Implementing GRPO
Today we go a bit deeper into GRPO by implementing it. I will first talk a little about the concept, discuss our approach, and then start implementing. This will be fun. Thanks to Will for showing us the way!
What is GRPO?
GRPO is a training technique for optimizing language models using reward functions that capture specific preferences, which we discussed in our last post, GRPO, simplified. Unlike reinforcement learning approaches such as PPO or RLHF, which require a separate critic model and extra compute, GRPO directly optimizes the language model by computing relative advantages within a group of generated responses.
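To make "relative advantages within a group" concrete, here is a minimal sketch of one way a per-completion advantage could be computed, by normalizing each reward against the mean and standard deviation of its group. The function name and the exact normalization are my own illustration, not code from this post or any particular library.

import statistics

def group_relative_advantages(rewards: list[float]) -> list[float]:
    """Normalize rewards within one group of completions for the same prompt.

    Illustrative sketch: each completion is scored relative to its group,
    which is why GRPO needs no separate critic/value model.
    """
    mean = statistics.mean(rewards)
    std = statistics.pstdev(rewards)
    # Guard against a zero-variance group (all completions got the same reward).
    if std == 0:
        return [0.0 for _ in rewards]
    return [(r - mean) / std for r in rewards]

# Example: rewards for 4 completions of one prompt.
print(group_relative_advantages([2.5, 0.0, 0.5, 2.5]))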
Key Aspects of GRPO
What Makes GRPO Different?
GRPO is a recent reinforcement learning technique that offers several advantages over traditional approaches:
- Direct optimization: Unlike methods that require a separate reward model, GRPO directly optimizes the language model using explicit reward functions.
- Multiple reward signals: You can define multiple reward functions that target different aspects of generation (correctness, format, style); see the sketch after this list.
- Exploration efficiency: GRPO explores the output space by generating multiple completions for each prompt during training.
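As a rough illustration of the last two points, the sketch below scores a group of completions for one prompt with several reward functions and sums their outputs per completion. The score_group helper, the toy reward functions, and the equal weighting are assumptions for illustration; a real GRPO trainer handles generation and reward aggregation internally.

from typing import Callable

def score_group(prompt: str,
                completions: list[str],
                reward_funcs: list[Callable[[str, str], float]]) -> list[float]:
    """Sum every reward function's score for each completion in the group.

    Hypothetical helper: real GRPO trainers pass richer structures
    (chat messages, ground-truth answers) to the reward functions.
    """
    return [sum(fn(prompt, completion) for fn in reward_funcs)
            for completion in completions]

# Toy reward functions targeting different aspects of a completion.
def ends_with_digit(prompt: str, completion: str) -> float:
    return 0.5 if completion.strip() and completion.strip()[-1].isdigit() else 0.0

def is_short(prompt: str, completion: str) -> float:
    return 0.5 if len(completion) < 80 else 0.0

print(score_group("What is 6 * 7?",
                  ["42", "The answer is 42", "I do not know"],
                  [ends_with_digit, is_short]))  # -> [1.0, 1.0, 0.5]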
Reward Functions
The code implements several reward functions that work together to guide the model:
- correctness_reward_func: Awards 2.0 points when the model's extracted answer matches the ground truth answer. This is the primary learning signal for factual correctness.
- int_reward_func: Gives 0.5 points when the answer is a digit, which is appropriate for math problems. This helps guide the model toward numerical responses.
- soft_format_reward_func and strict_format_reward_func: Reward proper XML formatting (0.5 points). This teaches the model to structure its responses with proper tags (the expected structure is sketched after this list).
- xmlcount_reward_func: Provides partial credit (0.125 points per tag) for each correctly used XML tag, creating a smoother learning gradient.
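The format-related rewards assume responses wrapped in simple XML-style tags. The tag names below (<reasoning> and <answer>) and the helper itself are assumptions for this sketch; the post's code references an extract_xml_answer function, and this shows one plausible way it could work.

def extract_xml_answer(text: str) -> str:
    """Pull the text between <answer> tags out of a completion.

    Assumes the model was prompted to respond in the form
    <reasoning> ... </reasoning> <answer> ... </answer>
    (the tag names are an assumption for this sketch).
    """
    if "<answer>" not in text or "</answer>" not in text:
        return ""
    answer = text.split("<answer>")[-1].split("</answer>")[0]
    return answer.strip()

print(extract_xml_answer("<reasoning>6*7=42</reasoning>\n<answer>42</answer>"))  # -> 42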
Implementation Considerations
We use a smaller model (Qwen2.5-1.5B-Instruct) that fits comfortably within memory constraints, reduce batch sizes and generation counts to manage memory usage, and work with a small dataset subset (20 examples) for quick experimentation. To test the implementation, we have included some evaluation code that runs immediately after training. You can increase max_samples for more thorough training or try different reward functions.
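For reference, a training setup along these lines could look roughly like the sketch below. It assumes the trl library's GRPOTrainer and GRPOConfig (exact argument names vary between trl versions), a GSM8K subset from the Hugging Face Hub, and the reward functions described above; the SYSTEM_PROMPT and hyperparameter values are illustrative, not the post's exact configuration.

# Rough sketch, not a drop-in script: the reward functions
# (correctness_reward_func, etc.) are assumed to be defined as in the post.
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

# Hypothetical system prompt asking for the XML structure the format rewards expect.
SYSTEM_PROMPT = "Respond as:\n<reasoning>\n...\n</reasoning>\n<answer>\n...\n</answer>"

def to_prompt(example):
    # GSM8K answers end with "#### <number>"; keep only the final number.
    return {
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": example["question"]},
        ],
        "answer": example["answer"].split("####")[-1].strip(),
    }

max_samples = 20  # small subset for quick experimentation; raise for longer runs
train_dataset = (load_dataset("gsm8k", "main", split="train")
                 .select(range(max_samples))
                 .map(to_prompt))

training_args = GRPOConfig(
    output_dir="outputs/qwen2.5-1.5b-grpo",
    learning_rate=5e-6,
    per_device_train_batch_size=4,   # kept small to manage memory
    num_generations=4,               # completions sampled per prompt (the "group")
    max_prompt_length=256,
    max_completion_length=200,
    logging_steps=1,
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-1.5B-Instruct",
    reward_funcs=[xmlcount_reward_func, soft_format_reward_func,
                  int_reward_func, correctness_reward_func],
    args=training_args,
    train_dataset=train_dataset,
)
trainer.train()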
Code Example: Reward Function
Here is an example of one of the reward functions implemented:
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """
    Reward function that checks if the extracted answer matches the ground truth.
    Returns 2.0 for correct answers, 0.0 otherwise.
    """
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    # Optional debugging print for the first example
    if kwargs.get('debug', False) and len(responses) > 0:
        q = prompts[0][-1]['content']
        print("--------------------")
        print(f"Question:\n{q}")
        print(f"\nGround Truth Answer:\n{answer[0]}")
        print(f"\nModel Response:\n{responses[0]}")
        print(f"\nExtracted Answer:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
This function evaluates the correctness of the model's extracted answer against the known ground truth.
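To make the expected input shapes concrete, here is a toy call with prompts and completions as lists of chat messages (role/content dicts) and a list of ground-truth answers. The values are made up, and the expected output assumes an extract_xml_answer along the lines sketched earlier.

# One chat-style prompt repeated for two generated completions,
# plus the matching ground-truth answers.
prompts = [
    [{"role": "user", "content": "What is 6 * 7?"}],
    [{"role": "user", "content": "What is 6 * 7?"}],
]
completions = [
    [{"role": "assistant", "content": "<reasoning>6*7=42</reasoning>\n<answer>42</answer>"}],
    [{"role": "assistant", "content": "<reasoning>Not sure.</reasoning>\n<answer>36</answer>"}],
]
answer = ["42", "42"]

print(correctness_reward_func(prompts, completions, answer))  # -> [2.0, 0.0]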
Conclusion
This implementation should provide good insight into how GRPO works and how it can be used to optimize language models for specific formats and tasks. The math problem-solving task with XML formatting serves as a clear example of the technique's capabilities. Fun session, I'd say.