Source code for trinity.common.workflows.envs.sciworld.sciworld_workflow

# -*- coding: utf-8 -*-
import json
from typing import List, Optional

from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
from trinity.common.workflows.workflow import WORKFLOWS, MultiTurnWorkflow, Task

SCIWORLD_SYSTEM_PROMPT = """
You are an agent, you job is to do some scientific experiment in a virtual test-based environments.

## Notes:
At each step, you should first think then perform action to fulfill the instruction. You should ALWAYS wrap your thinking with the <think> </think> tag and wrap your action with the <action> </action> tag.
You should ALWAYS take one action each step.
DONOT try to interact with the user at anytime. Finish the task by yourself.

## Action Format:
Below are the available commands you can use:
    open OBJ: open a container
    close OBJ: close a container
    activate OBJ: activate a device
    deactivate OBJ: deactivate a device
    connect OBJ to OBJ: connect electrical components
    disconnect OBJ: disconnect electrical components
    use OBJ [on OBJ]: use a device/item
    look around: describe the current room
    examine OBJ: describe an object in detail
    look at OBJ: describe a container's contents
    read OBJ: read a note or book
    move OBJ to OBJ: move an object to a container
    pick up OBJ: move an object to the inventory
    pour OBJ into OBJ: pour a liquid into a container
    mix OBJ: chemically mix a container
    teleport to LOC: teleport to a specific room
    focus on OBJ: signal intent on a task object
    wait: task no action for 10 steps
    wait1: task no action for a step

For example your output should be like this:
<think> Now I will check the bedroom ... </think><action>teleport to bedroom</action>
"""


def format_observation(observation: str):
    return "Observation: \n" + observation


def parse_action(response):
    try:
        # parse the action within the <action> </action> tag
        action = response.split("<action>")[1].split("</action>")[0].strip()
        return action
    except Exception as e:
        print("Error parsing action:", e)
        return ""


[docs] @WORKFLOWS.register_module("sciworld_workflow") class SciWorldWorkflow(MultiTurnWorkflow): """A workflow for sciworld task."""
[docs] def __init__( self, model: ModelWrapper, task: Task, auxiliary_models: Optional[List] = None, ): super().__init__( model=model, task=task, ) self.task_desc = task.task_desc or "0" self.repeat_times = task.rollout_args.n self.max_env_steps = 30 # should be less than 100
[docs] def get_model_response(self, messages): responses = self.model.chat(messages, n=1) return responses
[docs] def get_model_response_text(self, messages): return self.get_model_response(messages)[0].response_text
[docs] def generate_env_inference_samples(self, env, rollout_num) -> List[Experience]: # TODO: Make this parallel print("Generating env inference samples...") golden_rounds = len(env.get_gold_action_sequence()) experience_list = [] for i in range(rollout_num): observation, info = env.reset() observation = ( "Task Description: " + str(env.get_task_description()) + "\n" + observation ) final_reward = 0.0 current_reward = 0.0 memory = [] memory.append({"role": "system", "content": SCIWORLD_SYSTEM_PROMPT}) for r in range(self.max_env_steps): format_obs = format_observation(observation) memory = memory + [{"role": "user", "content": format_obs}] response_text = self.get_model_response_text(memory) memory.append({"role": "assistant", "content": response_text}) action = parse_action(response_text) observation, reward, done, info = env.step(action) current_reward += reward final_reward = max(current_reward, final_reward) if done: break final_reward = final_reward / 100.0 experience = self.process_messages_to_experience( memory, final_reward, {"env_rounds": r, "env_done": 1 if done else 0, "golden_rounds": golden_rounds}, ) experience_list.append(experience) # Close the env to save cpu memory env.close() return experience_list
[docs] def run(self) -> List[Experience]: # assume the task_description is the json object containing task index and the var_num # see Trinity-RFT/script/data_prepare/get_scriworld_data.py task_desc = self.task_desc task_config = json.loads(task_desc) rollout_n = self.repeat_times # TODO: Make parallel envs try: from scienceworld import ScienceWorldEnv def create_environment(task_config): var_num = task_config["var_num"] task_name = task_config["task_name"] jar_path = task_config["jar_path"] simplificationStr = "easy" env = ScienceWorldEnv("", jar_path, envStepLimit=100) env.load(task_name, var_num, simplificationStr, generateGoldPath=True) return env except Exception as e: print("Please make sure you have installed the sciworld package.") error_message = f"Error importing SciWorldTWEnv {str(e)}. Please make sure you have installed the sciworld package successfully, following the instructions in https://github.com/allenai/ScienceWorld" raise ImportError(error_message) env = create_environment(task_config) return self.generate_env_inference_samples(env, rollout_n)