Source code for trinity.common.workflows.agentscope_workflow

import inspect
from typing import Awaitable, Callable, Dict, List, Optional

from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
from trinity.common.workflows.workflow import Task, Workflow
from trinity.utils.annotations import Deprecated


@Deprecated
class AgentScopeWorkflowAdapter(Workflow):
    """Adapter to wrap an agentscope trainable workflow function into a Trinity Workflow.

    Only for agentscope versions between 1.0.7 and 1.0.11.
    For agentscope >= 1.0.12, please use AgentScopeWorkflowAdapterV1.
    """

    is_async: bool = True

    def __init__(
        self,
        *,
        task: Task,
        model: ModelWrapper,
        auxiliary_models: Optional[List[ModelWrapper]] = None,
    ):
        """Initialize the adapter with the task and model.

        Args:
            task (Task): The task to run. ``task.workflow_args["workflow_func"]``
                must hold the workflow coroutine function, and
                ``task.rollout_args`` supplies the sampling parameters
                (temperature, top_p, max_tokens, logprobs).
            model (ModelWrapper): The rollout model; its OpenAI async client is
                handed to the agentscope ``TrinityChatModel``.
            auxiliary_models (Optional[List[ModelWrapper]]): Forwarded unchanged
                to the base ``Workflow`` initializer.

        Raises:
            ImportError: If ``agentscope.model.TrinityChatModel`` cannot be imported.
            ValueError: If ``workflow_func`` is missing from ``task.workflow_args``.
        """
        # Imported lazily so this module can be loaded without agentscope installed.
        try:
            from agentscope.model import TrinityChatModel
        except ImportError as err:
            # The version specifier is quoted so that a shell does not treat
            # `>` as an output redirection when the hint is copy-pasted.
            raise ImportError(
                "This workflow requires agentscope >= 1.0.7, please install "
                'it via `pip install "agentscope>=1.0.7"`',
            ) from err

        super().__init__(
            task=task,
            model=model,
            auxiliary_models=auxiliary_models,
        )

        self.workflow_func: Callable[
            [Dict, TrinityChatModel], Awaitable[float]
        ] = task.workflow_args.get("workflow_func", None)
        if self.workflow_func is None:
            raise ValueError(
                "The 'workflow_func' is not provided.",
            )

        rollout_args = self.task.rollout_args
        self.chat_model: TrinityChatModel = TrinityChatModel(
            model.get_openai_async_client(),
            generate_kwargs={
                "temperature": rollout_args.temperature,
                "top_p": rollout_args.top_p,
                # Fall back to 4096 when max_tokens is unset (None or 0).
                "max_tokens": rollout_args.max_tokens or 4096,
                "logprobs": True,
                "top_logprobs": rollout_args.logprobs,
            },
        )

    def construct_experiences(
        self,
        reward: float,
    ) -> List[Experience]:
        """Construct experiences from the agent's interaction history.

        Args:
            reward (float): The reward value to assign to each experience.

        Returns:
            List: A list of Experience objects.
        """
        exps = self.model.extract_experience_from_history()
        # The same reward is written onto every extracted experience.
        for exp in exps:
            exp.reward = reward
        return exps

    async def run_async(self) -> List[Experience]:
        """Run the workflow asynchronously and return experiences.

        Awaits ``workflow_func(raw_task, chat_model)`` and uses its result as
        the reward for every experience extracted from the model history.

        Returns:
            List: A list of Experience objects.
        """
        reward = await self.workflow_func(self.task.raw_task, self.chat_model)  # type: ignore [arg-type]
        return self.construct_experiences(reward)
class AgentScopeWorkflowAdapterV1(Workflow):
    """A more general adapter to wrap agentscope trainable workflow and judge
    functions into a Trinity Workflow.

    Only for agentscope versions >= 1.0.12.
    """

    is_async: bool = True

    def __init__(
        self,
        *,
        task: Task,
        model: ModelWrapper,
        auxiliary_models: Optional[List[ModelWrapper]] = None,
    ):
        """Initialize the adapter with the task and model.

        Args:
            task (Task): The task to run. ``task.workflow_args`` must contain
                ``workflow_func`` and may contain ``judge_func``;
                ``task.rollout_args`` supplies the sampling parameters.
            model (ModelWrapper): The rollout model; its OpenAI async client is
                installed on the agentscope ``OpenAIChatModel``.
            auxiliary_models (Optional[List[ModelWrapper]]): Forwarded unchanged
                to the base ``Workflow`` initializer.

        Raises:
            ImportError: If ``agentscope.model.OpenAIChatModel`` cannot be imported.
            ValueError: If ``workflow_func`` is missing from ``task.workflow_args``.
        """
        # Imported lazily so this module can be loaded without agentscope installed.
        try:
            from agentscope.model import OpenAIChatModel
        except ImportError as err:
            # The version specifier is quoted so that a shell does not treat
            # `>` as an output redirection when the hint is copy-pasted.
            raise ImportError(
                "This workflow requires agentscope >= 1.0.12, please install "
                'it via `pip install "agentscope>=1.0.12"`',
            ) from err

        super().__init__(
            task=task,
            model=model,
            auxiliary_models=auxiliary_models,
        )
        self.workflow_func = task.workflow_args.get("workflow_func", None)
        self.judge_func = task.workflow_args.get("judge_func", None)
        self._openai_client = self.model.get_openai_async_client()

        if self.workflow_func is None:
            raise ValueError(
                "The 'workflow_func' is not provided.",
            )

        rollout_args = self.task.rollout_args
        self.chat_model: OpenAIChatModel = OpenAIChatModel(
            api_key="EMPTY",
            model_name=self._openai_client.model_path,
            stream=False,
            generate_kwargs={
                "temperature": rollout_args.temperature,
                "top_p": rollout_args.top_p,
                # Fall back to 4096 when max_tokens is unset (None or 0).
                "max_tokens": rollout_args.max_tokens or 4096,
                "logprobs": True,
                "top_logprobs": rollout_args.logprobs,
            },
        )
        # Swap in the client obtained from the Trinity model wrapper.
        self.chat_model.client = self._openai_client

        # Auxiliary chat models, keyed by the wrapper's model name.
        self.auxiliary_chat_models: Dict[str, OpenAIChatModel] = {}
        if self.auxiliary_model_wrappers is not None:
            for aux_model_wrapper in self.auxiliary_model_wrappers:
                aux_model_client = aux_model_wrapper.get_openai_async_client()
                aux_chat_model = OpenAIChatModel(
                    api_key="EMPTY",
                    model_name=aux_model_client.model_path,
                    generate_kwargs=aux_model_wrapper.generate_kwargs,
                    stream=False,
                )
                aux_chat_model.client = aux_model_client
                # Internal invariant rather than user-input validation.
                assert (
                    aux_model_wrapper.model_name is not None
                ), "Auxiliary model must have a name. This should not happen."
                self.auxiliary_chat_models[aux_model_wrapper.model_name] = aux_chat_model

    @staticmethod
    def _accepts_auxiliary_models(func: Callable) -> bool:
        """Return True if ``func`` declares a parameter named ``auxiliary_models``."""
        return "auxiliary_models" in inspect.signature(func).parameters

    def construct_experiences(
        self,
        reward: float,
        metrics: Dict,
    ) -> List[Experience]:
        """Construct experiences from the agent's interaction history.

        Args:
            reward (float): The reward value to assign to each experience.
            metrics (Dict): A dictionary of metrics to be attached to the last experience.

        Returns:
            List: A list of Experience objects.
        """
        exps = self.model.extract_experience_from_history()
        for exp in exps:
            exp.reward = reward
        # only attach metrics to the last experience
        if exps:
            exps[-1].metrics = metrics
        return exps

    async def run_async(self) -> List[Experience]:
        """Run the workflow asynchronously and return experiences.

        Calls ``workflow_func`` and, when configured, ``judge_func``. Either
        function receives ``auxiliary_models`` only if its signature declares a
        parameter with that name. The reward comes from the judge output when a
        judge is configured, otherwise from the workflow output. Metrics from
        both outputs are merged, with judge metrics overriding workflow metrics
        on key collisions.

        Returns:
            List: A list of Experience objects.

        Raises:
            ImportError: If the agentscope tuner types cannot be imported.
            ValueError: If ``workflow_func`` does not return a ``WorkflowOutput``,
                if ``judge_func`` does not return a ``JudgeOutput``, or if no
                judge is configured and the workflow output has no reward.
        """
        try:
            from agentscope.tuner import JudgeOutput, WorkflowOutput
        except ImportError as err:
            raise ImportError(
                "Fail to import agentscope tuner related types. Please ensure agentscope>=1.0.12 is installed."
            ) from err

        metrics: Dict = {}

        workflow_kwargs = {
            "task": self.task.raw_task,
            "model": self.chat_model,
        }
        if self._accepts_auxiliary_models(self.workflow_func):
            workflow_kwargs["auxiliary_models"] = self.auxiliary_chat_models
        workflow_output = await self.workflow_func(**workflow_kwargs)
        if not isinstance(workflow_output, WorkflowOutput):
            raise ValueError(
                "The 'workflow_func' must return a WorkflowOutput object.",
            )
        metrics.update(workflow_output.metrics or {})

        if self.judge_func is None:
            # Raised explicitly instead of asserted: asserts are removed under
            # `python -O`, which would let a None reward reach the experiences.
            if workflow_output.reward is None:
                raise ValueError("Either workflow or judge must provide reward.")
            return self.construct_experiences(workflow_output.reward, metrics)

        judge_kwargs = {
            "task": self.task.raw_task,
            "response": workflow_output.response,
        }
        if self._accepts_auxiliary_models(self.judge_func):
            judge_kwargs["auxiliary_models"] = self.auxiliary_chat_models
        judge_output = await self.judge_func(**judge_kwargs)
        if not isinstance(judge_output, JudgeOutput):
            raise ValueError(
                "The 'judge_func' must return a JudgeOutput object.",
            )
        metrics.update(judge_output.metrics or {})
        return self.construct_experiences(judge_output.reward, metrics)