"""
VIT的transformer结构没有因果掩码,因为任意一个位置都能访问其它位置,它们之间没有因果关系,或者说关系很弱
文本生成仍然考虑因果掩码。
"""
import torch.nn.functional as F
from VIT import model as VIT
from Text_Encoder import text_encoder as transformer
import torch.nn as nn
import torch
from Text_Encoder import MLP
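# Illustrative sketch (not part of the original model) of the two masking schemes
# described in the module docstring: the ViT branch attends bidirectionally (no
# mask), while an autoregressive text decoder would use a causal, lower-triangular
# mask so that each position only attends to earlier positions.
def _example_attention_masks(seq_len: int):
    # Bidirectional: an additive mask of zeros, i.e. nothing is masked out.
    bidirectional = torch.zeros(seq_len, seq_len)
    # Causal: -inf above the diagonal blocks attention to future positions.
    causal = torch.full((seq_len, seq_len), float("-inf")).triu(diagonal=1)
    return bidirectional, causal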
class Prompt_block(nn.Module):
    """Learnable prompt embeddings spliced into the text sequence (prompt tuning)."""
    def __init__(self, config):
        super(Prompt_block, self).__init__()
        self.prompt_embedding = nn.Embedding(config.prompt_num, config.hidden_size, dtype=config.dtype, device=config.device)

    def forward(self, text_embeddings):
        b, _, _ = text_embeddings.size()
        n, dim = self.prompt_embedding.weight.size()
        """
        # Earlier per-sample variant, kept for reference: insert the prompts at each
        # sample's EOT index instead of right after the BOS token.
        new_embeddings = []
        for batch, index_ in enumerate(index):
            text_embedding = text_embeddings[batch]
            text_embedding = torch.cat((text_embedding[:index_, :], self.prompt_embedding.weight, text_embedding[index_:, :]), 0)
            new_embeddings.append(text_embedding)
        stacked_embedding = torch.stack(new_embeddings, dim=0)
        return stacked_embedding
        """
        # Insert the learned prompts right after the first (BOS) token embedding.
        text_embeddings = torch.cat((text_embeddings[:, 0:1, :], self.prompt_embedding.weight.expand(b, n, dim), text_embeddings[:, 1:, :]), 1)
        return text_embeddings
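# Minimal usage sketch (toy config values, not the real Config below): the block
# turns a [batch, seq_len, hidden] tensor into [batch, seq_len + prompt_num, hidden]
# by splicing the prompt vectors in after position 0.
from types import SimpleNamespace
_toy_cfg = SimpleNamespace(prompt_num=20, hidden_size=512, dtype=torch.float32, device="cpu")
_toy_out = Prompt_block(_toy_cfg)(torch.zeros(2, 57, 512))
assert _toy_out.shape == (2, 77, 512)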
class CLIP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.visual = VIT
        self.device = config.device
        self.dtype = config.dtype
        self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size, dtype=config.dtype, device=config.device)
        self.max_position_embeddings = config.max_position_embeddings
        self.prompt_num = config.prompt_num
        self.transformer = transformer
        # add a prompt block for prompt tuning
        self.prompt_block = Prompt_block(config)
        self.positional_embedding = nn.Parameter(torch.empty(config.max_position_embeddings, config.hidden_size, device=config.device))
        self.ln_final = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, dtype=config.dtype, device=config.device)
        self.text_projection = nn.Parameter(torch.empty(config.hidden_size, config.hidden_size, device=config.device))
        # fixed temperature; torch.ones (not torch.empty) so the initial value is well defined
        self.logit_scale = nn.Parameter(torch.ones([], dtype=config.dtype, device=config.device) * config.logit_scale_init, requires_grad=False)
    def encode_image(self, img, use_emotion=True):
        cls_embedding = self.visual(img, use_emotion)
        # cls_embedding: [batch_size, 1, 512]; image_embedding: [batch_size, 7, 512]
        return cls_embedding
    def encode_text(self, text, use_emotion=True):
        # reserve prompt_num (20) positions for the learnable prompt tokens
        b, n = text.size()
        index = text.argmax(dim=-1)  # position of the EOT token (highest token id) in each row
        text_embedding = self.token_embedding(text)
        # text_embedding = self.prompt_block(index, text_embedding)
        if n == self.max_position_embeddings - self.prompt_num:
            text_embedding = self.prompt_block(text_embedding)
            # the inserted prompts shift the EOT position to the right by prompt_num
            index = index + self.prompt_num
        position_embedding = self.positional_embedding[None, :text_embedding.shape[1], :].to(self.dtype)
        text_embedding = position_embedding + text_embedding
        text_embedding = self.transformer(text_embedding, use_emotion=use_emotion)
        text_embedding = self.ln_final(text_embedding)
        # take the feature at each sequence's (shifted) EOT token
        # print(index[0], index_new[0], text_embedding.shape)
        text_embedding = text_embedding[torch.arange(text.shape[0]), index]
        text_embedding = text_embedding @ self.text_projection.to(self.dtype)
        return text_embedding
    def forward(self, image, text, use_emotion=True):
        image_features = self.encode_image(image, use_emotion)
        text_features = self.encode_text(text, use_emotion)
        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()
        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text
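# Training code is not part of this checkpoint file. For reference, a minimal sketch
# of the symmetric contrastive loss commonly paired with the logits returned by
# CLIP.forward (assumes the i-th image and i-th text in a batch are a matched pair):
def clip_contrastive_loss(logits_per_image, logits_per_text):
    targets = torch.arange(logits_per_image.shape[0], device=logits_per_image.device)
    loss_img = F.cross_entropy(logits_per_image, targets)
    loss_txt = F.cross_entropy(logits_per_text, targets)
    return (loss_img + loss_txt) / 2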
class Config:
    def __init__(self):
        self.vocab_size = 49408
        self.image_dim = 768
        self.num_patches = 49
        self.patch_size = 32
        self.hidden_size = 512
        self.prompt_num = 20
        self.max_position_embeddings = 77
        self.num_hidden_layers = 12
        self.num_attention_heads = 8
        self.head_size = 64
        self.layer_norm_eps = 1e-5
        self.activation_function = "Quickgelu"
        self.dtype = torch.float16
        self.device = torch.device("cuda:0")
        self.logit_scale_init = 4.6052
        self.num_virtual_tokens = 20
        self.token_dim = self.hidden_size
        self.encoder_hidden_size = self.hidden_size
config=Config()
model=CLIP(config)
# load the pretrained weights
model.load_state_dict(torch.load(r'/root/autodl-tmp/true_Emoset/EmotionCLIP_v2.bin',weights_only=True,map_location='cpu'),strict=True)
"""
for name, param in model.named_parameters():
if 'prefix' not in name and 'prompt' not in name and 'ln' not in name: # 如果参数名中不包含'prefix'
print(name,"'s requires_grad turn off.")
param.requires_grad = False # 冻结该参数
else:
print(name,"'s requires_grad turn on.")
param.requires_grad = True # 允许该参数进行训练
"""
# compile the model (optional)
#model=torch.compile(model)
import pickle
from PIL import Image
import clip
with open('./preprocess.pkl', 'rb') as f:
    preprocess = pickle.load(f)
with open('./tokenize.pkl', 'rb') as f:
    tokenizer = pickle.load(f)
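# Assumption: the two pickles wrap the standard preprocessing transform and tokenizer
# from the openai/CLIP package. If they are unavailable, equivalent objects could
# presumably be rebuilt directly (illustrative, not part of the original pipeline):
#   _, preprocess = clip.load("ViT-B/32", device="cpu")
#   tokenizer = clip.tokenize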
device=config.device
image = preprocess(Image.open("spider.jpg")).unsqueeze(0).to(device)
emotion_prompts = [
    "This picture conveys a sense of fear",
    "This picture conveys a sense of contentment",
    "This picture conveys a sense of anger",
    "This picture conveys a sense of sadness",
    "This picture conveys a sense of neutral",
    "This picture conveys a sense of disgust",
    "This picture conveys a sense of excitement",
    "This picture conveys a sense of awe",
    "This picture conveys a sense of amusement",
]
# context_length=57 so that 57 + prompt_num (20) == max_position_embeddings (77),
# which is the condition that triggers prompt insertion in encode_text
text = tokenizer(emotion_prompts, context_length=57).to(device)
with torch.no_grad():
    logits_per_image, logits_per_text = model(image.to(config.dtype), text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()
print("Emotion recognition:", probs)
# save the weights with the prefix/prompt parameters merged in
torch.save(model.state_dict(), './upload/EmotionCLIP-V2.pth')
# Generalization check: zero-shot classification with generic (non-emotion) prompts.
"""
text = tokenizer(['This is a spider.', 'This is a dog', 'This is a cat'], context_length=57).to(device)
with torch.no_grad():
    logits_per_image, logits_per_text = model(image.to(config.dtype), text, use_emotion=False)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()
print("Generalization:", probs)
"""