import os

# Set CPU-threading and cache variables before importing torch/transformers,
# since these environment variables are typically read at import time.
os.environ["OMP_NUM_THREADS"] = "8"
os.environ["MKL_NUM_THREADS"] = "8"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache"

import gc
import logging
import time
from typing import Optional

import torch
from fastapi import FastAPI, Request, Form, BackgroundTasks
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from transformers import AutoModelForCausalLM, AutoTokenizer

torch.set_num_threads(8)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize FastAPI
app = FastAPI()

# Load templates and static files
templates = Jinja2Templates(directory="")

# Disable gradient computation
torch.set_grad_enabled(False)

# Cache for responses
response_cache = {}

# Model initialization
model_id = "Gauri-tr/llama-3.1-8b-sarcasm"
tokenizer = None
model = None

# Load model in a lazy fashion
def load_model():
    global model, tokenizer
    if model is None:
        logger.info("Loading model and tokenizer...")
        start_time = time.time()

        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        tokenizer.pad_token = tokenizer.eos_token

        # Load model with optimizations
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float32,
            device_map="cpu",
            low_cpu_mem_usage=True,
        )

        # Set to evaluation mode
        model.eval()

        # Try to optimize with torch.compile if available
        try:
            import torch._dynamo
            model = torch.compile(model, backend="inductor", fullgraph=True)
            logger.info("Using torch.compile optimization")
        except Exception as e:
            logger.warning(f"Could not use torch.compile: {e}")

        logger.info(f"Model loaded in {time.time() - start_time:.2f} seconds")

        # Run a warmup inference
        _ = generate_response("Hello", max_length=10)
    return model, tokenizer

def generate_response(input_text: str, max_length: int = 30) -> str:
    # Check cache first
    cache_key = f"{input_text}_{max_length}"
    if cache_key in response_cache:
        logger.info("Using cached response")
        return response_cache[cache_key]

    # Format prompt
    prompt = f"""Below is an instruction that describes a task, and an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
Respond to this message as if you were in a conversation. Be funny, sarcastic and smart.
### Input:
{input_text}
### Response:
"""

    # Ensure model is loaded
    model, tokenizer = load_model()

    # Tokenize
    inputs = tokenizer(prompt, return_tensors="pt")

    # Generate with optimization
    start_time = time.time()
    with torch.inference_mode():
        outputs = model.generate(
            inputs["input_ids"],
            max_new_tokens=max_length,
            do_sample=True,
            temperature=0.8,
            top_p=0.9,
            repetition_penalty=1.2,
            num_beams=1,  # Single beam (no beam search) for speed
            pad_token_id=tokenizer.eos_token_id,
            use_cache=True,  # Use KV cache
        )
    generation_time = time.time() - start_time
    logger.info(f"Generated response in {generation_time:.2f} seconds")

    # Decode
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Extract response part
    if "### Response:" in response:
        response = response.split("### Response:")[1].strip()

    # Cache the result
    response_cache[cache_key] = response

    # Make sure to clean up memory
    gc.collect()
    return response

# Define routes
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    # Load the model on the first page view if it is not loaded yet
    # (this call is synchronous and blocks until loading finishes)
    if model is None:
        load_model()
    return templates.TemplateResponse("index.html", {"request": request})

@app.post("/chat/")
async def chat(message: str = Form(...), max_length: Optional[int] = Form(30)):
    response = generate_response(message, max_length)
    return {"response": response, "message": message}

# Health check endpoint
@app.get("/health")
async def health():
    return {"status": "ok"}

# Preload the tokenizer at startup
@app.on_event("startup")
async def startup_event():
    # Just initialize the tokenizer here - the full model loads on the first request
    global tokenizer
    if tokenizer is None:
        logger.info("Pre-loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        tokenizer.pad_token = tokenizer.eos_token
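
# Minimal local entry point -- a sketch that assumes uvicorn is installed and that
# port 7860 (the usual Hugging Face Spaces port) is free; on Spaces itself the
# server is normally started externally, so this block is only for local testing.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)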