Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -244,9 +244,8 @@ def track_video(n_frames,video_state):
|
|
| 244 |
video_state["origin_images"] = images
|
| 245 |
images = np.array(images)
|
| 246 |
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
inference_state = video_predictor.init_state(images=images/255, device="cuda")
|
| 250 |
video_state["inference_state"] = inference_state
|
| 251 |
|
| 252 |
if len(torch.from_numpy(video_state["masks"][0]).shape) == 3:
|
|
@@ -254,7 +253,7 @@ def track_video(n_frames,video_state):
|
|
| 254 |
else:
|
| 255 |
mask = torch.from_numpy(video_state["masks"][0])
|
| 256 |
|
| 257 |
-
|
| 258 |
inference_state=inference_state,
|
| 259 |
frame_idx=0,
|
| 260 |
obj_id=obj_id,
|
|
@@ -265,7 +264,7 @@ def track_video(n_frames,video_state):
|
|
| 265 |
mask_frames = []
|
| 266 |
color = np.array(COLOR_PALETTE[int(time.time()) % len(COLOR_PALETTE)], dtype=np.float32) / 255.0
|
| 267 |
color = color[None, None, :]
|
| 268 |
-
for out_frame_idx, out_obj_ids, out_mask_logits in
|
| 269 |
frame = images[out_frame_idx].astype(np.float32) / 255.0
|
| 270 |
mask = np.zeros((H, W, 3), dtype=np.float32)
|
| 271 |
for i, logit in enumerate(out_mask_logits):
|
|
|
|
| 244 |
video_state["origin_images"] = images
|
| 245 |
images = np.array(images)
|
| 246 |
|
| 247 |
+
video_predictor_local=video_predictor.to("cuda")
|
| 248 |
+
inference_state = video_predictor_local.init_state(images=images/255, device="cuda")
|
|
|
|
| 249 |
video_state["inference_state"] = inference_state
|
| 250 |
|
| 251 |
if len(torch.from_numpy(video_state["masks"][0]).shape) == 3:
|
|
|
|
| 253 |
else:
|
| 254 |
mask = torch.from_numpy(video_state["masks"][0])
|
| 255 |
|
| 256 |
+
video_predictor_local.add_new_mask(
|
| 257 |
inference_state=inference_state,
|
| 258 |
frame_idx=0,
|
| 259 |
obj_id=obj_id,
|
|
|
|
| 264 |
mask_frames = []
|
| 265 |
color = np.array(COLOR_PALETTE[int(time.time()) % len(COLOR_PALETTE)], dtype=np.float32) / 255.0
|
| 266 |
color = color[None, None, :]
|
| 267 |
+
for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor_local.propagate_in_video(inference_state):
|
| 268 |
frame = images[out_frame_idx].astype(np.float32) / 255.0
|
| 269 |
mask = np.zeros((H, W, 3), dtype=np.float32)
|
| 270 |
for i, logit in enumerate(out_mask_logits):
|