Model Details

We employ Llama3-Instruct (8B) as one of the base models for evaluating our proposed Reward-Driven Selective Penalization for Preference Alignment Optimization (RSPO) method. The model is trained with RSPO for one epoch on the Llama3-UltraFeedback dataset.

How to use

Transformers AutoModelForCausalLM

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_id = "li11111/Llama3-Instruct-8B-RSPO"

# Load the tokenizer and the model (bfloat16, sharded across available devices)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

messages = [
    {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
    {"role": "user", "content": "Who are you?"},
]

# Tokenize the conversation with the Llama-3 chat template and append the generation prompt
input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt"
).to(model.device)

# Stop generation at either the EOS token or the Llama-3 end-of-turn token <|eot_id|>
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

outputs = model.generate(
    input_ids,
    max_new_tokens=256,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
)
# Keep only the newly generated tokens and decode them
response = outputs[0][input_ids.shape[-1]:]
print(tokenizer.decode(response, skip_special_tokens=True))
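
Alternatively, the same checkpoint can be queried through the transformers text-generation pipeline, which accepts chat-formatted messages directly in recent transformers releases. The following is a minimal sketch under that assumption, not part of our original usage instructions:

import torch
from transformers import pipeline

# Build a text-generation pipeline around the same checkpoint.
pipe = pipeline(
    "text-generation",
    model="li11111/Llama3-Instruct-8B-RSPO",
    model_kwargs={"torch_dtype": torch.bfloat16},
    device_map="auto",
)

messages = [
    {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
    {"role": "user", "content": "Who are you?"},
]

# The pipeline applies the chat template internally and returns the full conversation;
# the last message is the newly generated assistant reply.
outputs = pipe(messages, max_new_tokens=256, do_sample=True, temperature=0.6, top_p=0.9)
print(outputs[0]["generated_text"][-1]["content"])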

Experiment Parameters

| Parameter | Llama-3-Instruct |
|---|---|
| GPU | 8 × Ascend 910B |
| beta | 0.01 |
| batch size | 128 |
| learning_rate | 7e-7 |
| max_prompt_length | 512 |
| max_length | 1024 |
| num_train_epochs | 1 |
| torch_dtype | bfloat16 |
| warmup_ratio | 0.1 |
| β_w | 0.01 |
| β_l | 0.1 |
| λ | 0.1 |
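
For readability, the table can also be collected into a single configuration object. The sketch below is illustrative only: the dictionary key names, and the reading of β_w, β_l, and λ as RSPO coefficients for the chosen term, the rejected term, and the penalization weight, are assumptions, not the exact arguments of our training scripts.

# Illustrative configuration assembled from the table above.
# Key names are hypothetical and do not correspond to a specific trainer API.
rspo_training_config = {
    "base_model": "meta-llama/Meta-Llama-3-8B-Instruct",
    "dataset": "princeton-nlp/llama3-ultrafeedback",
    "num_train_epochs": 1,
    "global_batch_size": 128,
    "learning_rate": 7e-7,
    "warmup_ratio": 0.1,
    "max_prompt_length": 512,
    "max_length": 1024,
    "torch_dtype": "bfloat16",
    "beta": 0.01,
    "beta_w": 0.01,        # β_w from the table (assumed coefficient on the chosen term)
    "beta_l": 0.1,         # β_l from the table (assumed coefficient on the rejected term)
    "lambda_weight": 0.1,  # λ from the table (assumed penalization weight)
}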

Training Data

We use the princeton-nlp/llama3-ultrafeedback dataset, created by the princeton-nlp team, to train the Llama3-Instruct models. The UltraFeedback dataset provides the prompts, and the chosen and rejected response pairs (y_w, y_l) are regenerated with the SFT models. For each prompt x, five responses are sampled from the SFT model at a temperature of 0.8. The responses are then scored with llm-blender/PairRM, and the highest-scoring response is selected as y_w and the lowest-scoring one as y_l.
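
The sketch below illustrates this pair-construction step. The helpers generate_response (sampling from the SFT model) and score_with_pairrm (wrapping llm-blender/PairRM) are hypothetical placeholders for the actual generation and scoring code.

# Illustrative reconstruction of the (y_w, y_l) pair construction described above.
# `generate_response` and `score_with_pairrm` are hypothetical helpers supplied by the caller.
def build_preference_pair(prompt, generate_response, score_with_pairrm,
                          num_samples=5, temperature=0.8):
    # Sample five responses from the SFT model at temperature 0.8.
    candidates = [generate_response(prompt, temperature=temperature) for _ in range(num_samples)]
    # Score every candidate with PairRM (higher score = preferred).
    scores = score_with_pairrm(prompt, candidates)
    # The highest-scoring response becomes y_w, the lowest-scoring one becomes y_l.
    ranked = sorted(zip(scores, candidates), key=lambda pair: pair[0])
    y_l, y_w = ranked[0][1], ranked[-1][1]
    return {"prompt": prompt, "chosen": y_w, "rejected": y_l}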

Benchmarks

| Method | AlpacaEval 2.0 LC (%) | AlpacaEval 2.0 WR (%) | Avg. Len |
|---|---|---|---|
| RSPO | 45.0 | 42.5 | 1870 |