# Source code for trinity.common.workflows.agentscope_workflow

from typing import Awaitable, Callable, Dict, List, Optional

import openai

from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow


@WORKFLOWS.register_module("agentscope_workflow_adapter")
class AgentScopeWorkflowAdapter(Workflow):
    """Adapter that wraps an AgentScope trainable workflow function into a Trinity Workflow.

    The wrapped function is awaited with the task's ``raw_task`` and a
    ``TrinityChatModel`` built on the wrapped model's async OpenAI client, and
    must return a scalar reward. Experiences are then pulled from the model's
    interaction history and every one of them is stamped with that reward.
    """

    # Marks this workflow as asynchronous; presumably the runner then drives it
    # through ``run_async`` rather than a synchronous ``run`` — confirm against
    # the ``Workflow`` base class.
    is_async: bool = True

    def __init__(
        self,
        *,
        task: Task,
        model: ModelWrapper,
        auxiliary_models: Optional[List[openai.OpenAI]] = None,
    ):
        """Initialize the adapter with the task and model.

        Args:
            task (Task): The task to run. ``task.workflow_args["workflow_func"]``
                must hold the async workflow function to wrap; it is called as
                ``await workflow_func(task.raw_task, chat_model)``.
            model (ModelWrapper): The model whose async OpenAI client backs the
                ``TrinityChatModel`` passed to the workflow function.
            auxiliary_models (Optional[List[openai.OpenAI]]): Extra models,
                forwarded unchanged to the base ``Workflow`` initializer.

        Raises:
            ImportError: If ``agentscope.model.TrinityChatModel`` cannot be
                imported (agentscope missing or too old).
            ValueError: If ``workflow_func`` is not present in
                ``task.workflow_args``.
        """
        # Imported here rather than at module top so this module can be loaded
        # without agentscope installed; the failure surfaces only when the
        # adapter is actually instantiated.
        try:
            from agentscope.model import TrinityChatModel
        except ImportError as err:
            # The requirement spec is quoted: unquoted, a POSIX shell treats
            # ``>=0.1.6`` as an output redirection and installs any version.
            raise ImportError(
                "This workflow requires agentscope >= 0.1.6, please install "
                'it via `pip install "agentscope>=0.1.6"`',
            ) from err
        super().__init__(
            task=task,
            model=model,
            auxiliary_models=auxiliary_models,
        )
        self.workflow_func: Callable[
            [Dict, TrinityChatModel], Awaitable[float]
        ] = task.workflow_args.get("workflow_func", None)
        if self.workflow_func is None:
            raise ValueError(
                "The 'workflow_func' is not provided.",
            )
        # Chat model handed to the AgentScope workflow. NOTE(review): its calls
        # presumably populate the history that ``construct_experiences`` reads
        # back via ``extract_experience_from_history`` — confirm.
        self.chat_model: TrinityChatModel = TrinityChatModel(
            model.get_openai_async_client(),
        )

    def construct_experiences(
        self,
        reward: float,
    ) -> List[Experience]:
        """Construct experiences from the agent's interaction history.

        Args:
            reward (float): The reward value to assign to each experience.

        Returns:
            List[Experience]: The experiences returned by
            ``self.model.extract_experience_from_history()``, each with its
            ``reward`` attribute overwritten in place by ``reward``.
        """
        exps = self.model.extract_experience_from_history()
        # Every experience from this rollout shares the single episode-level
        # reward produced by the workflow function.
        for exp in exps:
            exp.reward = reward
        return exps

    async def run_async(self) -> List[Experience]:
        """Run the wrapped workflow function and return the resulting experiences.

        Returns:
            List[Experience]: Experiences from the model's history, each
            labelled with the reward returned by ``workflow_func``.
        """
        reward = await self.workflow_func(self.task.raw_task, self.chat_model)  # type: ignore [arg-type]
        return self.construct_experiences(reward)