Upload my files
- .ipynb_checkpoints/EmotionCLIP-checkpoint.py +151 -0
- .ipynb_checkpoints/Text_Encoder-checkpoint.py +192 -0
- .ipynb_checkpoints/VIT-checkpoint.py +243 -0
- Dog sad.jpg +0 -0
- EmotionCLIP-V2.pth +3 -0
- EmotionCLIP.py +183 -0
- Text_Encoder.py +182 -0
- VIT.py +233 -0
- __pycache__/Text_Encoder.cpython-312.pyc +0 -0
- __pycache__/VIT.cpython-312.pyc +0 -0
- preprocess.pkl +3 -0
- tokenize.pkl +3 -0
.ipynb_checkpoints/EmotionCLIP-checkpoint.py
ADDED
@@ -0,0 +1,151 @@
"""
The transformer inside the ViT has no causal mask: every position may attend to every other
position, because image patches have no causal ordering between them (or only a very weak one).

Text generation, by contrast, still uses a causal mask.
"""
import torch.nn.functional as F
from VIT import model as VIT
from Text_Encoder import text_encoder as transformer
import torch.nn as nn
import torch
from Text_Encoder import MLP

class Prompt_block(nn.Module):
    def __init__(self,config):
        super(Prompt_block,self).__init__()
        self.prompt_embedding=nn.Embedding(config.prompt_num,config.hidden_size,dtype=config.dtype,device=config.device)
    def forward(self,text_embeddings):
        b,_,_=text_embeddings.size()
        n,dim=self.prompt_embedding.weight.size()
        """
        new_embeddings=[]
        for batch,index_ in enumerate(index):
            text_embedding=text_embeddings[0]
            text_embedding=torch.cat((text_embedding[:index_,:],self.prompt_embedding.weight,text_embedding[index_:,:]),0)
            new_embeddings.append(text_embedding)
        stacked_embedding= torch.stack(new_embeddings, dim=0)
        return stacked_embedding
        """
        text_embeddings=torch.cat((text_embeddings[:,0:1,:],self.prompt_embedding.weight.expand(b,n,dim),text_embeddings[:,1:,:]),1)
        return text_embeddings




class CLIP(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.visual=VIT
        self.device=config.device
        self.dtype=config.dtype
        self.token_embedding=nn.Embedding(config.vocab_size,config.hidden_size,dtype=config.dtype,device=config.device)
        self.max_position_embeddings=config.max_position_embeddings
        self.prompt_num=config.prompt_num
        self.transformer=transformer
        # add a prompt block
        self.prompt_block=Prompt_block(config)
        self.positional_embedding=nn.Parameter(torch.empty(config.max_position_embeddings,config.hidden_size,device=config.device))
        self.ln_final=nn.LayerNorm(config.hidden_size,eps=config.layer_norm_eps,dtype=config.dtype,device=config.device)
        self.text_projection=nn.Parameter(torch.empty(config.hidden_size,config.hidden_size,device=config.device))
        self.logit_scale=nn.Parameter(torch.empty([],dtype=config.dtype,device=config.device)*config.logit_scale_init,requires_grad=False)
    def encode_image(self,img,use_emotion=True):
        cls_embedding=self.visual(img,use_emotion)
        # cls_embedding: [batch_size,1,512], image_embedding: [batch_size,7,512]
        return cls_embedding
    def encode_text(self,text,use_emotion=True):
        # reserve room for the 20 prompt tokens
        b,n=text.size()
        index=text.argmax(dim=-1)
        text_embedding=self.token_embedding(text)
        #text_embedding=self.prompt_block(index,text_embedding)
        if n==self.max_position_embeddings-self.prompt_num:
            text_embedding=self.prompt_block(text_embedding)
            index=index+torch.tensor(20,device=index.device,dtype=index.dtype)
        position_embedding=self.positional_embedding[None,:text_embedding.shape[1],:].to(self.dtype)
        text_embedding=position_embedding+text_embedding
        text_embedding=self.transformer(text_embedding,use_emotion=use_emotion)
        text_embedding=self.ln_final(text_embedding)
        # take the feature at the EOT token position (argmax over token ids)
        #print(index[0],index_new[0],text_embedding.shape)
        text_embedding=text_embedding[torch.arange(text.shape[0]),index]
        [email protected]_projection.to(self.dtype)

        return text_embedding

    def forward(self,image,text,use_emotion=True):
        image_features=self.encode_image(image,use_emotion)
        text_features=self.encode_text(text,use_emotion)
        # normalized features
        image_features=image_features/image_features.norm(dim=-1,keepdim=True)
        text_features=text_features/text_features.norm(dim=-1,keepdim=True)
        # cosine similarity as logits
        logit_scale=self.logit_scale.exp()
        logits_per_image=logit_scale*image_features@text_features.t()
        logits_per_text=logits_per_image.t()
        # shape = [global_batch_size, global_batch_size]
        return logits_per_image,logits_per_text

class Config:
    def __init__(self):
        self.vocab_size=49408
        self.image_dim=768
        self.num_patches=49
        self.patch_size=32
        self.hidden_size=512
        self.prompt_num=20
        self.max_position_embeddings=77
        self.num_hidden_layers=12
        self.num_attention_heads=8
        self.head_size=64
        self.layer_norm_eps=1e-5
        self.activation_function="Quickgelu"
        self.dtype=torch.float16
        self.device=torch.device("cuda:0")
        self.logit_scale_init=4.6052
        self.num_virtual_tokens=20
        self.token_dim=self.hidden_size
        self.encoder_hidden_size=self.hidden_size

config=Config()
model=CLIP(config)
# load the pretrained weights
model.load_state_dict(torch.load(r'/root/autodl-tmp/true_Emoset/EmotionCLIP_v2.bin',weights_only=True,map_location='cpu'),strict=True)
"""
for name, param in model.named_parameters():
    if 'prefix' not in name and 'prompt' not in name and 'ln' not in name:  # name contains none of 'prefix'/'prompt'/'ln'
        print(name,"'s requires_grad turned off.")
        param.requires_grad = False  # freeze this parameter
    else:
        print(name,"'s requires_grad turned on.")
        param.requires_grad = True   # keep this parameter trainable
"""

# compile the model
#model=torch.compile(model)
import pickle
from PIL import Image
import clip
with open('./preprocess.pkl','rb') as f:
    preprocess = pickle.load(f)
with open('./tokenize.pkl','rb') as f:
    tokenizer=pickle.load(f)
device=config.device
image = preprocess(Image.open("spider.jpg")).unsqueeze(0).to(device)
text = tokenizer(["This picture conveys a sense of fear", "This picture conveys a sense of contentment", "This picture conveys a sense of anger","This picture conveys a sense of sadness","This picture conveys a sense of neutral","This picture conveys a sense of disgust","This picture conveys a sense of excitement","This picture conveys a sense of awe","This picture conveys a sense of amusement"],context_length=57).to(device)
#context_length=57
with torch.no_grad():
    logits_per_image, logits_per_text = model(image.to(config.dtype), text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()
print("Emotion recognition:",probs)
# save the weights with the prefix merged in
import torch
torch.save(model.state_dict(),'./upload/EmotionCLIP-V2.pth')
# generalization check
"""
text=tokenizer(['This is a spider.','This is a dog','This is a cat'],context_length=57).to(device)
with torch.no_grad():
    logits_per_image, logits_per_text = model(image.to(config.dtype), text,use_emotion=False)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()
print("Generalization check:",probs)
"""
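A standalone sketch of what Prompt_block above does to a tokenized caption (not one of the uploaded files; the batch size, sequence length, and EOT positions below are made up, only prompt_num=20 and hidden_size=512 come from the Config): the 20 learnable prompt vectors are spliced in right after the first start-of-text token, which is also why encode_text shifts the EOT index by 20.

import torch

batch, seq_len, hidden, prompt_num = 2, 57, 512, 20
text_embeddings = torch.randn(batch, seq_len, hidden)        # start token + caption tokens + padding
prompt = torch.randn(prompt_num, hidden)                     # stands in for prompt_embedding.weight

out = torch.cat((text_embeddings[:, 0:1, :],                 # keep the start token first
                 prompt.expand(batch, prompt_num, hidden),   # 20 learnable prompt vectors
                 text_embeddings[:, 1:, :]), dim=1)          # remaining caption tokens
print(out.shape)             # torch.Size([2, 77, 512]) == max_position_embeddings
eot_index = torch.tensor([9, 12])                            # example argmax positions
eot_index = eot_index + prompt_num                           # shifted by 20, as in encode_text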
.ipynb_checkpoints/Text_Encoder-checkpoint.py
ADDED
@@ -0,0 +1,192 @@
import torch
import torch.nn as nn
import math
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn import functional as F


class PrefixEncoder(torch.nn.Module):
    def __init__(self,config):
        super(PrefixEncoder,self).__init__()
        self.config=config
        self.device=config.device
        self.dtype=config.dtype
        self.num_virtual_tokens=config.num_virtual_tokens
        self.embedding=torch.nn.Embedding(config.num_virtual_tokens,config.token_dim,device=config.device,dtype=config.dtype)
        self.token_dim=config.token_dim
        self.encoder_hidden_size=config.encoder_hidden_size
        self.num_layers=config.num_layers
        self.transformer=torch.nn.Sequential(
            torch.nn.Linear(self.token_dim,self.encoder_hidden_size,device=self.device,dtype=self.dtype),
            torch.nn.Tanh(),
            torch.nn.Linear(self.encoder_hidden_size,self.num_layers*2*self.token_dim,device=self.device,dtype=self.dtype),
        )
    def forward(self,input_ids,batch_size):
        input_ids=input_ids.unsqueeze(0)
        prefix_embedding=self.embedding(input_ids)
        prefix_embedding=self.transformer(prefix_embedding)
        self.register_parameter("prefix_embedding",nn.Parameter(prefix_embedding,requires_grad=False))
        prefix_embedding=prefix_embedding.expand(batch_size,self.num_virtual_tokens,self.num_layers*2*self.token_dim)
        prefix_embedding=prefix_embedding.reshape(batch_size,self.num_virtual_tokens,self.num_layers,2,self.token_dim)
        prefix_embedding=prefix_embedding.permute(3,2,0,1,4)
        del self.embedding
        del self.transformer
        k,v=prefix_embedding.chunk(2,dim=0)
        return (k.squeeze(0),v.squeeze(0))


class Transformer(nn.Module):
    def __init__(self,config):
        super(Transformer,self).__init__()
        self.resblocks=nn.ModuleList([ResidualAttentionBlock(config) for _ in range(config.num_layers)])
        self.prefix=PrefixEncoder(config)
        prefix_tokens=torch.arange(0,config.num_virtual_tokens,device=config.device,dtype=torch.long)
        self.register_buffer("prefix_tokens",prefix_tokens)
    def forward(self,hidden_state,use_emotion):
        if use_emotion:
            #print("text transformer prefix active.")
            b,n,h=hidden_state.shape
            prefix_k,prefix_v=self.prefix(self.prefix_tokens,b)
            for index,resblock in enumerate(self.resblocks):
                hidden_state=resblock(hidden_state,prefix_k[index],prefix_v[index])
            return hidden_state
        else:
            for index,resblock in enumerate(self.resblocks):
                hidden_state=resblock(hidden_state)
            return hidden_state






class ResidualAttentionBlock(nn.Module):
    def __init__(self,config):
        super(ResidualAttentionBlock,self).__init__()
        self.ln_1=nn.LayerNorm(config.hidden_size,eps=config.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
        self.ln_2=nn.LayerNorm(config.hidden_size,eps=config.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
        #self.attn=nn.MultiheadAttention(config.hidden_size,config.num_heads,device=config.device,dtype=config.dtype)
        self.attn=MultiHeadAttention(config)
        self.mlp=MLP(config)
    def forward(self,hidden_state,prefix_k=None,prefix_v=None):
        residual=hidden_state
        hidden_state=self.ln_1(hidden_state)
        hidden_state=self.attn(hidden_state,prefix_k,prefix_v)
        hidden_state=residual+hidden_state
        residual=hidden_state
        hidden_state=self.ln_2(hidden_state)
        hidden_state=self.mlp(hidden_state)
        hidden_state=residual+hidden_state
        return hidden_state

class MultiHeadAttention(nn.Module):
    def __init__(self,config):
        super(MultiHeadAttention,self).__init__()
        self.hidden_size=config.hidden_size
        self.num_heads=config.num_heads
        self.head_size=self.hidden_size//self.num_heads
        # packed qkv projection weight and bias held as nn.Parameter (frozen here)
        self.in_proj_weight=nn.Parameter(torch.empty(3*config.hidden_size,config.hidden_size,device=config.device,dtype=config.dtype),requires_grad=False)
        self.in_proj_bias=nn.Parameter(torch.empty(3*config.hidden_size,device=config.device,dtype=config.dtype),requires_grad=False)
        #self.q_linear=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device)
        #self.k_linear=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device)
        #self.v_linear=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device)
        self.out_proj=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device,dtype=config.dtype)
    def forward(self,hidden_state,prefix_k=None,prefix_v=None):
        b,n,c=hidden_state.shape
        #q=self.q_linear(hidden_state).view(b,n,self.num_heads,self.head_size).permute(0,2,1,3)
        #k=self.k_linear(hidden_state).view(b,n,self.num_heads,self.head_size).permute(0,2,3,1)
        #v=self.v_linear(hidden_state).view(b,n,self.num_heads,self.head_size).permute(0,2,1,3)
        q,k,v=(torch.matmul(hidden_state,self.in_proj_weight.T)+self.in_proj_bias.expand(b,n,-1)).chunk(3,dim=-1)
        if prefix_k is not None and prefix_v is not None:
            # prepend the prefix to the sequence
            k=torch.cat((prefix_k,k),dim=1)
            v=torch.cat((prefix_v,v),dim=1)
        #print("model original k :",k[:,0,0])
        bk,nk,hk=k.shape
        bq,nq,hq=q.shape
        q=q.view(bq,nq,self.num_heads,self.head_size).permute(0,2,1,3)
        k=k.view(bk,nk,self.num_heads,self.head_size).permute(0,2,1,3)
        v=v.view(bk,nk,self.num_heads,self.head_size).permute(0,2,1,3)
        attention_logits=F.scaled_dot_product_attention(q, k, v)
        attention_logits=attention_logits.permute(0,2,1,3).contiguous().view(bk,nq,self.hidden_size)
        attention_output=self.out_proj(attention_logits)
        return attention_output


class GELU(nn.Module):
    """
    Error function erf:
        erf(x) = 2/sqrt(pi) * integral(exp(-t^2), t=0..x)
    Here t is a dummy variable covering the integration range from 0 to x:
        x is the input to the error function, i.e. the upper limit of the integral;
        t is the integration variable, running from 0 to x, with exp(-t^2) evaluated at each point;
        exp(-t^2) is the integrand, the Gaussian density at each t.
    The integral therefore accumulates the Gaussian probability mass over [0, x].
    """
    def forward(self,x):
        return 0.5*x*(1.0+torch.erf(x/math.sqrt(2.0)))  # math.sqrt: torch.sqrt does not accept a Python float

class QuickGELU(nn.Module):
    def __init__(self):
        super(QuickGELU,self).__init__()
    def forward(self,x):
        old_dtype=x.dtype
        x=x.to(torch.float32)
        return (x*torch.sigmoid(1.702*x)).to(old_dtype)


class MLP(nn.Module):
    def __init__(self,config):
        super(MLP,self).__init__()
        self.hidden_size=config.hidden_size
        self.c_fc=nn.Linear(self.hidden_size,4*self.hidden_size,device=config.device,bias=True,dtype=config.dtype)
        self.gelu=QuickGELU()
        self.c_proj=nn.Linear(self.hidden_size*4,self.hidden_size,device=config.device,bias=True,dtype=config.dtype)
    def forward(self,hidden_state):
        hidden_state=self.c_fc(hidden_state)
        hidden_state=self.gelu(hidden_state)
        hidden_state=self.c_proj(hidden_state)
        return hidden_state

class Config:
    def __init__(self,vocab_size,max_position_embeddings,hidden_size,num_layers,num_heads,device,dtype):
        self.vocab_size=vocab_size
        self.max_position_embeddings=max_position_embeddings
        self.hidden_size=hidden_size
        self.num_layers=num_layers
        self.num_heads=num_heads
        self.device=device
        self.dtype=dtype
        self.norm_eps=1e-5
        self.num_virtual_tokens=20
        self.token_dim=hidden_size
        self.encoder_hidden_size=hidden_size
config=Config(
    vocab_size=49408,
    max_position_embeddings=77,
    hidden_size=512,
    num_layers=12,
    num_heads=8,
    device=torch.device('cuda:0'),
    dtype=torch.float16
)
class TextEncoder(nn.Module):  # note: not used below; text_encoder is built from Transformer directly
    def __init__(self,config):
        super(TextEncoder,self).__init__()
        self.token_embedding=nn.Embedding(config.vocab_size,config.hidden_size,device=config.device,dtype=config.dtype)
        self.positional_embedding=nn.Parameter(torch.zeros(config.max_position_embeddings,config.hidden_size,device=config.device,dtype=config.dtype),requires_grad=False)
        self.transformer=Transformer(config)
        self.ln_final=nn.LayerNorm(config.hidden_size,eps=config.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
    def forward(self,input_ids):
        b,n=input_ids.shape
        prompt_embedding,token_embeddings=self.token_embedding(input_ids)
        position_ids=torch.arange(n,device=config.device,dtype=config.dtype).unsqueeze(0).expand(b,n)
        position_embeddings=self.positional_embedding[position_ids]
        embeddings=token_embeddings+position_embeddings
        embeddings=torch.cat((prompt_embedding,embeddings),dim=1)
        embeddings=self.transformer(embeddings)
        embeddings=self.ln_final(embeddings)
        return embeddings

text_encoder=Transformer(config)
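A shape walkthrough of how PrefixEncoder above turns the 20 virtual tokens into per-layer key/value prefixes (illustrative only; the batch size is made up, the other sizes follow the Config in this file):

import torch

batch, n_virtual, n_layers, dim = 4, 20, 12, 512
flat = torch.randn(1, n_virtual, n_layers * 2 * dim)         # output of the Linear-Tanh-Linear stack
flat = flat.expand(batch, n_virtual, n_layers * 2 * dim)     # same prefix shared across the batch
kv = flat.reshape(batch, n_virtual, n_layers, 2, dim)
kv = kv.permute(3, 2, 0, 1, 4)                               # (2, layers, batch, tokens, dim)
k_all, v_all = kv.chunk(2, dim=0)
prefix_k, prefix_v = k_all.squeeze(0), v_all.squeeze(0)
print(prefix_k.shape)        # torch.Size([12, 4, 20, 512]); prefix_k[i] feeds the i-th resblock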
.ipynb_checkpoints/VIT-checkpoint.py
ADDED
@@ -0,0 +1,243 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import os
import sys

# prefix tuning, following the huggingface implementation
class PrefixEncoder(torch.nn.Module):
    def __init__(self,config):
        super(PrefixEncoder,self).__init__()
        self.config=config
        self.device=config.device
        self.dtype=config.dtype
        self.num_virtual_tokens=config.num_virtual_tokens
        self.embedding=torch.nn.Embedding(config.num_virtual_tokens,config.token_dim,device=config.device,dtype=config.dtype)
        self.token_dim=config.token_dim
        self.encoder_hidden_size=config.encoder_hidden_size
        self.num_layers=config.num_layers
        self.transformer=torch.nn.Sequential(
            torch.nn.Linear(self.token_dim,self.encoder_hidden_size,device=self.device,dtype=self.dtype),
            torch.nn.Tanh(),
            torch.nn.Linear(self.encoder_hidden_size,self.num_layers*2*self.token_dim,device=self.device,dtype=self.dtype),
        )
    def forward(self,input_ids,batch_size):
        input_ids=input_ids.unsqueeze(0)
        prefix_embedding=self.embedding(input_ids)
        prefix_embedding=self.transformer(prefix_embedding)
        self.register_parameter("prefix_embedding",nn.Parameter(prefix_embedding,requires_grad=False))
        prefix_embedding=prefix_embedding.expand(batch_size,self.num_virtual_tokens,self.num_layers*2*self.token_dim)
        prefix_embedding=prefix_embedding.reshape(batch_size,self.num_virtual_tokens,self.num_layers,2,self.token_dim)
        prefix_embedding=prefix_embedding.permute(3,2,0,1,4)
        del self.embedding
        del self.transformer
        k,v=prefix_embedding.chunk(2,dim=0)
        return (k.squeeze(0),v.squeeze(0))


import torch
import torch.nn as nn
import math
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn import functional as F
def position_embedding(x,position_ids):
    hidden_size=x.size(2)
    seq_len=x.size(1)
    div_term=torch.exp(torch.arange(0,hidden_size,2,device=x.device).float()*(-math.log(10000.0)/hidden_size))
    positional_encoding=torch.zeros(seq_len,hidden_size,device=x.device)
    positional_encoding[:,0::2]=torch.sin(position_ids.float()[:,None]*div_term)
    positional_encoding[:,1::2]=torch.cos(position_ids.float()[:,None]*div_term)
    positional_encoding=positional_encoding.unsqueeze(0)
    return positional_encoding




class VisionTransformer(nn.Module):
    def __init__(self,config):
        super(VisionTransformer,self).__init__()
        self.image_channel=config.image_channel
        self.hidden_size=config.hidden_size
        self.norm_eps=config.norm_eps
        self.patch_size=config.patch_size
        self.output_dim=config.output_dim
        self.dtype=config.dtype
        self.num_patches=config.num_patches
        self.num_virtual_tokens=config.num_virtual_tokens if hasattr(config,"num_virtual_tokens") else None
        self.conv1=nn.Conv2d(self.image_channel,self.hidden_size,self.patch_size,stride=self.patch_size,bias=False,device=config.device,dtype=config.dtype)
        self.ln_pre=nn.LayerNorm(self.hidden_size,eps=self.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
        self.transformer=Transformer(config)
        #self.position_ids=torch.arange(config.num_patches+1,dtype=torch.long,device=config.device)
        #self.position_embeddings=nn.Parameter(torch.zeros(1,config.num_patches+1,config.hidden_size))
        #nn.init.normal_(self.position_embeddings)
        # cls token, used for the image-level (classification) representation
        #self.cls_token=nn.Parameter(torch.zeros(1,1,config.hidden_size,device=config.device))
        # the class token is not trained here
        self.class_embedding=nn.Parameter(torch.empty(config.hidden_size,device=config.device),requires_grad=False)
        # the position embedding is likewise a learned parameter (frozen here)
        self.positional_embedding=nn.Parameter(torch.empty(config.num_patches+1,config.hidden_size,device=config.device),requires_grad=False)
        # projection to the shared embedding space
        self.proj=nn.Parameter(torch.empty(config.hidden_size,config.output_dim,device=config.device,dtype=config.dtype),requires_grad=False)
        self.ln_post=nn.LayerNorm(self.hidden_size,eps=self.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
    def forward(self,hidden_state,use_emotion):
        b,c,h,w=hidden_state.shape
        # compute the patch embeddings
        hidden_state=self.conv1(hidden_state)
        hidden_state=hidden_state.reshape(b,self.hidden_size,-1).transpose(1,2)
        # prepend the cls token embedding
        hidden_state=torch.cat((self.class_embedding.expand(b,1,-1).to(hidden_state.dtype),hidden_state),dim=1)
        # fixed sinusoidal position embedding from the original transformer paper (disabled)
        #hidden_state=hidden_state+position_embedding(hidden_state,self.position_ids)
        hidden_state=hidden_state+self.positional_embedding.unsqueeze(0).to(hidden_state.dtype)
        hidden_state=self.ln_pre(hidden_state)
        hidden_state=self.transformer(hidden_state,use_emotion)
        # split the cls token output from the image patch outputs
        cls_state=hidden_state[:,0,:]
        cls_state=self.ln_post(cls_state)
        cls_state=torch.matmul(cls_state,self.proj)
        #image_state=hidden_state[:,1:,:]
        #image_state size (batch_size,49,768)
        return cls_state

class Transformer(nn.Module):
    def __init__(self,config):
        super(Transformer,self).__init__()
        self.resblocks=nn.ModuleList([ResidualAttentionBlock(config) for _ in range(config.num_layers)])
        self.prefix=PrefixEncoder(config)
        prefix_tokens=torch.arange(0,config.num_virtual_tokens,device=config.device,dtype=torch.long)
        self.register_buffer("prefix_tokens",prefix_tokens)
    def forward(self,hidden_state,use_emotion):
        if use_emotion:
            b,n,h=hidden_state.shape
            prefix_k,prefix_v=self.prefix(self.prefix_tokens,b)
            for index,resblock in enumerate(self.resblocks):
                # before each layer, pass that layer's prefix vectors into the resblock for concatenation
                hidden_state=resblock(hidden_state,prefix_k[index],prefix_v[index])
            return hidden_state
        else:
            for index,resblock in enumerate(self.resblocks):
                # no prefix vectors are passed in this branch
                hidden_state=resblock(hidden_state)
            return hidden_state






class ResidualAttentionBlock(nn.Module):
    def __init__(self,config):
        super(ResidualAttentionBlock,self).__init__()
        self.ln_1=nn.LayerNorm(config.hidden_size,eps=config.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
        self.ln_2=nn.LayerNorm(config.hidden_size,eps=config.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
        #self.attn=nn.MultiheadAttention(config.hidden_size,config.num_heads,device=config.device,dtype=config.dtype)
        self.attn=MultiHeadAttention(config)
        self.mlp=MLP(config)
    def forward(self,hidden_state,prefix_k=None,prefix_v=None):
        residual=hidden_state
        hidden_state=self.ln_1(hidden_state)
        hidden_state=self.attn(hidden_state,prefix_k,prefix_v)
        hidden_state=residual+hidden_state
        residual=hidden_state
        hidden_state=self.ln_2(hidden_state)
        hidden_state=self.mlp(hidden_state)
        hidden_state=residual+hidden_state
        return hidden_state

class MultiHeadAttention(nn.Module):
    def __init__(self,config):
        super(MultiHeadAttention,self).__init__()
        self.hidden_size=config.hidden_size
        self.num_heads=config.num_heads
        self.head_size=self.hidden_size//self.num_heads
        # packed qkv projection weight and bias held as nn.Parameter (frozen here)
        self.in_proj_weight=nn.Parameter(torch.empty(3*config.hidden_size,config.hidden_size,device=config.device,dtype=config.dtype),requires_grad=False)
        self.in_proj_bias=nn.Parameter(torch.empty(3*config.hidden_size,device=config.device,dtype=config.dtype),requires_grad=False)
        #self.q_linear=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device)
        #self.k_linear=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device)
        #self.v_linear=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device)
        self.out_proj=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device,dtype=config.dtype)
    def forward(self,hidden_state,prefix_k=None,prefix_v=None):
        b,n,h=hidden_state.shape
        #q=self.q_linear(hidden_state).view(b,n,self.num_heads,self.head_size).permute(0,2,1,3)
        #k=self.k_linear(hidden_state).view(b,n,self.num_heads,self.head_size).permute(0,2,3,1)
        #v=self.v_linear(hidden_state).view(b,n,self.num_heads,self.head_size).permute(0,2,1,3)
        q,k,v=(torch.matmul(hidden_state,self.in_proj_weight.T)+self.in_proj_bias.expand(b,n,-1)).chunk(3,dim=-1)
        if prefix_k is not None and prefix_v is not None:
            # prepend the prefix to the sequence
            #print("original k.shape",prefix_k.shape)
            k=torch.cat((prefix_k,k),dim=1)
            v=torch.cat((prefix_v,v),dim=1)
        #print("model original k :",k[:,0,0])
        bk,nk,hk=k.shape
        bq,nq,hq=q.shape
        q=q.view(bq,nq,self.num_heads,self.head_size).permute(0,2,1,3)
        k=k.view(bk,nk,self.num_heads,self.head_size).permute(0,2,1,3)
        v=v.view(bk,nk,self.num_heads,self.head_size).permute(0,2,1,3)
        attention_logits=F.scaled_dot_product_attention(q, k, v)
        attention_logits=attention_logits.permute(0,2,1,3).contiguous().view(bk,nq,self.hidden_size)
        attention_output=self.out_proj(attention_logits)
        return attention_output



class GELU(nn.Module):
    """
    Error function erf:
        erf(x) = 2/sqrt(pi) * integral(exp(-t^2), t=0..x)
    Here t is a dummy variable covering the integration range from 0 to x:
        x is the input to the error function, i.e. the upper limit of the integral;
        t is the integration variable, running from 0 to x, with exp(-t^2) evaluated at each point;
        exp(-t^2) is the integrand, the Gaussian density at each t.
    The integral therefore accumulates the Gaussian probability mass over [0, x].
    """
    def forward(self,x):
        old_dtype=x.dtype
        x=x.to(torch.float32)
        return (0.5*x*(1.0+torch.erf(x/math.sqrt(2.0)))).to(old_dtype)  # math.sqrt: torch.sqrt does not accept a Python float

class QuickGELU(nn.Module):
    def __init__(self):
        super(QuickGELU,self).__init__()
    def forward(self,x):
        old_dtype=x.dtype
        x=x.to(torch.float32)
        return (x*torch.sigmoid(1.702*x)).to(old_dtype)


class MLP(nn.Module):
    def __init__(self,config):
        super(MLP,self).__init__()
        self.hidden_size=config.hidden_size
        self.c_fc=nn.Linear(self.hidden_size,4*self.hidden_size,device=config.device,bias=True,dtype=config.dtype)
        self.gelu=QuickGELU()
        self.c_proj=nn.Linear(self.hidden_size*4,self.hidden_size,device=config.device,bias=True,dtype=config.dtype)
    def forward(self,hidden_state):
        hidden_state=self.c_fc(hidden_state)
        hidden_state=self.gelu(hidden_state)
        hidden_state=self.c_proj(hidden_state)
        return hidden_state



class ViTConfig:
    def __init__(self,image_channel,hidden_size,num_heads,num_layers,patch_size,num_patches,output_dim,norm_eps,device):
        self.image_channel=image_channel
        self.hidden_size=hidden_size
        self.num_heads=num_heads
        self.num_layers=num_layers
        self.patch_size=patch_size
        self.num_patches=num_patches
        self.norm_eps=norm_eps
        self.device=device
        self.dtype=torch.float16
        self.patch_token_num=self.hidden_size//self.patch_size**2+1
        self.output_dim=output_dim
        self.num_virtual_tokens=20
        self.token_dim=self.hidden_size
        self.encoder_hidden_size=self.hidden_size

config=ViTConfig(3,768,12,12,32,49,512,1e-5,torch.device("cuda"))
model=VisionTransformer(config)
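A minimal sketch of the patch-embedding step performed by conv1 in VisionTransformer above; the 224x224 input resolution is an assumption chosen to be consistent with patch_size=32 and num_patches=49 (everything else is toy data):

import torch
import torch.nn as nn

conv1 = nn.Conv2d(3, 768, kernel_size=32, stride=32, bias=False)
img = torch.randn(1, 3, 224, 224)
patches = conv1(img)                                    # (1, 768, 7, 7): one column per 32x32 patch
patches = patches.reshape(1, 768, -1).transpose(1, 2)   # (1, 49, 768): one token per patch
cls_tok = torch.zeros(1, 1, 768)                        # stands in for class_embedding
tokens = torch.cat((cls_tok, patches), dim=1)
print(tokens.shape)          # torch.Size([1, 50, 768]) == num_patches + 1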
Dog sad.jpg
ADDED
EmotionCLIP-V2.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b3d83b57423a070150ca67c87f7ad2a163b531a890632f4ebe3cf1c12a08ffd9
size 304602701
EmotionCLIP.py
ADDED
@@ -0,0 +1,183 @@
"""
The transformer inside the ViT has no causal mask: every position may attend to every other
position, because image patches have no causal ordering between them (or only a very weak one).

Text generation, by contrast, still uses a causal mask.
"""
import torch.nn.functional as F
from VIT import model as VIT
from Text_Encoder import text_encoder as transformer
import torch.nn as nn
import torch
from Text_Encoder import MLP

class Prompt_block(nn.Module):
    def __init__(self,config):
        super(Prompt_block,self).__init__()
        self.prompt_embedding=nn.Embedding(config.prompt_num,config.hidden_size,dtype=config.dtype,device=config.device)
    def forward(self,text_embeddings):
        b,_,_=text_embeddings.size()
        n,dim=self.prompt_embedding.weight.size()
        """
        new_embeddings=[]
        for batch,index_ in enumerate(index):
            text_embedding=text_embeddings[0]
            text_embedding=torch.cat((text_embedding[:index_,:],self.prompt_embedding.weight,text_embedding[index_:,:]),0)
            new_embeddings.append(text_embedding)
        stacked_embedding= torch.stack(new_embeddings, dim=0)
        return stacked_embedding
        """
        text_embeddings=torch.cat((text_embeddings[:,0:1,:],self.prompt_embedding.weight.expand(b,n,dim),text_embeddings[:,1:,:]),1)
        return text_embeddings




class CLIP(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.visual=VIT
        self.device=config.device
        self.dtype=config.dtype
        self.token_embedding=nn.Embedding(config.vocab_size,config.hidden_size,dtype=config.dtype,device=config.device)
        self.max_position_embeddings=config.max_position_embeddings
        self.prompt_num=config.prompt_num
        self.transformer=transformer
        # add a prompt block
        self.prompt_block=Prompt_block(config)
        self.positional_embedding=nn.Parameter(torch.empty(config.max_position_embeddings,config.hidden_size,device=config.device))
        self.ln_final=nn.LayerNorm(config.hidden_size,eps=config.layer_norm_eps,dtype=config.dtype,device=config.device)
        self.text_projection=nn.Parameter(torch.empty(config.hidden_size,config.hidden_size,device=config.device))
        self.logit_scale=nn.Parameter(torch.empty([],dtype=config.dtype,device=config.device)*config.logit_scale_init,requires_grad=False)
    def encode_image(self,img,use_emotion=True):
        cls_embedding=self.visual(img,use_emotion)
        # cls_embedding: [batch_size,1,512], image_embedding: [batch_size,7,512]
        return cls_embedding
    def encode_text(self,text,use_emotion=True):
        # reserve room for the 20 prompt tokens
        b,n=text.size()
        index=text.argmax(dim=-1)
        text_embedding=self.token_embedding(text)
        #text_embedding=self.prompt_block(index,text_embedding)
        if n==self.max_position_embeddings-self.prompt_num:
            text_embedding=self.prompt_block(text_embedding)
            index=index+torch.tensor(20,device=index.device,dtype=index.dtype)
        position_embedding=self.positional_embedding[None,:text_embedding.shape[1],:].to(self.dtype)
        text_embedding=position_embedding+text_embedding
        text_embedding=self.transformer(text_embedding,use_emotion=use_emotion)
        text_embedding=self.ln_final(text_embedding)
        # take the feature at the EOT token position (argmax over token ids)
        #print(index[0],index_new[0],text_embedding.shape)
        text_embedding=text_embedding[torch.arange(text.shape[0]),index]
        [email protected]_projection.to(self.dtype)

        return text_embedding

    def forward(self,image,text,use_emotion=True):
        image_features=self.encode_image(image,use_emotion)
        text_features=self.encode_text(text,use_emotion)
        # normalized features
        image_features=image_features/image_features.norm(dim=-1,keepdim=True)
        text_features=text_features/text_features.norm(dim=-1,keepdim=True)
        # cosine similarity as logits
        logit_scale=self.logit_scale.exp()
        logits_per_image=logit_scale*image_features@text_features.t()
        logits_per_text=logits_per_image.t()
        # shape = [global_batch_size, global_batch_size]
        return logits_per_image,logits_per_text

class Config:
    def __init__(self):
        self.vocab_size=49408
        self.image_dim=768
        self.num_patches=49
        self.patch_size=32
        self.hidden_size=512
        self.prompt_num=20
        self.max_position_embeddings=77
        self.num_hidden_layers=12
        self.num_attention_heads=8
        self.head_size=64
        self.layer_norm_eps=1e-5
        self.activation_function="Quickgelu"
        self.dtype=torch.float16
        self.device=torch.device("cuda:0")
        self.logit_scale_init=4.6052
        self.num_virtual_tokens=20
        self.token_dim=self.hidden_size
        self.encoder_hidden_size=self.hidden_size

config=Config()
model=CLIP(config)
# load the pretrained weights
model.load_state_dict(torch.load(r'./EmotionCLIP-V2.pth',weights_only=True,map_location='cpu'),strict=True)
"""
for name, param in model.named_parameters():
    if 'prefix' not in name and 'prompt' not in name and 'ln' not in name:  # name contains none of 'prefix'/'prompt'/'ln'
        print(name,"'s requires_grad turned off.")
        param.requires_grad = False  # freeze this parameter
    else:
        print(name,"'s requires_grad turned on.")
        param.requires_grad = True   # keep this parameter trainable
"""

# compile the model
#model=torch.compile(model)
import pickle
from PIL import Image
import numpy as np
import clip
with open('./preprocess.pkl','rb') as f:
    preprocess = pickle.load(f)
with open('./tokenize.pkl','rb') as f:
    tokenizer=pickle.load(f)
device=config.device
image = preprocess(Image.open("Dog sad.jpg")).unsqueeze(0).to(device)
# emotion recognition
labels=[
    'amusement',
    'anger',
    'awe',
    'contentment',
    'disgust',
    'excitement',
    'fear',
    'sadness',
    'neutral'
]
text_list=[ f"This picture conveys a sense of {label}" for label in labels]
tokens= tokenizer(text_list,
                  context_length=57).to(device)

with torch.no_grad():
    logits_per_image, logits_per_text = model(image.to(config.dtype), tokens)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

# get the predicted label
predicted_index = np.argmax(probs, axis=1)
predicted_label=labels[predicted_index[0]]

print("Emotion recognition:", probs)
print("Predicted emotion label:", predicted_label)

# generalization check
labels=[
    'spider',
    'dog',
    'cat',
    'fish'
]
text_list=[ f"This is a {label}" for label in labels]
tokens= tokenizer(text_list,context_length=57).to(device)

with torch.no_grad():
    logits_per_image, logits_per_text = model(image.to(config.dtype), tokens, use_emotion=False)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

# get the predicted label
predicted_index = np.argmax(probs, axis=1)
predicted_label=labels[predicted_index[0]]

print("Generalization check:", probs)
print("Predicted generalization label:", predicted_label)
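A toy illustration (random features, not the real checkpoint) of how the logits returned by CLIP.forward become the probability vector printed above: both feature sets are L2-normalized, so logits_per_image is a temperature-scaled cosine similarity over the prompt list, and argmax of the softmax picks the predicted label.

import torch

logit_scale = torch.tensor(4.6052).exp()                              # ~100, matches logit_scale_init
img = torch.nn.functional.normalize(torch.randn(1, 512), dim=-1)      # one image feature
txt = torch.nn.functional.normalize(torch.randn(9, 512), dim=-1)      # nine emotion prompts
logits_per_image = logit_scale * img @ txt.t()                        # (1, 9) scaled cosine similarities
probs = logits_per_image.softmax(dim=-1)
pred = probs.argmax(dim=-1)                                           # index into the labels list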
Text_Encoder.py
ADDED
@@ -0,0 +1,182 @@
import torch
import torch.nn as nn
import math
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn import functional as F


class PrefixEncoder(torch.nn.Module):
    def __init__(self,config):
        super(PrefixEncoder,self).__init__()
        self.config=config
        self.device=config.device
        self.dtype=config.dtype
        self.num_virtual_tokens=config.num_virtual_tokens
        self.token_dim=config.token_dim
        self.encoder_hidden_size=config.encoder_hidden_size
        self.num_layers=config.num_layers
        self.prefix_embedding=nn.Parameter(torch.empty(1,self.num_virtual_tokens,self.num_layers*2*self.token_dim,device=config.device,dtype=config.dtype),requires_grad=False)
    def forward(self,input_ids,batch_size):
        prefix_embedding=self.prefix_embedding
        prefix_embedding=prefix_embedding.expand(batch_size,self.num_virtual_tokens,self.num_layers*2*self.token_dim)
        prefix_embedding=prefix_embedding.reshape(batch_size,self.num_virtual_tokens,self.num_layers,2,self.token_dim)
        prefix_embedding=prefix_embedding.permute(3,2,0,1,4)
        k,v=prefix_embedding.chunk(2,dim=0)
        return (k.squeeze(0),v.squeeze(0))


class Transformer(nn.Module):
    def __init__(self,config):
        super(Transformer,self).__init__()
        self.resblocks=nn.ModuleList([ResidualAttentionBlock(config) for _ in range(config.num_layers)])
        self.prefix=PrefixEncoder(config)
        prefix_tokens=torch.arange(0,config.num_virtual_tokens,device=config.device,dtype=torch.long)
        self.register_buffer("prefix_tokens",prefix_tokens)
    def forward(self,hidden_state,use_emotion):
        if use_emotion:
            #print("text transformer prefix active.")
            b,n,h=hidden_state.shape
            prefix_k,prefix_v=self.prefix(self.prefix_tokens,b)
            for index,resblock in enumerate(self.resblocks):
                hidden_state=resblock(hidden_state,prefix_k[index],prefix_v[index])
            return hidden_state
        else:
            for index,resblock in enumerate(self.resblocks):
                hidden_state=resblock(hidden_state)
            return hidden_state






class ResidualAttentionBlock(nn.Module):
    def __init__(self,config):
        super(ResidualAttentionBlock,self).__init__()
        self.ln_1=nn.LayerNorm(config.hidden_size,eps=config.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
        self.ln_2=nn.LayerNorm(config.hidden_size,eps=config.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
        #self.attn=nn.MultiheadAttention(config.hidden_size,config.num_heads,device=config.device,dtype=config.dtype)
        self.attn=MultiHeadAttention(config)
        self.mlp=MLP(config)
    def forward(self,hidden_state,prefix_k=None,prefix_v=None):
        residual=hidden_state
        hidden_state=self.ln_1(hidden_state)
        hidden_state=self.attn(hidden_state,prefix_k,prefix_v)
        hidden_state=residual+hidden_state
        residual=hidden_state
        hidden_state=self.ln_2(hidden_state)
        hidden_state=self.mlp(hidden_state)
        hidden_state=residual+hidden_state
        return hidden_state

class MultiHeadAttention(nn.Module):
    def __init__(self,config):
        super(MultiHeadAttention,self).__init__()
        self.hidden_size=config.hidden_size
        self.num_heads=config.num_heads
        self.head_size=self.hidden_size//self.num_heads
        # packed qkv projection weight and bias held as nn.Parameter (frozen here)
        self.in_proj_weight=nn.Parameter(torch.empty(3*config.hidden_size,config.hidden_size,device=config.device,dtype=config.dtype),requires_grad=False)
        self.in_proj_bias=nn.Parameter(torch.empty(3*config.hidden_size,device=config.device,dtype=config.dtype),requires_grad=False)
        #self.q_linear=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device)
        #self.k_linear=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device)
        #self.v_linear=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device)
        self.out_proj=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device,dtype=config.dtype)
    def forward(self,hidden_state,prefix_k=None,prefix_v=None):
        b,n,c=hidden_state.shape
        #q=self.q_linear(hidden_state).view(b,n,self.num_heads,self.head_size).permute(0,2,1,3)
        #k=self.k_linear(hidden_state).view(b,n,self.num_heads,self.head_size).permute(0,2,3,1)
        #v=self.v_linear(hidden_state).view(b,n,self.num_heads,self.head_size).permute(0,2,1,3)
        q,k,v=(torch.matmul(hidden_state,self.in_proj_weight.T)+self.in_proj_bias.expand(b,n,-1)).chunk(3,dim=-1)
        if prefix_k is not None and prefix_v is not None:
            # prepend the prefix to the sequence
            k=torch.cat((prefix_k,k),dim=1)
            v=torch.cat((prefix_v,v),dim=1)
        #print("model original k :",k[:,0,0])
        bk,nk,hk=k.shape
        bq,nq,hq=q.shape
        q=q.view(bq,nq,self.num_heads,self.head_size).permute(0,2,1,3)
        k=k.view(bk,nk,self.num_heads,self.head_size).permute(0,2,1,3)
        v=v.view(bk,nk,self.num_heads,self.head_size).permute(0,2,1,3)
        attention_logits=F.scaled_dot_product_attention(q, k, v)
        attention_logits=attention_logits.permute(0,2,1,3).contiguous().view(bk,nq,self.hidden_size)
        attention_output=self.out_proj(attention_logits)
        return attention_output


class GELU(nn.Module):
    """
    Error function erf:
        erf(x) = 2/sqrt(pi) * integral(exp(-t^2), t=0..x)
    Here t is a dummy variable covering the integration range from 0 to x:
        x is the input to the error function, i.e. the upper limit of the integral;
        t is the integration variable, running from 0 to x, with exp(-t^2) evaluated at each point;
        exp(-t^2) is the integrand, the Gaussian density at each t.
    The integral therefore accumulates the Gaussian probability mass over [0, x].
    """
    def forward(self,x):
        return 0.5*x*(1.0+torch.erf(x/math.sqrt(2.0)))  # math.sqrt: torch.sqrt does not accept a Python float

class QuickGELU(nn.Module):
    def __init__(self):
        super(QuickGELU,self).__init__()
    def forward(self,x):
        old_dtype=x.dtype
        x=x.to(torch.float32)
        return (x*torch.sigmoid(1.702*x)).to(old_dtype)


class MLP(nn.Module):
    def __init__(self,config):
        super(MLP,self).__init__()
        self.hidden_size=config.hidden_size
        self.c_fc=nn.Linear(self.hidden_size,4*self.hidden_size,device=config.device,bias=True,dtype=config.dtype)
        self.gelu=QuickGELU()
        self.c_proj=nn.Linear(self.hidden_size*4,self.hidden_size,device=config.device,bias=True,dtype=config.dtype)
    def forward(self,hidden_state):
        hidden_state=self.c_fc(hidden_state)
        hidden_state=self.gelu(hidden_state)
        hidden_state=self.c_proj(hidden_state)
        return hidden_state

class Config:
    def __init__(self,vocab_size,max_position_embeddings,hidden_size,num_layers,num_heads,device,dtype):
        self.vocab_size=vocab_size
        self.max_position_embeddings=max_position_embeddings
        self.hidden_size=hidden_size
        self.num_layers=num_layers
        self.num_heads=num_heads
        self.device=device
        self.dtype=dtype
        self.norm_eps=1e-5
        self.num_virtual_tokens=20
        self.token_dim=hidden_size
        self.encoder_hidden_size=hidden_size
config=Config(
    vocab_size=49408,
    max_position_embeddings=77,
    hidden_size=512,
    num_layers=12,
    num_heads=8,
    device=torch.device('cuda:0'),
    dtype=torch.float16
)
class TextEncoder(nn.Module):  # note: not used below; text_encoder is built from Transformer directly
    def __init__(self,config):
        super(TextEncoder,self).__init__()
        self.token_embedding=nn.Embedding(config.vocab_size,config.hidden_size,device=config.device,dtype=config.dtype)
        self.positional_embedding=nn.Parameter(torch.zeros(config.max_position_embeddings,config.hidden_size,device=config.device,dtype=config.dtype),requires_grad=False)
        self.transformer=Transformer(config)
        self.ln_final=nn.LayerNorm(config.hidden_size,eps=config.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
    def forward(self,input_ids):
        b,n=input_ids.shape
        prompt_embedding,token_embeddings=self.token_embedding(input_ids)
        position_ids=torch.arange(n,device=config.device,dtype=config.dtype).unsqueeze(0).expand(b,n)
        position_embeddings=self.positional_embedding[position_ids]
        embeddings=token_embeddings+position_embeddings
        embeddings=torch.cat((prompt_embedding,embeddings),dim=1)
        embeddings=self.transformer(embeddings)
        embeddings=self.ln_final(embeddings)
        return embeddings

text_encoder=Transformer(config)
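A toy-size sketch of the effect of prepending prefix_k / prefix_v inside MultiHeadAttention.forward above (the sizes are illustrative; only the 20 prefix slots and 8 heads of 64 dims follow this file's Config): the keys and values gain 20 extra positions while the queries keep their length, and since no mask is passed every text token can attend to the prefix slots.

import torch
import torch.nn.functional as F

b, n, heads, head_dim, n_prefix = 2, 57, 8, 64, 20
q = torch.randn(b, heads, n, head_dim)
k = torch.randn(b, heads, n + n_prefix, head_dim)   # prefix prepended along the sequence axis
v = torch.randn(b, heads, n + n_prefix, head_dim)
out = F.scaled_dot_product_attention(q, k, v)       # no mask: every query sees the prefix
print(out.shape)             # torch.Size([2, 8, 57, 64])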
VIT.py
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
import math
|
6 |
+
import os
|
7 |
+
import sys
|
8 |
+
|
9 |
+
#huggingface实现的前缀微调
|
10 |
+
class PrefixEncoder(torch.nn.Module):
|
11 |
+
def __init__(self,config):
|
12 |
+
super(PrefixEncoder,self).__init__()
|
13 |
+
self.config=config
|
14 |
+
self.device=config.device
|
15 |
+
self.dtype=config.dtype
|
16 |
+
self.num_virtual_tokens=config.num_virtual_tokens
|
17 |
+
self.token_dim=config.token_dim
|
18 |
+
self.encoder_hidden_size=config.encoder_hidden_size
|
19 |
+
self.num_layers=config.num_layers
|
20 |
+
self.prefix_embedding=nn.Parameter(torch.empty(1,self.num_virtual_tokens,self.num_layers*2*self.token_dim,device=config.device,dtype=config.dtype),requires_grad=False)
|
21 |
+
def forward(self,input_ids,batch_size):
|
22 |
+
prefix_embedding=self.prefix_embedding
|
23 |
+
prefix_embedding=prefix_embedding.expand(batch_size,self.num_virtual_tokens,self.num_layers*2*self.token_dim)
|
24 |
+
prefix_embedding=prefix_embedding.reshape(batch_size,self.num_virtual_tokens,self.num_layers,2,self.token_dim)
|
25 |
+
prefix_embedding=prefix_embedding.permute(3,2,0,1,4)
|
26 |
+
k,v=prefix_embedding.chunk(2,dim=0)
|
27 |
+
return (k.squeeze(0),v.squeeze(0))
|
28 |
+
|
29 |
+
|
30 |
+
import torch
|
31 |
+
import torch.nn as nn
|
32 |
+
import math
|
33 |
+
from torch.nn.attention import SDPBackend, sdpa_kernel
|
34 |
+
from torch.nn import functional as F
|
35 |
+
def position_embedding(x,position_ids):
|
36 |
+
hidden_size=x.size(2)
|
37 |
+
seq_len=x.size(1)
|
38 |
+
div_term=torch.exp(torch.arange(0,hidden_size,2,device=x.device).float()*(-math.log(10000.0)/hidden_size))
|
39 |
+
positional_encoding=torch.zeros(seq_len,hidden_size,device=x.device)
|
40 |
+
positional_encoding[:,0::2]=torch.sin(position_ids.float()[:,None]*div_term)
|
41 |
+
positional_encoding[:,1::2]=torch.cos(position_ids.float()[:,None]*div_term)
|
42 |
+
positional_encoding=positional_encoding.unsqueeze(0)
|
43 |
+
return positional_encoding
|
44 |
+
|
45 |
+
|
46 |
+
|
47 |
+
|
48 |
+
class VisionTransformer(nn.Module):
|
49 |
+
def __init__(self,config):
|
50 |
+
super(VisionTransformer,self).__init__()
|
51 |
+
self.image_channel=config.image_channel
|
52 |
+
self.hidden_size=config.hidden_size
|
53 |
+
self.norm_eps=config.norm_eps
|
54 |
+
self.patch_size=config.patch_size
|
55 |
+
self.output_dim=config.output_dim
|
56 |
+
self.dtype=config.dtype
|
57 |
+
self.num_patches=config.num_patches
|
58 |
+
self.num_virtual_tokens=config.num_virtual_tokens if hasattr(config,"num_virtual_tokens") else None
|
59 |
+
self.conv1=nn.Conv2d(self.image_channel,self.hidden_size,self.patch_size,stride=self.patch_size,bias=False,device=config.device,dtype=config.dtype)
|
60 |
+
self.ln_pre=nn.LayerNorm(self.hidden_size,eps=self.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
|
61 |
+
self.transformer=Transformer(config)
|
62 |
+
#self.position_ids=torch.arange(config.num_patches+1,dtype=torch.long,device=config.device)
|
63 |
+
#self.position_embeddings=nn.Parameter(torch.zeros(1,config.num_patches+1,config.hidden_size))
|
64 |
+
#nn.init.normal_(self.position_embeddings)
|
65 |
+
#clsToken,用于图像分类任务
|
66 |
+
#self.cls_token=nn.Parameter(torch.zeros(1,1,config.hidden_size,device=config.device))
|
67 |
+
#分类token不是可训练参数
|
68 |
+
self.class_embedding=nn.Parameter(torch.empty(config.hidden_size,device=config.device),requires_grad=False)
|
69 |
+
#很明显这里的position_embedding也是一个可学习参数
|
70 |
+
self.positional_embedding=nn.Parameter(torch.empty(config.num_patches+1,config.hidden_size,device=config.device),requires_grad=False)
|
71 |
+
#可训练参数
|
72 |
+
self.proj=nn.Parameter(torch.empty(config.hidden_size,config.output_dim,device=config.device,dtype=config.dtype),requires_grad=False)
|
73 |
+
self.ln_post=nn.LayerNorm(self.hidden_size,eps=self.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
|
74 |
+
def forward(self,hidden_state,use_emotion):
|
75 |
+
b,c,h,w=hidden_state.shape
|
76 |
+
#获得embedding向量
|
77 |
+
hidden_state=self.conv1(hidden_state)
|
78 |
+
hidden_state=hidden_state.reshape(b,self.hidden_size,-1).transpose(1,2)
|
79 |
+
#添加cls token embedding
|
80 |
+
hidden_state=torch.cat((self.class_embedding.expand(b,1,-1).to(hidden_state.dtype),hidden_state),dim=1)
|
81 |
+
#使用transformer原论文中的固定位置嵌入
|
82 |
+
#hidden_state=hidden_state+position_embedding(hidden_state,self.position_ids)
|
83 |
+
hidden_state=hidden_state+self.positional_embedding.unsqueeze(0).to(hidden_state.dtype)
|
84 |
+
hidden_state=self.ln_pre(hidden_state)
|
85 |
+
hidden_state=self.transformer(hidden_state,use_emotion)
|
86 |
+
#提取cls token输出 与image patch输出
|
87 |
+
cls_state=hidden_state[:,0,:]
|
88 |
+
cls_state=self.ln_post(cls_state)
|
89 |
+
cls_state=torch.matmul(cls_state,self.proj)
|
90 |
+
#image_state=hidden_state[:,1:,:]
|
91 |
+
#image_state size (batch_size,49,768)
|
92 |
+
return cls_state
|
93 |
+
|
94 |
+
class Transformer(nn.Module):
    def __init__(self,config):
        super(Transformer,self).__init__()
        self.resblocks=nn.ModuleList([ResidualAttentionBlock(config) for _ in range(config.num_layers)])
        self.prefix=PrefixEncoder(config)
        prefix_tokens=torch.arange(0,config.num_virtual_tokens,device=config.device,dtype=torch.long)
        self.register_buffer("prefix_tokens",prefix_tokens)
    def forward(self,hidden_state,use_emotion):
        if use_emotion:
            b,n,h=hidden_state.shape
            prefix_k,prefix_v=self.prefix(self.prefix_tokens,b)
            for index,resblock in enumerate(self.resblocks):
                #fetch this layer's prefix key/value vectors and pass them into the resblock, where they are concatenated with the sequence
                hidden_state=resblock(hidden_state,prefix_k[index],prefix_v[index])
            return hidden_state
        else:
            for resblock in self.resblocks:
                #no prefix injection when use_emotion is False
                hidden_state=resblock(hidden_state)
            return hidden_state

class ResidualAttentionBlock(nn.Module):
    def __init__(self,config):
        super(ResidualAttentionBlock,self).__init__()
        self.ln_1=nn.LayerNorm(config.hidden_size,eps=config.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
        self.ln_2=nn.LayerNorm(config.hidden_size,eps=config.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
        #self.attn=nn.MultiheadAttention(config.hidden_size,config.num_heads,device=config.device,dtype=config.dtype)
        self.attn=MultiHeadAttention(config)
        self.mlp=MLP(config)
    def forward(self,hidden_state,prefix_k=None,prefix_v=None):
        residual=hidden_state
        hidden_state=self.ln_1(hidden_state)
        hidden_state=self.attn(hidden_state,prefix_k,prefix_v)
        hidden_state=residual+hidden_state
        residual=hidden_state
        hidden_state=self.ln_2(hidden_state)
        hidden_state=self.mlp(hidden_state)
        hidden_state=residual+hidden_state
        return hidden_state

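# Note (added): ResidualAttentionBlock follows the pre-LayerNorm ordering used by CLIP's ViT,
#   x = x + Attn(LN1(x));  x = x + MLP(LN2(x))
# with the optional prefix key/value tensors forwarded untouched to the attention layer.
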
class MultiHeadAttention(nn.Module):
    def __init__(self,config):
        super(MultiHeadAttention,self).__init__()
        self.hidden_size=config.hidden_size
        self.num_heads=config.num_heads
        self.head_size=self.hidden_size//self.num_heads
        #packed q/k/v projection weight and bias, mirroring nn.MultiheadAttention's in_proj; frozen because they are loaded from pretrained weights
        self.in_proj_weight=nn.Parameter(torch.empty(3*config.hidden_size,config.hidden_size,device=config.device,dtype=config.dtype),requires_grad=False)
        self.in_proj_bias=nn.Parameter(torch.empty(3*config.hidden_size,device=config.device,dtype=config.dtype),requires_grad=False)
        #self.q_linear=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device)
        #self.k_linear=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device)
        #self.v_linear=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device)
        self.out_proj=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device,dtype=config.dtype)
    def forward(self,hidden_state,prefix_k=None,prefix_v=None):
        b,n,h=hidden_state.shape
        #q=self.q_linear(hidden_state).view(b,n,self.num_heads,self.head_size).permute(0,2,1,3)
        #k=self.k_linear(hidden_state).view(b,n,self.num_heads,self.head_size).permute(0,2,3,1)
        #v=self.v_linear(hidden_state).view(b,n,self.num_heads,self.head_size).permute(0,2,1,3)
        q,k,v=(torch.matmul(hidden_state,self.in_proj_weight.T)+self.in_proj_bias.expand(b,n,-1)).chunk(3,dim=-1)
        if prefix_k is not None and prefix_v is not None:
            #prepend the prefix key/value tokens to the sequence
            #print("original prefix_k.shape",prefix_k.shape)
            k=torch.cat((prefix_k,k),dim=1)
            v=torch.cat((prefix_v,v),dim=1)
            #print("model original k :",k[:,0,0])
        bk,nk,hk=k.shape
        bq,nq,hq=q.shape
        q=q.view(bq,nq,self.num_heads,self.head_size).permute(0,2,1,3)
        k=k.view(bk,nk,self.num_heads,self.head_size).permute(0,2,1,3)
        v=v.view(bk,nk,self.num_heads,self.head_size).permute(0,2,1,3)
        attention_logits=F.scaled_dot_product_attention(q,k,v)
        attention_logits=attention_logits.permute(0,2,1,3).contiguous().view(bk,nq,self.hidden_size)
        attention_output=self.out_proj(attention_logits)
        return attention_output

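# A minimal sketch (illustrative only) of how the prefix tokens enter attention above:
# with P prefix tokens, keys/values grow to N+P positions while queries keep N, so
# scaled_dot_product_attention still returns N output positions per image token. The
# tensor sizes are assumptions (hidden_size=768, num_heads=12, P=20 virtual tokens, as in
# the config below); PrefixEncoder itself is defined or imported earlier in this file.
def _prefix_attention_shape_sketch(b=2,n=50,p=20,hidden=768,heads=12):
    head=hidden//heads
    q=torch.randn(b,n,hidden)
    k=torch.cat((torch.randn(b,p,hidden),torch.randn(b,n,hidden)),dim=1)   #(b,n+p,hidden)
    v=torch.cat((torch.randn(b,p,hidden),torch.randn(b,n,hidden)),dim=1)
    q=q.view(b,n,heads,head).permute(0,2,1,3)
    k=k.view(b,n+p,heads,head).permute(0,2,1,3)
    v=v.view(b,n+p,heads,head).permute(0,2,1,3)
    out=F.scaled_dot_product_attention(q,k,v)
    return out.shape                                                       #torch.Size([2,12,50,64])
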
class GELU(nn.Module):
    """
    The error function erf:
        erf(x) = 2/sqrt(pi) * integral(exp(-t^2), t=0..x)
    Here t is a dummy variable ranging over every point of the integration interval from 0 to x. Concretely:
        x is the input to the error function and sets the upper limit of the integral;
        t is the integration variable, running from 0 to x, with exp(-t^2) evaluated at each point;
        exp(-t^2) is the integrand, the (unnormalised) Gaussian density at t.
    The integral therefore accumulates Gaussian probability mass over [0, x], which is what gives
    GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2))) its interpretation as a smooth probabilistic gate on x.
    """
    def forward(self,x):
        old_dtype=x.dtype
        x=x.to(torch.float32)
        #torch.erf expects a tensor argument, but the divisor is a plain float, so use math.sqrt here
        return (0.5*x*(1.0+torch.erf(x/math.sqrt(2.0)))).to(old_dtype)

class QuickGELU(nn.Module):
    def __init__(self):
        super(QuickGELU,self).__init__()
    def forward(self,x):
        old_dtype=x.dtype
        x=x.to(torch.float32)
        return (x*torch.sigmoid(1.702*x)).to(old_dtype)

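# A quick numerical check (illustrative only) of the two activations above: QuickGELU is the
# sigmoid approximation x*sigmoid(1.702x) used by CLIP, while GELU is the exact erf form.
# Comparing against torch.nn.functional.gelu is an added sanity check, not part of the model.
def _gelu_sanity_check():
    x=torch.linspace(-3,3,7)
    exact=GELU()(x)
    quick=QuickGELU()(x)
    reference=F.gelu(x)                       #erf-based GELU in PyTorch
    return (exact-reference).abs().max(),(quick-reference).abs().max()
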
class MLP(nn.Module):
    def __init__(self,config):
        super(MLP,self).__init__()
        self.hidden_size=config.hidden_size
        self.c_fc=nn.Linear(self.hidden_size,4*self.hidden_size,device=config.device,bias=True,dtype=config.dtype)
        self.gelu=QuickGELU()
        self.c_proj=nn.Linear(self.hidden_size*4,self.hidden_size,device=config.device,bias=True,dtype=config.dtype)
    def forward(self,hidden_state):
        hidden_state=self.c_fc(hidden_state)
        hidden_state=self.gelu(hidden_state)
        hidden_state=self.c_proj(hidden_state)
        return hidden_state

class ViTConfig:
    def __init__(self,image_channel,hidden_size,num_heads,num_layers,patch_size,num_patches,output_dim,norm_eps,device):
        self.image_channel=image_channel
        self.hidden_size=hidden_size
        self.num_heads=num_heads
        self.num_layers=num_layers
        self.patch_size=patch_size
        self.num_patches=num_patches
        self.norm_eps=norm_eps
        self.device=device
        self.dtype=torch.float16
        #note: this expression evaluates to 768//1024 + 1 == 1; the token count the model actually uses is num_patches + 1
        self.patch_token_num=self.hidden_size//self.patch_size**2+1
        self.output_dim=output_dim
        self.num_virtual_tokens=20
        self.token_dim=self.hidden_size
        self.encoder_hidden_size=self.hidden_size

#ViT-B/32-style config: 32x32 patches and 49 patch tokens, i.e. 224x224 inputs, with a 512-d output embedding
config=ViTConfig(3,768,12,12,32,49,512,1e-5,torch.device("cuda"))
model=VisionTransformer(config)
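# A hypothetical usage sketch for the module-level `model` built above. The parameters are
# created with torch.empty here and are expected to be overwritten by the pretrained checkpoint
# shipped in this repo before real use; this block only checks shapes and assumes a CUDA device.
if __name__ == "__main__":
    images=torch.randn(2,3,224,224,device=config.device,dtype=config.dtype)
    with torch.no_grad():
        feats=model(images,use_emotion=False)
    print(feats.shape)                        #expected: torch.Size([2, 512])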
__pycache__/Text_Encoder.cpython-312.pyc
ADDED
Binary file (14.6 kB)

__pycache__/VIT.cpython-312.pyc
ADDED
Binary file (17.7 kB)
preprocess.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:51ff0f1d35da9d25c16b5a82957cfb43b76d01a94084c501ec4a9180dc4b53aa
size 1116
tokenize.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b7cd84774d43b4d7513f250b615ecd579a9a0c852f3e011043330407f7ca93e1
size 37