import os
import gc
import time
import logging

# Thread and cache settings are read when the numeric libraries and the
# Hugging Face cache are initialised, so export them before importing
# torch/transformers; setting them afterwards may have no effect.
os.environ["OMP_NUM_THREADS"] = "8"
os.environ["MKL_NUM_THREADS"] = "8"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache"

import torch
from fastapi import FastAPI, Request, Form, BackgroundTasks
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Optional

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# CPU-only inference: cap PyTorch's intra-op thread count to match the
# environment settings above.
torch.set_num_threads(8)

app = FastAPI()
templates = Jinja2Templates(directory="")

torch.set_grad_enabled(False)

response_cache = {}

model_id = "Gauri-tr/llama-3.1-8b-sarcasm"
tokenizer = None
model = None


def load_model():
    """Lazily load the tokenizer and model on first use and cache them globally."""
    global model, tokenizer
    if model is None:
        logger.info("Loading model and tokenizer...")
        start_time = time.time()

        tokenizer = AutoTokenizer.from_pretrained(model_id)
        tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float32,
            device_map="cpu",
            low_cpu_mem_usage=True,
        )
        model.eval()

        # torch.compile is optional; fall back to eager execution if the
        # backend is unavailable or compilation fails.
        try:
            import torch._dynamo
            model = torch.compile(model, backend="inductor", fullgraph=True)
            logger.info("Using torch.compile optimization")
        except Exception as e:
            logger.warning(f"Could not use torch.compile: {e}")

        logger.info(f"Model loaded in {time.time() - start_time:.2f} seconds")

        # Warm-up generation so the first real request does not pay the
        # one-off compilation/initialisation cost.
        _ = generate_response("Hello", max_length=10)

    return model, tokenizer


def generate_response(input_text: str, max_length: int = 30) -> str:
    """Generate a sarcastic reply for input_text, using a simple in-memory cache."""
    cache_key = f"{input_text}_{max_length}"
    if cache_key in response_cache:
        logger.info("Using cached response")
        return response_cache[cache_key]

    # Alpaca-style instruction prompt.
    prompt = f"""Below is an instruction that describes a task, and an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
Respond to this message as if you were in a conversation. Be funny, sarcastic and smart.

### Input:
{input_text}

### Response:
"""

    model, tokenizer = load_model()

    inputs = tokenizer(prompt, return_tensors="pt")

    start_time = time.time()
    with torch.inference_mode():
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=max_length,
            do_sample=True,
            temperature=0.8,
            top_p=0.9,
            repetition_penalty=1.2,
            num_beams=1,
            pad_token_id=tokenizer.eos_token_id,
            use_cache=True,
        )

    generation_time = time.time() - start_time
    logger.info(f"Generated response in {generation_time:.2f} seconds")

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # The decoded text contains the full prompt; keep only the generated part.
    if "### Response:" in response:
        response = response.split("### Response:")[1].strip()

    response_cache[cache_key] = response

    gc.collect()

    return response


@app.get("/", response_class=HTMLResponse) |
|
async def index(request: Request): |
|
|
|
if model is None: |
|
load_model() |
|
return templates.TemplateResponse("index.html", {"request": request}) |
|
|
|
@app.post("/chat/") |
|
async def chat(message: str = Form(...), max_length: Optional[int] = Form(30)): |
|
response = generate_response(message, max_length) |
|
return {"response": response, "message": message} |

@app.get("/health")
async def health():
    return {"status": "ok"}


@app.on_event("startup")
async def startup_event():
    global tokenizer
    if tokenizer is None:
        logger.info("Pre-loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        tokenizer.pad_token = tokenizer.eos_token
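
# To serve the app (the module name "app" below assumes this file is saved
# as app.py; adjust it to the actual filename):
#
#   uvicorn app:app --host 0.0.0.0 --port 8000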