Skip to content

Commit c55c18c

Browse files
committed
refactor(workforce): Refactor WorkforceCallback and all related callbacks to async interface
1 parent 2561f5c commit c55c18c

22 files changed

+585
-551
lines changed

camel/benchmarks/browsecomp.py

Lines changed: 10 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14-
14+
import asyncio
1515
import base64
1616
import hashlib
1717
import json
@@ -585,15 +585,17 @@ def process_benchmark_row(row: Dict[str, Any]) -> Dict[str, Any]:
585585
input_message = QUERY_TEMPLATE.format(question=problem)
586586

587587
if isinstance(pipeline_template, (ChatAgent)):
588-
pipeline = pipeline_template.clone() # type: ignore[assignment]
588+
chat_pipeline = pipeline_template.clone()
589589

590-
response_text = pipeline.step(
590+
response_text = chat_pipeline.step(
591591
input_message, response_format=QueryResponse
592592
)
593593
elif isinstance(pipeline_template, Workforce):
594-
pipeline = pipeline_template.clone() # type: ignore[assignment]
594+
workforce_pipeline = asyncio.run(pipeline_template.clone())
595595
task = Task(content=input_message, id="0")
596-
task = pipeline.process_task(task) # type: ignore[attr-defined]
596+
task = asyncio.run(
597+
workforce_pipeline.process_task_async(task)
598+
) # type: ignore[attr-defined]
597599
if task_json_formatter:
598600
formatter_in_process = task_json_formatter.clone()
599601
else:
@@ -607,16 +609,16 @@ def process_benchmark_row(row: Dict[str, Any]) -> Dict[str, Any]:
607609

608610
elif isinstance(pipeline_template, RolePlaying):
609611
# RolePlaying is different.
610-
pipeline = pipeline_template.clone( # type: ignore[assignment]
612+
rp_pipeline = pipeline_template.clone(
611613
task_prompt=input_message
612614
)
613615

614616
n = 0
615-
input_msg = pipeline.init_chat() # type: ignore[attr-defined]
617+
input_msg = rp_pipeline.init_chat()
616618
chat_history = []
617619
while n < chat_turn_limit:
618620
n += 1
619-
assistant_response, user_response = pipeline.step(
621+
assistant_response, user_response = rp_pipeline.step(
620622
input_msg
621623
)
622624
if assistant_response.terminated: # type: ignore[union-attr]

camel/societies/workforce/workforce.py

Lines changed: 56 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -339,6 +339,9 @@ def __init__(
339339
self.snapshot_interval: float = 30.0
340340
# Shared memory UUID tracking to prevent re-sharing duplicates
341341
self._shared_memory_uuids: Set[str] = set()
342+
# Defer initial worker-created callbacks until an event loop is
343+
# available in async context.
344+
self._pending_worker_created: Deque[BaseNode] = deque(self._children)
342345
self._initialize_callbacks(callbacks)
343346

344347
# Set up coordinator agent with default system message
@@ -532,10 +535,7 @@ def _initialize_callbacks(
532535
"WorkforceLogger addition."
533536
)
534537

535-
for child in self._children:
536-
self._notify_worker_created(child)
537-
538-
def _notify_worker_created(
538+
async def _notify_worker_created(
539539
self,
540540
worker_node: BaseNode,
541541
*,
@@ -551,7 +551,19 @@ def _notify_worker_created(
551551
metadata=metadata,
552552
)
553553
for cb in self._callbacks:
554-
cb.log_worker_created(event)
554+
await cb.log_worker_created(event)
555+
556+
async def _flush_initial_worker_created_callbacks(self) -> None:
557+
r"""Flush pending worker-created callbacks that were queued during
558+
initialization before an event loop was available."""
559+
if not self._pending_worker_created:
560+
return
561+
562+
pending = list(self._pending_worker_created)
563+
self._pending_worker_created.clear()
564+
565+
for child in pending:
566+
await self._notify_worker_created(child)
555567

556568
def _get_or_create_shared_context_utility(
557569
self,
@@ -1659,7 +1671,7 @@ async def _apply_recovery_strategy(
16591671
subtask_ids=[st.id for st in subtasks],
16601672
)
16611673
for cb in self._callbacks:
1662-
cb.log_task_decomposed(task_decomposed_event)
1674+
await cb.log_task_decomposed(task_decomposed_event)
16631675
for subtask in subtasks:
16641676
task_created_event = TaskCreatedEvent(
16651677
task_id=subtask.id,
@@ -1669,7 +1681,7 @@ async def _apply_recovery_strategy(
16691681
metadata=subtask.additional_info,
16701682
)
16711683
for cb in self._callbacks:
1672-
cb.log_task_created(task_created_event)
1684+
await cb.log_task_created(task_created_event)
16731685

16741686
# Insert subtasks at the head of the queue
16751687
self._pending_tasks.extendleft(reversed(subtasks))
@@ -2188,7 +2200,7 @@ async def handle_decompose_append_task(
21882200
)
21892201
return [task]
21902202

2191-
self.reset()
2203+
await self.reset()
21922204
self._task = task
21932205
task.state = TaskState.FAILED
21942206

@@ -2199,7 +2211,7 @@ async def handle_decompose_append_task(
21992211
metadata=task.additional_info,
22002212
)
22012213
for cb in self._callbacks:
2202-
cb.log_task_created(task_created_event)
2214+
await cb.log_task_created(task_created_event)
22032215

22042216
# The agent tend to be overconfident on the whole task, so we
22052217
# decompose the task into subtasks first
@@ -2220,7 +2232,7 @@ async def handle_decompose_append_task(
22202232
subtask_ids=[st.id for st in subtasks],
22212233
)
22222234
for cb in self._callbacks:
2223-
cb.log_task_decomposed(task_decomposed_event)
2235+
await cb.log_task_decomposed(task_decomposed_event)
22242236
for subtask in subtasks:
22252237
task_created_event = TaskCreatedEvent(
22262238
task_id=subtask.id,
@@ -2230,7 +2242,7 @@ async def handle_decompose_append_task(
22302242
metadata=subtask.additional_info,
22312243
)
22322244
for cb in self._callbacks:
2233-
cb.log_task_created(task_created_event)
2245+
await cb.log_task_created(task_created_event)
22342246

22352247
if subtasks:
22362248
# _pending_tasks will contain both undecomposed
@@ -2258,6 +2270,9 @@ async def process_task_async(
22582270
Returns:
22592271
Task: The updated task.
22602272
"""
2273+
# Emit worker-created callbacks lazily once an event loop is present.
2274+
await self._flush_initial_worker_created_callbacks()
2275+
22612276
# Delegate to intervention pipeline when requested to keep
22622277
# backward-compat.
22632278
if interactive:
@@ -2606,7 +2621,7 @@ def _start_child_node_when_paused(
26062621
# Close the coroutine to prevent RuntimeWarning
26072622
start_coroutine.close()
26082623

2609-
def add_single_agent_worker(
2624+
async def add_single_agent_worker(
26102625
self,
26112626
description: str,
26122627
worker: ChatAgent,
@@ -2662,13 +2677,13 @@ def add_single_agent_worker(
26622677
# If workforce is paused, start the worker's listening task
26632678
self._start_child_node_when_paused(worker_node.start())
26642679

2665-
self._notify_worker_created(
2680+
await self._notify_worker_created(
26662681
worker_node,
26672682
worker_type='SingleAgentWorker',
26682683
)
26692684
return self
26702685

2671-
def add_role_playing_worker(
2686+
async def add_role_playing_worker(
26722687
self,
26732688
description: str,
26742689
assistant_role_name: str,
@@ -2739,7 +2754,7 @@ def add_role_playing_worker(
27392754
# If workforce is paused, start the worker's listening task
27402755
self._start_child_node_when_paused(worker_node.start())
27412756

2742-
self._notify_worker_created(
2757+
await self._notify_worker_created(
27432758
worker_node,
27442759
worker_type='RolePlayingWorker',
27452760
)
@@ -2781,7 +2796,7 @@ async def _async_reset(self) -> None:
27812796
self._pause_event.set()
27822797

27832798
@check_if_running(False)
2784-
def reset(self) -> None:
2799+
async def reset(self) -> None:
27852800
r"""Reset the workforce and all the child nodes under it. Can only
27862801
be called when the workforce is not running.
27872802
"""
@@ -2816,9 +2831,7 @@ def reset(self) -> None:
28162831
if self._loop and not self._loop.is_closed():
28172832
# If we have a loop, use it to set the event safely
28182833
try:
2819-
asyncio.run_coroutine_threadsafe(
2820-
self._async_reset(), self._loop
2821-
).result()
2834+
await self._async_reset()
28222835
except RuntimeError as e:
28232836
logger.warning(f"Failed to reset via existing loop: {e}")
28242837
# Fallback to direct event manipulation
@@ -2829,7 +2842,7 @@ def reset(self) -> None:
28292842

28302843
for cb in self._callbacks:
28312844
if isinstance(cb, WorkforceMetrics):
2832-
cb.reset_task_data()
2845+
await cb.reset_task_data()
28332846

28342847
def save_workflow_memories(
28352848
self,
@@ -3697,7 +3710,7 @@ async def _post_task(self, task: Task, assignee_id: str) -> None:
36973710
task_id=task.id, worker_id=assignee_id
36983711
)
36993712
for cb in self._callbacks:
3700-
cb.log_task_started(task_started_event)
3713+
await cb.log_task_started(task_started_event)
37013714

37023715
try:
37033716
await self._channel.post_task(task, self.node_id, assignee_id)
@@ -3844,7 +3857,7 @@ async def _create_worker_node_for_task(self, task: Task) -> Worker:
38443857

38453858
self._children.append(new_node)
38463859

3847-
self._notify_worker_created(
3860+
await self._notify_worker_created(
38483861
new_node,
38493862
worker_type='SingleAgentWorker',
38503863
role=new_node_conf.role,
@@ -3982,7 +3995,7 @@ async def _post_ready_tasks(self) -> None:
39823995
for cb in self._callbacks:
39833996
# queue_time_seconds can be derived by logger if task
39843997
# creation time is logged
3985-
cb.log_task_assigned(task_assigned_event)
3998+
await cb.log_task_assigned(task_assigned_event)
39863999

39874000
# Step 2: Iterate through all pending tasks and post those that are
39884001
# ready
@@ -4140,7 +4153,7 @@ async def _post_ready_tasks(self) -> None:
41404153
},
41414154
)
41424155
for cb in self._callbacks:
4143-
cb.log_task_failed(task_failed_event)
4156+
await cb.log_task_failed(task_failed_event)
41444157

41454158
self._completed_tasks.append(task)
41464159
self._cleanup_task_tracking(task.id)
@@ -4203,7 +4216,7 @@ async def _handle_failed_task(self, task: Task) -> bool:
42034216
},
42044217
)
42054218
for cb in self._callbacks:
4206-
cb.log_task_failed(task_failed_event)
4219+
await cb.log_task_failed(task_failed_event)
42074220

42084221
# Check for immediate halt conditions after max retries.
42094222
if task.failure_count >= MAX_TASK_RETRIES:
@@ -4390,7 +4403,7 @@ async def _handle_completed_task(self, task: Task) -> None:
43904403
metadata={'current_state': task.state.value},
43914404
)
43924405
for cb in self._callbacks:
4393-
cb.log_task_completed(task_completed_event)
4406+
await cb.log_task_completed(task_completed_event)
43944407

43954408
# Find and remove the completed task from pending tasks
43964409
tasks_list = list(self._pending_tasks)
@@ -4506,7 +4519,7 @@ async def _graceful_shutdown(self, failed_task: Task) -> None:
45064519
# Wait for the full timeout period
45074520
await asyncio.sleep(self.graceful_shutdown_timeout)
45084521

4509-
def get_workforce_log_tree(self) -> str:
4522+
async def get_workforce_log_tree(self) -> str:
45104523
r"""Returns an ASCII tree representation of the task hierarchy and
45114524
worker status.
45124525
"""
@@ -4516,19 +4529,19 @@ def get_workforce_log_tree(self) -> str:
45164529
if len(metrics_cb) == 0:
45174530
return "Metrics Callback not initialized."
45184531
else:
4519-
return metrics_cb[0].get_ascii_tree_representation()
4532+
return await metrics_cb[0].get_ascii_tree_representation()
45204533

4521-
def get_workforce_kpis(self) -> Dict[str, Any]:
4534+
async def get_workforce_kpis(self) -> Dict[str, Any]:
45224535
r"""Returns a dictionary of key performance indicators."""
45234536
metrics_cb: List[WorkforceMetrics] = [
45244537
cb for cb in self._callbacks if isinstance(cb, WorkforceMetrics)
45254538
]
45264539
if len(metrics_cb) == 0:
45274540
return {"error": "Metrics Callback not initialized."}
45284541
else:
4529-
return metrics_cb[0].get_kpis()
4542+
return await metrics_cb[0].get_kpis()
45304543

4531-
def dump_workforce_logs(self, file_path: str) -> None:
4544+
async def dump_workforce_logs(self, file_path: str) -> None:
45324545
r"""Dumps all collected logs to a JSON file.
45334546
45344547
Args:
@@ -4540,7 +4553,7 @@ def dump_workforce_logs(self, file_path: str) -> None:
45404553
if len(metrics_cb) == 0:
45414554
print("Logger not initialized. Cannot dump logs.")
45424555
return
4543-
metrics_cb[0].dump_to_json(file_path)
4556+
await metrics_cb[0].dump_to_json(file_path)
45444557
# Use logger.info or print, consistent with existing style
45454558
logger.info(f"Workforce logs dumped to {file_path}")
45464559

@@ -5016,7 +5029,7 @@ async def _listen_to_channel(self) -> None:
50165029
logger.info("All tasks completed.")
50175030
all_tasks_completed_event = AllTasksCompletedEvent()
50185031
for cb in self._callbacks:
5019-
cb.log_all_tasks_completed(all_tasks_completed_event)
5032+
await cb.log_all_tasks_completed(all_tasks_completed_event)
50205033

50215034
# shut down the whole workforce tree
50225035
self.stop()
@@ -5104,7 +5117,7 @@ async def cleanup():
51045117

51055118
self._running = False
51065119

5107-
def clone(self, with_memory: bool = False) -> 'Workforce':
5120+
async def clone(self, with_memory: bool = False) -> 'Workforce':
51085121
r"""Creates a new instance of Workforce with the same configuration.
51095122
51105123
Args:
@@ -5136,13 +5149,13 @@ def clone(self, with_memory: bool = False) -> 'Workforce':
51365149
for child in self._children:
51375150
if isinstance(child, SingleAgentWorker):
51385151
cloned_worker = child.worker.clone(with_memory)
5139-
new_instance.add_single_agent_worker(
5152+
await new_instance.add_single_agent_worker(
51405153
child.description,
51415154
cloned_worker,
51425155
pool_max_size=10,
51435156
)
51445157
elif isinstance(child, RolePlayingWorker):
5145-
new_instance.add_role_playing_worker(
5158+
await new_instance.add_role_playing_worker(
51465159
child.description,
51475160
child.assistant_role_name,
51485161
child.user_role_name,
@@ -5152,7 +5165,7 @@ def clone(self, with_memory: bool = False) -> 'Workforce':
51525165
child.chat_turn_limit,
51535166
)
51545167
elif isinstance(child, Workforce):
5155-
new_instance.add_workforce(child.clone(with_memory))
5168+
new_instance.add_workforce(await child.clone(with_memory))
51565169
else:
51575170
logger.warning(f"{type(child)} is not being cloned.")
51585171
continue
@@ -5387,7 +5400,7 @@ def get_children_info():
53875400
return children_info
53885401

53895402
# Add single agent worker
5390-
def add_single_agent_worker(
5403+
async def add_single_agent_worker(
53915404
description,
53925405
system_message=None,
53935406
role_name="Assistant",
@@ -5451,7 +5464,9 @@ def add_single_agent_worker(
54515464
"message": str(e),
54525465
}
54535466

5454-
workforce_instance.add_single_agent_worker(description, agent)
5467+
await workforce_instance.add_single_agent_worker(
5468+
description, agent
5469+
)
54555470

54565471
return {
54575472
"status": "success",
@@ -5462,7 +5477,7 @@ def add_single_agent_worker(
54625477
return {"status": "error", "message": str(e)}
54635478

54645479
# Add role playing worker
5465-
def add_role_playing_worker(
5480+
async def add_role_playing_worker(
54665481
description,
54675482
assistant_role_name,
54685483
user_role_name,
@@ -5519,7 +5534,7 @@ def add_role_playing_worker(
55195534
"message": "Cannot add workers while workforce is running", # noqa: E501
55205535
}
55215536

5522-
workforce_instance.add_role_playing_worker(
5537+
await workforce_instance.add_role_playing_worker(
55235538
description=description,
55245539
assistant_role_name=assistant_role_name,
55255540
user_role_name=user_role_name,

0 commit comments

Comments
 (0)