from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline
import torch

# Initialize the FastAPI app with permissive CORS (tighten for production).
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Request models
class ImagePair(BaseModel):
    image_url: str
    image_description: str

class MatchRequest(BaseModel):
    image_pairs: List[ImagePair]
    node_texts: List[str]

# Load models once at startup so every request reuses them.
sim_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
nli_pipeline = pipeline("zero-shot-classification", model="roberta-large-mnli")

def match_image_to_node_conditional(image_description, image_url, node_texts,
                                    sim_threshold=0.8, nli_threshold=0.7):
    """
    Match an image (via its description) to the best candidate node text.

    Rules:
      - Reject if the NLI score is below nli_threshold.
      - If the NLI score >= nli_threshold and the cosine similarity
        >= sim_threshold, mark the pair as a high-priority match.

    Returns a tuple: (candidate_node, image_url, match_status).
    """
    # Ensure node_texts is not empty.
    if not node_texts:
        raise ValueError("No node texts provided. Please include at least one node in your request.")

    # Embed the image description and all node texts.
    desc_embedding = sim_model.encode(image_description, convert_to_tensor=True)
    node_embeddings = sim_model.encode(node_texts, convert_to_tensor=True)

    # Pick the node with the highest cosine similarity.
    similarities = util.cos_sim(desc_embedding, node_embeddings)  # shape [1, N]
    best_sim, best_idx = torch.max(similarities, dim=1)
    candidate_node = node_texts[best_idx.item()]
    cosine_score = best_sim.item()

    # Verify the candidate with NLI. The hypothesis_template frames the
    # candidate node as the claim to test ("This image is about <node>.").
    nli_result = nli_pipeline(
        image_description,
        candidate_labels=[candidate_node],
        hypothesis_template="This image is about {}.",
    )
    nli_score = nli_result["scores"][0]

    print(f"Image: '{image_description}'")
    print(f"  Candidate node: '{candidate_node}'")
    print(f"  Cosine similarity: {cosine_score:.3f}, NLI score: {nli_score:.3f}")

    # Apply the conditional acceptance logic.
    if nli_score < nli_threshold:
        return None, None, "Rejected: Low NLI score"
    if cosine_score >= sim_threshold:
        return candidate_node, image_url, "High-priority match"
    return candidate_node, image_url, "Accepted match (lower cosine similarity)"

@app.post("/match")
def match_images(request: MatchRequest):
    matched_results = []
    for pair in request.image_pairs:
        candidate, url, status = match_image_to_node_conditional(
            pair.image_description,
            pair.image_url,
            request.node_texts,
            sim_threshold=0.8,
            nli_threshold=0.7,
        )
        # Rejected pairs return candidate=None and are omitted from the response.
        if candidate:
            matched_results.append({"node": candidate, "image_url": url, "status": status})
    return {"matched_pairs": matched_results}
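# ---------------------------------------------------------------------------
# Usage sketch. Assumptions (not part of the original code): the module is
# saved as main.py, uvicorn is installed, and the payload values below are
# purely illustrative.
#
# Start the server:
#   uvicorn main:app --reload
#
# Example request using the `requests` library:
#   import requests
#   payload = {
#       "image_pairs": [
#           {"image_url": "https://example.com/cat.jpg",
#            "image_description": "A tabby cat sleeping on a windowsill"}
#       ],
#       "node_texts": ["Domestic cats", "Dog breeds", "Houseplants"],
#   }
#   print(requests.post("http://localhost:8000/match", json=payload).json())

if __name__ == "__main__":
    # Convenience entry point for local runs (assumes uvicorn is available).
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)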