@DataConverterRegistry.register("prmbench")
class PRMBenchConverter(DataConverter):
    """
    Unified converter for Process Reward Model (PRM) data.

    Handles mathematical reasoning records that carry step-wise solution
    processes (an original, correct process and a modified process with
    injected error steps) and converts them into ``DataSample`` objects.
    """

    # Maps a PRMBench "classification" value to the internal classification
    # used for filtering. Defined as a class attribute (shared, read-only)
    # rather than an instance attribute. Values may be ``None`` ("*" is a
    # wildcard meaning "no filtering"), so the value type is ``Any`` rather
    # than ``str``.
    DIMENSION_CLASSIFICATION_MAPPING: ClassVar[Dict[str, Any]] = {
        "confidence": "confidence",
        "*": None,  # wildcard, means no filtering
    }

    def convert_to_data_sample(
        self, data_dict: Dict[str, Any], source_info: Dict[str, Any]
    ) -> "DataSample | None":
        """Convert one PRM record to a ``DataSample``.

        Expected input format::

            {
                "original_question": "...",
                "modified_question": "...",
                "original_process": ["step1", "step2", ...],
                "modified_process": ["step1", "step2", ...],
                "modified_steps": [5, 6],
                "error_steps": [5, 6],
                "reason": "...",
                "idx": "...",
                "question": "...",
                "classification": "confidence"
            }

        Args:
            data_dict: A single PRMBench record (see format above).
            source_info: Loader-provided provenance; ``load_type`` selects
                which extra metadata keys are recorded.

        Returns:
            The converted ``DataSample``, or ``None`` if conversion failed.
        """
        # Prefer the dataset-provided idx; only hash the question as a
        # fallback (avoids computing md5 eagerly on every record).
        unique_id = data_dict.get("idx")
        if unique_id is None:
            unique_id = hashlib.md5(
                str(data_dict.get("question", "")).encode()
            ).hexdigest()

        try:
            # Input is the (possibly modified) question as a user turn.
            data_input = self._create_prm_input(data_dict)
            # Outputs are the original and/or modified reasoning processes.
            data_output = self._create_prm_output(data_dict)

            metadata = {
                "classification": data_dict.get("classification"),
                "modified_steps": data_dict.get("modified_steps", []),
                "error_steps": data_dict.get("error_steps", []),
                "reason": data_dict.get("reason"),
                "idx": data_dict.get("idx"),
                "original_process_length": len(data_dict.get("original_process", [])),
                "modified_process_length": len(data_dict.get("modified_process", [])),
                "load_strategy": "PRMBenchConverter",
            }
            metadata.update(self._build_source_metadata(source_info))

            return DataSample(
                unique_id=str(unique_id),
                input=data_input,
                output=data_output,
                source="prmbench",
                task_category=data_dict.get("classification", "reasoning"),
                metadata=metadata,
            )
        except Exception:
            # logger.exception records the traceback, which logger.error
            # with a formatted message did not.
            logger.exception("Error creating DataSample from PRM data")
            return None

    def _build_source_metadata(self, source_info: Dict[str, Any]) -> Dict[str, Any]:
        """Return source-specific metadata for local-file vs. HuggingFace loads."""
        load_type = source_info.get("load_type")
        if load_type == "local":
            return {
                "source_file_path": source_info.get("source_file_path"),
                "load_type": "local",
            }
        if load_type == "huggingface":
            return {
                "dataset_name": source_info.get("dataset_name"),
                "dataset_config": source_info.get("dataset_config"),
                "split": source_info.get("split", "train"),
                "load_type": "huggingface",
            }
        # Unknown load_type: record nothing extra (matches original behavior).
        return {}

    def _create_prm_input(self, data_dict: Dict[str, Any]) -> list[ChatMessage]:
        """Create the chat input from the PRM question.

        Falls back to ``original_question`` when ``question`` is missing
        or falsy.
        """
        question = data_dict.get("question") or data_dict.get("original_question", "")
        return [ChatMessage(role="user", content=question)]

    def _create_prm_output(self, data_dict: Dict[str, Any]) -> list[DataOutput]:
        """Create DataOutput list from the PRM reasoning processes.

        Emits up to two outputs: one for the fully-correct original process
        (every step labeled "correct") and one for the modified process,
        whose steps listed in ``error_steps`` (1-based) are labeled "error".
        """
        outputs: list[DataOutput] = []

        if "original_process" in data_dict:
            original_steps = [
                Step(
                    role="assistant",
                    content=step_content,
                    # step_idx is 1-based, matching error_steps numbering.
                    label={"correctness": "correct", "step_idx": i + 1},
                )
                for i, step_content in enumerate(data_dict["original_process"])
            ]
            outputs.append(
                DataOutput(
                    answer=Step(
                        role="assistant",
                        content="\n".join(data_dict["original_process"]),
                        label={"process_type": "original_correct"},
                    ),
                    steps=original_steps,
                )
            )

        if "modified_process" in data_dict:
            # Set for O(1) membership tests while labeling steps.
            error_steps = set(data_dict.get("error_steps", []))
            modified_steps = []
            for i, step_content in enumerate(data_dict["modified_process"]):
                step_idx = i + 1
                is_correct = step_idx not in error_steps
                modified_steps.append(
                    Step(
                        role="assistant",
                        content=step_content,
                        label={
                            "correctness": "correct" if is_correct else "error",
                            "step_idx": step_idx,
                        },
                    )
                )

            # Summarize the error ratio in the answer-level label.
            total_steps = len(data_dict["modified_process"])
            error_count = len(error_steps)
            outputs.append(
                DataOutput(
                    answer=Step(
                        role="assistant",
                        content="\n".join(data_dict["modified_process"]),
                        label={
                            "process_type": f"Modified process with {error_count}/{total_steps} error steps"
                        },
                    ),
                    steps=modified_steps,
                )
            )

        return outputs