Bases: BaseModule
A module implementing iterative response refinement using LLM and reward feedback.
Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `reward` | `BaseReward` | Reward for evaluating response quality |
| `llm` | `BaseLLM` | Language model client for generating responses |
| `max_iterations` | `int` | Maximum number of refinement iterations |
Source code in rm_gallery/core/reward/refinement.py
```python
class LLMRefinement(BaseModule):
    """
    A module implementing iterative response refinement using LLM and reward feedback.

    Attributes:
        reward: Reward for evaluating response quality
        llm: Language model client for generating responses
        max_iterations: Maximum number of refinement iterations
    """

    reward: BaseReward = Field(default=..., description="reward")
    llm: BaseLLM = Field(default=..., description="llm client")
    max_iterations: int = Field(default=3, description="max iterations")

    def _generate_response(
        self,
        sample: DataSample,
        feedback: str | None = None,
        **kwargs,
    ) -> DataSample:
        """
        Generate refined response based on conversation history and feedback.

        Args:
            sample: DataSample object containing input and previous responses
            feedback: Quality assessment feedback for previous responses
            **kwargs: Additional parameters for LLM generation

        Returns:
            Generated response as a DataSample object
        """
        # Construct prompt based on feedback availability
        if feedback is None:
            prompt = """# Task
Please generate a response as the conversation requires.
# Conversation history
{history}
""".format(
                history=format_messages(sample.input)
            )
        else:
            prompt = """# Task
Please generate a better response based on the feedback provided on candidate responses.
# Conversation history
{history}
# Responses
{responses}
# Feedback
{feedback}
""".format(
                history=format_messages(sample.input),
                responses="\n".join(
                    [
                        f"<response_{i}>{output.answer.content}</response_{i}>"
                        for i, output in enumerate(sample.output)
                    ]
                ),
                feedback=feedback,
            )

        response = self.llm.simple_chat(prompt)
        sample.output.append(
            DataOutput(
                answer=Step(role=MessageRole.ASSISTANT, content=filter_think(response))
            )
        )
        return sample

    def _generate_feedback(self, sample: DataSample, **kwargs) -> str:
        """
        Generate quality feedback for a response sample.

        Args:
            sample: Data sample containing input-response pair for evaluation
            **kwargs: Additional parameters for reward evaluation

        Returns:
            Feedback string describing response quality assessment
        """
        # Evaluate response quality using reward module
        sample = self.reward.evaluate(sample)

        # Safety check before reading the first reward detail
        if (
            len(sample.output) > 0
            and hasattr(sample.output[0].answer, "reward")
            and len(sample.output[0].answer.reward.details) > 0
        ):
            feedback = sample.output[0].answer.reward.details[0].reason
        else:
            feedback = "No valid evaluation feedback available."
        return feedback

    def run(self, sample: DataSample, **kwargs) -> DataSample:
        """
        Execute iterative response refinement process.

        Args:
            sample: Data sample containing input for refinement
            **kwargs: Additional parameters for generation and evaluation

        Returns:
            Final refined response as a DataSample object
        """
        sample = deepcopy(sample)

        if len(sample.output) == 0:
            # Initial response generation
            response = self.llm.chat(sample.input)
            sample.output.append(
                DataOutput(
                    answer=Step(
                        role=MessageRole.ASSISTANT,
                        content=filter_think(response.message.content),
                    )
                )
            )

        # Iterative refinement loop
        for i in range(self.max_iterations):
            # Generate feedback and create refined response
            feedback = self._generate_feedback(sample, **kwargs)
            sample = self._generate_response(sample, feedback, **kwargs)

        return sample
```
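A minimal usage sketch follows. It assumes you already have a concrete `BaseLLM` client, a `BaseReward` implementation, and a `DataSample` whose `input` conversation is populated; `my_llm`, `my_reward`, and `sample` are placeholder names, not objects provided by the library.

```python
from rm_gallery.core.reward.refinement import LLMRefinement

# `my_llm` (a BaseLLM), `my_reward` (a BaseReward), and `sample` (a DataSample)
# are assumed to exist already; they are placeholders, not library-provided objects.
refiner = LLMRefinement(reward=my_reward, llm=my_llm, max_iterations=3)

refined = refiner.run(sample)

# Each refinement pass appends a new DataOutput, so the latest (most refined)
# candidate is the last element of `output`.
print(refined.output[-1].answer.content)
```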
run(sample, **kwargs)
Execute iterative response refinement process.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `sample` | `DataSample` | Data sample containing input for refinement | *required* |
| `**kwargs` | | Additional parameters for generation and evaluation | `{}` |
Returns:

| Type | Description |
| --- | --- |
| `DataSample` | Final refined response as a DataSample object |
Source code in rm_gallery/core/reward/refinement.py
```python
def run(self, sample: DataSample, **kwargs) -> DataSample:
    """
    Execute iterative response refinement process.

    Args:
        sample: Data sample containing input for refinement
        **kwargs: Additional parameters for generation and evaluation

    Returns:
        Final refined response as a DataSample object
    """
    sample = deepcopy(sample)

    if len(sample.output) == 0:
        # Initial response generation
        response = self.llm.chat(sample.input)
        sample.output.append(
            DataOutput(
                answer=Step(
                    role=MessageRole.ASSISTANT,
                    content=filter_think(response.message.content),
                )
            )
        )

    # Iterative refinement loop
    for i in range(self.max_iterations):
        # Generate feedback and create refined response
        feedback = self._generate_feedback(sample, **kwargs)
        sample = self._generate_response(sample, feedback, **kwargs)

    return sample
```
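Note that `run()` always performs `max_iterations` refinement passes after the initial generation, and each pass appends a new candidate to `sample.output`. Continuing the hypothetical sketch above, the intermediate candidates can be inspected like this:

```python
# Index 0 is the initial response (when the sample started with no outputs);
# each later index is the result of one refinement pass.
for i, output in enumerate(refined.output):
    print(f"--- candidate {i} ---")
    print(output.answer.content)
```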