# gem-1o / generate.py
# Text Generation — English, instruction-following, reasoning
# Commit #1: GEM_1o_Aug trained (d18eb09, verified)
import torch
from models.gem_model import GEM
from utils.data_preprocessing import load_tokenizer
from configs.config import MODEL_CONFIG
def generate_text(model, tokenizer, prompt, max_length=100, temperature=0.7):
    """Generate a text continuation of *prompt* with a trained GEM model.

    Args:
        model: Model exposing a ``generate(input_ids, max_length, temperature)``
            method (e.g. the project's ``GEM``); must be an ``nn.Module`` so its
            device can be inferred from its parameters.
        tokenizer: Tokenizer with HF-style ``encode``/``decode`` methods.
        prompt: Seed string to continue.
        max_length: Maximum sequence length passed to ``model.generate``.
        temperature: Sampling temperature passed to ``model.generate``.

    Returns:
        The decoded generated sequence with special tokens stripped.
    """
    # Use the device the model's weights actually live on instead of
    # re-reading MODEL_CONFIG — avoids a device mismatch if the caller
    # moved the model, and removes a config coupling.
    device = next(model.parameters()).device
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    # Inference only: no_grad skips autograd bookkeeping and saves memory.
    with torch.no_grad():
        generated = model.generate(
            input_ids, max_length=max_length, temperature=temperature
        )
    return tokenizer.decode(generated[0], skip_special_tokens=True)
def main():
    """Load the trained GEM checkpoint and print a sample generation."""
    device = torch.device(MODEL_CONFIG['DEVICE'])
    tokenizer = load_tokenizer()

    # Rebuild the model with the same hyperparameters used for training;
    # the state dict below only matches this exact architecture.
    model = GEM(
        vocab_size=MODEL_CONFIG['VOCAB_SIZE'],
        d_model=MODEL_CONFIG['D_MODEL'],
        n_heads=MODEL_CONFIG['N_HEADS'],
        d_ff=MODEL_CONFIG['D_FF'],
        n_layers=MODEL_CONFIG['N_LAYERS'],
        max_seq_len=MODEL_CONFIG['MAX_SEQ_LEN'],
        dropout=MODEL_CONFIG['DROPOUT']
    ).to(device)

    # map_location lets a checkpoint saved on GPU load on a CPU-only
    # (or differently numbered GPU) machine instead of raising.
    checkpoint = torch.load('final_model/model.pt', map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()  # disable dropout for inference

    prompt = "Once upon a time"
    generated_text = generate_text(model, tokenizer, prompt, max_length=100)
    print(f"Generated text:\n{generated_text}")


if __name__ == "__main__":
    main()