Amir Mahla committed on
Commit
975f40e
·
1 Parent(s): 3cf734e

FIX race condition

Browse files
cua2-core/src/cua2_core/models/models.py CHANGED
@@ -1,6 +1,6 @@
 
1
  import json
2
  import os
3
- import threading
4
  from datetime import datetime
5
  from typing import Annotated, Literal, Optional
6
  from uuid import uuid4
@@ -269,51 +269,58 @@ class ActiveTask(BaseModel):
269
  timestamp: datetime = datetime.now()
270
  steps: list[AgentStep] = []
271
  traceMetadata: AgentTraceMetadata = AgentTraceMetadata()
272
- _file_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
 
 
 
 
 
 
273
 
274
  @property
275
  def trace_path(self):
276
  """Trace path"""
277
  return f"data/trace-{self.message_id}-{self.model_id.replace('/', '-')}"
278
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
279
  @model_validator(mode="after")
280
  def store_model(self):
281
- """Validate model ID"""
282
- with self._file_lock:
283
- self.traceMetadata.traceId = self.message_id
284
- os.makedirs(self.trace_path, exist_ok=True)
285
- with open(f"{self.trace_path}/tasks.json", "w") as f:
286
- json.dump(
287
- self.model_dump(
288
- mode="json",
289
- exclude={"_file_locks"},
290
- context={"actions_as_json": True, "image_as_path": True},
291
- ),
292
- f,
293
- indent=2,
294
- )
295
  return self
296
 
297
- def update_step(self, step: AgentStep):
 
 
 
 
 
298
  """Update step"""
299
- with self._file_lock:
300
  if int(step.stepId) <= len(self.steps):
301
  self.steps[int(step.stepId) - 1] = step
302
  else:
303
  self.steps.append(step)
304
  self.traceMetadata.numberOfSteps = len(self.steps)
305
- with open(f"{self.trace_path}/tasks.json", "w") as f:
306
- json.dump(
307
- self.model_dump(
308
- mode="json",
309
- exclude={"_file_locks"},
310
- context={"actions_as_json": True, "image_as_path": True},
311
- ),
312
- f,
313
- indent=2,
314
- )
315
-
316
- def update_trace_metadata(
317
  self,
318
  step_input_tokens_used: int | None = None,
319
  step_output_tokens_used: int | None = None,
@@ -327,7 +334,7 @@ class ActiveTask(BaseModel):
327
  user_evaluation: Literal["success", "failed", "not_evaluated"] | None = None,
328
  ):
329
  """Update trace metadata"""
330
- with self._file_lock:
331
  if step_input_tokens_used is not None:
332
  self.traceMetadata.inputTokensUsed += step_input_tokens_used
333
  if step_output_tokens_used is not None:
 
1
+ import asyncio
2
  import json
3
  import os
 
4
  from datetime import datetime
5
  from typing import Annotated, Literal, Optional
6
  from uuid import uuid4
 
269
  timestamp: datetime = datetime.now()
270
  steps: list[AgentStep] = []
271
  traceMetadata: AgentTraceMetadata = AgentTraceMetadata()
272
+ _file_lock: asyncio.Lock | None = PrivateAttr(default=None)
273
+
274
+ def _get_lock(self) -> asyncio.Lock:
275
+ """Get or create the async lock (lazy initialization)"""
276
+ if self._file_lock is None:
277
+ self._file_lock = asyncio.Lock()
278
+ return self._file_lock
279
 
280
  @property
281
  def trace_path(self):
282
  """Trace path"""
283
  return f"data/trace-{self.message_id}-{self.model_id.replace('/', '-')}"
284
 
285
+ def _write_to_file_sync(self):
286
+ """Synchronous file write helper (used in async context via to_thread)"""
287
+ self.traceMetadata.traceId = self.message_id
288
+ os.makedirs(self.trace_path, exist_ok=True)
289
+ with open(f"{self.trace_path}/tasks.json", "w") as f:
290
+ json.dump(
291
+ self.model_dump(
292
+ mode="json",
293
+ exclude={"_file_lock", "_lock_initialized"},
294
+ context={"actions_as_json": True, "image_as_path": True},
295
+ ),
296
+ f,
297
+ indent=2,
298
+ )
299
+
300
  @model_validator(mode="after")
301
  def store_model(self):
302
+ """Validate model ID - creates directory, but file write is deferred to async method"""
303
+ self.traceMetadata.traceId = self.message_id
304
+ os.makedirs(self.trace_path, exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
305
  return self
306
 
307
+ async def save_to_file(self):
308
+ """Async method to save task data to file"""
309
+ async with self._get_lock():
310
+ await asyncio.to_thread(self._write_to_file_sync)
311
+
312
+ async def update_step(self, step: AgentStep):
313
  """Update step"""
314
+ async with self._get_lock():
315
  if int(step.stepId) <= len(self.steps):
316
  self.steps[int(step.stepId) - 1] = step
317
  else:
318
  self.steps.append(step)
319
  self.traceMetadata.numberOfSteps = len(self.steps)
320
+ # Use to_thread for file I/O to avoid blocking
321
+ await asyncio.to_thread(self._write_to_file_sync)
322
+
323
+ async def update_trace_metadata(
 
 
 
 
 
 
 
 
324
  self,
325
  step_input_tokens_used: int | None = None,
326
  step_output_tokens_used: int | None = None,
 
334
  user_evaluation: Literal["success", "failed", "not_evaluated"] | None = None,
335
  ):
336
  """Update trace metadata"""
337
+ async with self._get_lock():
338
  if step_input_tokens_used is not None:
339
  self.traceMetadata.inputTokensUsed += step_input_tokens_used
340
  if step_output_tokens_used is not None:
cua2-core/src/cua2_core/routes/routes.py CHANGED
@@ -74,7 +74,7 @@ async def update_trace_step(
74
  ):
75
  """Update a specific step in a trace (e.g., update step evaluation)"""
76
  try:
77
- agent_service.update_trace_step(
78
  trace_id=trace_id,
79
  step_id=step_id,
80
  step_evaluation=request.step_evaluation,
@@ -99,7 +99,7 @@ async def update_trace_evaluation(
99
  ):
100
  """Update the user evaluation for a trace (overall task feedback)"""
101
  try:
102
- agent_service.update_trace_evaluation(
103
  trace_id=trace_id,
104
  user_evaluation=request.user_evaluation,
105
  )
 
74
  ):
75
  """Update a specific step in a trace (e.g., update step evaluation)"""
76
  try:
77
+ await agent_service.update_trace_step(
78
  trace_id=trace_id,
79
  step_id=step_id,
80
  step_evaluation=request.step_evaluation,
 
99
  ):
100
  """Update the user evaluation for a trace (overall task feedback)"""
101
  try:
102
+ await agent_service.update_trace_evaluation(
103
  trace_id=trace_id,
104
  user_evaluation=request.user_evaluation,
105
  )
cua2-core/src/cua2_core/services/agent_service.py CHANGED
@@ -104,9 +104,13 @@ class AgentService:
104
  """
105
  Update the archival service with current active task IDs.
106
  Should be called whenever tasks are added or removed.
 
 
107
  """
108
  if self.archival_service.is_alive():
109
- self.archival_service.update_active_tasks(set(self.active_tasks.keys()))
 
 
110
 
111
  async def create_id_and_sandbox(self, websocket: WebSocket) -> str:
112
  """Create a new ID and sandbox"""
@@ -174,8 +178,8 @@ class AgentService:
174
  self.active_tasks[trace_id] = active_task
175
  self.last_screenshot[trace_id] = None
176
 
177
- # Update archival service with new active task
178
- self._update_archival_active_tasks()
179
 
180
  asyncio.create_task(self._agent_processing(trace_id))
181
 
@@ -351,13 +355,13 @@ class AgentService:
351
 
352
  novnc_active = False
353
 
354
- self.active_tasks[message_id].update_trace_metadata(
355
  final_state=final_state,
356
  completed=True,
357
  )
358
 
359
  if message_id in self.active_tasks:
360
- self.active_tasks[message_id].store_model()
361
 
362
  # Clean up
363
  async with self._lock:
@@ -370,8 +374,8 @@ class AgentService:
370
  if message_id in self.last_screenshot:
371
  del self.last_screenshot[message_id]
372
 
373
- # Update archival service after task removal
374
- self._update_archival_active_tasks()
375
 
376
  # Always release sandbox back to the pool, even if it's still in "creating" state
377
  # This handles cases where acquire_sandbox was called but sandbox never became ready
@@ -469,14 +473,23 @@ class AgentService:
469
  step_evaluation="neutral",
470
  )
471
 
472
- self.active_tasks[message_id].update_trace_metadata(
473
- step_input_tokens_used=memory_step.token_usage.input_tokens,
474
- step_output_tokens_used=memory_step.token_usage.output_tokens,
475
- step_duration=memory_step.timing.duration,
476
- step_numberOfSteps=1,
 
 
 
 
477
  )
478
-
479
- self.active_tasks[message_id].update_step(step)
 
 
 
 
 
480
 
481
  websocket = self.task_websockets.get(message_id)
482
  if websocket and websocket.client_state == WebSocketState.CONNECTED:
@@ -529,7 +542,7 @@ class AgentService:
529
  # Re-raise to ensure error is logged
530
  raise
531
 
532
- def update_trace_step(
533
  self,
534
  trace_id: str,
535
  step_id: str,
@@ -559,7 +572,8 @@ class AgentService:
559
  step_index = int(step_id) - 1
560
  if 0 <= step_index < len(active_task.steps):
561
  active_task.steps[step_index].step_evaluation = step_evaluation
562
- active_task.update_step(active_task.steps[step_index])
 
563
  else:
564
  raise ValueError(f"Step {step_id} not found in trace")
565
  except (ValueError, TypeError) as e:
@@ -602,7 +616,7 @@ class AgentService:
602
  except (ValueError, KeyError, TypeError) as e:
603
  raise ValueError(f"Error processing step update: {e}")
604
 
605
- def update_trace_evaluation(
606
  self,
607
  trace_id: str,
608
  user_evaluation: Literal["success", "failed", "not_evaluated"],
@@ -622,7 +636,7 @@ class AgentService:
622
 
623
  if active_task:
624
  # Task is still active
625
- active_task.update_trace_metadata(user_evaluation=user_evaluation)
626
  else:
627
  # Task is not active, try to load from file
628
  data_dir = "data"
@@ -657,7 +671,7 @@ class AgentService:
657
  async def stop_task(self, trace_id: str):
658
  """Stop a task"""
659
  if trace_id in self.active_tasks:
660
- self.active_tasks[trace_id].update_trace_metadata(
661
  completed=True,
662
  )
663
 
@@ -687,7 +701,7 @@ class AgentService:
687
  try:
688
  # Mark task as completed to stop the agent (if task exists)
689
  if message_id in self.active_tasks:
690
- self.active_tasks[message_id].update_trace_metadata(
691
  completed=True,
692
  )
693
  logger.info(
 
104
  """
105
  Update the archival service with current active task IDs.
106
  Should be called whenever tasks are added or removed.
107
+ Note: This should be called while holding self._lock to ensure consistent snapshot.
108
+ The archival service update itself is fast and non-blocking.
109
  """
110
  if self.archival_service.is_alive():
111
+ # Create a snapshot of active task IDs (should be called with lock held)
112
+ active_task_ids = set(self.active_tasks.keys())
113
+ self.archival_service.update_active_tasks(active_task_ids)
114
 
115
  async def create_id_and_sandbox(self, websocket: WebSocket) -> str:
116
  """Create a new ID and sandbox"""
 
178
  self.active_tasks[trace_id] = active_task
179
  self.last_screenshot[trace_id] = None
180
 
181
+ # Update archival service with new active task (while holding lock)
182
+ self._update_archival_active_tasks()
183
 
184
  asyncio.create_task(self._agent_processing(trace_id))
185
 
 
355
 
356
  novnc_active = False
357
 
358
+ await self.active_tasks[message_id].update_trace_metadata(
359
  final_state=final_state,
360
  completed=True,
361
  )
362
 
363
  if message_id in self.active_tasks:
364
+ await self.active_tasks[message_id].save_to_file()
365
 
366
  # Clean up
367
  async with self._lock:
 
374
  if message_id in self.last_screenshot:
375
  del self.last_screenshot[message_id]
376
 
377
+ # Update archival service after task removal (while holding lock)
378
+ self._update_archival_active_tasks()
379
 
380
  # Always release sandbox back to the pool, even if it's still in "creating" state
381
  # This handles cases where acquire_sandbox was called but sandbox never became ready
 
473
  step_evaluation="neutral",
474
  )
475
 
476
+ # Schedule async operations in the event loop (callback runs in worker thread)
477
+ future1 = asyncio.run_coroutine_threadsafe(
478
+ self.active_tasks[message_id].update_trace_metadata(
479
+ step_input_tokens_used=memory_step.token_usage.input_tokens,
480
+ step_output_tokens_used=memory_step.token_usage.output_tokens,
481
+ step_duration=memory_step.timing.duration,
482
+ step_numberOfSteps=1,
483
+ ),
484
+ loop,
485
  )
486
+ future2 = asyncio.run_coroutine_threadsafe(
487
+ self.active_tasks[message_id].update_step(step),
488
+ loop,
489
+ )
490
+ # Wait for both to complete
491
+ future1.result()
492
+ future2.result()
493
 
494
  websocket = self.task_websockets.get(message_id)
495
  if websocket and websocket.client_state == WebSocketState.CONNECTED:
 
542
  # Re-raise to ensure error is logged
543
  raise
544
 
545
+ async def update_trace_step(
546
  self,
547
  trace_id: str,
548
  step_id: str,
 
572
  step_index = int(step_id) - 1
573
  if 0 <= step_index < len(active_task.steps):
574
  active_task.steps[step_index].step_evaluation = step_evaluation
575
+ await active_task.update_step(active_task.steps[step_index])
576
+ return active_task.steps[step_index]
577
  else:
578
  raise ValueError(f"Step {step_id} not found in trace")
579
  except (ValueError, TypeError) as e:
 
616
  except (ValueError, KeyError, TypeError) as e:
617
  raise ValueError(f"Error processing step update: {e}")
618
 
619
+ async def update_trace_evaluation(
620
  self,
621
  trace_id: str,
622
  user_evaluation: Literal["success", "failed", "not_evaluated"],
 
636
 
637
  if active_task:
638
  # Task is still active
639
+ await active_task.update_trace_metadata(user_evaluation=user_evaluation)
640
  else:
641
  # Task is not active, try to load from file
642
  data_dir = "data"
 
671
  async def stop_task(self, trace_id: str):
672
  """Stop a task"""
673
  if trace_id in self.active_tasks:
674
+ await self.active_tasks[trace_id].update_trace_metadata(
675
  completed=True,
676
  )
677
 
 
701
  try:
702
  # Mark task as completed to stop the agent (if task exists)
703
  if message_id in self.active_tasks:
704
+ await self.active_tasks[message_id].update_trace_metadata(
705
  completed=True,
706
  )
707
  logger.info(
cua2-core/src/cua2_core/services/sandbox_service.py CHANGED
@@ -158,10 +158,14 @@ class SandboxService:
158
  asyncio.create_task(self._kill_sandbox_safe(desktop, session_hash))
159
  return
160
 
161
- # Check capacity before adding
162
- if len(self.sandboxes) >= self.max_sandboxes:
 
 
 
163
  print(
164
- f"Pool at capacity, killing newly created sandbox for {session_hash}"
 
165
  )
166
  asyncio.create_task(self._kill_sandbox_safe(desktop, session_hash))
167
  return
 
158
  asyncio.create_task(self._kill_sandbox_safe(desktop, session_hash))
159
  return
160
 
161
+ # Check total capacity before adding (sandboxes + other pending creations)
162
+ # Note: We already removed this session_hash from pending, so we check
163
+ # if adding it to sandboxes would exceed capacity
164
+ total_count = len(self.sandboxes) + len(self.pending)
165
+ if total_count >= self.max_sandboxes:
166
  print(
167
+ f"Pool at capacity ({total_count}/{self.max_sandboxes}), "
168
+ f"killing newly created sandbox for {session_hash}"
169
  )
170
  asyncio.create_task(self._kill_sandbox_safe(desktop, session_hash))
171
  return
cua2-core/tests/test_routes.py CHANGED
@@ -1,4 +1,4 @@
1
- from unittest.mock import Mock
2
 
3
  import pytest
4
  from cua2_core.models.models import AvailableModelsResponse, UpdateStepResponse
@@ -15,7 +15,9 @@ def mock_agent_service():
15
  """Fixture to create a mocked AgentService"""
16
  service = Mock(spec=AgentService)
17
  service.active_tasks = {}
18
- service.update_trace_step = Mock()
 
 
19
  return service
20
 
21
 
@@ -112,8 +114,8 @@ class TestUpdateTraceStep:
112
  step_id = "1"
113
  request_data = {"step_evaluation": "like"}
114
 
115
- # Mock the service method to succeed
116
- mock_agent_service.update_trace_step.return_value = None
117
 
118
  response = client.patch(
119
  f"/traces/{trace_id}/steps/{step_id}", json=request_data
@@ -136,8 +138,6 @@ class TestUpdateTraceStep:
136
  step_id = "2"
137
  request_data = {"step_evaluation": "dislike"}
138
 
139
- mock_agent_service.update_trace_step.return_value = None
140
-
141
  response = client.patch(
142
  f"/traces/{trace_id}/steps/{step_id}", json=request_data
143
  )
@@ -154,8 +154,6 @@ class TestUpdateTraceStep:
154
  step_id = "3"
155
  request_data = {"step_evaluation": "neutral"}
156
 
157
- mock_agent_service.update_trace_step.return_value = None
158
-
159
  response = client.patch(
160
  f"/traces/{trace_id}/steps/{step_id}", json=request_data
161
  )
@@ -186,8 +184,8 @@ class TestUpdateTraceStep:
186
  request_data = {"step_evaluation": "like"}
187
 
188
  # Mock the service to raise ValueError
189
- mock_agent_service.update_trace_step.side_effect = ValueError(
190
- "Invalid step_id format"
191
  )
192
 
193
  response = client.patch(
@@ -204,8 +202,8 @@ class TestUpdateTraceStep:
204
  request_data = {"step_evaluation": "like"}
205
 
206
  # Mock the service to raise FileNotFoundError
207
- mock_agent_service.update_trace_step.side_effect = FileNotFoundError(
208
- "Trace not found"
209
  )
210
 
211
  response = client.patch(
@@ -222,8 +220,8 @@ class TestUpdateTraceStep:
222
  request_data = {"step_evaluation": "like"}
223
 
224
  # Mock the service to raise ValueError for step not found
225
- mock_agent_service.update_trace_step.side_effect = ValueError(
226
- "Step 999 not found in trace"
227
  )
228
 
229
  response = client.patch(
@@ -251,8 +249,6 @@ class TestUpdateTraceStep:
251
  step_id = "1"
252
  request_data = {"step_evaluation": "like"}
253
 
254
- mock_agent_service.update_trace_step.return_value = None
255
-
256
  response = client.patch(
257
  f"/traces/{trace_id}/steps/{step_id}", json=request_data
258
  )
@@ -269,8 +265,6 @@ class TestUpdateTraceStep:
269
  step_id = "1"
270
  request_data = {"step_evaluation": "like"}
271
 
272
- mock_agent_service.update_trace_step.return_value = None
273
-
274
  response = client.patch(
275
  f"/traces/{trace_id}/steps/{step_id}", json=request_data
276
  )
@@ -294,8 +288,7 @@ class TestRoutesIntegration:
294
 
295
  def test_update_step_endpoint_available(self, client, mock_agent_service):
296
  """Test that update step endpoint is available"""
297
- mock_agent_service.update_trace_step.return_value = None
298
-
299
  response = client.patch(
300
  "/traces/test/steps/1", json={"step_evaluation": "like"}
301
  )
 
1
+ from unittest.mock import AsyncMock, Mock
2
 
3
  import pytest
4
  from cua2_core.models.models import AvailableModelsResponse, UpdateStepResponse
 
15
  """Fixture to create a mocked AgentService"""
16
  service = Mock(spec=AgentService)
17
  service.active_tasks = {}
18
+ # update_trace_step and update_trace_evaluation are now async, so use AsyncMock
19
+ service.update_trace_step = AsyncMock(return_value=None)
20
+ service.update_trace_evaluation = AsyncMock(return_value=None)
21
  return service
22
 
23
 
 
114
  step_id = "1"
115
  request_data = {"step_evaluation": "like"}
116
 
117
+ # Mock the service method to succeed (already set up as AsyncMock in fixture)
118
+ pass
119
 
120
  response = client.patch(
121
  f"/traces/{trace_id}/steps/{step_id}", json=request_data
 
138
  step_id = "2"
139
  request_data = {"step_evaluation": "dislike"}
140
 
 
 
141
  response = client.patch(
142
  f"/traces/{trace_id}/steps/{step_id}", json=request_data
143
  )
 
154
  step_id = "3"
155
  request_data = {"step_evaluation": "neutral"}
156
 
 
 
157
  response = client.patch(
158
  f"/traces/{trace_id}/steps/{step_id}", json=request_data
159
  )
 
184
  request_data = {"step_evaluation": "like"}
185
 
186
  # Mock the service to raise ValueError
187
+ mock_agent_service.update_trace_step = AsyncMock(
188
+ side_effect=ValueError("Invalid step_id format")
189
  )
190
 
191
  response = client.patch(
 
202
  request_data = {"step_evaluation": "like"}
203
 
204
  # Mock the service to raise FileNotFoundError
205
+ mock_agent_service.update_trace_step = AsyncMock(
206
+ side_effect=FileNotFoundError("Trace not found")
207
  )
208
 
209
  response = client.patch(
 
220
  request_data = {"step_evaluation": "like"}
221
 
222
  # Mock the service to raise ValueError for step not found
223
+ mock_agent_service.update_trace_step = AsyncMock(
224
+ side_effect=ValueError("Step 999 not found in trace")
225
  )
226
 
227
  response = client.patch(
 
249
  step_id = "1"
250
  request_data = {"step_evaluation": "like"}
251
 
 
 
252
  response = client.patch(
253
  f"/traces/{trace_id}/steps/{step_id}", json=request_data
254
  )
 
265
  step_id = "1"
266
  request_data = {"step_evaluation": "like"}
267
 
 
 
268
  response = client.patch(
269
  f"/traces/{trace_id}/steps/{step_id}", json=request_data
270
  )
 
288
 
289
  def test_update_step_endpoint_available(self, client, mock_agent_service):
290
  """Test that update step endpoint is available"""
291
+ # Mock is already set up as AsyncMock in fixture
 
292
  response = client.patch(
293
  "/traces/test/steps/1", json={"step_evaluation": "like"}
294
  )