March 12, 2025
Current self-training methods for Large Language Models (LLMs) tend to under-sample challenging queries, leading to inadequate learning on difficult problems and limiting LLMs' capabilities. To address this, this work proposes a difficulty-aware self-training (DAST) framework that focuses on improving both the quantity and quality of self-generated responses to challenging queries during self-training. DAST comprises three components: 1) sampling-based difficulty level estimation, 2) difficulty-aware data augmentation, and 3) a self-training algorithm using SFT and DPO respectively. Experiments on mathematical tasks demonstrate the effectiveness and generalization of DAST, highlighting the critical role of difficulty-aware strategies in advancing LLM self-training.
What doesn’t kill you makes you stronger. — Friedrich Wilhelm Nietzsche
The lack of extensive, high-quality human-curated training data for Large Language Models (LLMs) constrains the upper bound of their capabilities, particularly on complex reasoning tasks [1]. Recently, LLM self-training techniques, which iteratively fine-tune LLMs on their self-generated outputs, have garnered increasing attention, attaining sustained improvements and reducing the reliance on human intervention [2]–[5].
To ensure the quality of LLMs' self-generated training data, previous works employ rejection sampling [6] to filter out low-quality or incorrect responses with external reward models [2] or ground-truth labels [3]. This may lead LLMs to over-sample simple queries they are already adept at while under-sampling challenging ones [7], [8]. The resulting insufficient learning on challenging instances manifests in two ways during self-training. First, with a fixed sampling number, only a few or even no correct responses are acquired on challenging queries, which iteratively exacerbates the distribution imbalance of the training data and causes severe overfitting on simple questions (left panel of Figure 1 (a)). Second, the sampled self-generated responses on difficult questions are too short (right panel of Figure 1 (a)). Given that challenging problems require more thinking steps [9], [10], the quality of these responses tends to be lower. As a result, LLMs cannot adequately learn from challenging tasks, thereby restricting their capacity improvements.
Considering the above two issues, this work proposes a difficulty-aware self-training (DAST) framework that focuses on increasing both the quantity and quality of self-generated responses to challenging queries during self-training: 1) DAST employs a sampling-based, model-specific method to estimate the difficulty level of each query. 2) Given the difficulty levels, two data augmentation approaches are employed to balance the distribution and improve the response quality of the training data. Specifically, we up-sample challenging questions to control the data proportion across difficulty levels, and we employ a difficulty-matched few-shot prompting method to control response lengths, encouraging LLMs to use more thinking steps on challenging questions. These two methods are combined incrementally. 3) We iteratively perform the above difficulty estimation and data augmentation steps over several rounds of LLM self-training using supervised fine-tuning (SFT) and direct preference optimization (DPO) [11] respectively.
Experiments are conducted on both in-domain and out-of-domain tasks across various mathematical datasets. Results demonstrate that DAST significantly enhances LLMs' mathematical ability and generalizability over several baselines.
Our contributions are as follows: 1) This work is the first to comprehensively incorporate difficulty levels into LLM self-training, demonstrating the significance of considering difficulty for future work; our code is available at https://github.com/AmourWaltz/DAST. 2) We propose two data augmentation methods in DAST to improve both the quantity and quality of responses to challenging queries using the estimated difficulty levels. 3) We conduct experiments validating that DAST enhances LLMs' mathematical ability and generalizability using SFT and DPO respectively.
Figure 1: Changes in the data proportion and response length distribution of samples at different difficulty levels during a three-round self-training process. Vanilla rejection sampling for constructing training data (a) is widely employed in [2], [3], [5], [6]. (b) and (c) show the proposed DAST, which aims to control the data proportion and response lengths for challenging queries. Note that in iteration 0, the training data \(\mathcal{D}_{\mathrm u}\) is the original dataset \(\mathcal{D}_{\mathrm o}\) with ground-truth labels, while in iterations 1, 2, and 3, the training data is a combination of self-generated data \(\mathcal{D}_{\mathrm a}\) and the original dataset \(\mathcal{D}_{\mathrm o}\). All difficulty levels are measured with the initial policy \(\mathcal{M}_0\) on the GSM8K test set and are fixed during self-training.
We employ a sampling-based, model-specific method to estimate the difficulty level of each question for the model. Given the initial policy \(\mathcal{M}_0\) and the training set \(\mathcal{D}_{\mathrm{o}}={\left \{ \boldsymbol{x}_i, \boldsymbol{\hat{r}}_i, \boldsymbol{\hat{y}}_i \right \}}_{i=1}^{N}\), where \(\boldsymbol{x}_i, \boldsymbol{\hat{r}}_i, \boldsymbol{\hat{y}}_i\) represent the question, rationale, and ground-truth answer respectively, each rationale \(\boldsymbol{\hat{r}}_i = \left [\hat{r}_{i,1}, \dots,\hat{r}_{i,l} \right ]\) contains \(l\) reasoning steps, where \(l\) varies across rationales. For each \((\boldsymbol{x}_i, \boldsymbol{\hat{r}}_i, \boldsymbol{\hat{y}}_i)\) and a prompt set \(\mathcal{P}\) containing \(K\) different few-shot prompts, we pair each few-shot exemplar \(\boldsymbol{p}_{k}\in \mathcal{P}\) with the question \(\boldsymbol{x}_i\) and let the policy \(\mathcal{M}_{0}\) generate the \(k\)-th response \(\left ({\boldsymbol{y}_i}^{(k)},{\boldsymbol{r}_i}^{(k)}\right )=\mathcal{M}_{0}\left (\boldsymbol{p}_{k}, \boldsymbol{x}_i\right )\) using temperature sampling (\(T=0.2, \mathrm{top}~p=0.9\)). We obtain the response set \(\boldsymbol{Y}_{i}=\{{\boldsymbol{y}_i}^{(k)}\}_{k=1}^{K}\) and the label set \(\boldsymbol{Z}_i = {\left \{ {z_{i}}^{(k)} \right \}}_{k=1}^{K}\) by comparing each extracted answer in \(\boldsymbol{Y}_{i}\) with the ground truth \(\boldsymbol{\hat{y}}_i\) to determine correctness (\({z_{i}}^{(k)}\in\left \{ 0,1 \right \}\), 1 for true and 0 for false). The difficulty level \(d_i\) is estimated as follows.1 Details and splits of the four difficulty levels are in Table 2.
\[\begin{align} \label{eq:diff} d_i=P\left (\boldsymbol{Y}_i|\boldsymbol{x}_i\right)=\frac{\sum_{k=1}^{K}\mathbb{I}\left ({\boldsymbol{y}_i}^{(k)}={\boldsymbol{\hat{y}}}_i\right)}{K} \end{align}\tag{1}\]
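For illustration, a minimal Python sketch of this estimation is given below. It assumes a hypothetical `generate_answer(prompt, question)` helper that samples one response from \(\mathcal{M}_0\) (with \(T=0.2\), top-\(p=0.9\)) and returns the extracted final answer; the level splits follow Table 2, with the Easy lower bound inferred from the Middle range.

```python
from typing import Callable, List

def estimate_difficulty(question: str,
                        gold_answer: str,
                        prompts: List[str],
                        generate_answer: Callable[[str, str], str]) -> float:
    """Pass-rate difficulty estimate d_i from Equation 1.

    generate_answer(prompt, question) is a hypothetical helper that samples one
    response from the initial policy M_0 and returns the extracted final answer.
    """
    correct = sum(generate_answer(p, question) == gold_answer for p in prompts)
    return correct / len(prompts)  # d_i = (#correct) / K

def difficulty_level(d: float) -> str:
    """Map the pass rate to a level following the splits in Table 2
    (the 0.8 threshold for Easy is an inferred assumption)."""
    if d >= 0.8:
        return "E"   # Easy
    if d >= 0.4:
        return "M"   # Middle
    if d > 0.0:
        return "H"   # Hard
    return "U"       # Unsolved
```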
We augment \(\mathcal{D}_{\mathrm o}\) with a strategy \(\mathcal{A}(\cdot)\) applied to each query \(\boldsymbol{x}_i\) according to \(d_i\), controlling the data proportion and response lengths on \(\mathcal{M}\) to obtain an augmented dataset \(\mathcal{D}_{\mathrm a}\) for self-training, as follows.
As shown in the left panel of Figure 1 (a), constructing self-training data with vanilla rejection sampling may bias the data toward simple questions. Therefore, we set different sampling numbers \(K\) for different difficulty levels \(d_i\) of \(\boldsymbol{x}_i\). More specifically, the sampling number \(K\) is multiplied by a coefficient \(\beta\) determined by \(d_i\), as presented in Table 2. For \(d_i\in \{M, H, U\}\), which indicates that \(\boldsymbol{x}_i\) is a challenging question, \(\beta\) is larger so as to increase the number of correct responses sampled from the policy \(\mathcal{M}\). The sampled responses are added to \(\mathcal{D}_{\mathrm a}\). As illustrated in Figure 1 (b), we can thus dynamically control the proportion of samples at each difficulty level and balance the distribution of the training data in every self-training iteration.
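A minimal sketch of this proportion control is shown below, assuming each training item already carries its estimated level and a hypothetical `sample_responses(question, n)` helper that draws `n` (rationale, answer) pairs from the current policy; the \(\beta\) values follow Table 2.

```python
# Up-sampling coefficients from Table 2 (larger beta for harder queries).
BETA = {"E": 1, "M": 3, "H": 5, "U": 5}

def build_proportion_controlled_data(dataset, base_k, sample_responses):
    """Construct D_a by up-sampling challenging queries.

    dataset items: dicts with "question", "answer", and "level" (E/M/H/U).
    sample_responses(question, n) is a hypothetical helper that draws n
    responses from the current policy and returns (rationale, answer) pairs.
    """
    augmented = []
    for item in dataset:
        k = base_k * BETA[item["level"]]          # more samples for M/H/U
        for rationale, answer in sample_responses(item["question"], k):
            if answer == item["answer"]:          # rejection sampling filter
                augmented.append({"question": item["question"],
                                  "rationale": rationale,
                                  "answer": answer})
    return augmented
```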
As shown in the right panel of Figure 1 (a), responses generated with vanilla few-shot sampling have similar average lengths across all difficulty levels during self-training (iterations 1, 2, and 3) and are relatively shorter than the ground-truth responses in \(\mathcal{D}_{\mathrm o}\) (iteration 0). To generate longer, difficulty-matched responses, we propose a difficulty-matched few-shot (DMFS) prompting method: for each difficulty level \(d\in\{E, M, H, U\}\), we select samples from the training set whose response length exceeds the average response length of that difficulty level to construct four prompt sets \(\mathcal{P}_{E}, \mathcal{P}_{M}, \mathcal{P}_{H}, \mathcal{P}_{U}\). DMFS exemplars are then chosen based on \(d_i\) to sample responses for \(\boldsymbol{x}_i\) from \(\mathcal{M}\), and the sampled responses are added to \(\mathcal{D}_{\mathrm a}\). As a result, the length distribution of \(\mathcal{D}_{\mathrm a}\) is close to that of the ground truth in iteration 0, as shown in Figure 1 (c), which improves response quality through more thinking steps [9], [12].
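The DMFS prompt pools can be built roughly as follows; this is a sketch under the assumption that each training item stores its difficulty level and rationale text, with the 8-shot pool size taken from the sampling setup in the appendix.

```python
from collections import defaultdict
from statistics import mean

def build_dmfs_prompt_sets(train_set, n_shots=8):
    """Build difficulty-matched few-shot prompt pools P_E, P_M, P_H, P_U.

    train_set items: dicts with "level", "question", and "rationale" (solution
    text). For each level, keep exemplars whose solution is longer than the
    level's average length.
    """
    by_level = defaultdict(list)
    for item in train_set:
        by_level[item["level"]].append(item)

    prompt_sets = {}
    for level, items in by_level.items():
        avg_len = mean(len(it["rationale"]) for it in items)
        long_items = [it for it in items if len(it["rationale"]) > avg_len]
        prompt_sets[level] = long_items[:n_shots]   # one exemplar pool per level
    return prompt_sets
```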
Figure 2: DAST Algorithm
As presented in Algorithm 2, in the \(j\)-th iteration, the training set \(\mathcal{D}_{\mathrm{u}}\) is updated by merging the augmented dataset \(\mathcal{D}_{\mathrm{a}}^{(j)}\) with the initial training set \(\mathcal{D}_{\mathrm{t}}\), ensuring that \(\mathcal{D}_{\mathrm{u}}\) does not diverge too far from \(\mathcal{D}_{\mathrm{t}}\). The policy \(\mathcal{M}_j\) is fine-tuned from \(\mathcal{M}_{j-1}\) (SFT) or \(\mathcal{M}_{0}\) (DPO [11]) on \(\mathcal{D}_{\mathrm u}\) by optimizing \(\mathcal{L}_{\mathrm{sft}}\) or \(\mathcal{L}_{\mathrm{dpo}}\) in Equation 2 or 3 respectively. \(\mathcal{M}_j\) is trained until accuracy no longer increases on the validation set \(\mathcal{D}_{\mathrm v}\). We denote DAST using SFT and DPO by DAST-S and DAST-D respectively. For DAST-S, we also investigate employing only data proportion control or only length control, denoted DAST-P and DAST-L respectively.
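A high-level sketch of this loop is given below; all callables (`estimate_levels`, `augment`, `finetune`) are hypothetical stand-ins for the components described above and in Algorithm 2.

```python
def dast_self_training(M0, D_t, D_v, n_iters,
                       estimate_levels, augment, finetune):
    """Sketch of the DAST loop; all callables are hypothetical stand-ins.

    estimate_levels: sampling-based difficulty estimation on the initial policy.
    augment:         difficulty-aware data augmentation (proportion + length).
    finetune:        one round of SFT or DPO, trained until accuracy on the
                     validation set D_v stops improving.
    """
    levels = estimate_levels(M0, D_t)        # levels are fixed across rounds
    policy = M0
    for _ in range(n_iters):
        D_a = augment(policy, D_t, levels)   # self-generated, augmented data
        D_u = D_a + D_t                      # merge with the original training set
        # SFT initializes from the previous policy; DPO restarts from M0 while
        # using the previous policy only to generate D_a (see Appendix).
        policy = finetune(policy, D_u, D_v)
    return policy
```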
During the training stage, we jointly combine the training sets of GSM8K [1] and MATH [13] as \(\mathcal{D}_{\mathrm t}\). We evaluate in-domain (ID) performance on the corresponding test sets. We also assess out-of-domain (OOD) performance on three challenging test sets: TAL-SCQ [14], College [15], and TheoremQA [16]. We standardize the data format as in Appendix 10 and employ the evaluation script of MWPBench2 [15] to judge the correctness of the extracted answer against the ground-truth label. Dataset details are in Appendix 8.
We utilize in-context learning (ICL) [17] to generate responses. We also employ several SFT-based and DPO-based baselines. SFT-based baselines include: 1) single-round standard SFT and difficulty-aware rejection tuning (DART) [8] (specified as DART-Uniform and DART-Prob2Diff); and 2) multi-round ReST-EM [3]. DPO-based [11] baselines include single-round and multi-round DPO (DPO and mDPO). Detailed implementations of the above baselines are provided in Appendix 9.
Figure 3: Performance results of DAST over various baselines on both in-domain (ID) and out-of-domain (OOD) mathematical test sets using Llama-3.1. Note that the names of employed baselines are in lowercase.
Experiments are conducted on Llama-3.1-8B (Llama-3.1) [18] in this work. As shown in Figure 3, our main findings are summarized below.
1. Despite different sizes of self-training data in each iteration, DAST-S and DAST-D consistently yield superior performance over the corresponding SFT and DPO baselines with comparable or less data, demonstrating the effectiveness and efficiency of DAST for both SFT and DPO during self-training. Data size statistics are presented in Table 3.
2. DAST-P exhibits better performance than DAST-L, suggesting that increasing the data size yields more improvement than increasing response lengths for challenging queries. This can be attributed to the fact that the initial policy is suboptimal, so the sampled lengthy responses are also of low quality; therefore, increasing data quantity leads to more obvious gains.
3. DAST-S and DAST-P generalize better to OOD tasks than the other methods. DAST enables LLMs to adequately learn from more diverse challenging questions, thereby achieving more pronounced improvements on relatively challenging OOD tasks.
In this part, we investigate the research question: "As self-training progresses iteratively, will increasing the proportion of challenging samples lead to further improvements?" We control the proportion of challenging queries with a fixed data size in each iteration by adjusting \(\beta\) during self-training, as illustrated in Figure 4. Results suggest that LLMs perform better when trained on a dataset with a balanced distribution over difficulty levels (DAST-P-\(\alpha 1\)) than on one with more hard samples (DAST-P-\(\alpha 2\)) during self-training. Excessive challenging samples may lead to a large distribution shift, hurting LLMs' original abilities on simple queries.
Figure 4: Results of data proportion control.
In this part, we investigate the research question: "Will performance be further improved by employing difficult examples across all queries to generate lengthy responses during self-training?" We generate training data using few-shot examples drawn from a single difficulty level in the first round of DAST and compare against our proposed difficulty-matched few-shot (DMFS) prompting method. Results in Table 1 show that training data generated by DMFS outperforms data obtained from any single level. Tailoring response length to the difficulty level of queries is more effective, as sampling lengthy responses for simple queries may result in overthinking and undermine performance [19].
Table 1: Results of using few-shot exemplars from a single difficulty level versus DMFS prompting on ID and OOD test sets.

| Exemplar Level | \(E\) | \(M\) | \(H\) | \(U\) | DMFS |
|---|---|---|---|---|---|
| ID | 35.58 | 37.44 | 38.90 | 38.66 | 41.94 |
| OOD | 11.45 | 12.15 | 12.48 | 12.06 | 13.07 |
This work proposes the DAST framework to enhance both the quantity and quality of self-generated responses to challenging queries during self-training, comprising three key components: difficulty level estimation, data augmentation, and a self-training algorithm. Experiments on math tasks using SFT and DPO showcase the effectiveness and generalization of DAST.
The limitations of this work are as follows:
This work enhances response quality solely by increasing the number of thinking steps, i.e., the lengths of responses. Although improving response quality by adding length is simple yet effective for challenging queries, further exploration is needed to comprehensively evaluate response quality along other dimensions.
Another limitation is that the experiments are conducted solely on mathematical reasoning tasks. This constraint primarily arises because, for many other tasks such as long-form generation, it is also challenging to evaluate generation quality. Future research should prioritize a wider range of long-form generation datasets to thoroughly assess the applicability and effectiveness of DAST.
The definitions of the notations in this work are summarized in Table [table:notation].
Table 2: Difficulty level splits based on the pass rate \(p\) (Equation 1) and the corresponding sampling coefficients \(\beta\).

| \(p\) | Difficulty Level | Denotation \(d_i\) | \(\beta\) |
|---|---|---|---|
| \([0.8, 1.0]\) | Easy | \(E\) | 1 |
| \([0.4, 0.8)\) | Middle | \(M\) | 3 |
| \((0.0, 0.4)\) | Hard | \(H\) | 5 |
| \(0.0\) | Unsolved | \(U\) | 5 |
SFT is optimized by minimizing the negative log-likelihood loss as follows.
\[\begin{align} \label{eq:sft} \mathcal{L}_{\mathrm{sft}}=\mathbb{E} \left [-\log \mathcal{M}_{j-1}(\boldsymbol{y}_i^{+}, \boldsymbol{r}_i^{+}|\boldsymbol{x}_i) \right ] \end{align}\tag{2}\]
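For concreteness, a PyTorch-style sketch of Equation 2 is shown below, assuming `logits` come from the policy over the concatenated query and positive response and that prompt positions in `labels` are masked with -100.

```python
import torch
import torch.nn.functional as F

def sft_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Negative log-likelihood of the positive response (Equation 2).

    logits: (batch, seq_len, vocab); labels: (batch, seq_len), with prompt
    positions set to -100 so only response tokens contribute to the loss.
    """
    # Shift so that position t predicts token t+1, as in causal LM training.
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,
    )
```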
DPO is optimized to minimize the preference loss as follows.
\[\begin{align} \label{eq:dpo} \mathcal{L}_{\mathrm{dpo}}={\mathbb{E}}\left [-\log \sigma \left ( \theta(\boldsymbol{y}_i^{+}, \boldsymbol{r}_i^{+}|\boldsymbol{x}) - \theta(\boldsymbol{y}_i^{-}, \boldsymbol{r}_i^{-}|\boldsymbol{x}) \right ) \right ] \end{align}\tag{3}\] where \({(\boldsymbol{x}_i, \boldsymbol{y}_i^{+}, \boldsymbol{r}_i^{+}, \boldsymbol{y}_i^{-}, \boldsymbol{r}_i^{-})\sim \mathcal{D}_{\mathrm u}}\) and \(\theta(\cdot|\boldsymbol{x})=\log \frac{\mathcal{M}_{j-1}(\cdot|\boldsymbol{x})}{\mathcal{M}_{0}(\cdot|\boldsymbol{x})}\).
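A corresponding sketch of Equation 3 over summed sequence log-probabilities is given below; note that, following the formulation above, the usual DPO temperature coefficient is omitted.

```python
import torch
import torch.nn.functional as F

def dpo_loss(policy_logp_pos: torch.Tensor, policy_logp_neg: torch.Tensor,
             ref_logp_pos: torch.Tensor, ref_logp_neg: torch.Tensor) -> torch.Tensor:
    """Preference loss of Equation 3 from summed sequence log-probs.

    theta(.|x) = log M_{j-1}(.|x) - log M_0(.|x); the reference model M_0
    log-probs are treated as constants (no gradient).
    """
    theta_pos = policy_logp_pos - ref_logp_pos.detach()
    theta_neg = policy_logp_neg - ref_logp_neg.detach()
    return -F.logsigmoid(theta_pos - theta_neg).mean()
```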
LLM Self-Training [2], [3] is a machine learning paradigm in which an LLM iteratively improves its performance by generating and leveraging its own synthetic data for further training without human intervention; it is also referred to as self-taught [5], [20], self-evolving [21], or self-improving [4]. Such self-training paradigms typically involve a generation step, in which the LLM is prompted to self-generate training data, and an improve step, in which the LLM is trained on the self-generated data [2]. In the generation step, to ensure data quality, the generated data are filtered and selected using rejection sampling [22] before being employed for training. The filtering signals can be reward scores returned by a reward model [2], binary correctness scores judged against gold answers for mathematical or coding tasks [3], [5], [22], [23], or two scores from separate process and outcome reward models on reasoning tasks [24]. The LLM itself can also be regarded as the judge or reward model [25], [26].
In the improve step, the selected data are utilized to train the LLM using supervised fine-tuning (SFT) [2], [3], [5], [27] or reinforcement learning [2], [20], [23]. Some studies iteratively train the policy LLM from the model obtained in the previous iteration [2], while others train the base LLM instead of the model from the previous iteration [3], [5], [23], [28], [29].
Since the growth of high-quality data is significantly outpaced by the expansion of training datasets, synthetic data has emerged as a promising solution [30] to address this data limitation and further improve LLM performance according to scaling laws [31]. The self-training paradigm employs the LLM itself to generate synthetic training data for mathematical problems [3], [5], [23]. [8] proposes synthesizing more responses for challenging questions. [32] bootstraps the diversity of math problems by re-writing the training set and further fine-tunes the LLM on the enhanced training set. [33] designs several re-writing principles to enhance both questions and responses and obtain an enhanced training set. [34] proposes synthesizing more complex and diverse mathematical instructions to improve LLMs' mathematical reasoning ability. [7] employs the Socratic-Guided Sampling (GSI) method to synthesize data to address the long-tail distribution issue during self-training. Some studies also investigate synthesizing new questions [35], [36].
Furthermore, test-time scaling [9] has attracted much attention recently, which allocates more computation at inference time to generate higher-quality responses. Such self-generated data can be further used for training to self-improve LLMs [2]. Many works validate that incorporating data sampled multiple times from LLMs at inference can benefit training and lead to further improvements, for example in dialogue systems [37]–[39], multilingual LLMs [40], [41], and knowledge-intensive QA [42]–[44], which represents a new trend for LLM training. Although additional computation costs are required, such methods can still be utilized efficiently in practice and bring significant improvements.
GSM8K [1] 3 is a high-quality multi-step mathematical reasoning dataset of diverse grade school math word problems constructed by human problem writers, including 7,472 training samples and 1,319 test samples. All the questions take 2 to 8 steps to solve, involving a series of basic arithmetic operations to parse the final answer.
MATH [13] 4 is a challenging mathematical dataset with competition mathematics problems, consisting of 7,500 training samples and 5,000 test samples. Each problem in MATH also has a full step-by-step solution which can be used to teach models to generate answer derivations and explanations across several subjects including algebra, geometry, number theory, counting and probability, calculus, etc.
TAL-SCQ5K-EN [14] 5 is a high-quality English mathematical competition dataset created by TAL Education Group, with 5,000 samples in total, split into 3,000 training and 2,000 test questions. The questions are in multiple-choice form and cover mathematical topics at the primary, junior high, and high school levels. We format all samples in a standard QA format.
[15] 6 The College dataset contains 1,281 training and 2,818 test college-level mathematical problems extracted from 9 textbooks across 7 domains such as linear algebra and differential equations. This dataset tests generalization to complex mathematical reasoning in diverse domains.
[16] 7 The TheoremQA dataset contains 800 problems focused on applying mathematical theorems to solve challenging problems in fields such as math, physics, finance, and engineering, testing generalization to theoretical reasoning in general STEM. The dataset was collected by human experts and is of very high quality. We filter out questions requiring pictures, leaving 747 samples for testing.
Sampling Stage: Set the sampling temperature to 0.5. For each query, sample 10 responses. Retain responses based on whether the final answer matches the ground truth. Training Stage: Combine the sampled data from the current policy model with the original dataset \(\mathcal{D}_{\mathrm o}\) to form a new training dataset, which is then used for supervised fine-tuning (SFT).
Sampling Stage: Set the sampling temperature to 0.5. During dataset construction, perform oversampling for difficult samples to ensure every sample has 4 correct responses. Training Stage: Combine the sampled data with the original dataset \(\mathcal{D}_{\mathrm o}\) to form a new training dataset, which is then used for supervised fine-tuning (SFT).
Sampling Stage: Set the sampling temperature to 0.5. During dataset construction, perform oversampling for difficult samples, applying a coefficient based on the difficulty level. More challenging samples are assigned more responses. Training Stage: Combine the sampled data with the original dataset to form a new training dataset, which is then used for supervised fine-tuning (SFT).
Sampling Stage: Set the sampling temperature to 0.5. The dataset construction is similar to SFT, except that we also add negative samples to the training data in order to run the DPO algorithm.
The sampling stage is similar to ReST-EM, except that we also add negative samples to the training data in order to run the DPO algorithm. For multi-round DPO, we sample the self-generated training data from the model obtained in the previous training iteration, but we train the model from the initial policy as in Equation 3.
You are an excellent mathematician. Answer the following mathematical questions based on your knowledge.
### Question ###: {Question}
### Response ###:
<think>{Reasoning steps}</think>. The answer is \box{Answer}.
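As an illustration, the template can be filled with plain string formatting as in the sketch below (the example question and field names are placeholders, not drawn from the datasets).

```python
PROMPT_TEMPLATE = (
    "You are an excellent mathematician. Answer the following mathematical "
    "questions based on your knowledge.\n"
    "### Question ###: {question}\n"
    "### Response ###:\n"
    "<think>{reasoning}</think>. The answer is \\box{{{answer}}}."
)

# Hypothetical example filling the placeholders.
example = PROMPT_TEMPLATE.format(
    question="What is 12 * 7 + 5?",
    reasoning="12 * 7 = 84, and 84 + 5 = 89.",
    answer="89",
)
print(example)
```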
Experiments are conducted on Llama-3.1-8B (Llama-3.1) 8 [18].
During dataset construction, we sample responses using 8-shot examples with the sampling temperature set to \(T=0.5\). For the response length control of DAST, challenging samples are paired with longer few-shot examples. When sampling, we dynamically adjust the sampling number \(K\) to keep the training data size comparable across methods in each iteration, as shown in Table 3.
During training, the Adam optimizer is used with mini-batch updates, an initial learning rate of 1e-4, a warm-up ratio of 0.05, and a weight decay of 0.01. We fix the number of training steps and ensure that all models are trained to convergence. Although the training data sizes of different methods differ, fixing the total training steps maintains fairness across all methods.
When decoding, the temperature is set to 0.2, consistent with the sampling setting. All models are loaded and saved in float16 (fp16). The vLLM library [45] 9 is utilized to accelerate generation. All experiments are conducted on 4 \(\times\) NVIDIA A100-40GB GPUs.
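A minimal vLLM generation sketch matching these settings is shown below; the model path, prompt, and `max_tokens` value are placeholders.

```python
from vllm import LLM, SamplingParams

# Placeholder model path; decoding settings follow the configuration above.
llm = LLM(model="meta-llama/Llama-3.1-8B", dtype="float16")
sampling_params = SamplingParams(temperature=0.2, top_p=0.9, max_tokens=1024)

prompts = ["<formatted few-shot prompt + question>"]
outputs = llm.generate(prompts, sampling_params)
for out in outputs:
    print(out.outputs[0].text)
```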
Table 3: Training data sizes of different methods in each iteration.

| Method | Iteration | Data Size |
|---|---|---|
| ICL | - | - |
| SFT | - | 15k |
| DART-Uniform | - | 60k |
| DART-Prob2Diff | - | 60k |
| ReST-EM | 1 | 50k |
| ReST-EM | 2 | 55k |
| ReST-EM | 3 | 58k |
| ReST-EM | 4 | 58k |
| DAST-P | 1 | 55k |
| DAST-P | 2 | 56k |
| DAST-P | 3 | 58k |
| DAST-P | 4 | 58k |
| DAST-L | 1 | 56k |
| DAST-L | 2 | 56k |
| DAST-L | 3 | 56k |
| DAST-L | 4 | 56k |
| DAST-S | 1 | 58k |
| DAST-S | 2 | 59k |
| DAST-S | 3 | 60k |
| DAST-S | 4 | 60k |
| DPO | - | 15k |
| mDPO | 1 | 50k |
| mDPO | 2 | 55k |
| mDPO | 3 | 58k |
| mDPO | 4 | 58k |
| DAST-D | 1 | 58k |
| DAST-D | 2 | 59k |
| DAST-D | 3 | 60k |
| DAST-D | 4 | 60k |
In this study, challenging queries refer to queries estimated at the difficulty levels of Middle, Hard, and Unsolved.
https://github.com/microsoft/unilm/tree/master/mathscale/MWPBench