Amitz244 committed
Commit 71df5d4 · verified · 1 Parent(s): dbe6a3d

Update README.md

Files changed (1): README.md (+36 -14)
README.md CHANGED
@@ -44,23 +44,45 @@ To use the model for inference:
from torchvision import transforms
import torch
from PIL import Image
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
# Load model
model = torch.load("EmoSet_clip_Lora_16.0R_8.0alphaLora_32_batch_0.0001_headmlp.pth").to(device).eval()
- # Load an image
- image = Image.open("image_path.jpg").convert("RGB")
- # Preprocess and predict
- def Emo_preprocess():
+
+ # Emotion label mapping
+ idx2label = {
+     0: "amusement",
+     1: "awe",
+     2: "contentment",
+     3: "excitement",
+     4: "anger",
+     5: "disgust",
+     6: "fear",
+     7: "sadness"
+ }
+
+ # Preprocessing function
+ def emo_preprocess():
    transform = transforms.Compose([
-         transforms.Resize(224),
-         transforms.CenterCrop(size=(224,224)),
-         transforms.ToTensor(),
-         # Note: The model normalizes the image inside the forward pass
-         # using mean = (0.48145466, 0.4578275, 0.40821073) and
-         # std = (0.26862954, 0.26130258, 0.27577711)
-     ])
+         transforms.Resize(224),
+         transforms.CenterCrop(size=(224, 224)),
+         transforms.ToTensor(),
+         # Note: The model normalizes the image inside the forward pass
+         # using mean = (0.48145466, 0.4578275, 0.40821073) and
+         # std = (0.26862954, 0.26130258, 0.27577711)
+     ])
    return transform
- image = Emo_preprocess()(image).unsqueeze(0).to(device)
+
+ # Load an image
+ image = Image.open("image_path.jpg").convert("RGB")
+ image = emo_preprocess()(image).unsqueeze(0).to(device)
+
+ # Run inference
with torch.no_grad():
-     emo_label = model(image).item()
-     print(f"Predicted Emotion: {emo_label}")
+     outputs = model(image)
+     _, predicted = outputs.max(1) # Get the class index
+
+ # Get emotion label
+ predicted_emotion = idx2label[predicted.item()]
+ print(f"Predicted Emotion: {predicted_emotion}")