Skip to content

prmbench

PRMBenchConverter

Bases: DataConverter

Unified converter for Process Reward Model (PRM) data. Handles mathematical reasoning data with step-wise processes.

Source code in rm_gallery/gallery/data/load/prmbench.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
@DataConverterRegistry.register("prmbench")
class PRMBenchConverter(DataConverter):
    """
    Unified converter for Process Reward Model (PRM) data.

    Handles mathematical reasoning data with step-wise processes, producing
    one DataSample per record with up to two outputs: the original (correct)
    process and a modified process annotated with error steps.
    """

    # define as class attribute instead of instance attribute.
    # Values map a PRM "classification" to an internal dimension name; a
    # value of None means "no filtering" (wildcard), so the value type is
    # Any rather than str (the previous Dict[str, str] annotation was wrong).
    DIMENSION_CLASSIFICATION_MAPPING: ClassVar[Dict[str, Any]] = {
        "confidence": "confidence",
        "*": None,  # wildcard, means no filtering
    }

    def convert_to_data_sample(
        self, data_dict: Dict[str, Any], source_info: Dict[str, Any]
    ) -> "DataSample | None":
        """Convert PRM data to DataSample format.

        Expected input format:
        {
            "original_question": "...",
            "modified_question": "...",
            "original_process": ["step1", "step2", ...],
            "modified_process": ["step1", "step2", ...],
            "modified_steps": [5, 6],
            "error_steps": [5, 6],
            "reason": "...",
            "idx": "...",
            "question": "...",
            "classification": "confidence"
        }

        Args:
            data_dict: One raw PRM record in the format above.
            source_info: Loader metadata; its "load_type" value ("local" or
                "huggingface") selects which extra fields are copied into
                the sample metadata.

        Returns:
            The converted DataSample, or None if conversion failed
            (the error is logged, not raised).
        """
        # Prefer the dataset-provided idx; fall back to an MD5 digest of the
        # question. Unlike dict.get with a default, this computes the digest
        # only when needed and also covers an explicit "idx": None.
        unique_id = data_dict.get("idx")
        if unique_id is None:
            unique_id = hashlib.md5(
                str(data_dict.get("question", "")).encode()
            ).hexdigest()

        try:
            # Create input from question
            data_input = self._create_prm_input(data_dict)

            # Create outputs from processes
            data_output = self._create_prm_output(data_dict)

            # Common metadata shared by every source type
            metadata = {
                "classification": data_dict.get("classification"),
                "modified_steps": data_dict.get("modified_steps", []),
                "error_steps": data_dict.get("error_steps", []),
                "reason": data_dict.get("reason"),
                "idx": data_dict.get("idx"),
                "original_process_length": len(data_dict.get("original_process", [])),
                "modified_process_length": len(data_dict.get("modified_process", [])),
                "load_strategy": "PRMBenchConverter",
            }

            # Add source-specific metadata
            load_type = source_info.get("load_type")
            if load_type == "local":
                metadata.update(
                    {
                        "source_file_path": source_info.get("source_file_path"),
                        "load_type": "local",
                    }
                )
            elif load_type == "huggingface":
                metadata.update(
                    {
                        "dataset_name": source_info.get("dataset_name"),
                        "dataset_config": source_info.get("dataset_config"),
                        "split": source_info.get("split", "train"),
                        "load_type": "huggingface",
                    }
                )

            return DataSample(
                unique_id=str(unique_id),
                input=data_input,
                output=data_output,
                source="prmbench",
                task_category=data_dict.get("classification", "reasoning"),
                metadata=metadata,
            )

        except Exception:
            # logger.exception records the traceback, unlike logger.error.
            # Best-effort contract: a bad record is skipped, not fatal.
            logger.exception("Error creating DataSample from PRM data")
            return None

    def _create_prm_input(self, data_dict: Dict[str, Any]) -> list[ChatMessage]:
        """Build the user message list from the (possibly modified) question.

        Falls back to "original_question" when "question" is missing or empty.
        """
        question = data_dict.get("question") or data_dict.get("original_question", "")
        return [ChatMessage(role="user", content=question)]

    def _create_prm_output(self, data_dict: Dict[str, Any]) -> list[DataOutput]:
        """Build DataOutput entries from the original and modified processes.

        Produces up to two outputs:
          * the original process, with every step labeled "correct";
          * the modified process, with steps listed in "error_steps"
            (1-based indices) labeled "error".
        """
        outputs = []

        # Original process output — all steps assumed correct by construction
        if "original_process" in data_dict:
            original_steps = [
                Step(
                    role="assistant",
                    content=step_content,
                    label={"correctness": "correct", "step_idx": i + 1},
                )
                for i, step_content in enumerate(data_dict["original_process"])
            ]

            outputs.append(
                DataOutput(
                    answer=Step(
                        role="assistant",
                        content="\n".join(data_dict["original_process"]),
                        label={"process_type": "original_correct"},
                    ),
                    steps=original_steps,
                )
            )

        # Modified process output (with errors)
        if "modified_process" in data_dict:
            # error_steps holds 1-based step indices; set gives O(1) lookup
            error_steps = set(data_dict.get("error_steps", []))

            modified_steps = []
            for i, step_content in enumerate(data_dict["modified_process"]):
                step_idx = i + 1
                is_correct = step_idx not in error_steps
                modified_steps.append(
                    Step(
                        role="assistant",
                        content=step_content,
                        label={
                            "correctness": "correct" if is_correct else "error",
                            "step_idx": step_idx,
                        },
                    )
                )

            # Summarize the error ratio in the answer label
            total_steps = len(data_dict["modified_process"])
            error_count = len(error_steps)

            outputs.append(
                DataOutput(
                    answer=Step(
                        role="assistant",
                        content="\n".join(data_dict["modified_process"]),
                        label={
                            "process_type": f"Modified process with {error_count}/{total_steps} error steps"
                        },
                    ),
                    steps=modified_steps,
                )
            )

        return outputs

convert_to_data_sample(data_dict, source_info)

Convert PRM data to DataSample format

Expected input format: { "original_question": "...", "modified_question": "...", "original_process": ["step1", "step2", ...], "modified_process": ["step1", "step2", ...], "modified_steps": [5, 6], "error_steps": [5, 6], "reason": "...", "idx": "...", "question": "...", "classification": "confidence" }

Source code in rm_gallery/gallery/data/load/prmbench.py
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
def convert_to_data_sample(
    self, data_dict: Dict[str, Any], source_info: Dict[str, Any]
) -> DataSample:
    """Convert PRM data to DataSample format

    Expected input format:
    {
        "original_question": "...",
        "modified_question": "...",
        "original_process": ["step1", "step2", ...],
        "modified_process": ["step1", "step2", ...],
        "modified_steps": [5, 6],
        "error_steps": [5, 6],
        "reason": "...",
        "idx": "...",
        "question": "...",
        "classification": "confidence"
    }
    """

    # Generate unique id from idx or question
    # NOTE(review): the md5 default is evaluated eagerly even when "idx"
    # exists, and an explicit "idx": None yields the literal id "None".
    unique_id = data_dict.get(
        "idx", hashlib.md5(str(data_dict.get("question", "")).encode()).hexdigest()
    )

    try:
        # Create input from question
        data_input = self._create_prm_input(data_dict)

        # Create outputs from processes
        data_output = self._create_prm_output(data_dict)

        # Build metadata based on source type
        metadata = {
            "classification": data_dict.get("classification"),
            "modified_steps": data_dict.get("modified_steps", []),
            "error_steps": data_dict.get("error_steps", []),
            "reason": data_dict.get("reason"),
            "idx": data_dict.get("idx"),
            "original_process_length": len(data_dict.get("original_process", [])),
            "modified_process_length": len(data_dict.get("modified_process", [])),
            "load_strategy": "PRMBenchConverter",
        }

        # Add source-specific metadata
        if source_info.get("load_type") == "local":
            metadata.update(
                {
                    "source_file_path": source_info.get("source_file_path"),
                    "load_type": "local",
                }
            )
        elif source_info.get("load_type") == "huggingface":
            metadata.update(
                {
                    "dataset_name": source_info.get("dataset_name"),
                    "dataset_config": source_info.get("dataset_config"),
                    # default split follows the huggingface convention
                    "split": source_info.get("split", "train"),
                    "load_type": "huggingface",
                }
            )

        # Create DataSample object
        data_sample = DataSample(
            unique_id=str(unique_id),
            input=data_input,
            output=data_output,
            source="prmbench",
            task_category=data_dict.get("classification", "reasoning"),
            metadata=metadata,
        )

        return data_sample

    except Exception as e:
        # NOTE(review): swallows all errors and returns None despite the
        # "-> DataSample" annotation — callers must handle a None result.
        logger.error(f"Error creating DataSample from PRM data: {str(e)}")
        return None