trinity.common.workflows.agentscope_workflow 源代码

from typing import Awaitable, Callable, Dict, List, Optional

from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
from trinity.common.workflows.workflow import Task, Workflow


[文档] class AgentScopeWorkflowAdapter(Workflow): """Adapter to wrap a agentscope trainable workflow function into a Trinity Workflow.""" is_async: bool = True
[文档] def __init__( self, *, task: Task, model: ModelWrapper, auxiliary_models: Optional[List[ModelWrapper]] = None, ): """Initialize the adapter with the task and model.""" try: from agentscope.model import TrinityChatModel except ImportError: raise ImportError( "This workflow requires agentscope >= 1.0.7, please install " "it via `pip install agentscope>=1.0.7`", ) super().__init__( task=task, model=model, auxiliary_models=auxiliary_models, ) self.workflow_func: Callable[ [Dict, TrinityChatModel], Awaitable[float] ] = task.workflow_args.get("workflow_func", None) if self.workflow_func is None: raise ValueError( "The 'workflow_func' is not provided.", ) self.chat_model: TrinityChatModel = TrinityChatModel( model.get_openai_async_client(), generate_kwargs={ "temperature": self.task.rollout_args.temperature, "top_p": self.task.rollout_args.top_p, "max_tokens": self.task.rollout_args.max_tokens or 4096, "logprobs": True, "top_logprobs": self.task.rollout_args.logprobs, }, )
[文档] def construct_experiences( self, reward: float, ) -> List[Experience]: """Construct experiences from the agent's interaction history. Args: reward (float): The reward value to assign to each experience. Returns: List: A list of Experience objects. """ exps = self.model.extract_experience_from_history() for exp in exps: exp.reward = reward return exps
[文档] async def run_async(self) -> List[Experience]: """Run the workflow asynchronously and return experiences.""" reward = await self.workflow_func(self.task.raw_task, self.chat_model) # type: ignore [arg-type] return self.construct_experiences(reward)
[文档] class AgentScopeWorkflowAdapterV1(Workflow): """A more general adapter to wrap agentscope trainable workflow and judge functions into a Trinity Workflow.""" is_async: bool = True
[文档] def __init__( self, *, task: Task, model: ModelWrapper, auxiliary_models: Optional[List[ModelWrapper]] = None, ): """Initialize the adapter with the task and model.""" try: from agentscope.model import TrinityChatModel except ImportError: raise ImportError( "This workflow requires agentscope >= 1.0.11, please install " "it via `pip install agentscope>=1.0.11`", ) super().__init__( task=task, model=model, auxiliary_models=auxiliary_models, ) self.workflow_func = task.workflow_args.get("workflow_func", None) self.judge_func = task.workflow_args.get("judge_func", None) if self.workflow_func is None: raise ValueError( "The 'workflow_func' is not provided.", ) self.chat_model: TrinityChatModel = TrinityChatModel( model.get_openai_async_client(), generate_kwargs={ "temperature": self.task.rollout_args.temperature, "top_p": self.task.rollout_args.top_p, "max_tokens": self.task.rollout_args.max_tokens or 4096, "logprobs": True, "top_logprobs": self.task.rollout_args.logprobs, }, ) self.auxiliary_chat_models = [ TrinityChatModel( openai_async_client=aux_model, # TODO: customize generate_kwargs for auxiliary models if needed ) for aux_model in (self.auxiliary_models or []) ]
[文档] def construct_experiences( self, reward: float, metrics: Dict, ) -> List[Experience]: """Construct experiences from the agent's interaction history. Args: reward (float): The reward value to assign to each experience. metrics (Dict): A dictionary of metrics to be attached to the last experience. Returns: List: A list of Experience objects. """ exps = self.model.extract_experience_from_history() for exp in exps: exp.reward = reward # only attach metrics to the last experience if len(exps) > 0: exps[-1].metrics = metrics return exps
[文档] async def run_async(self) -> List[Experience]: """Run the workflow asynchronously and return experiences.""" try: from agentscope.tuner import JudgeOutput, WorkflowOutput except ImportError: raise ImportError( "Fail to import agentscope tuner related types. Please ensure agentscope>=1.0.11 is installed." ) metrics = {} workflow_output: WorkflowOutput = await self.workflow_func( self.task.raw_task, self.chat_model, self.auxiliary_chat_models ) # type: ignore [arg-type] metrics.update(workflow_output.metrics or {}) if self.judge_func is not None: assert ( workflow_output.response is not None ), "Workflow must provide response for judging." judge_output: JudgeOutput = await self.judge_func( self.task.raw_task, workflow_output.response, self.auxiliary_chat_models ) # type: ignore [arg-type] reward = judge_output.reward metrics.update(judge_output.metrics or {}) else: assert ( workflow_output.reward is not None ), "Either workflow or judge must provide reward." reward = workflow_output.reward return self.construct_experiences(reward, metrics)