# Hugging Face Spaces page header at capture time (scraped residue): "Spaces: Sleeping"
import logging
import os

import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from transformers import pipeline
app = FastAPI()

# Jinja2 renderer for HTML pages stored under ./templates. This only renders
# templates; it does not mount any route.
templates = Jinja2Templates(directory="templates")

# Zero-shot classification pipeline. It is built once at import time and
# shared by every request handler in this module.
classifier = pipeline(
    task="zero-shot-classification",
    model="facebook/bart-large-mnli",
)
# Route to serve index.html
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render ``index.html`` from the ``templates`` directory as the landing page.

    The original had no route decorator, so this handler was never registered
    on ``app``. It is now served at ``/`` as HTML.

    Args:
        request: The incoming request. It is placed in the template context
            because this ``TemplateResponse(name, context)`` call form expects
            a ``"request"`` key there.

    Returns:
        The rendered template response.
    """
    return templates.TemplateResponse("index.html", {"request": request})
# Route to handle text classification requests
async def classify_text(data: dict):
    """Run zero-shot classification of ``data["document"]`` over ``data["labels"]``.

    Args:
        data: The parsed JSON body. It must contain a non-empty ``"document"``
            (the text to classify) and non-empty ``"labels"`` (the candidate
            labels; presumably a list of strings, TODO confirm against the
            frontend).

    Returns:
        * HTTP 200 with ``{"labels": [...], "scores": [...]}`` on success.
        * HTTP 400 with ``{"error": ...}`` if either field is missing or empty.
        * HTTP 500 with ``{"error": <exception text>}`` if the model call fails.

    Note:
        The previous version returned ``(body, status)`` tuples. FastAPI does
        not interpret those as a status code; it serialized them as a JSON
        array ``[body, status]`` with HTTP 200. Any client that indexed
        ``[0]`` into the response must be updated.

    NOTE(review): no route decorator is visible for this handler, so it is
    not registered on ``app``. It needs something like ``@app.post(...)`` with
    the path the frontend calls. TODO confirm the path.
    """
    text = data.get("document")
    labels = data.get("labels")
    if not text or not labels:
        return JSONResponse(
            status_code=400,
            content={"error": "Please provide both text and labels"},
        )

    try:
        # NOTE(review): synchronous model inference inside an async handler
        # blocks the event loop for the duration of the call.
        result = classifier(text, labels, multi_label=False)
        response = {
            "labels": result["labels"],
            "scores": result["scores"],
        }
    except Exception as e:  # request boundary: report to the client, keep a trace
        logging.getLogger(__name__).exception("Zero-shot classification failed")
        return JSONResponse(status_code=500, content={"error": str(e)})

    return response
# Script entry point: serve on the port Hugging Face Spaces expects, taken
# from the PORT environment variable and defaulting to 7860.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))