import os

# Set optimization variables. Thread counts and the HF cache location are set
# before torch and transformers are imported so they are picked up when those
# libraries initialize; set afterwards they may be silently ignored.
os.environ["OMP_NUM_THREADS"] = "8"
os.environ["MKL_NUM_THREADS"] = "8"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache"

import gc
import logging
import time
from typing import Optional

import torch
from fastapi import FastAPI, Request, Form
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from transformers import AutoModelForCausalLM, AutoTokenizer

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Cap PyTorch's intra-op parallelism to match the thread settings above
torch.set_num_threads(8)

# Initialize FastAPI
app = FastAPI()

# Jinja2 templates (index.html is expected in the process working directory)
templates = Jinja2Templates(directory=".")


# Disable gradient computation
torch.set_grad_enabled(False)

# Simple unbounded in-memory cache of generated replies, keyed by (message, max_length)
response_cache = {}

# Model initialization
model_id = "Gauri-tr/llama-3.1-8b-sarcasm"
tokenizer = None
model = None

# Load model in a lazy fashion
def load_model():
    global model, tokenizer
    if model is None:
        logger.info("Loading model and tokenizer...")
        start_time = time.time()
        
        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        tokenizer.pad_token = tokenizer.eos_token
        
        # Load model with optimizations
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float32,
            device_map="cpu", 
            low_cpu_mem_usage=True,
        )
        
        # Set to evaluation mode
        model.eval()
        
        # Try to optimize with torch.compile (PyTorch >= 2.0). Compilation is
        # lazy, so graph breaks are allowed; forcing fullgraph=True tends to
        # fail only later, at the first generate() call, outside this try block.
        try:
            model = torch.compile(model, backend="inductor")
            logger.info("Using torch.compile optimization")
        except Exception as e:
            logger.warning(f"Could not use torch.compile: {e}")
        
        logger.info(f"Model loaded in {time.time() - start_time:.2f} seconds")
        
        # Run a warmup inference
        _ = generate_response("Hello", max_length=10)
        
    return model, tokenizer

def generate_response(input_text: str, max_length: int = 30) -> str:
    # Check cache first
    cache_key = f"{input_text}_{max_length}"
    if cache_key in response_cache:
        logger.info("Using cached response")
        return response_cache[cache_key]
    
    # Format prompt
    prompt = f"""Below is an instruction that describes a task, and an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
Respond to this message as if you were in a conversation. Be funny, sarcastic and smart.

### Input:
{input_text}

### Response:
"""
    
    # Ensure model is loaded
    model, tokenizer = load_model()
    
    # Tokenize 
    inputs = tokenizer(prompt, return_tensors="pt")
    
    # Generate with optimization
    start_time = time.time()
    with torch.inference_mode():
        outputs = model.generate(
            inputs["input_ids"],
            max_new_tokens=max_length,
            do_sample=True,
            temperature=0.8,
            top_p=0.9,
            repetition_penalty=1.2,
            num_beams=1,  # No beam search; sample a single sequence for speed
            pad_token_id=tokenizer.eos_token_id,
            use_cache=True,  # Use KV cache
        )
    
    generation_time = time.time() - start_time
    logger.info(f"Generated response in {generation_time:.2f} seconds")
    
    # Decode
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    # Extract response part
    if "### Response:" in response:
        response = response.split("### Response:")[1].strip()
    
    # Cache the result
    response_cache[cache_key] = response
    
    # Make sure to clean up memory
    gc.collect()
    
    return response

# Define routes
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    # Trigger the lazy model load on the first page view (blocks until loaded)
    if model is None:
        load_model()
    return templates.TemplateResponse("index.html", {"request": request})

@app.post("/chat/")
async def chat(message: str = Form(...), max_length: Optional[int] = Form(30)):
    response = generate_response(message, max_length)
    return {"response": response, "message": message}

# Health check endpoint
@app.get("/health")
async def health():
    return {"status": "ok"}

# Preload the tokenizer at startup; the full model is loaded lazily on the first request
@app.on_event("startup")
async def startup_event():
    global tokenizer
    if tokenizer is None:
        logger.info("Pre-loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        tokenizer.pad_token = tokenizer.eos_token
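
# Local development entry point -- a minimal sketch, assuming uvicorn is
# installed and this module is run directly (in production you would more
# likely start it with `uvicorn <module_name>:app`).
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)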