from groundingdino.util.inference import load_model, load_image, predict, annotate
import cv2
import gradio as gr
import re

# GroundingDINO config / checkpoint locations (relative to the working dir).
_DINO_CONFIG = "groundingdino/config/GroundingDINO_SwinT_OGC.py"
_DINO_WEIGHTS = "weights/groundingdino_swint_ogc.pth"

# Lazily-initialized model cache: the original reloaded the checkpoint from
# disk on every inference call; load once and reuse instead.
_DINO_MODEL = None


def dino_inference(image, text):
    """Run GroundingDINO open-vocabulary detection and return an annotated image.

    Args:
        image: Filepath of the uploaded image (from ``gr.Image(type="filepath")``).
        text: Text prompt of object categories, e.g. ``"apple . fruit basket . table ."``.

    Returns:
        The annotated image as an RGB numpy array, suitable for a
        ``gr.Image(type="numpy")`` output component. A BGR copy is also
        written to ``annotated_image.jpg`` as a side effect.
    """
    global _DINO_MODEL

    print("Image: ", image)
    print("Text: ", text)

    # Remap the uploaded file's basename into the project's asset directory.
    # NOTE(review): assumes a same-named file exists under .asset/train/ —
    # confirm this lookup scheme is intended.
    new_path = re.sub(r'^.+/([^/]+)$', r'.asset/train/\1', image)
    new_path = re.sub(r'\.jpeg$', '.jpg', new_path)

    if _DINO_MODEL is None:
        _DINO_MODEL = load_model(_DINO_CONFIG, _DINO_WEIGHTS)

    BOX_THRESHOLD = 0.35
    TEXT_THRESHOLD = 0.25

    # Keep the preprocessed tensor under its own name instead of shadowing
    # the `image` (filepath) parameter as the original did.
    image_source, image_tensor = load_image(new_path)

    boxes, logits, phrases = predict(
        model=_DINO_MODEL,
        image=image_tensor,
        caption=text,
        box_threshold=BOX_THRESHOLD,
        text_threshold=TEXT_THRESHOLD,
    )

    # annotate() draws with OpenCV and returns a BGR frame.
    annotated_frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)
    print("Annotated frame shape: ", annotated_frame.shape)

    # cv2.imwrite expects BGR, so save the frame as-is before converting.
    cv2.imwrite("annotated_image.jpg", annotated_frame)

    # Gradio's numpy image component expects RGB; returning the BGR frame
    # directly (as the original did) renders with red/blue channels swapped.
    return cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)

def app():
    """Build the detection UI inside the current Gradio Blocks context.

    The original opened a redundant inner ``with gr.Blocks():`` even though
    ``app()`` is only ever called inside the ``gradio_app`` Blocks context;
    components are now created directly in the caller's context. The button
    was also renamed (the original ``yolov10_infer`` name was copy-paste
    residue — this app runs GroundingDINO, not YOLOv10) and the pointless
    ``run_inference`` pass-through wrapper was removed.
    """
    with gr.Row():
        with gr.Column():
            image = gr.Image(type="filepath", label="Image", visible=True)
            text = gr.Textbox(label="Text", placeholder="Enter text here")
            detect_button = gr.Button(value="Detect Objects")

        with gr.Column():
            output_image = gr.Image(
                width=1024,
                height=768,
                type="numpy",
                label="Annotated Image",
                visible=True,
            )

        detect_button.click(
            fn=dino_inference,
            inputs=[image, text],
            outputs=[output_image],
        )

# Assemble the top-level Blocks app; app() populates it with the
# detection UI (inputs, output image, and click wiring).
with gr.Blocks() as gradio_app:
    with gr.Row(), gr.Column():
        app()

if __name__ == '__main__':
    # share=True publishes a temporary public Gradio link.
    gradio_app.launch(share=True)