Key Concepts
Neither forward Kullback-Leibler (FKL) divergence nor reverse Kullback-Leibler (RKL) divergence exhibits the expected mean-seeking or mode-seeking behavior in knowledge distillation for large language models. Instead, both FKL and RKL converge to the same optimization objective after a sufficient number of epochs. However, due to practical constraints, large language models are rarely trained for that many epochs. The authors propose an Adaptive Kullback-Leibler (AKL) divergence method that adaptively allocates weights to combine FKL and RKL, focusing on aligning the head and tail parts of the distributions.
Summary
The paper starts by discussing the use of Kullback-Leibler (KL) divergence in knowledge distillation (KD) for compressing large language models (LLMs). Contrary to previous assertions that reverse KL (RKL) divergence is mode-seeking and thus preferable to the mean-seeking forward KL (FKL) divergence, the authors demonstrate both empirically and theoretically that these properties do not hold for KD in LLMs.
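For reference, the two divergences can be written at the token level as follows, where p denotes the teacher distribution and q_θ the student distribution over the vocabulary V (notation assumed here for illustration, not taken verbatim from the paper):

$$
\mathrm{FKL}(p \,\|\, q_\theta) = \sum_{y \in \mathcal{V}} p(y \mid x)\, \log \frac{p(y \mid x)}{q_\theta(y \mid x)},
\qquad
\mathrm{RKL}(p \,\|\, q_\theta) = \sum_{y \in \mathcal{V}} q_\theta(y \mid x)\, \log \frac{q_\theta(y \mid x)}{p(y \mid x)}.
$$

The conventional intuition is that minimizing FKL spreads q_θ over all regions where p has mass (mean-seeking), while minimizing RKL concentrates q_θ on a dominant mode of p (mode-seeking); the paper argues this intuition does not carry over to LLM distillation.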
The key insights are:
- FKL and RKL share the same optimization objective, which is to align the logits of the student model with those of the teacher model. Both converge to the same solution after a sufficient number of epochs (more than 50 in the experiments).
- However, in practice, LLMs are rarely trained for such an extensive number of epochs (e.g., 10 epochs in prior work). The authors find that, in the early epochs, FKL focuses on the head of the distributions while RKL focuses on the tail.
- Based on these observations, the authors propose a novel Adaptive Kullback-Leibler (AKL) divergence method, which adaptively allocates weights to combine FKL and RKL to better align the distributions (a minimal sketch follows this list).
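The sketch below illustrates how such an adaptive combination could be computed per token. This is not the authors' released code: the head/tail split at 50% of the teacher mass and the gap-based weighting are illustrative assumptions, and the paper's exact AKL formulation may differ.

```python
import torch
import torch.nn.functional as F

def akl_loss(student_logits, teacher_logits, head_mass=0.5):
    """Adaptive FKL/RKL combination (sketch). Logits: (batch, seq_len, vocab)."""
    p = F.softmax(teacher_logits, dim=-1)   # teacher distribution
    q = F.softmax(student_logits, dim=-1)   # student distribution
    log_p = torch.log(p + 1e-9)
    log_q = torch.log(q + 1e-9)

    # Token-level FKL and RKL, summed over the vocabulary.
    fkl = (p * (log_p - log_q)).sum(dim=-1)   # forward KL: teacher-weighted log-ratio
    rkl = (q * (log_q - log_p)).sum(dim=-1)   # reverse KL: student-weighted log-ratio

    # Split the vocabulary into a "head" (highest teacher probabilities, covering
    # roughly `head_mass` of the mass) and a "tail" (the rest) -- assumed heuristic.
    sorted_p, idx = torch.sort(p, dim=-1, descending=True)
    cum_mass = sorted_p.cumsum(dim=-1)
    head_sorted = (cum_mass - sorted_p) < head_mass   # always keeps the top token
    head_mask = torch.zeros_like(p).scatter(-1, idx, head_sorted.float()).bool()

    # Measure how badly the student fits head vs. tail, then weight FKL toward
    # head mismatch and RKL toward tail mismatch (illustrative choice).
    gap = (p - q).abs()
    head_gap = (gap * head_mask).sum(dim=-1)
    tail_gap = (gap * (~head_mask)).sum(dim=-1)
    w_fkl = head_gap / (head_gap + tail_gap + 1e-9)

    loss = w_fkl * fkl + (1.0 - w_fkl) * rkl
    return loss.mean()
```

Under this weighting, a student that already matches the teacher's head but misses the tail receives a larger RKL weight, and vice versa, which mirrors the paper's goal of aligning both parts of the distribution.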
The experimental results demonstrate that AKL outperforms the baseline methods on various benchmarks. Additionally, the authors use GPT-4 to evaluate the diversity and quality of the generated responses, showing that AKL can improve both aspects compared to the baselines.
Statistics
The dataset contains 14k samples for training, 500 for validation, and 500 for testing.
The teacher models used are GPT-2 with 1.5B parameters and LLaMA with 6.7B parameters.
The student models are GPT-2 with 120M parameters and TinyLLaMA with 1.1B parameters.