"""FastAPI service for zero-shot text classification.

Serves a static frontend from ``static/`` and exposes ``POST /classify``,
backed by a Hugging Face ``zero-shot-classification`` pipeline.
"""

import logging
import os
import threading

import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set the cache directory to a writable location. This MUST happen before
# transformers (and huggingface_hub) are imported: both libraries read these
# environment variables into module-level constants at import time, so setting
# them afterwards does not change where models are downloaded.
cache_dir = "/tmp/hf_cache"
os.environ["HF_HOME"] = cache_dir
os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
os.environ["TRANSFORMERS_CACHE"] = cache_dir
# exist_ok=True already covers the "directory exists" case.
os.makedirs(cache_dir, exist_ok=True)

from transformers import pipeline  # noqa: E402  (must follow the env setup above)

# NOTE(review): the zero-shot pipeline scores labels through an NLI
# (entailment) classification head. UBC-NLP/ARBERTv2 looks like a pretrained
# encoder without an NLI fine-tune -- confirm; if so, transformers warns that
# it cannot find an 'entailment' label and the scores are not meaningful.
# MODEL_NAME lets a deployment switch to an NLI-tuned checkpoint without a
# code change; the default keeps the previous model.
MODEL_NAME = os.environ.get("MODEL_NAME", "UBC-NLP/ARBERTv2")

app = FastAPI()

# Add CORS middleware to allow frontend requests.
# NOTE(review): the FastAPI/Starlette CORS docs state credentials are only
# honoured when origins/methods/headers are listed explicitly, not "*".
# Confirm whether credentialed cross-origin requests are actually needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.mount("/static", StaticFiles(directory="static"), name="static")

# Load the zero-shot classification model once at startup.
logger.info("Loading the model %s...", MODEL_NAME)
try:
    classifier = pipeline(
        "zero-shot-classification",
        model=MODEL_NAME,
        tokenizer=MODEL_NAME,
        # model_kwargs is the documented channel for arguments that must reach
        # from_pretrained(), such as the cache location.
        model_kwargs={"cache_dir": cache_dir},
    )
    logger.info("Model loaded successfully!")
except Exception:
    logger.exception("Error loading model")
    raise

# classify_text runs in FastAPI's worker threadpool, so several requests may
# call the single shared pipeline concurrently. Hugging Face (fast) tokenizers
# are not guaranteed safe to drive from multiple threads at once, so inference
# is serialised. The event loop stays free to serve other routes meanwhile.
_classifier_lock = threading.Lock()


def _normalize_labels(labels):
    """Return candidate labels as a list of non-empty, stripped strings.

    Accepts a comma-separated string or a list of strings. Returns ``None``
    when *labels* has any other type (including ``None`` or a list containing
    non-strings), and an empty list when no non-blank label remains.
    """
    if isinstance(labels, str):
        labels = labels.split(",")
    elif not (isinstance(labels, list) and all(isinstance(label, str) for label in labels)):
        return None
    return [label.strip() for label in labels if label.strip()]


def _error(message, status_code):
    """Build a ``{"error": message}`` JSON response with the given HTTP status."""
    return JSONResponse(status_code=status_code, content={"error": message})


@app.get("/")
async def index():
    """Serve the frontend entry page."""
    logger.info("Serving index.html")
    return FileResponse("static/index.html")


@app.post("/classify")
def classify_text(data: dict):
    """Classify ``data["document"]`` against the candidate ``data["labels"]``.

    ``labels`` may be a comma-separated string or a list of strings.

    Declared as a plain ``def`` (not ``async def``) so FastAPI runs it in its
    threadpool. The pipeline call is blocking CPU work and would otherwise
    stall the event loop for every other request during inference.

    Returns ``{"labels": [...], "scores": [...]}`` on success. On failure it
    returns an ``{"error": ...}`` body with HTTP 400 for missing or invalid
    input, or HTTP 500 when inference raises.
    """
    text = data.get("document")
    labels = _normalize_labels(data.get("labels"))

    if not isinstance(text, str) or not text.strip() or not labels:
        logger.warning("Missing text or labels in request")
        return _error("Please provide both text and labels", 400)

    logger.info("Classifying text: %s... with labels: %s", text[:50], labels)
    try:
        with _classifier_lock:
            result = classifier(text, labels, multi_label=False)
    except Exception as e:
        logger.exception("Error during classification")
        return _error(str(e), 500)

    logger.info(
        "Classification result: labels=%s scores=%s",
        result["labels"],
        result["scores"],
    )
    return {"labels": result["labels"], "scores": result["scores"]}


if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)