DART-Math: Difficulty-Aware Rejection Tuning for Mathematical Problem-Solving

link: https://arxiv.org/abs/2407.13690

Authors: Yuxuan Tong (Tsinghua University), Xiwen Zhang (Helixon Research), Rui Wang (Helixon Research), Ruidong Wu (Helixon Research), Junxian He (HKUST)

Introduction

Mathematical reasoning remains one of the most challenging domains for large language models (LLMs). Despite recent advances, even state-of-the-art models struggle with complex mathematical problems. A common approach to improve mathematical capabilities is instruction tuning with synthetic data generated through rejection sampling from stronger models. However, the paper's analysis reveals a critical flaw in current methods: they inadvertently bias training data toward easier problems.

The DART-Math paper identifies and addresses this bias through Difficulty-Aware Rejection Tuning, a novel approach that strategically rebalances training data distribution based on problem difficulty. The results show significant improvements across mathematical benchmarks while requiring smaller datasets and avoiding reliance on proprietary models.

The Problem: Bias Toward Easy Problems

In standard rejection tuning, practitioners sample a fixed number of response candidates M for each mathematical query, then filter out incorrect responses. The authors demonstrate that this seemingly neutral approach creates a systematic bias:

For difficult problems, fewer of the M samples yield correct answers, sometimes even zero. The MetaMathQA dataset, for instance, completely drops 51.1% of the most difficult MATH training queries and allocates only 10.5% of examples to the hardest problems despite their prevalence in the original dataset.

Mathematically, vanilla rejection sampling works as follows:

  • For each query $x_i$, sample $M$ responses: $\{(x_i, y_i^{(j)})\}_{j=1}^{M}$
  • Keep only pairs where $\text{IsCorrect}(y_i^{(j)}) = \text{True}$

This approach fails to account for the fact that $P(\text{IsCorrect}(y_i^{(j)}) = \text{True})$ varies dramatically based on problem difficulty.
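The vanilla procedure can be sketched in a few lines of Python. This is a minimal illustration, not the paper's implementation: `sample_fn` stands in for a call to the teacher model, and `is_correct` for the answer checker.

```python
def is_correct(response: str, answer: str) -> bool:
    # Placeholder verifier: in practice the final answer is extracted
    # from the model's response and compared to the reference answer.
    return response.strip() == answer.strip()

def vanilla_rejection_sampling(queries, sample_fn, M=4):
    """Sample a fixed budget of M responses per query and keep only the
    correct ones. Hard queries naturally contribute fewer (or zero)
    examples -- the bias the paper identifies."""
    dataset = []
    for query, answer in queries:
        for _ in range(M):
            response = sample_fn(query)
            if is_correct(response, answer):
                dataset.append((query, response))
    return dataset
```

With a teacher that never solves a hard query, that query vanishes from the dataset entirely, while easy queries fill all M slots.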

The DART Solution

The authors propose two alternative sampling strategies:

1. DARS-Uniform

This strategy ensures balanced representation by collecting the same number of correct responses for each query:

  • For each query $x_i$, continue sampling until accumulating exactly $k_u$ correct responses
  • If after $n_{max}$ sampling attempts fewer than $k_u$ correct responses are found, stop and keep what's collected
  • Mathematically: continue sampling until $|\{y_i^{(j)} : \text{IsCorrect}(y_i^{(j)}) = \text{True}\}| = k_u$ or the number of attempts reaches $n_{max}$
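The steps above amount to replacing the fixed per-query budget with a per-query target. A sketch under the same illustrative assumptions as before (`sample_fn` and `is_correct` are stand-ins, not the paper's code):

```python
def is_correct(response: str, answer: str) -> bool:
    # Placeholder verifier (illustrative only).
    return response.strip() == answer.strip()

def dars_uniform(queries, sample_fn, k_u=40, n_max=2048):
    """Collect exactly k_u correct responses per query, sampling until
    the target is met or n_max attempts are exhausted."""
    dataset = []
    for query, answer in queries:
        kept, attempts = [], 0
        while len(kept) < k_u and attempts < n_max:
            attempts += 1
            response = sample_fn(query)
            if is_correct(response, answer):
                kept.append((query, response))
        dataset.extend(kept)  # may hold fewer than k_u for very hard queries
    return dataset
```

Note that the loop keeps sampling hard queries until they reach the target, which is exactly what makes the resulting distribution uniform over difficulty.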

2. DARS-Prop2Diff

This strategy deliberately oversamples difficult problems, allocating resources proportional to problem difficulty:

  • Calculate the difficulty of each query $x_i$ as its fail rate: $d_i = \frac{\text{\# incorrect responses}}{\text{\# total responses}}$
  • For each query, collect $\min(k_p \cdot d_i, n_{max})$ correct responses, where $k_p$ is a hyperparameter
  • Ensure at least one correct response for each query
  • Mathematically: Target number of correct responses for query $x_i$ = $\max(1, \min(k_p \cdot d_i, n_{max}))$
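The per-query target above works out to a one-liner. Rounding $k_p \cdot d_i$ to an integer is my assumption; the paper's formula as quoted here only specifies the clipping:

```python
def prop2diff_target(fail_rate, k_p=192, n_max=8192):
    """Number of correct responses to collect for a query with the given
    fail rate: proportional to difficulty, at least 1, at most n_max."""
    return max(1, min(round(k_p * fail_rate), n_max))
```

An easy query (fail rate 0.0) still receives one response via the `max(1, ...)` floor, while a maximally hard one (fail rate 1.0) receives the full $k_p$.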

Implementation Details

The authors implement difficulty assessment using a "fail rate" metric - the proportion of incorrect responses when sampling from a strong model (DeepSeekMath-7B-RL). They then use this model to synthesize all training data, setting:

  • $k_u = 40$ for DARS-Uniform
  • $k_p = 192$ for DARS-Prop2Diff
  • Sampling temperature = 1.6 (higher than standard to improve diversity)
  • Top-p = 0.95, maximum output tokens = 2048
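The fail-rate metric itself is cheap to estimate by sampling from the grading model. A sketch, where the sample count and the string-comparison verifier are my assumptions rather than values from the paper:

```python
def estimate_fail_rate(query, answer, sample_fn, n_samples=16):
    """Estimate a query's difficulty as the fraction of sampled
    responses whose final answer is wrong (the paper's fail-rate
    metric; `sample_fn` stands in for the teacher model)."""
    incorrect = sum(
        1 for _ in range(n_samples)
        if sample_fn(query).strip() != answer.strip()
    )
    return incorrect / n_samples
```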

This process yields two datasets of approximately 590k examples each: DART-Math-Uniform and DART-Math-Hard.

Experimental Results

The authors evaluate their approach across six mathematical benchmarks using four base models:

Performance Improvements

Across 7-8B general base models (Mistral-7B and Llama3-8B):

  • MATH benchmark: +6.8 and +6.9 percentage points improvement over vanilla rejection tuning
  • Averaged across 6 benchmarks: +4.5 percentage points improvement for the Prop2Diff strategy

For Llama3-8B specifically:

  • MATH: Improved from 21.2% to 46.6%
  • GSM8K: Improved from 51.0% to 82.5%

Ablation Studies

Several ablation studies validate the approach:

  1. One-Response Coverage: For smaller training data sizes, ensuring at least one synthetic response per query significantly improved performance on simpler problems like GSM8K (+8 percentage points).
  2. Scaling Behavior: Both DART strategies consistently outperformed vanilla rejection tuning as training data size increased from thousands to millions of examples, with more pronounced differences for general models compared to math-specialized models.
  3. Domain Analysis: Improvements were consistent across all mathematical domains, with the largest gains in Geometry (+9.2 points) and Number Theory (+5.8 points).
  4. Comparison with RL: DART (a supervised fine-tuning approach) achieved comparable performance to reinforcement learning methods, demonstrating that well-designed supervised fine-tuning can match or exceed RL approaches.

Synthesis Cost Analysis

The authors acknowledge that DART requires more sampling to create datasets of comparable size:

  • At $n_{max}=2048$, DARS-Uniform could achieve the target for 90% of queries
  • DARS-Prop2Diff needed $n_{max}=8192$ to reach similar coverage
  • The total sampling cost for DARS-Prop2Diff was higher (15+ million samples vs. roughly 5 million for DARS-Uniform)

However, as a one-time cost amortized across all future usages, this increased synthesis expense is negligible compared to the performance benefits.

Conclusion

The DART-Math paper presents a compelling case for rethinking how we create training data for mathematical reasoning. By identifying and addressing the bias toward easy problems, the authors achieve substantial improvements across diverse benchmarks while using smaller datasets and an open-source model for synthesis.

The key insights generalize beyond mathematical reasoning:

  1. The distribution of training data significantly impacts model performance
  2. Difficult examples are crucial for effective learning
  3. Naively applying seemingly neutral sampling methods can introduce harmful biases

DART-Math demonstrates that thoughtful curation of training data distribution can be as important as increasing data quantity or model size, offering a cost-effective path toward improving mathematical reasoning capabilities in language models.
