Source code for trinity.common.workflows.rubric_judge_workflow

# -*- coding: utf-8 -*-
"""A workflow with LLM-as-a-judge."""
import json
from typing import List, Optional, Tuple

import openai

from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
from trinity.common.workflows.workflow import WORKFLOWS, SimpleWorkflow, Task


@WORKFLOWS.register_module("rubric_judge_workflow")
class RubricJudgeWorkflow(SimpleWorkflow):
    """A workflow using LLM-as-a-judge and rubrics to get the reward.

    Adapted from https://arxiv.org/pdf/2507.17746
    """

    def __init__(
        self,
        *,
        task: Task,
        model: ModelWrapper,
        auxiliary_models: Optional[List[openai.OpenAI]] = None,
    ):
        super().__init__(
            task=task,
            model=model,
            auxiliary_models=auxiliary_models,
        )

    def reset(self, task: Task):
        """Modified from SimpleWorkflow.reset"""
        self.format_args = task.format_args
        self.system_prompt = task.format_args.system_prompt
        self.reply_prefix = task.format_args.reply_prefix
        if self.system_prompt is None:
            self.system_prompt = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
"""
        self.raw_task = task.raw_task
        self.task_desc = task.task_desc
        self.truth = task.truth
        self.rubric = self.raw_task.get("rubric", [])

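    # Illustrative note (not part of the original source): get_judge_reward below assumes
    # each entry of self.rubric is a dict with "weight" and "description" keys, e.g.
    #   raw_task = {"rubric": [{"weight": 1.0, "description": "Cites at least one source."}]}
    # The values shown here are hypothetical. An empty rubric list is tolerated; the judge
    # prompt then simply states that rubrics are not provided.
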
    def run(self) -> List[Experience]:
        """Modified from SimpleWorkflow.run"""
        messages = self.format_messages()
        self.logger.debug("start chat")
        responses = self.model.chat(messages, **self.rollout_args)

        # === Calculate rubric-based rewards ===
        assert (
            self.auxiliary_models is not None
        ), "Current implementation of rubric-based rewards requires that auxiliary_models is not None."
        judge_success_list = []
        for i, response in enumerate(responses):
            judge_success, reward = self.get_judge_reward(
                response=response.response_text, judger=self.auxiliary_models[0]
            )
            response.reward = reward
            response.eid.run = i + self.run_id_base
            judge_success_list.append(judge_success)
            if i == 0:
                self.logger.debug(
                    f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {response.reward}"
                )

        # record judge success
        judge_success_rate = (
            sum(judge_success_list) / len(judge_success_list) if judge_success_list else 0.0
        )
        for response in responses:
            if response.metrics is None:
                response.metrics = {}
            response.metrics.update({"judge_success": float(judge_success_rate)})
        return responses

    def get_judge_reward(self, response: str, judger: openai.OpenAI) -> Tuple[bool, float]:
        """Get rewards with LLM-as-a-judge.

        The prompts are adapted from the RAR-IMPLICIT method in https://arxiv.org/pdf/2507.17746
        """
        # Step 1: format prompts
        # system prompt
        ruler_system_prompt = """You are an expert evaluator. Given a user prompt, a generated response, and a list of quality rubrics, please rate the overall quality of the response on a scale of 1 to 10 based on how well it satisfies the rubrics.

Consider all rubrics holistically when determining your score. A response that violates multiple rubrics should receive a lower score, while a response that satisfies all rubrics should receive a higher score.

Start your response with a valid JSON object that starts with "```json" and ends with "```". The JSON object should contain a single key "rating" and the value should be an integer between 1 and 10.

Example response:
```json
{
    "rating": 7
}```"""

        # user prompt
        if len(self.rubric) > 0:
            rubric_prompt_parts = [
                f"Rubric {i} (weight: {single_rubric['weight']}): {single_rubric['description']}"
                for i, single_rubric in enumerate(self.rubric)
            ]
            rubric_list_string = "\n".join(rubric_prompt_parts)
        else:
            self.logger.warning("No rubric is provided!")
            rubric_list_string = "Rubrics are not provided."

        ruler_user_prompt = f"""Given the following prompt, response, and rubrics, please rate the overall quality of the response on a scale of 1 to 10 based on how well it satisfies the rubrics.

<prompt>
{self.task_desc}
</prompt>

<response>
{response}
</response>

<rubrics>
{rubric_list_string}
</rubrics>

Your JSON Evaluation:
""".strip()

        # Step 2: invoke judger LLM
        messages = [
            {"role": "system", "content": ruler_system_prompt},
            {"role": "user", "content": ruler_user_prompt},
        ]
        completion = judger.chat.completions.create(
            model=judger.model_path, messages=messages, stream=False, temperature=0.0
        )
        judger_response = completion.choices[0].message.content
        self.logger.debug(f"LLM judge response: {judger_response}")

        # Step 3: extract score from judger's response (expecting a JSON block with "rating")
        try:
            # Extract content between ```json and ```
            start_tag = "```json"
            start_index = judger_response.find(start_tag)
            if start_index == -1:
                start_tag = "```"
                start_index = judger_response.find(start_tag)
            if start_index == -1:
                self.logger.warning("No JSON code block found in judger response.")
                return False, 0.0

            end_index = judger_response.find("```", start_index + len(start_tag))
            if end_index == -1:
                self.logger.warning("Malformed JSON code block in judger response.")
                return False, 0.0

            json_str = judger_response[start_index + len(start_tag) : end_index].strip()
            parsed = json.loads(json_str)
            rating = parsed.get("rating")
            if not isinstance(rating, (int, float)) or not (1 <= rating <= 10):
                self.logger.warning(f"Invalid or out-of-range rating: {rating}")
                return False, 0.0

            normalized_score = rating * 0.1  # Normalize 1-10 to 0-1 scale
            return True, normalized_score
        except json.JSONDecodeError as e:
            self.logger.warning(f"Failed to parse JSON from judger response: {e}")
            return False, 0.0
        except Exception as e:
            self.logger.warning(f"Unexpected error when processing judger response: {e}")
            return False, 0.0

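# Illustrative example (not part of the original module): given a judger reply such as
#   ```json
#   {"rating": 7}
#   ```
# get_judge_reward returns (True, 0.7), since ratings in [1, 10] are normalized with
# rating * 0.1; a missing or malformed JSON block, or an out-of-range rating, yields
# (False, 0.0) and is reflected in the "judge_success" metric recorded by run().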