sugan04 commited on
Commit
a5dd6dc
·
verified ·
1 Parent(s): 329ed5a

Upload modal_inference.py

Browse files
Files changed (1) hide show
  1. modal_inference.py +71 -0
modal_inference.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import modal
2
+ import os
3
+
4
+ app = modal.App("surgisight")
5
+
6
+ image = (
7
+ modal.Image.debian_slim(python_version="3.11")
8
+ .apt_install("libgl1", "libglib2.0-0")
9
+ .pip_install(
10
+ "ultralytics",
11
+ "pillow",
12
+ "numpy",
13
+ "opencv-python-headless",
14
+ "huggingface_hub",
15
+ )
16
+ )
17
+
18
+ # Cache the model weights inside the Modal image so it doesn't re-download every call
19
+ with image.imports():
20
+ from ultralytics import YOLO
21
+ from PIL import Image as PILImage
22
+ import numpy as np
23
+ import cv2
24
+ import io
25
+
26
+
27
+ @app.cls(gpu="T4", image=image, secrets=[modal.Secret.from_name("hf-secret")])
28
+ class SurgiSightDetector:
29
+
30
+ @modal.enter()
31
+ def load_model(self):
32
+ from huggingface_hub import hf_hub_download
33
+ model_path = hf_hub_download(
34
+ repo_id="sugan04/cholec-yolo26n-seg",
35
+ filename="best.pt",
36
+ token=os.environ.get("HF_TOKEN")
37
+ )
38
+ self.model = YOLO(model_path)
39
+
40
+ @modal.method()
41
+ def run(self, image_bytes: bytes, conf_threshold: float = 0.25):
42
+ nparr = np.frombuffer(image_bytes, np.uint8)
43
+ frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
44
+
45
+ results = self.model(frame, task="segment", conf=conf_threshold)
46
+ annotated = results[0].plot()
47
+
48
+ # Encode annotated image back to bytes
49
+ _, buffer = cv2.imencode(".png", annotated)
50
+ annotated_bytes = buffer.tobytes()
51
+
52
+ # Extract detections
53
+ boxes = results[0].boxes
54
+ detections = []
55
+ if boxes is not None and len(boxes) > 0:
56
+ for box in boxes:
57
+ detections.append({
58
+ "cls_id": int(box.cls[0]),
59
+ "conf": float(box.conf[0])
60
+ })
61
+
62
+ return {"annotated_bytes": annotated_bytes, "detections": detections}
63
+
64
+
65
+ # For local testing
66
+ @app.local_entrypoint()
67
+ def main():
68
+ from PIL import Image as PILImage
69
+ import io
70
+ detector = SurgiSightDetector()
71
+ print("Modal SurgiSight detector ready.")