|
|
|
import torch |
|
from models.gem_model import GEM |
|
from utils.data_preprocessing import load_tokenizer |
|
from configs.config import MODEL_CONFIG |
|
|
|
def generate_text(model, tokenizer, prompt, max_length=100, temperature=0.7):
    """Generate a text continuation for *prompt* with a trained model.

    Args:
        model: model (expected in eval mode) exposing a ``generate`` method
            that accepts ``input_ids``, ``max_length`` and ``temperature``.
        tokenizer: tokenizer exposing ``encode``/``decode``.
        prompt: seed string to continue.
        max_length: maximum length passed through to ``model.generate``.
        temperature: sampling temperature passed through to ``model.generate``.

    Returns:
        The decoded generated sequence with special tokens stripped.
    """
    # Put the prompt on the same device as the model's own weights instead of
    # trusting the global MODEL_CONFIG entry — robust if the model was moved.
    device = next(model.parameters()).device
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    # Inference only: disable autograd so no computation graph is built.
    with torch.no_grad():
        generated = model.generate(input_ids, max_length=max_length, temperature=temperature)
    return tokenizer.decode(generated[0], skip_special_tokens=True)
|
|
|
def main():
    """Load the trained GEM checkpoint and print a sample generation."""
    device = torch.device(MODEL_CONFIG['DEVICE'])

    tokenizer = load_tokenizer()

    # Rebuild the architecture exactly as configured at training time so the
    # saved state_dict keys/shapes line up.
    model = GEM(
        vocab_size=MODEL_CONFIG['VOCAB_SIZE'],
        d_model=MODEL_CONFIG['D_MODEL'],
        n_heads=MODEL_CONFIG['N_HEADS'],
        d_ff=MODEL_CONFIG['D_FF'],
        n_layers=MODEL_CONFIG['N_LAYERS'],
        max_seq_len=MODEL_CONFIG['MAX_SEQ_LEN'],
        dropout=MODEL_CONFIG['DROPOUT']
    ).to(device)

    # map_location remaps checkpoint tensors onto the configured device —
    # without it, a checkpoint saved on GPU fails to load on a CPU-only host.
    # NOTE(review): consider weights_only=True if the torch version supports
    # it, since pickled checkpoints from untrusted sources are unsafe.
    checkpoint = torch.load('final_model/model.pt', map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()  # disable dropout for deterministic-quality sampling

    prompt = "Once upon a time"
    generated_text = generate_text(model, tokenizer, prompt, max_length=100)
    print(f"Generated text:\n{generated_text}")
|
|
|
# Script entry point: load the checkpoint and print one sample generation.
if __name__ == "__main__":

    main()
|
|