On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes
Abstract
Knowledge distillation (KD) is widely used for compressing a teacher model to reduce its inference cost and memory footprint, by training a smaller student model. However, current KD methods for auto-regressive sequence models suffer from distribution mismatch between output sequences seen during training and those generated by the student during inference. To address this issue, we introduce Generalized Knowledge Distillation (GKD). Instead of solely relying on a fixed set of output sequences, GKD trains the student on its self-generated output sequences by leveraging feedback from the teacher on such sequences. Unlike supervised KD approaches, GKD also offers the flexibility to employ alternative loss functions between the student and teacher, which can be useful when the student lacks the expressivity to mimic the teacher's distribution. Furthermore, GKD facilitates the seamless integration of distillation with RL fine-tuning (RLHF). We demonstrate the efficacy of GKD for distilling auto-regressive language models on summarization, translation, and arithmetic reasoning tasks, and task-agnostic distillation for instruction-tuning.
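For readers who want to see the core objective in code: below is a minimal PyTorch sketch of the token-level generalized Jensen-Shannon divergence JSD(β) that GKD uses to match the student to the teacher on student-generated (on-policy) sequences. This is my own illustrative reading of the objective, not the authors' reference implementation; the function name, tensor shapes, and mean reduction are assumptions, and padding/prompt tokens would need to be masked in a real training loop.

import math
import torch
import torch.nn.functional as F

def generalized_jsd(student_logits, teacher_logits, beta=0.5, temperature=1.0):
    """Token-level JSD(beta) between student and teacher next-token distributions.

    student_logits, teacher_logits: (batch, seq_len, vocab) computed over the
    same token sequence; in on-policy GKD that sequence is sampled from the
    student (teacher logits are obtained without gradients). Assumes
    0 < beta < 1. Up to a scaling factor, beta -> 0 behaves like forward
    KL(teacher || student) and beta -> 1 like reverse KL.
    """
    s_logp = F.log_softmax(student_logits / temperature, dim=-1)
    t_logp = F.log_softmax(teacher_logits / temperature, dim=-1)

    # Log of the mixture m = beta * p_teacher + (1 - beta) * p_student,
    # computed stably in log space.
    m_logp = torch.logsumexp(
        torch.stack([t_logp + math.log(beta), s_logp + math.log(1.0 - beta)]),
        dim=0,
    )

    # JSD(beta) = beta * KL(teacher || m) + (1 - beta) * KL(student || m),
    # with each KL summed over the vocabulary at every position.
    kl_t_m = (t_logp.exp() * (t_logp - m_logp)).sum(-1)
    kl_s_m = (s_logp.exp() * (s_logp - m_logp)).sum(-1)
    return (beta * kl_t_m + (1.0 - beta) * kl_s_m).mean()

# Quick shape check with random logits (batch=2, seq_len=5, vocab=32).
loss = generalized_jsd(torch.randn(2, 5, 32), torch.randn(2, 5, 32), beta=0.5)
print(loss)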
Community
Hi! Thanks for this interesting paper. Do you have a PyTorch implementation of the GKD loss somewhere? Do you plan on providing one publicly in the future?
Thanks
Hi, so like DPO, can we use GKD with Unsloth?
I was trying to run GKD with Unsloth and am facing this issue:
/root/home/deeksha/envs/unsloth_env/lib/python3.10/site-packages/transformers/training_args.py:1594: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead
  warnings.warn(
Traceback (most recent call last):
  File "/root/home/deeksha/codes/student/GKD_unsloth.py", line 140, in <module>
    trainer = GKDTrainer(
  File "/root/home/deeksha/envs/unsloth_env/lib/python3.10/site-packages/unsloth/trainer.py", line 203, in new_init
    original_init(self, *args, **kwargs)
  File "/root/home/deeksha/codes/student/unsloth_compiled_cache/UnslothGKDTrainer.py", line 805, in __init__
    super().__init__(
  File "/root/home/deeksha/codes/student/unsloth_compiled_cache/UnslothGKDTrainer.py", line 419, in __init__
    super().__init__(
  File "/root/home/deeksha/envs/unsloth_env/lib/python3.10/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
  File "/root/home/deeksha/envs/unsloth_env/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 170, in __init__
    args = SFTConfig(**dict_args)
  File "/root/home/deeksha/codes/student/unsloth_compiled_cache/UnslothSFTTrainer.py", line 251, in __init__
    super().__init__(
TypeError: SFTConfig.__init__() got an unexpected keyword argument 'temperature'
Yes, I have raised an issue in Unsloth: https://github.com/unslothai/unsloth/issues/1941
I will raise one on TRL as well.
This is the code that I am running:
from unsloth import FastLanguageModel, is_bfloat16_supported
from unsloth.chat_templates import get_chat_template, standardize_sharegpt, train_on_responses_only

import re
import numpy as np
import pandas as pd
import torch
from datasets import load_dataset, Dataset
from evaluate import load  # Hugging Face's evaluate library
from peft import LoraConfig, get_peft_model, TaskType
from sklearn.model_selection import train_test_split
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForSeq2Seq,
    TextStreamer,
    TrainingArguments,
    pipeline,
)
from trl import SFTTrainer, GKDConfig, GKDTrainer
max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.
student_model_path = "/home/deeksha/codes/Qwen2.5-0.5B-Instruct"
teacher_model_path = "/root/home/deeksha/codes/student/models/qwen_2.5_model_1.5B_fine"
train_data = pd.read_csv('/root/home/deeksha/codes/annotated_data/data/formatted_training_data.csv')
train_dataset = Dataset.from_pandas(train_data)
# Format dataset for training
def format_example(example):
    return {
        "messages": [
            {"role": "user", "content": example["prompt"]},
            {"role": "assistant", "content": example["completion"]},
        ]
    }

train_dataset = train_dataset.map(format_example, remove_columns=train_dataset.column_names)
train_dataset = Dataset.from_dict({"messages": train_dataset["messages"]})
# Optional: Check the structure of the dataset
print(train_dataset)
test_data = pd.read_csv('/root/home/deeksha/codes/annotated_data/data/formatted_testing_data.csv')
test_data = test_data.head(10)
test_dataset = Dataset.from_pandas(test_data)
test_dataset = test_dataset.map(format_example, remove_columns=test_dataset.column_names)
eval_dataset = test_dataset
print(test_dataset)
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(student_model_path)
teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model_path)
student_model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = student_model_path,  # or choose e.g. "unsloth/Llama-3.2-3B-Instruct" / "unsloth/Llama-3.2-1B-Instruct"
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)
training_args = GKDConfig(
    output_dir="models/gkd-model-unsloth",
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    eval_strategy="steps",  # `evaluation_strategy` is deprecated in recent transformers
    eval_steps=10,
    save_steps=2,
    save_total_limit=3,
    logging_steps=100,
    learning_rate=5e-5,  # Can be increased (e.g., 1e-4 for LoRA)
    num_train_epochs=1,
    weight_decay=0.0,  # No weight decay for LoRA
    report_to="none",
    seq_kd=True,
)
# Initialize Trainer
trainer = GKDTrainer(
    model=student_model,  # LoRA-applied model
    teacher_model=teacher_model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
)
# Train the model
trainer.train()
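A note on the TypeError above: the traceback shows TRL's SFTTrainer rebuilding an SFTConfig from the full argument dict (args = SFTConfig(**dict_args)), so any field that exists only on GKDConfig, such as temperature, gets rejected at that point. The snippet below is only a hedged diagnostic to list those GKD-only fields in the installed TRL version; it is not a fix for the Unsloth-compiled wrapper.

import dataclasses
from trl import GKDConfig, SFTConfig

# GKDConfig subclasses SFTConfig in TRL; the set difference shows the
# GKD-specific arguments (e.g. 'temperature', 'seq_kd') that a plain
# SFTConfig(**kwargs) call cannot accept.
gkd_only = {f.name for f in dataclasses.fields(GKDConfig)} - {
    f.name for f in dataclasses.fields(SFTConfig)
}
print(sorted(gkd_only))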