---
language:
  - en
base_model:
  - openai/clip-vit-large-patch14
tags:
  - emotion_prediction
  - VEA
  - computer_vision
  - perceptual_tasks
  - CLIP
  - EmoSet
---

# Don’t Judge Before You CLIP: Visual Emotion Analysis Model

This model is part of our paper
**"Don’t Judge Before You CLIP: A Unified Approach for Perceptual Tasks"**
and was trained on the EmoSet dataset to predict the emotion class of an image.

## Model Overview

Visual perceptual tasks, such as visual emotion analysis, aim to estimate how humans perceive and interpret images. Unlike objective tasks (e.g., object recognition), these tasks rely on subjective human judgment, making labeled data scarce.

Our approach leverages CLIP as a prior for perceptual tasks, inspired by cognitive research showing that CLIP correlates well with human judgment. This suggests that CLIP implicitly captures human biases, emotions, and preferences. We fine-tune CLIP minimally using LoRA and incorporate an MLP head to adapt it to each specific task.
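Conceptually, the model is a CLIP ViT-L/14 vision encoder with LoRA adapters and an MLP classification head on top. The sketch below shows one way to assemble such a model with `transformers` and `peft`; the LoRA targets (attention projections) and the MLP head shape are assumptions for illustration, with `r=16` and `alpha=8` inferred from the checkpoint name, and may differ from the exact configuration used in the paper.

```python
# Rough architecture sketch, not the paper's exact training code.
import torch
from torch import nn
from transformers import CLIPVisionModel
from peft import LoraConfig, get_peft_model


class EmoSetClassifier(nn.Module):
    def __init__(self, num_classes: int = 8):
        super().__init__()
        backbone = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14")
        # Assumed LoRA setup: r=16, alpha=8, applied to attention projections.
        lora_cfg = LoraConfig(r=16, lora_alpha=8, target_modules=["q_proj", "v_proj"])
        self.backbone = get_peft_model(backbone, lora_cfg)  # freezes all but LoRA params
        # Hypothetical MLP head on the pooled CLIP embedding (1024-d for ViT-L/14).
        self.head = nn.Sequential(
            nn.Linear(1024, 512),
            nn.GELU(),
            nn.Linear(512, num_classes),
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        pooled = self.backbone(pixel_values=pixel_values).pooler_output
        return self.head(pooled)
```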

## Training Details

- **Dataset:** EmoSet
- **Architecture:** CLIP vision encoder (ViT-L/14) with LoRA adaptation and an MLP classification head
- **Loss Function:** Cross-entropy loss
- **Optimizer:** AdamW
- **Learning Rate:** 0.0001
- **Batch Size:** 32
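A minimal training-loop sketch using the settings above is shown below. `EmoSetDataset` is a hypothetical `Dataset` yielding `(image_tensor, label)` pairs with the CLIP preprocessing from the Usage section, and `EmoSetClassifier` is the architecture sketch above; the actual training script is not included here.

```python
# Sketch of the training recipe listed above; EmoSetDataset is hypothetical.
import torch
from torch import nn
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = EmoSetClassifier().to(device)  # from the architecture sketch above
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad),  # LoRA + head params only
    lr=1e-4,
)
loader = DataLoader(EmoSetDataset("train"), batch_size=32, shuffle=True)

model.train()
for images, labels in loader:
    images, labels = images.to(device), labels.to(device)
    optimizer.zero_grad()
    loss = criterion(model(images), labels)
    loss.backward()
    optimizer.step()
```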

## Performance

The model was trained and evaluated on the EmoSet dataset, following the standard dataset splits. Our method achieves state-of-the-art performance compared to existing approaches, as described in our paper.
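For reference, a minimal top-1 accuracy evaluation could look like the sketch below, continuing from the training sketch above (same hypothetical `EmoSetDataset` wrapper); the reported numbers and exact protocol are those of the paper.

```python
# Sketch: top-1 accuracy on a held-out split; EmoSetDataset is hypothetical.
import torch
from torch.utils.data import DataLoader

model.eval()
loader = DataLoader(EmoSetDataset("test"), batch_size=32)

correct = total = 0
with torch.no_grad():
    for images, labels in loader:
        preds = model(images.to(device)).argmax(dim=1).cpu()
        correct += (preds == labels).sum().item()
        total += labels.size(0)

print(f"Top-1 accuracy: {correct / total:.4f}")
```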

## Usage

To use the model for inference:

```python
from torchvision import transforms
import torch
from PIL import Image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load model (map_location keeps this working on CPU-only machines;
# on newer PyTorch versions you may also need weights_only=False to unpickle a full model)
model = torch.load(
    "EmoSet_clip_Lora_16.0R_8.0alphaLora_32_batch_0.0001_headmlp.pth",
    map_location=device,
).to(device).eval()

# Emotion label mapping
idx2label = {
    0: "amusement",
    1: "awe",
    2: "contentment",
    3: "excitement",
    4: "anger",
    5: "disgust",
    6: "fear",
    7: "sadness"
}

# Preprocessing function
def emo_preprocess():
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)),
    ])
    return transform

# Load an image
image = Image.open("image_path.jpg").convert("RGB")
image = emo_preprocess()(image).unsqueeze(0).to(device)

# Run inference
with torch.no_grad():
    outputs = model(image)
    _, predicted = outputs.max(1)  # Get the class index

# Get emotion label
predicted_emotion = idx2label[predicted.item()]
print(f"Predicted Emotion: {predicted_emotion}")