mabbas125 commited on
Commit
f5feb33
·
verified ·
1 Parent(s): b174fd7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -7
app.py CHANGED
@@ -1,17 +1,116 @@
1
  import torch
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
- import os
 
 
4
 
5
- # Explicitly set the cache directory
6
- # cache_dir = 'D:/huggingface_cache'
7
 
8
- # Check the current working directory
9
- print("Current working directory:", os.getcwd())
10
-
11
- # Load the tokenizer and model with cache_dir
12
  tokenizer = AutoTokenizer.from_pretrained("chatdb/natural-sql-7b")
13
  model = AutoModelForCausalLM.from_pretrained(
14
  "chatdb/natural-sql-7b",
15
  device_map="auto",
16
  torch_dtype=torch.float16,
17
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ from fastapi import FastAPI, HTTPException
4
+ from pydantic import BaseModel
5
+ from fastapi.responses import JSONResponse
6
 
7
+ # Initialize FastAPI app
8
+ app = FastAPI()
9
 
10
+ # Load the tokenizer and model
 
 
 
11
  tokenizer = AutoTokenizer.from_pretrained("chatdb/natural-sql-7b")
12
  model = AutoModelForCausalLM.from_pretrained(
13
  "chatdb/natural-sql-7b",
14
  device_map="auto",
15
  torch_dtype=torch.float16,
16
  )
17
+
18
+ schema = """
19
+ CREATE TABLE users (
20
+ id SERIAL PRIMARY KEY,
21
+ manager_id INTEGER,
22
+ first_name VARCHAR(100) NOT NULL,
23
+ last_name VARCHAR(100) NOT NULL,
24
+ designation VARCHAR(100),
25
+ email VARCHAR(100) UNIQUE NOT NULL,
26
+ phone VARCHAR(15) UNIQUE NOT NULL,
27
+ password TEXT NOT NULL,
28
+ role VARCHAR(50) NOT NULL, -- employee, manager, hr
29
+ country VARCHAR(50) NOT NULL, -- pakistan, uae, uk
30
+ fcm_token VARCHAR(255),
31
+ image VARCHAR(255) DEFAULT '',
32
+ created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
33
+ updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
34
+ );
35
+
36
+ CREATE TABLE leaves_balances (
37
+ id SERIAL PRIMARY KEY,
38
+ sick_available FLOAT NOT NULL,
39
+ casual_available FLOAT NOT NULL,
40
+ wfh_available FLOAT NOT NULL,
41
+ sick_taken FLOAT NOT NULL,
42
+ casual_taken FLOAT NOT NULL,
43
+ wfh_taken FLOAT NOT NULL,
44
+ user_id INTEGER UNIQUE REFERENCES users(id) ON DELETE CASCADE,
45
+ created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
46
+ updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
47
+ );
48
+
49
+ CREATE TABLE leaves (
50
+ id SERIAL PRIMARY KEY,
51
+ user_id INTEGER REFERENCES users(id) ON DELETE CASCADE,
52
+ manager_id INTEGER REFERENCES users(id) ON DELETE CASCADE,
53
+ username VARCHAR(100) NOT NULL,
54
+ type VARCHAR(50) NOT NULL, -- sick, casual, wfh
55
+ from_date TIMESTAMP NOT NULL,
56
+ to_date TIMESTAMP NOT NULL,
57
+ comments TEXT,
58
+ status VARCHAR(50) DEFAULT 'pending', -- pending, approved, rejected
59
+ created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
60
+ updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
61
+ );
62
+
63
+ CREATE TABLE user_otps (
64
+ id SERIAL PRIMARY KEY,
65
+ user_id INTEGER REFERENCES users(id) ON DELETE CASCADE,
66
+ otp INTEGER NOT NULL,
67
+ otp_expiry TIMESTAMP NOT NULL,
68
+ created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
69
+ );
70
+ """
71
+
72
+ # Define the request body model using Pydantic
73
+ class QuestionRequest(BaseModel):
74
+ question: str
75
+
76
+
77
+ @app.post('/generate-sql')
78
+ async def generate_sql(request: QuestionRequest):
79
+ """
80
+ Endpoint to generate a SQL query based on a given question.
81
+ The schema is defined within the code (in the `schema` variable).
82
+ """
83
+ question = request.question
84
+ if not question:
85
+ raise HTTPException(status_code=400, detail="No question provided")
86
+
87
+ prompt = f"""
88
+ ### Task
89
+
90
+ Generate a SQL query to answer the following question: `{question}`
91
+
92
+ ### PostgreSQL Database Schema
93
+ The query will run on a database with the following schema:
94
+ {schema}
95
+
96
+ ### Answer
97
+ Here is the SQL query that answers the question: `{question}`
98
+ ```sql
99
+ """
100
+
101
+ # Generate SQL query
102
+ inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
103
+ generated_ids = model.generate(
104
+ **inputs,
105
+ num_return_sequences=1,
106
+ eos_token_id=100001,
107
+ pad_token_id=100001,
108
+ max_new_tokens=400,
109
+ do_sample=False,
110
+ num_beams=1,
111
+ )
112
+
113
+ outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
114
+ sql_query = outputs[0].split("```sql")[-1].strip()
115
+
116
+ return JSONResponse(content={'sql_query': sql_query})