Gauri-tr committed on
Commit dbf8811 · verified · 1 Parent(s): 60bf674

Update app.py

Files changed (1)
  1. app.py +161 -392
app.py CHANGED
@@ -1,392 +1,161 @@
- # app.py - FastAPI implementation for Hugging Face Spaces
- import os
- import gc
- import time
- import torch
- from fastapi import FastAPI, Request, Form
- from fastapi.responses import HTMLResponse
- from fastapi.staticfiles import StaticFiles
- from fastapi.templating import Jinja2Templates
- from pydantic import BaseModel
- from typing import Optional
- import logging
- from threading import Thread
- from queue import Queue
-
- # Import optimized model loading utilities
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
- from peft import PeftModel, PeftConfig
-
- # Set up logging
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
- logger = logging.getLogger(__name__)
-
- # Initialize FastAPI app
- app = FastAPI()
-
- # Set up templates and static files
- templates = Jinja2Templates(directory="templates")
- os.makedirs("templates", exist_ok=True)
- os.makedirs("static", exist_ok=True)
- app.mount("/static", StaticFiles(directory="static"), name="static")
-
- # Create chat template
- with open("templates/index.html", "w") as f:
-     f.write("""
- <!DOCTYPE html>
- <html>
- <head>
-     <title>Sarcastic Assistant Chat</title>
-     <meta name="viewport" content="width=device-width, initial-scale=1">
-     <style>
-         body {
-             font-family: Arial, sans-serif;
-             max-width: 800px;
-             margin: 0 auto;
-             padding: 20px;
-             background-color: #f5f5f5;
-         }
-         .chat-container {
-             border-radius: 10px;
-             background: white;
-             box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
-             padding: 20px;
-             height: 70vh;
-             overflow-y: auto;
-             margin-bottom: 20px;
-         }
-         .message {
-             padding: 10px 15px;
-             border-radius: 18px;
-             margin-bottom: 10px;
-             max-width: 80%;
-             word-wrap: break-word;
-         }
-         .user {
-             background-color: #e1ffc7;
-             margin-left: auto;
-             text-align: right;
-         }
-         .assistant {
-             background-color: #f0f0f0;
-             margin-right: auto;
-         }
-         .input-area {
-             display: flex;
-             gap: 10px;
-         }
-         #user-input {
-             flex: 1;
-             padding: 10px 15px;
-             border: 1px solid #ddd;
-             border-radius: 25px;
-             outline: none;
-         }
-         button {
-             padding: 10px 20px;
-             background-color: #4CAF50;
-             color: white;
-             border: none;
-             border-radius: 25px;
-             cursor: pointer;
-         }
-         button:hover {
-             background-color: #45a049;
-         }
-         #thinking {
-             color: #666;
-             font-style: italic;
-             display: none;
-         }
-         .status-area {
-             margin-top: 10px;
-             color: #666;
-             font-size: 0.9em;
-         }
-     </style>
- </head>
- <body>
-     <h1>Sarcastic Assistant</h1>
-     <div class="chat-container" id="chat-container">
-         <div class="message assistant">
-             Hi there! I'm your sarcastic assistant. What's on your mind today?
-         </div>
-     </div>
-     <div class="input-area">
-         <input type="text" id="user-input" placeholder="Type your message..." autocomplete="off">
-         <button onclick="sendMessage()">Send</button>
-     </div>
-     <div class="status-area">
-         <div id="thinking">Thinking...</div>
-         <div id="model-info">Llama 3.1-8B with AdaLoRA fine-tuning</div>
-     </div>
-
-     <script>
-         const chatContainer = document.getElementById('chat-container');
-         const userInput = document.getElementById('user-input');
-         const thinkingIndicator = document.getElementById('thinking');
-
-         // Enable Enter key to send messages
-         userInput.addEventListener("keyup", function(event) {
-             if (event.key === "Enter") {
-                 sendMessage();
-             }
-         });
-
-         async function sendMessage() {
-             const message = userInput.value.trim();
-             if (!message) return;
-
-             // Add user message to chat
-             addMessage(message, 'user');
-             userInput.value = '';
-
-             // Show thinking indicator
-             thinkingIndicator.style.display = 'block';
-
-             try {
-                 // Send message to API
-                 const response = await fetch('/generate', {
-                     method: 'POST',
-                     headers: {
-                         'Content-Type': 'application/json'
-                     },
-                     body: JSON.stringify({ message: message })
-                 });
-
-                 if (!response.ok) {
-                     throw new Error('Network response was not ok');
-                 }
-
-                 const data = await response.json();
-
-                 // Add AI response to chat
-                 addMessage(data.response, 'assistant');
-             } catch (error) {
-                 console.error('Error:', error);
-                 addMessage('Sorry, I had trouble processing that. Please try again.', 'assistant');
-             } finally {
-                 // Hide thinking indicator
-                 thinkingIndicator.style.display = 'none';
-
-                 // Scroll to bottom
-                 chatContainer.scrollTop = chatContainer.scrollHeight;
-             }
-         }
-
-         function addMessage(text, sender) {
-             const messageDiv = document.createElement('div');
-             messageDiv.classList.add('message', sender);
-             messageDiv.textContent = text;
-             chatContainer.appendChild(messageDiv);
-             chatContainer.scrollTop = chatContainer.scrollHeight;
-         }
-     </script>
- </body>
- </html>
- """)
-
- # Create response queue for background processing
- response_queue = Queue()
-
- # Model loading - optimized for CPU
- class ModelManager:
-     def __init__(self):
-         self.model = None
-         self.tokenizer = None
-         self.pipeline = None
-         self.is_loaded = False
-         self.loading_thread = None
-
-     def load_model_in_background(self):
-         """Load model in a background thread to avoid blocking the server startup"""
-         if self.loading_thread is None or not self.loading_thread.is_alive():
-             self.loading_thread = Thread(target=self._load_model)
-             self.loading_thread.daemon = True
-             self.loading_thread.start()
-
-     def _load_model(self):
-         """Internal method to load the model with optimizations for CPU"""
-         try:
-             logger.info("Loading tokenizer...")
-             # Loading base model tokenizer
-             self.tokenizer = AutoTokenizer.from_pretrained(
-                 "meta-llama/Llama-3.1-8B-Instruct",
-                 use_fast=True
-             )
-             self.tokenizer.pad_token = self.tokenizer.eos_token
-             self.tokenizer.padding_side = "right"
-
-             logger.info("Loading model with CPU optimizations...")
-             # Load the base model with CPU optimizations
-             model_kwargs = {
-                 # Load in 8-bit for reduced memory usage
-                 "load_in_8bit": True,
-                 "device_map": "auto",
-                 # CPU optimizations
-                 "low_cpu_mem_usage": True,
-             }
-
-             # Load the base model
-             base_model = AutoModelForCausalLM.from_pretrained(
-                 "meta-llama/Llama-3.1-8B-Instruct",
-                 **model_kwargs
-             )
-
-             logger.info("Loading adapter weights...")
-             # Load the PEFT adapter - assuming the adapter is in the lora_model directory
-             try:
-                 # First try with directory in current folder
-                 adapter_path = "Gauri-tr/lora_model"
-                 if not os.path.exists(adapter_path):
-                     # Check in parent directories
-                     adapter_path = "../lora_model"
-
-                 self.model = PeftModel.from_pretrained(
-                     base_model,
-                     adapter_path,
-                     device_map="auto"
-                 )
-             except Exception as e:
-                 logger.error(f"Failed to load PEFT adapter: {e}")
-                 # Fallback to using base model
-                 self.model = base_model
-                 logger.warning("Using base model without adapters")
-
-             logger.info("Setting up inference pipeline...")
-             # Create pipeline with optimized settings
-             self.pipeline = pipeline(
-                 "text-generation",
-                 model=self.model,
-                 tokenizer=self.tokenizer,
-                 max_new_tokens=64,
-                 temperature=0.8,
-                 top_p=0.9,
-                 top_k=40,
-                 repetition_penalty=1.15,
-                 pad_token_id=self.tokenizer.eos_token_id,
-                 do_sample=True
-             )
-
-             self.is_loaded = True
-             logger.info("Model loading complete!")
-
-         except Exception as e:
-             logger.error(f"Error loading model: {e}")
-             self.is_loaded = False
-
-     def generate_response(self, user_message):
-         """Generate a response using the loaded model"""
-         if not self.is_loaded:
-             return "Model is still loading, please try again in a moment."
-
-         try:
-             # Format prompt for sarcastic responses
-             instruction = "Respond to this message as if you were in a conversation. Determine the tone and style of the conversation and reply accordingly. Be funny, sarcastic and smart as well."
-
-             prompt = f"""Below is an instruction that describes a task, and an input that provides further context. Write a response that appropriately completes the request.
-
- ### Instruction:
- {instruction}
-
- ### Input:
- {user_message}
-
- ### Response:
- """
-
-             # Generate response
-             start_time = time.time()
-             outputs = self.pipeline(
-                 prompt,
-                 return_full_text=False
-             )
-             generation_time = time.time() - start_time
-             logger.info(f"Generation took {generation_time:.2f} seconds")
-
-             # Extract response
-             full_response = outputs[0]['generated_text']
-
-             # Extract just the response part
-             response_parts = full_response.split("### Response:")
-             if len(response_parts) > 1:
-                 response = response_parts[1].strip()
-                 # Clean up any trailing text
-                 response = response.split("[Your Name]")[0].strip()
-                 response = response.split("---")[0].strip()
-                 return response
-             else:
-                 return full_response.strip()
-
-         except Exception as e:
-             logger.error(f"Error generating response: {e}")
-             return "I'm having trouble thinking right now. Can you try again?"
-
- # Create model manager
- model_manager = ModelManager()
-
- # Background response generation
- def generate_response_in_background(user_message):
-     response = model_manager.generate_response(user_message)
-     response_queue.put(response)
-
- # API model
- class MessageRequest(BaseModel):
-     message: str
-
- # Routes
- @app.get("/", response_class=HTMLResponse)
- async def read_root(request: Request):
-     return templates.TemplateResponse("index.html", {"request": request})
-
- @app.post("/generate")
- async def generate(message_request: MessageRequest):
-     user_message = message_request.message
-
-     # If model isn't loaded yet, start loading it
-     if not model_manager.is_loaded:
-         model_manager.load_model_in_background()
-         return {"response": "I'm just starting up. Please try again in a moment!"}
-
-     # Handle message generation
-     thread = Thread(target=generate_response_in_background, args=(user_message,))
-     thread.daemon = True
-     thread.start()
-
-     # Wait for response with timeout
-     try:
-         thread.join(timeout=30)  # 30 second timeout
-         if thread.is_alive():
-             # If still running after timeout, return a message
-             return {"response": "I'm thinking hard about this one! Try sending a simpler message or try again later."}
-
-         # Get response from queue if available
-         if not response_queue.empty():
-             response = response_queue.get()
-             return {"response": response}
-         else:
-             return {"response": "Sorry, I couldn't generate a response. Please try again."}
-     except Exception as e:
-         logger.error(f"Error in response generation: {e}")
-         return {"response": "Something went wrong. Please try again."}
-
- # Startup event
- @app.on_event("startup")
- async def startup_event():
-     # Start loading model in background at startup
-     model_manager.load_model_in_background()
-     logger.info("Starting model loading in background")
-
- # Shutdown event
- @app.on_event("shutdown")
- async def shutdown_event():
-     # Clean up resources
-     logger.info("Shutting down and cleaning up resources")
-     gc.collect()
-     if torch.cuda.is_available():
-         torch.cuda.empty_cache()
-
- if __name__ == "__main__":
-     import uvicorn
-     # Run the FastAPI app
-     uvicorn.run(app, host="0.0.0.0", port=7860)
 
+ from fastapi import FastAPI, Request, Form, BackgroundTasks
+ from fastapi.responses import HTMLResponse
+ from fastapi.templating import Jinja2Templates
+ from fastapi.staticfiles import StaticFiles
+ import torch
+ import os
+ import gc
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from typing import Optional
+ import time
+ import logging
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Set optimization variables
+ os.environ["OMP_NUM_THREADS"] = "8"
+ os.environ["MKL_NUM_THREADS"] = "8"
+ torch.set_num_threads(8)
+
+ # Initialize FastAPI
+ app = FastAPI()
+
+ # Load templates and static files
+ templates = Jinja2Templates(directory="")
+ app.mount("/static", StaticFiles(directory="static"), name="static")
+
+ # Disable gradient computation
+ torch.set_grad_enabled(False)
+
+ # Cache for responses
+ response_cache = {}
+
+ # Model initialization
+ model_id = "Gauri-tr/llama-3.1-8b-sarcasm"
+ tokenizer = None
+ model = None
+
+ # Load model in a lazy fashion
+ def load_model():
+     global model, tokenizer
+     if model is None:
+         logger.info("Loading model and tokenizer...")
+         start_time = time.time()
+
+         # Load tokenizer
+         tokenizer = AutoTokenizer.from_pretrained(model_id)
+         tokenizer.pad_token = tokenizer.eos_token
+
+         # Load model with optimizations
+         model = AutoModelForCausalLM.from_pretrained(
+             model_id,
+             torch_dtype=torch.float32,
+             device_map="cpu",
+             low_cpu_mem_usage=True,
+         )
+
+         # Set to evaluation mode
+         model.eval()
+
+         # Try to optimize with torch.compile if available
+         try:
+             import torch._dynamo
+             model = torch.compile(model, backend="inductor", fullgraph=True)
+             logger.info("Using torch.compile optimization")
+         except Exception as e:
+             logger.warning(f"Could not use torch.compile: {e}")
+
+         logger.info(f"Model loaded in {time.time() - start_time:.2f} seconds")
+
+         # Run a warmup inference
+         _ = generate_response("Hello", max_length=10)
+
+     return model, tokenizer
+
+ def generate_response(input_text: str, max_length: int = 30) -> str:
+     # Check cache first
+     cache_key = f"{input_text}_{max_length}"
+     if cache_key in response_cache:
+         logger.info("Using cached response")
+         return response_cache[cache_key]
+
+     # Format prompt
+     prompt = f"""Below is an instruction that describes a task, and an input that provides further context. Write a response that appropriately completes the request.
+
+ ### Instruction:
+ Respond to this message as if you were in a conversation. Be funny, sarcastic and smart.
+
+ ### Input:
+ {input_text}
+
+ ### Response:
+ """
+
+     # Ensure model is loaded
+     model, tokenizer = load_model()
+
+     # Tokenize
+     inputs = tokenizer(prompt, return_tensors="pt")
+
+     # Generate with optimization
+     start_time = time.time()
+     with torch.inference_mode():
+         outputs = model.generate(
+             inputs["input_ids"],
+             max_new_tokens=max_length,
+             do_sample=True,
+             temperature=0.8,
+             top_p=0.9,
+             repetition_penalty=1.2,
+             num_beams=1,  # single beam (no beam search) for speed; sampling is enabled above
+             pad_token_id=tokenizer.eos_token_id,
+             use_cache=True,  # Use KV cache
+         )
+
+     generation_time = time.time() - start_time
+     logger.info(f"Generated response in {generation_time:.2f} seconds")
+
+     # Decode
+     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+     # Extract response part
+     if "### Response:" in response:
+         response = response.split("### Response:")[1].strip()
+
+     # Cache the result
+     response_cache[cache_key] = response
+
+     # Make sure to clean up memory
+     gc.collect()
+
+     return response
+
+ # Define routes
+ @app.get("/", response_class=HTMLResponse)
+ async def index(request: Request):
+     # Load the model on the first request if it isn't loaded yet
+     if model is None:
+         load_model()
+     return templates.TemplateResponse("index.html", {"request": request})
+
+ @app.post("/chat/")
+ async def chat(message: str = Form(...), max_length: Optional[int] = Form(30)):
+     response = generate_response(message, max_length)
+     return {"response": response, "message": message}
+
+ # Health check endpoint
+ @app.get("/health")
+ async def health():
+     return {"status": "ok"}
+
+ # Preload model at startup
+ @app.on_event("startup")
+ async def startup_event():
+     # Just initialize the tokenizer at startup - model will load on first request
+     global tokenizer
+     if tokenizer is None:
+         logger.info("Pre-loading tokenizer...")
+         tokenizer = AutoTokenizer.from_pretrained(model_id)
+         tokenizer.pad_token = tokenizer.eos_token
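
Client-side note (not part of the commit): the updated app.py replaces the JSON /generate route with a form-encoded /chat/ route, adds a /health probe, and no longer includes a uvicorn.run(...) entry point. The sketch below shows one way a client could exercise the new endpoints; the base URL and port are assumptions (7860 is the usual Spaces port), so adjust them to wherever the Space is actually served.

# Hypothetical client sketch for the new endpoints; requires the `requests` package.
import requests

BASE_URL = "http://localhost:7860"  # assumed host/port; replace with the deployed Space URL

# /health returns {"status": "ok"} once the server is up.
print(requests.get(f"{BASE_URL}/health", timeout=10).json())

# /chat/ expects form fields (message, max_length), unlike the removed /generate
# route, which took a JSON body.
resp = requests.post(
    f"{BASE_URL}/chat/",
    data={"message": "How was your day?", "max_length": 30},
    timeout=300,  # the first call triggers the lazy model load, which is slow on CPU
)
print(resp.json()["response"])

Because load_model() runs synchronously inside the request handlers, the first /chat/ or / request after startup blocks until the 8B model has loaded, so a generous client timeout is advisable.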