arxiv:2306.13649

On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes

Published on Jun 23, 2023
· Submitted by akhaliq on Jun 26, 2023
#1 Paper of the day

Abstract

Knowledge distillation (KD) is widely used for compressing a teacher model to reduce its inference cost and memory footprint, by training a smaller student model. However, current KD methods for auto-regressive sequence models suffer from distribution mismatch between output sequences seen during training and those generated by the student during inference. To address this issue, we introduce Generalized Knowledge Distillation (GKD). Instead of solely relying on a fixed set of output sequences, GKD trains the student on its self-generated output sequences by leveraging feedback from the teacher on such sequences. Unlike supervised KD approaches, GKD also offers the flexibility to employ alternative loss functions between the student and teacher, which can be useful when the student lacks the expressivity to mimic the teacher's distribution. Furthermore, GKD facilitates the seamless integration of distillation with RL fine-tuning (RLHF). We demonstrate the efficacy of GKD for distilling auto-regressive language models on summarization, translation, and arithmetic reasoning tasks, and task-agnostic distillation for instruction-tuning.
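
The core of GKD is a divergence between the teacher's and student's next-token distributions, evaluated on sequences the student sampled itself. Below is a minimal PyTorch sketch of the generalized Jensen-Shannon divergence JSD(beta) used as that divergence; this is not the authors' reference implementation, and the function name, temperature handling, and exact weighting convention are assumptions (TRL's GKDTrainer ships its own version), so treat it as an illustration only.

import math
import torch
import torch.nn.functional as F

def generalized_jsd_loss(student_logits, teacher_logits, beta=0.5, temperature=1.0):
    # Hypothetical helper (not the paper's code). Logits have shape [batch, seq_len, vocab]
    # and are aligned on the same (student-sampled) tokens; assumes 0 < beta < 1.
    student_logp = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_logp = F.log_softmax(teacher_logits / temperature, dim=-1)

    # Log of the mixture M = beta * p_teacher + (1 - beta) * p_student
    mixture_logp = torch.logsumexp(
        torch.stack([teacher_logp + math.log(beta), student_logp + math.log(1 - beta)]),
        dim=0,
    )

    # JSD(beta) = beta * KL(p_teacher || M) + (1 - beta) * KL(p_student || M), averaged over tokens
    kl_teacher = F.kl_div(mixture_logp, teacher_logp, log_target=True, reduction="none").sum(-1)
    kl_student = F.kl_div(mixture_logp, student_logp, log_target=True, reduction="none").sum(-1)
    return (beta * kl_teacher + (1 - beta) * kl_student).mean()

In on-policy GKD this loss is computed on tokens drawn from the student's own samples; a mixing coefficient lambda controls how often training batches are self-generated rather than drawn from a fixed dataset, and beta interpolates between forward-KL-like and reverse-KL-like behavior.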

Community

Hi! Thanks for this interesting paper. Do you have a PyTorch implementation of the GKD loss somewhere? Do you plan on providing one publicly in the future?
Thanks

Hi, so like DPO, can we use GKD with unsloth?

I was trying to run GKD with unsloth and am facing this issue:

/root/home/deeksha/envs/unsloth_env/lib/python3.10/site-packages/transformers/training_args.py:1594: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead
  warnings.warn(
Traceback (most recent call last):
  File "/root/home/deeksha/codes/student/GKD_unsloth.py", line 140, in <module>
    trainer = GKDTrainer(
  File "/root/home/deeksha/envs/unsloth_env/lib/python3.10/site-packages/unsloth/trainer.py", line 203, in new_init
    original_init(self, *args, **kwargs)
  File "/root/home/deeksha/codes/student/unsloth_compiled_cache/UnslothGKDTrainer.py", line 805, in __init__
    super().__init__(
  File "/root/home/deeksha/codes/student/unsloth_compiled_cache/UnslothGKDTrainer.py", line 419, in __init__
    super().__init__(
  File "/root/home/deeksha/envs/unsloth_env/lib/python3.10/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
  File "/root/home/deeksha/envs/unsloth_env/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 170, in __init__
    args = SFTConfig(**dict_args)
  File "/root/home/deeksha/codes/student/unsloth_compiled_cache/UnslothSFTTrainer.py", line 251, in __init__
    super().__init__(
TypeError: SFTConfig.__init__() got an unexpected keyword argument 'temperature'

@deeksha2695 let me have a look... can you also open an issue on TRL?

Yes, I have raised an issue in unsloth: https://github.com/unslothai/unsloth/issues/1941
I will do so on TRL as well.
This is the code I am running:

from unsloth import FastLanguageModel, is_bfloat16_supported
from unsloth.chat_templates import get_chat_template, standardize_sharegpt, train_on_responses_only
import torch
import re
import numpy as np
import pandas as pd
from datasets import load_dataset, Dataset
from sklearn.model_selection import train_test_split
from trl import SFTTrainer, GKDConfig, GKDTrainer
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    DataCollatorForSeq2Seq,
    TextStreamer,
    BitsAndBytesConfig,
    pipeline,
)
from evaluate import load  # HF中国镜像站 (Hugging Face mirror) evaluate library
from peft import LoraConfig, get_peft_model, TaskType

max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.

student_model_path = "/home/deeksha/codes/Qwen2.5-0.5B-Instruct"
teacher_model_path = "/root/home/deeksha/codes/student/models/qwen_2.5_model_1.5B_fine"

train_data = pd.read_csv('/root/home/deeksha/codes/annotated_data/data/formatted_training_data.csv')
train_dataset = Dataset.from_pandas(train_data)

# Format the dataset for training

def format_example(example):
    return {
        "messages": [
            {"role": "user", "content": example["prompt"]},
            {"role": "assistant", "content": example["completion"]},
        ]
    }

train_dataset = train_dataset.map(format_example, remove_columns=train_dataset.column_names)
train_dataset = Dataset.from_dict({"messages": train_dataset["messages"]})

# Optional: check the structure of the dataset

print(train_dataset)

test_data = pd.read_csv('/root/home/deeksha/codes/annotated_data/data/formatted_testing_data.csv')
test_data = test_data.head(10)
test_dataset = Dataset.from_pandas(test_data)
test_dataset = test_dataset.map(format_example, remove_columns=test_dataset.column_names)

eval_dataset = test_dataset
print(test_dataset)

# Load the tokenizer and the teacher model

tokenizer = AutoTokenizer.from_pretrained(student_model_path)
teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model_path)

student_model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = student_model_path,  # or e.g. "unsloth/Llama-3.2-3B-Instruct" / "unsloth/Llama-3.2-1B-Instruct"
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    # token = "hf_...",  # use one if loading gated models like meta-llama/Llama-2-7b-hf
)

training_args = GKDConfig(
    output_dir="models/gkd-model-unsloth",
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    eval_strategy="steps",  # evaluation_strategy is deprecated in recent transformers
    eval_steps=10,
    save_steps=2,
    save_total_limit=3,
    logging_steps=100,
    learning_rate=5e-5,  # can be increased (e.g., 1e-4 for LoRA)
    num_train_epochs=1,
    weight_decay=0.0,  # no weight decay for LoRA
    report_to="none",
    seq_kd=True,  # sequence-level KD: train on teacher-generated completions
)

# Initialize the trainer

trainer = GKDTrainer(
    model=student_model,  # LoRA-applied model
    teacher_model=teacher_model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
)

# Train the model

trainer.train()
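
For comparison, a stripped-down run through plain TRL (no unsloth) may help isolate whether the unexpected 'temperature' keyword comes from the unsloth-compiled wrapper forwarding GKD-specific arguments into SFTConfig. This is only a sketch under the assumption of a recent TRL release that provides GKDTrainer/GKDConfig; the model names and the toy dataset below are placeholders, not the setup from the report above.

from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GKDConfig, GKDTrainer

student = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
teacher = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

# Tiny placeholder dataset in the "messages" format GKDTrainer expects
train_dataset = Dataset.from_dict({
    "messages": [[
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "4"},
    ]]
})

args = GKDConfig(
    output_dir="models/gkd-model-trl-only",
    per_device_train_batch_size=1,
    lmbda=0.5,        # fraction of batches using student-generated (on-policy) sequences
    beta=0.5,         # interpolation coefficient of the generalized JSD
    temperature=0.9,  # generation temperature; a GKDConfig field, not an SFTConfig one
)

trainer = GKDTrainer(
    model=student,
    teacher_model=teacher,
    args=args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
)
trainer.train()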


Models citing this paper: 10

Datasets citing this paper: 0

Spaces citing this paper: 0

Collections including this paper: 3