diff --git a/LLMs/training/train_grpo.py b/LLMs/training/train_grpo.py index 0d844e0..819d796 100644 --- a/LLMs/training/train_grpo.py +++ b/LLMs/training/train_grpo.py @@ -77,7 +77,8 @@ def strict_format_reward_func(completions, **kwargs) -> list[float]: def soft_format_reward_func(completions, **kwargs) -> list[float]: """Reward function that checks if the completion has a specific format.""" - pattern = r".*?\s*.*?" + #pattern = r".*?\s*.*?" + pattern = r"[\s\S]*\s*.*?" responses = [completion[0]["content"] for completion in completions] matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses] return [0.5 if match else 0.0 for match in matches]