Amitz244 commited on
Commit
8f0cdb2
·
verified ·
1 Parent(s): be5e013

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +11 -6
README.md CHANGED
@@ -43,15 +43,20 @@ import importlib.util
43
 
44
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
45
 
46
- # Load model class
47
  class_path = hf_hub_download(repo_id="PerceptCLIP/PerceptCLIP_Emotions", filename="modeling.py")
48
- spec = importlib.util.spec_from_file_location("clip_lora_model", class_path)
49
- clip_lora_model = importlib.util.module_from_spec(spec)
50
- spec.loader.exec_module(clip_lora_model)
 
 
 
 
51
 
52
  # Load pretrained model
53
  model_path = hf_hub_download(repo_id="PerceptCLIP/PerceptCLIP_Emotions", filename="perceptCLIP_Emotions.pth")
54
- model = torch.load(model_path).to(device).eval()
 
55
 
56
  # Emotion label mapping
57
  idx2label = {
@@ -82,7 +87,7 @@ image = emo_preprocess()(image).unsqueeze(0).to(device)
82
  # Run inference
83
  with torch.no_grad():
84
  outputs = model(image)
85
- _, predicted = outputs.max(1) # Get the class index
86
 
87
  # Get emotion label
88
  predicted_emotion = idx2label[predicted.item()]
 
43
 
44
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
45
 
46
+ # Load the model class definition dynamically
47
  class_path = hf_hub_download(repo_id="PerceptCLIP/PerceptCLIP_Emotions", filename="modeling.py")
48
+ spec = importlib.util.spec_from_file_location("modeling", class_path)
49
+ modeling = importlib.util.module_from_spec(spec)
50
+ spec.loader.exec_module(modeling)
51
+
52
+ # initialize a model
53
+ ModelClass = modeling.clip_lora_model
54
+ model = ModelClass().to(device)
55
 
56
  # Load pretrained model
57
  model_path = hf_hub_download(repo_id="PerceptCLIP/PerceptCLIP_Emotions", filename="perceptCLIP_Emotions.pth")
58
+ model.load_state_dict(torch.load(model_path, map_location=device))
59
+ model.eval()
60
 
61
  # Emotion label mapping
62
  idx2label = {
 
87
  # Run inference
88
  with torch.no_grad():
89
  outputs = model(image)
90
+ _, predicted = outputs.max(1)
91
 
92
  # Get emotion label
93
  predicted_emotion = idx2label[predicted.item()]