"""EmotionCLIP inference script with learnable text prompts.

The ViT transformer has no causal mask: any position may attend to any other
position, because there is no (or only a weak) causal relationship between
them. Text generation, by contrast, still uses a causal mask.
"""
import os
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import clip

from VIT import model as VIT
from Text_Encoder import text_encoder as transformer
from Text_Encoder import MLP


class Prompt_block(nn.Module):
    """Learnable prompt tokens spliced into a text-embedding sequence.

    ``config.prompt_num`` trainable embeddings of width ``config.hidden_size``
    are inserted directly after the first token of every sequence.
    """

    def __init__(self, config):
        super().__init__()
        self.prompt_embedding = nn.Embedding(
            config.prompt_num,
            config.hidden_size,
            dtype=config.dtype,
            device=config.device,
        )

    def forward(self, text_embeddings):
        """Insert the prompt embeddings after position 0 of every sequence.

        Args:
            text_embeddings: rank-3 tensor ``(batch, seq_len, hidden_size)``.

        Returns:
            Tensor ``(batch, seq_len + prompt_num, hidden_size)`` laid out as
            ``[token_0, prompt_0 .. prompt_{n-1}, token_1 .. token_{seq_len-1}]``.
        """
        batch_size, _, _ = text_embeddings.size()
        # (prompt_num, hidden) -> (batch, prompt_num, hidden) as a broadcast
        # view; torch.cat below materialises the copy.
        prompts = self.prompt_embedding.weight.unsqueeze(0).expand(batch_size, -1, -1)
        return torch.cat(
            (text_embeddings[:, :1, :], prompts, text_embeddings[:, 1:, :]),
            dim=1,
        )


class CLIP(nn.Module):
    """CLIP-style image/text dual encoder with a learnable text prompt block.

    Attribute names below must match the keys of the pretrained checkpoint,
    which is loaded with ``strict=True`` further down; do not rename them.
    """

    def __init__(self, config):
        super().__init__()
        self.visual = VIT
        self.device = config.device
        self.dtype = config.dtype
        self.token_embedding = nn.Embedding(
            config.vocab_size,
            config.hidden_size,
            dtype=config.dtype,
            device=config.device,
        )
        self.max_position_embeddings = config.max_position_embeddings
        self.prompt_num = config.prompt_num
        self.transformer = transformer
        # Learnable prompt tokens inserted into the text sequence.
        self.prompt_block = Prompt_block(config)
        # Created in torch's default dtype and cast to self.dtype where used;
        # the actual values come from the pretrained checkpoint.
        self.positional_embedding = nn.Parameter(
            torch.empty(
                config.max_position_embeddings,
                config.hidden_size,
                device=config.device,
            )
        )
        self.ln_final = nn.LayerNorm(
            config.hidden_size,
            eps=config.layer_norm_eps,
            dtype=config.dtype,
            device=config.device,
        )
        self.text_projection = nn.Parameter(
            torch.empty(config.hidden_size, config.hidden_size, device=config.device)
        )
        # Fixed (non-trainable) log temperature; forward() uses exp() of it.
        # torch.full gives a deterministic initial value instead of scaling
        # uninitialised memory.
        self.logit_scale = nn.Parameter(
            torch.full(
                (),
                config.logit_scale_init,
                dtype=config.dtype,
                device=config.device,
            ),
            requires_grad=False,
        )

    def encode_image(self, img, use_emotion=True):
        """Return the visual encoder's embedding for ``img``.

        NOTE(review): an earlier comment claimed the result is
        ``[batch_size, 1, 512]``, but forward() calls ``.t()`` on logits built
        from it, which requires a <=2-D tensor, so presumably it is
        ``(batch, hidden)`` — confirm against the VIT module.
        """
        return self.visual(img, use_emotion)

    def encode_text(self, text, use_emotion=True):
        """Encode token ids into projected text features.

        If the sequence length equals ``max_position_embeddings - prompt_num``
        (i.e. the tokenizer reserved room for the prompts), the prompt block is
        inserted after the first token so the result fills the positional table.

        Args:
            text: integer token ids of shape ``(batch, seq_len)``.
            use_emotion: forwarded to the text transformer.

        Returns:
            ``(batch, hidden_size)`` features taken at each row's largest token
            id position and projected through ``text_projection``.
        """
        batch_size, seq_len = text.size()
        # Position of the largest token id per row — presumably the
        # end-of-text token (highest id in the CLIP vocabulary); confirm
        # against the tokenizer.
        eot_index = text.argmax(dim=-1)
        text_embedding = self.token_embedding(text)
        if seq_len == self.max_position_embeddings - self.prompt_num:
            text_embedding = self.prompt_block(text_embedding)
            # Every token after position 0 moved right by prompt_num slots.
            eot_index = eot_index + self.prompt_num
        position_embedding = self.positional_embedding[
            None, : text_embedding.shape[1], :
        ].to(self.dtype)
        text_embedding = position_embedding + text_embedding
        text_embedding = self.transformer(text_embedding, use_emotion=use_emotion)
        text_embedding = self.ln_final(text_embedding)
        rows = torch.arange(batch_size, device=text.device)
        text_embedding = text_embedding[rows, eot_index]
        return text_embedding @ self.text_projection.to(self.dtype)

    def forward(self, image, text, use_emotion=True):
        """Return scaled cosine-similarity logits between images and texts.

        Returns:
            ``(logits_per_image, logits_per_text)``; ``logits_per_image`` has
            one column per text and ``logits_per_text`` is its transpose.
        """
        image_features = self.encode_image(image, use_emotion)
        text_features = self.encode_text(text, use_emotion)

        # L2-normalise so the dot product below is a cosine similarity.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()
        return logits_per_image, logits_per_text


class Config:
    """Hyper-parameters for the EmotionCLIP model built below."""

    def __init__(self):
        self.vocab_size = 49408
        self.image_dim = 768
        self.num_patches = 49
        self.patch_size = 32
        self.hidden_size = 512
        # Number of learnable prompt tokens inserted by Prompt_block.
        self.prompt_num = 20
        self.max_position_embeddings = 77
        self.num_hidden_layers = 12
        self.num_attention_heads = 8
        self.head_size = 64
        self.layer_norm_eps = 1e-5
        self.activation_function = "Quickgelu"
        self.dtype = torch.float16
        # NOTE(review): hard-coded to the first CUDA device.
        self.device = torch.device("cuda:0")
        # exp(4.6052) ~= 100, the logit temperature applied in CLIP.forward.
        self.logit_scale_init = 4.6052
        self.num_virtual_tokens = 20
        self.token_dim = self.hidden_size
        self.encoder_hidden_size = self.hidden_size


config = Config()
model = CLIP(config)
# Load the pretrained weights; strict=True requires every parameter name in
# the model to match the checkpoint exactly.
model.load_state_dict(
    torch.load(
        r'/root/autodl-tmp/true_Emoset/EmotionCLIP_v2.bin',
        weights_only=True,
        map_location='cpu',
    ),
    strict=True,
)

# Optional (prompt tuning): train only prefix/prompt/LayerNorm parameters.
# for name, param in model.named_parameters():
#     if 'prefix' not in name and 'prompt' not in name and 'ln' not in name:
#         print(name, "'s requires_grad turn off.")
#         param.requires_grad = False
#     else:
#         print(name, "'s requires_grad turn on.")
#         param.requires_grad = True

# Optional: compile the model.
# model = torch.compile(model)

# SECURITY: pickle.load can execute arbitrary code; only load these files
# from a trusted source.
with open('./preprocess.pkl', 'rb') as f:
    preprocess = pickle.load(f)
with open('./tokenize.pkl', 'rb') as f:
    tokenizer = pickle.load(f)

device = config.device
with Image.open("spider.jpg") as pil_image:
    image = preprocess(pil_image).unsqueeze(0).to(device)

# The context length leaves exactly prompt_num slots free (77 - 20 = 57) so
# encode_text inserts the learnable prompts.
text = tokenizer(
    [
        "This picture conveys a sense of fear",
        "This picture conveys a sense of contentment",
        "This picture conveys a sense of anger",
        "This picture conveys a sense of sadness",
        "This picture conveys a sense of neutral",
        "This picture conveys a sense of disgust",
        "This picture conveys a sense of excitement",
        "This picture conveys a sense of awe",
        "This picture conveys a sense of amusement",
    ],
    context_length=config.max_position_embeddings - config.prompt_num,
).to(device)

with torch.no_grad():
    logits_per_image, logits_per_text = model(image.to(config.dtype), text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print("情感识别:", probs)

# Save the weights, including the merged prompt parameters.
os.makedirs('./upload', exist_ok=True)
torch.save(model.state_dict(), './upload/EmotionCLIP-V2.pth')

# Optional: zero-shot generalisation check without the emotion branch.
# text = tokenizer(['This is a spider.', 'This is a dog', 'This is a cat'],
#                  context_length=57).to(device)
# with torch.no_grad():
#     logits_per_image, logits_per_text = model(
#         image.to(config.dtype), text, use_emotion=False)
#     probs = logits_per_image.softmax(dim=-1).cpu().numpy()
# print("泛化识别:", probs)