Model Details
We use Llama3-Instruct (8B) as one of the base models for evaluating our proposed Reward-Driven Selective Penalization for Preference Alignment Optimization (RSPO) method. This model was trained for one epoch on the Llama3-UltraFeedback dataset using the RSPO method.
How to use
Transformers AutoModelForCausalLM
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_id = "li11111/Llama3-Instruct-8B-RSPO"

# Load the tokenizer and the model in bfloat16, placing weights automatically.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

messages = [
    {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
    {"role": "user", "content": "Who are you?"},
]

# Format the conversation with the model's chat template and append the
# assistant header so that generation starts a new assistant turn.
input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt"
).to(model.device)

# Stop on either the regular EOS token or the Llama 3 end-of-turn token.
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

outputs = model.generate(
    input_ids,
    max_new_tokens=256,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
)

# Drop the prompt tokens and decode only the newly generated response.
response = outputs[0][input_ids.shape[-1]:]
print(tokenizer.decode(response, skip_special_tokens=True))
Experiment Parameters
Parameter | Llama-3-Instruct |
---|---|
GPU | 8×Ascend910B |
beta | 0.01 |
batch | 128 |
learning_rate | 7e-7 |
max_prompt_length | 512 |
max_length | 1024 |
num_train_epochs | 1 |
torch_dtype | bfloat16 |
warmup_ratio | 0.1 |
β_w | 0.01 |
β_l | 0.1 |
λ | 0.1 |
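As a rough illustration of how the batch setting might map onto the hardware, the sketch below computes the gradient accumulation steps needed to reach a batch of 128. It assumes that `batch` in the table is the global batch size per optimizer step and that training is data-parallel across all 8 devices; neither is stated above. The per-device batch size of 2 and the helper name `gradient_accumulation_steps` are hypothetical.

```python
def gradient_accumulation_steps(global_batch: int, num_devices: int, per_device_batch: int) -> int:
    """Return how many micro-batches each device accumulates per optimizer step."""
    samples_per_micro_step = num_devices * per_device_batch
    if global_batch % samples_per_micro_step != 0:
        raise ValueError("global_batch must be divisible by num_devices * per_device_batch")
    return global_batch // samples_per_micro_step


GLOBAL_BATCH = 128     # "batch" in the table (assumed to be the global batch size)
NUM_DEVICES = 8        # 8×Ascend910B
PER_DEVICE_BATCH = 2   # hypothetical value, not taken from the table

steps = gradient_accumulation_steps(GLOBAL_BATCH, NUM_DEVICES, PER_DEVICE_BATCH)
print(f"gradient accumulation steps: {steps}")
```

Under these assumptions each device would accumulate 8 micro-batches per optimizer step; a different per-device batch size changes only that number, not the effective batch of 128.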
Training Data
We use the princeton-nlp/llama3-ultrafeedback dataset, created by the princeton-nlp team, to train the Llama3 Instruct models. The UltraFeedback dataset provides the prompts, and the chosen and rejected response pairs (y_w, y_l) are regenerated using the SFT models. For each prompt x, five responses are generated with the SFT model at a sampling temperature of 0.8. The responses are then scored with llm-blender/PairRM; the highest-scoring response is selected as y_w and the lowest-scoring one as y_l.
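The selection step described above can be sketched in a few lines. This is a minimal illustration, not the dataset authors' pipeline: it assumes one scalar score per response is already available, and the function name `select_preference_pair`, the candidate strings, and the scores are all made up for the example. It does not call PairRM itself.

```python
from typing import List, Tuple


def select_preference_pair(responses: List[str], scores: List[float]) -> Tuple[str, str]:
    """Return (chosen, rejected): the highest- and lowest-scoring responses."""
    if len(responses) != len(scores):
        raise ValueError("responses and scores must have the same length")
    if len(responses) < 2:
        raise ValueError("need at least two responses to form a pair")
    indices = range(len(scores))
    best = max(indices, key=scores.__getitem__)
    worst = min(indices, key=scores.__getitem__)
    return responses[best], responses[worst]


# Five sampled responses for one prompt, with made-up reward scores.
candidates = ["response A", "response B", "response C", "response D", "response E"]
reward_scores = [0.12, -0.40, 0.87, 0.05, -0.91]

chosen, rejected = select_preference_pair(candidates, reward_scores)
print("y_w:", chosen)
print("y_l:", rejected)
```

With these illustrative scores, "response C" becomes y_w and "response E" becomes y_l; the three middle-ranked responses are discarded.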
Benchmarks
Method | AlpacaEval 2.0 LC | AlpacaEval 2.0 WR | AlpacaEval 2.0 Avg. Len |
---|---|---|---|
RSPO | 45.0 | 42.5 | 1870 |