from fastapi import FastAPI, Request, Form, BackgroundTasks
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
import torch
import os
import gc
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Optional
import time
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set CPU optimization variables
os.environ["OMP_NUM_THREADS"] = "8"
os.environ["MKL_NUM_THREADS"] = "8"
torch.set_num_threads(8)
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache"

# Initialize FastAPI
app = FastAPI()

# Load templates
templates = Jinja2Templates(directory="")

# Disable gradient computation
torch.set_grad_enabled(False)

# Cache for responses
response_cache = {}

# Model initialization
model_id = "Gauri-tr/llama-3.1-8b-sarcasm"
tokenizer = None
model = None


# Load model in a lazy fashion
def load_model():
    global model, tokenizer
    if model is None:
        logger.info("Loading model and tokenizer...")
        start_time = time.time()

        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        tokenizer.pad_token = tokenizer.eos_token

        # Load model with optimizations
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float32,
            device_map="cpu",
            low_cpu_mem_usage=True,
        )

        # Set to evaluation mode
        model.eval()

        # Try to optimize with torch.compile if available
        try:
            import torch._dynamo
            model = torch.compile(model, backend="inductor", fullgraph=True)
            logger.info("Using torch.compile optimization")
        except Exception as e:
            logger.warning(f"Could not use torch.compile: {e}")

        logger.info(f"Model loaded in {time.time() - start_time:.2f} seconds")

        # Run a warmup inference (model is already set, so this call returns quickly)
        _ = generate_response("Hello", max_length=10)

    return model, tokenizer


def generate_response(input_text: str, max_length: int = 30) -> str:
    # Check cache first
    cache_key = f"{input_text}_{max_length}"
    if cache_key in response_cache:
        logger.info("Using cached response")
        return response_cache[cache_key]

    # Format prompt (Alpaca-style instruction template)
    prompt = f"""Below is an instruction that describes a task, and an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
Respond to this message as if you were in a conversation. Be funny, sarcastic and smart.

### Input:
{input_text}

### Response:
"""

    # Ensure model is loaded
    model, tokenizer = load_model()

    # Tokenize
    inputs = tokenizer(prompt, return_tensors="pt")

    # Generate with optimization
    start_time = time.time()
    with torch.inference_mode():
        outputs = model.generate(
            inputs["input_ids"],
            max_new_tokens=max_length,
            do_sample=True,
            temperature=0.8,
            top_p=0.9,
            repetition_penalty=1.2,
            num_beams=1,  # No beam search, single sampled stream for speed
            pad_token_id=tokenizer.eos_token_id,
            use_cache=True,  # Use KV cache
        )
    generation_time = time.time() - start_time
    logger.info(f"Generated response in {generation_time:.2f} seconds")

    # Decode
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Extract response part
    if "### Response:" in response:
        response = response.split("### Response:")[1].strip()

    # Cache the result
    response_cache[cache_key] = response

    # Make sure to clean up memory
    gc.collect()

    return response


# Define routes
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    # Load the model on the first page view if it isn't loaded yet (synchronous)
    if model is None:
        load_model()
    return templates.TemplateResponse("index.html", {"request": request})


@app.post("/chat/")
async def chat(message: str = Form(...), max_length: Optional[int] = Form(30)):
    response = generate_response(message, max_length)
    return {"response": response, "message": message}


# Health check endpoint
@app.get("/health")
async def health():
    return {"status": "ok"}


# Preload tokenizer at startup
@app.on_event("startup")
async def startup_event():
    # Just initialize the tokenizer at startup - model will load on first request
    global tokenizer
    if tokenizer is None:
        logger.info("Pre-loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        tokenizer.pad_token = tokenizer.eos_token
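

# Minimal local entry point, sketched as an assumption: it presumes `uvicorn` is
# installed (and `python-multipart`, which FastAPI requires for Form(...) fields).
# The host/port values below are illustrative, not part of the original app.
if __name__ == "__main__":
    import uvicorn

    # A single process keeps exactly one copy of the model and response_cache in memory.
    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example form-encoded request against the /chat/ endpoint defined above:
#   curl -X POST -F "message=How was your day?" -F "max_length=30" http://localhost:8000/chat/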