|
""" |
|
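ViT 的 Transformer 结构不使用因果掩码：任意一个位置都能访问（attend to）其他所有位置，它们之间没有因果关系，或者说因果关系很弱。

文本生成仍然使用因果掩码。

The ViT transformer uses no causal mask: every position may attend to every other position, because there is no (or only a weak) causal relationship between them. Text generation, by contrast, still uses a causal mask.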
|
""" |
|
import torch.nn.functional as F |
|
from VIT import model as VIT |
|
from Text_Encoder import text_encoder as transformer |
|
import torch.nn as nn |
|
import torch |
|
from Text_Encoder import MLP |
|
|
|
class Prompt_block(nn.Module):
    """Learnable prompt embeddings spliced into a text-embedding sequence.

    ``config.prompt_num`` trainable vectors of width ``config.hidden_size`` are
    inserted right after position 0 of every sequence in the batch.
    """

    def __init__(self, config):
        super(Prompt_block, self).__init__()
        # The attribute name is part of the saved state_dict keys; keep it stable.
        self.prompt_embedding = nn.Embedding(
            config.prompt_num,
            config.hidden_size,
            dtype=config.dtype,
            device=config.device,
        )

    def forward(self, text_embeddings):
        """Insert the prompt vectors after the first token.

        Args:
            text_embeddings: tensor of shape (batch, seq_len, hidden_size).

        Returns:
            Tensor of shape (batch, seq_len + prompt_num, hidden_size) equal to
            ``cat([x[:, :1], prompts, x[:, 1:]], dim=1)``.
        """
        batch, _, _ = text_embeddings.size()
        # Broadcast the (prompt_num, hidden_size) table over the batch; expand()
        # creates a view, so no per-sample copy is made before the cat.
        prompts = self.prompt_embedding.weight.unsqueeze(0).expand(batch, -1, -1)
        return torch.cat(
            (text_embeddings[:, :1, :], prompts, text_embeddings[:, 1:, :]),
            dim=1,
        )
|
|
|
|
|
|
|
|
|
|
|
class CLIP(nn.Module):
    """CLIP-style dual encoder that scores images against text.

    Images are encoded by the ``VIT`` module object and text by the
    ``transformer`` module object (both imported from sibling project modules).
    When a text batch is exactly ``max_position_embeddings - prompt_num`` tokens
    long, ``prompt_num`` learnable prompt embeddings are inserted (via
    ``Prompt_block``) before the text encoder runs.
    """

    def __init__(self, config):
        super().__init__()

        # Image encoder: the model object exported by the VIT module.
        self.visual = VIT

        self.device = config.device
        self.dtype = config.dtype

        self.token_embedding = nn.Embedding(
            config.vocab_size, config.hidden_size, dtype=config.dtype, device=config.device
        )

        self.max_position_embeddings = config.max_position_embeddings
        self.prompt_num = config.prompt_num

        # Text encoder: the model object exported by the Text_Encoder module.
        self.transformer = transformer

        self.prompt_block = Prompt_block(config)

        # Created in the default dtype and cast to self.dtype where it is used.
        self.positional_embedding = nn.Parameter(
            torch.empty(config.max_position_embeddings, config.hidden_size, device=config.device)
        )

        self.ln_final = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps, dtype=config.dtype, device=config.device
        )

        # Created in the default dtype and cast to self.dtype where it is used.
        self.text_projection = nn.Parameter(
            torch.empty(config.hidden_size, config.hidden_size, device=config.device)
        )

        # Frozen log-temperature applied to the cosine logits. Initialised to
        # logit_scale_init directly; multiplying an uninitialised torch.empty
        # tensor would start from arbitrary memory until a checkpoint is loaded.
        self.logit_scale = nn.Parameter(
            torch.full([], config.logit_scale_init, dtype=config.dtype, device=config.device),
            requires_grad=False,
        )

    def encode_image(self, img, use_emotion=True):
        """Encode an image batch with the visual encoder.

        Args:
            img: image batch passed straight to ``self.visual``.
            use_emotion: forwarded to ``self.visual``.

        Returns:
            Whatever ``self.visual`` returns (presumably one CLS embedding per
            image — confirm against the VIT module).
        """
        cls_embedding = self.visual(img, use_emotion)
        return cls_embedding

    def encode_text(self, text, use_emotion=True):
        """Encode token ids into projected text features.

        Args:
            text: integer token-id tensor of shape (batch, seq_len).
            use_emotion: forwarded to the text encoder.

        Returns:
            Tensor of shape (batch, hidden_size): the final-layer feature at the
            argmax-token position, multiplied by ``text_projection``.
        """
        b, n = text.size()

        # Features are pooled at the position of the largest token id;
        # presumably that is the end-of-text token (CLIP tokenizer
        # convention) — confirm against the tokenizer in use.
        index = text.argmax(dim=-1)

        text_embedding = self.token_embedding(text)

        # Prompts are only inserted when they exactly fill the positional table.
        if n == self.max_position_embeddings - self.prompt_num:
            text_embedding = self.prompt_block(text_embedding)
            # Prompt_block inserts prompt_num vectors after position 0, so every
            # later position (including the pooled one) shifts by prompt_num.
            index = index + self.prompt_num

        position_embedding = self.positional_embedding[None, :text_embedding.shape[1], :].to(self.dtype)
        text_embedding = position_embedding + text_embedding

        text_embedding = self.transformer(text_embedding, use_emotion=use_emotion)
        text_embedding = self.ln_final(text_embedding)

        # Build the row index on the same device as `index` for advanced indexing.
        rows = torch.arange(b, device=index.device)
        text_embedding = text_embedding[rows, index]

        return text_embedding @ self.text_projection.to(self.dtype)

    def forward(self, image, text, use_emotion=True):
        """Return scaled cosine-similarity logits between images and texts.

        Returns:
            (logits_per_image, logits_per_text), where logits_per_image has
            shape (num_images, num_texts) and logits_per_text is its transpose.
        """
        image_features = self.encode_image(image, use_emotion)
        text_features = self.encode_text(text, use_emotion)

        # L2-normalise so the dot product is a cosine similarity.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        return logits_per_image, logits_per_text
|
|
|
class Config:
    """Fixed hyper-parameters consumed by CLIP and Prompt_block.

    This is a plain attribute container; every value is set in ``__init__``
    and nothing is validated. The checkpoint loaded with ``strict=True``
    must match the shapes these values produce.
    """

    def __init__(self):
        # Vocabulary size and image-side geometry.
        self.vocab_size = 49408
        self.image_dim = 768
        self.num_patches = 49
        self.patch_size = 32

        # Model width; token_dim and encoder_hidden_size alias it at the end.
        width = 512
        self.hidden_size = width

        # Learnable prompt count and the size of the positional table.
        # encode_text inserts prompts only when the text length equals
        # max_position_embeddings - prompt_num.
        self.prompt_num = 20
        self.max_position_embeddings = 77

        # Transformer depth, heads and normalisation settings.
        self.num_hidden_layers = 12
        self.num_attention_heads = 8
        self.head_size = 64
        self.layer_norm_eps = 1e-5
        self.activation_function = "Quickgelu"

        # Numeric precision and placement for all created tensors.
        self.dtype = torch.float16
        self.device = torch.device("cuda:0")

        # Initial log-temperature for the logits (exp(4.6052) is about 100).
        self.logit_scale_init = 4.6052

        self.num_virtual_tokens = 20
        self.token_dim = width
        self.encoder_hidden_size = width
|
|
|
# --- Build the model and load the fine-tuned weights ------------------------
config = Config()

model = CLIP(config)

# weights_only=True limits torch.load to tensor data instead of arbitrary pickles.
model.load_state_dict(
    torch.load(r'/root/autodl-tmp/true_Emoset/EmotionCLIP_v2.bin', weights_only=True, map_location='cpu'),
    strict=True,
)

# NOTE(review): the model is left in its default (training) mode. Call
# model.eval() before inference if the image/text encoders contain dropout or
# batch-norm layers — confirm against the VIT / Text_Encoder modules.

# --- Zero-shot emotion recognition on a sample image ------------------------
import os
import pickle

from PIL import Image

import clip

# SECURITY: pickle.load can execute arbitrary code from the file; only load
# these preprocessing / tokenizer pickles from a trusted source.
with open('./preprocess.pkl', 'rb') as f:
    preprocess = pickle.load(f)

with open('./tokenize.pkl', 'rb') as f:
    tokenizer = pickle.load(f)

device = config.device

# Use a context manager so the image file handle is closed after preprocessing.
with Image.open("spider.jpg") as pil_image:
    image = preprocess(pil_image).unsqueeze(0).to(device)

EMOTION_LABELS = [
    "fear",
    "contentment",
    "anger",
    "sadness",
    "neutral",
    "disgust",
    "excitement",
    "awe",
    "amusement",
]

# A context length of max_position_embeddings - prompt_num (57 with the default
# Config) is the length at which encode_text inserts the learned prompts.
text = tokenizer(
    [f"This picture conveys a sense of {label}" for label in EMOTION_LABELS],
    context_length=config.max_position_embeddings - config.prompt_num,
).to(device)

with torch.no_grad():
    logits_per_image, logits_per_text = model(image.to(config.dtype), text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print("情感识别:", probs)

import torch

# --- Export the weights ------------------------------------------------------
# torch.save does not create missing directories.
os.makedirs('./upload', exist_ok=True)
torch.save(model.state_dict(), './upload/EmotionCLIP-V2.pth')
|
|