jiangchengchengNLP committed on
Commit ec091c3 · verified · 1 Parent(s): 80b0321

Upload 3 files

Files changed (3)
  1. EmotionCLIP.py +353 -0
  2. preprocess.pkl +3 -0
  3. tokenize.pkl +3 -0
EmotionCLIP.py ADDED
@@ -0,0 +1,353 @@
+ import torch
+ import torch.nn as nn
+ import math
+ from torch.nn import functional as F
+
+ #--------------------------------------
+
+ ############# PUBLIC MODEL CLASS ################
+
+ #----------------------------------------
+ class PrefixEncoder(torch.nn.Module):
+     def __init__(self,config):
+         super(PrefixEncoder,self).__init__()
+         self.config=config
+         self.device=config.device
+         self.dtype=config.dtype
+         self.num_virtual_tokens=config.num_virtual_tokens
+         #self.embedding=torch.nn.Embedding(config.num_virtual_tokens,config.token_dim,device=config.device,dtype=config.dtype)
+         self.token_dim=config.token_dim
+         self.encoder_hidden_size=config.encoder_hidden_size
+         self.num_layers=config.num_layers
+         """
+         self.transformer=torch.nn.Sequential(
+             torch.nn.Linear(self.token_dim,self.encoder_hidden_size,device=self.device,dtype=self.dtype),
+             torch.nn.Tanh(),
+             torch.nn.Linear(self.encoder_hidden_size,self.num_layers*2*self.token_dim,device=self.device,dtype=self.dtype),
+         )
+         """
+         self.prefix_embedding=nn.Parameter(torch.zeros(1,self.num_virtual_tokens,self.token_dim*2*self.num_layers,device=self.device,dtype=self.dtype),requires_grad=False)
+     def forward(self,batch_size):
+         """
+         input_ids=input_ids.unsqueeze(0).expand(batch_size,self.num_virtual_tokens)
+         prefix_embedding=self.embedding(input_ids)
+         prefix_embedding=self.transformer(prefix_embedding)
+         self.register_parameter("prefix_embedding",nn.Parameter(prefix_embedding,requires_grad=False))
+         """
+         #prefix_embedding=self.prefix_embedding.expand(b,self.num_virtual_tokens,self.token_dim*2*self.num_layers)
+         #prefix_embedding=prefix_embedding.contiguous().view(2,self.num_layers,prefix_embedding.shape[0],self.num_virtual_tokens,self.token_dim)
+         prefix_embedding=self.prefix_embedding.expand(batch_size,self.num_virtual_tokens,self.token_dim*2*self.num_layers)
+         prefix_embedding=prefix_embedding.reshape(batch_size,self.num_virtual_tokens,self.num_layers,2,self.token_dim)
+         prefix_embedding=prefix_embedding.permute(3,2,0,1,4)
+         k,v=prefix_embedding.chunk(2,dim=0)
+         return (k.squeeze(0),v.squeeze(0))
+
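+ # Shape flow of the prefix above: prefix_embedding is stored as
+ # (1, num_virtual_tokens, 2*num_layers*token_dim), expanded to the batch size,
+ # reshaped to (batch, V, L, 2, D), permuted to (2, L, batch, V, D) and split into
+ # key/value stacks of shape (num_layers, batch, num_virtual_tokens, token_dim).
+ # Each layer's slice is later concatenated to that layer's k/v only, so the
+ # query length of every attention block is unchanged.
+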
+ class MultiHeadAttention(nn.Module):
+     def __init__(self,config):
+         super(MultiHeadAttention,self).__init__()
+         self.hidden_size=config.hidden_size
+         self.num_heads=config.num_heads
+         self.head_size=self.hidden_size//self.num_heads
+         # packed q/k/v projection: trainable weight and bias, same layout as nn.MultiheadAttention's in_proj_weight/in_proj_bias
+         self.in_proj_weight=nn.Parameter(torch.zeros(3*config.hidden_size,config.hidden_size,device=config.device,dtype=config.dtype),requires_grad=True)
+         self.in_proj_bias=nn.Parameter(torch.zeros(3*config.hidden_size,device=config.device,dtype=config.dtype),requires_grad=True)
+         #self.q_linear=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device)
+         #self.k_linear=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device)
+         #self.v_linear=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device)
+         self.out_proj=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device,dtype=config.dtype)
+     def forward(self,hidden_state,prefix_k=None,prefix_v=None):
+         b,n,c=hidden_state.shape
+         #q=self.q_linear(hidden_state).view(b,n,self.num_heads,self.head_size).permute(0,2,1,3)
+         #k=self.k_linear(hidden_state).view(b,n,self.num_heads,self.head_size).permute(0,2,3,1)
+         #v=self.v_linear(hidden_state).view(b,n,self.num_heads,self.head_size).permute(0,2,1,3)
+         q,k,v=(torch.matmul(hidden_state,self.in_proj_weight.T)+self.in_proj_bias.expand(b,n,-1)).chunk(3,dim=-1)
+         if prefix_k is not None and prefix_v is not None:
+             # prepend the prefix keys/values in front of the sequence
+             k=torch.cat((prefix_k,k),dim=1)
+             #print("model k :",k[:,0,0])
+             v=torch.cat((prefix_v,v),dim=1)
+         bk,nk,hk=k.shape
+         bq,nq,hq=q.shape
+         q=q.view(bq,nq,self.num_heads,self.head_size).permute(0,2,1,3)
+         k=k.view(bk,nk,self.num_heads,self.head_size).permute(0,2,1,3)
+         v=v.view(bk,nk,self.num_heads,self.head_size).permute(0,2,1,3)
+         attention_logits=F.scaled_dot_product_attention(q, k, v)
+         attention_logits=attention_logits.permute(0,2,1,3).contiguous().view(bk,nq,self.hidden_size)
+         attention_output=self.out_proj(attention_logits)
+         return attention_output
+
+
+ class QuickGELU(nn.Module):
+     def __init__(self):
+         super(QuickGELU,self).__init__()
+     def forward(self,x):
+         old_dtype=x.dtype
+         x=x.to(torch.float32)
+         return (x*torch.sigmoid(1.702*x)).to(old_dtype)
+
+
+ class MLP(nn.Module):
+     def __init__(self,config):
+         super(MLP,self).__init__()
+         self.hidden_size=config.hidden_size
+         self.c_fc=nn.Linear(self.hidden_size,4*self.hidden_size,device=config.device,bias=True,dtype=config.dtype)
+         self.gelu=QuickGELU()
+         self.c_proj=nn.Linear(self.hidden_size*4,self.hidden_size,device=config.device,bias=True,dtype=config.dtype)
+     def forward(self,hidden_state):
+         hidden_state=self.c_fc(hidden_state)
+         hidden_state=self.gelu(hidden_state)
+         hidden_state=self.c_proj(hidden_state)
+         return hidden_state
+
+ class ResidualAttentionBlock(nn.Module):
+     def __init__(self,config):
+         super(ResidualAttentionBlock,self).__init__()
+         self.ln_1=nn.LayerNorm(config.hidden_size,eps=config.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
+         self.ln_2=nn.LayerNorm(config.hidden_size,eps=config.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
+         #self.attn=nn.MultiheadAttention(config.hidden_size,config.num_heads,device=config.device,dtype=config.dtype)
+         self.attn=MultiHeadAttention(config)
+         self.mlp=MLP(config)
+     def forward(self,hidden_state,prefix_k=None,prefix_v=None):
+         residual=hidden_state
+         hidden_state=self.ln_1(hidden_state)
+         hidden_state=self.attn(hidden_state,prefix_k,prefix_v)
+         hidden_state=residual+hidden_state
+         residual=hidden_state
+         hidden_state=self.ln_2(hidden_state)
+         hidden_state=self.mlp(hidden_state)
+         hidden_state=residual+hidden_state
+         return hidden_state
+
+ class Transformer(nn.Module):
+     def __init__(self,config):
+         super(Transformer,self).__init__()
+         self.resblocks=nn.ModuleList([ResidualAttentionBlock(config) for _ in range(config.num_layers)])
+         self.prefix=PrefixEncoder(config)
+         #prefix_tokens=torch.arange(0,config.num_virtual_tokens,device=config.device,dtype=torch.long)
+         #self.register_buffer("prefix_tokens",prefix_tokens)
+     def forward(self,hidden_state):
+         b,n,h=hidden_state.shape
+         prefix_k,prefix_v=self.prefix(b)
+         for index,resblock in enumerate(self.resblocks):
+             hidden_state=resblock(hidden_state,prefix_k[index],prefix_v[index])
+         return hidden_state
+
+ #-----------------------------------------
+
+ ############### TEXT ENCODER ----> transformer ################
+
+ #-----------------------------------------
+
+ class TextEncoder_Config:
+     def __init__(self,vocab_size,max_position_embeddings,hidden_size,num_layers,num_heads,device,dtype):
+         self.vocab_size=vocab_size
+         self.max_position_embeddings=max_position_embeddings
+         self.hidden_size=hidden_size
+         self.num_layers=num_layers
+         self.num_heads=num_heads
+         self.device=device
+         self.dtype=dtype
+         self.norm_eps=1e-5
+         self.num_virtual_tokens=20
+         self.token_dim=hidden_size
+         self.encoder_hidden_size=hidden_size
+ textencoder_config=TextEncoder_Config(
+     vocab_size=49408,
+     max_position_embeddings=77,
+     hidden_size=512,
+     num_layers=12,
+     num_heads=8,
+     device=torch.device('cuda:0'),
+     dtype=torch.float16
+ )
+
+ Encoder_model=Transformer(textencoder_config)
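+
+ # Illustrative sanity check for the text transformer (assumes a CUDA device is
+ # available, as configured above; kept commented out so it does not run on import):
+ # dummy=torch.zeros(2,77,512,device=textencoder_config.device,dtype=textencoder_config.dtype)
+ # assert Encoder_model(dummy).shape==(2,77,512)   # prefixes extend only k/v, so the length is unchanged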
+
+ #--------------------------------------------
+
+ ################### VISION TRANSFORMER ##################
+
+ #--------------------------------------------
+
+ def position_embedding(x,position_ids):
+     # fixed sinusoidal position encoding from the original Transformer paper
+     # (kept for reference; the ViT below uses a learned positional_embedding instead)
+     hidden_size=x.size(2)
+     seq_len=x.size(1)
+     div_term=torch.exp(torch.arange(0,hidden_size,2,device=x.device).float()*(-math.log(10000.0)/hidden_size))
+     positional_encoding=torch.zeros(seq_len,hidden_size,device=x.device)
+     positional_encoding[:,0::2]=torch.sin(position_ids.float()[:,None]*div_term)
+     positional_encoding[:,1::2]=torch.cos(position_ids.float()[:,None]*div_term)
+     positional_encoding=positional_encoding.unsqueeze(0)
+     return positional_encoding
+
+ class VisionTransformer(nn.Module):
+     def __init__(self,config):
+         super(VisionTransformer,self).__init__()
+         self.image_channel=config.image_channel
+         self.hidden_size=config.hidden_size
+         self.norm_eps=config.norm_eps
+         self.patch_size=config.patch_size
+         self.output_dim=config.output_dim
+         self.dtype=config.dtype
+         self.num_virtual_tokens=config.num_virtual_tokens if hasattr(config,"num_virtual_tokens") else None
+         self.conv1=nn.Conv2d(self.image_channel,self.hidden_size,self.patch_size,stride=self.patch_size,bias=False,device=config.device,dtype=config.dtype)
+         self.ln_pre=nn.LayerNorm(self.hidden_size,eps=self.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
+         self.transformer=Transformer(config)
+         #self.position_ids=torch.arange(config.num_patches+1,dtype=torch.long,device=config.device)
+         #self.position_embeddings=nn.Parameter(torch.zeros(1,config.num_patches+1,config.hidden_size))
+         #nn.init.normal_(self.position_embeddings)
+         # cls token, used for image classification tasks
+         #self.cls_token=nn.Parameter(torch.zeros(1,1,config.hidden_size,device=config.device))
+         # class token embedding (trainable parameter)
+         self.class_embedding=nn.Parameter(torch.zeros(config.hidden_size,device=config.device),requires_grad=True)
+         # the position embedding here is likewise a learnable parameter
+         self.positional_embedding=nn.Parameter(torch.zeros(config.num_patches+1,config.hidden_size,device=config.device),requires_grad=True)
+         # trainable projection into the shared embedding space
+         self.proj=nn.Parameter(torch.zeros(config.hidden_size,config.output_dim,device=config.device,dtype=config.dtype),requires_grad=True)
+         self.ln_post=nn.LayerNorm(self.hidden_size,eps=self.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
+     def forward(self,hidden_state):
+         b,c,h,w=hidden_state.shape
+         # obtain the patch embedding vectors
+         hidden_state=self.conv1(hidden_state)
+         hidden_state=hidden_state.reshape(b,self.hidden_size,-1).transpose(1,2)
+         # prepend the cls token embedding
+         hidden_state=torch.cat((self.class_embedding.expand(b,1,-1).to(hidden_state.dtype),hidden_state),dim=1)
+         # (alternative) fixed position embedding from the original Transformer paper
+         #hidden_state=hidden_state+position_embedding(hidden_state,self.position_ids)
+         hidden_state=hidden_state+self.positional_embedding.unsqueeze(0).to(hidden_state.dtype)
+         hidden_state=self.ln_pre(hidden_state)
+         hidden_state=self.transformer(hidden_state)
+         # pool a single token's output; note the prefix is concatenated only to k/v,
+         # so the sequence itself is unchanged and index 0 still holds the cls token
+         if self.num_virtual_tokens is not None:
+             hidden_state=hidden_state[:,self.num_virtual_tokens,:]
+         else:
+             hidden_state=hidden_state[:,0,:]
+         hidden_state=self.ln_post(hidden_state)
+         hidden_state=torch.matmul(hidden_state,self.proj)
+         return hidden_state
+
+ class ViTConfig:
+     def __init__(self,image_channel,hidden_size,num_heads,num_layers,patch_size,num_patches,output_dim,norm_eps,device):
+         self.image_channel=image_channel
+         self.hidden_size=hidden_size
+         self.num_heads=num_heads
+         self.num_layers=num_layers
+         self.patch_size=patch_size
+         self.num_patches=num_patches
+         self.norm_eps=norm_eps
+         self.device=device
+         self.dtype=torch.float16
+         self.patch_token_num=self.hidden_size//self.patch_size**2+1  # not used elsewhere in this file
+         self.output_dim=output_dim
+         self.num_virtual_tokens=20
+         self.token_dim=self.hidden_size
+         self.encoder_hidden_size=self.hidden_size
+
+ config=ViTConfig(3,768,12,12,32,49,512,1e-5,torch.device("cuda"))
+ VIT_model=VisionTransformer(config)
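+
+ # With patch_size=32 and num_patches=49 (7x7), the expected input resolution is
+ # presumably 224x224; that is an assumption, not enforced anywhere in this file.
+ # Illustrative, commented-out shape check:
+ # dummy=torch.zeros(1,3,224,224,device=config.device,dtype=config.dtype)
+ # assert VIT_model(dummy).shape==(1,512)   # 512-d image features via self.proj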
+
+ #-------------------------------------------------
+
+ ################## PrefixCLIP ###############
+
+ #------------------------------------------------
+
+ class CLIP(nn.Module):
+     def __init__(self,config):
+         super().__init__()
+         self.visual=VIT_model
+         self.device=config.device
+         self.dtype=config.dtype
+         self.token_embedding=nn.Embedding(config.vocab_size,config.hidden_size,dtype=config.dtype,device=config.device)
+         self.transformer=Encoder_model
+         self.positional_embedding=nn.Parameter(torch.randn(config.max_position_embeddings,config.hidden_size,device=config.device))
+         self.ln_final=nn.LayerNorm(config.hidden_size,eps=config.layer_norm_eps,dtype=config.dtype,device=config.device)
+         self.text_projection=nn.Parameter(torch.empty(config.hidden_size,config.hidden_size,device=config.device))
+         self.logit_scale=nn.Parameter(torch.ones([],dtype=config.dtype,device=config.device)*config.logit_scale_init,requires_grad=True)
+     def encode_image(self,img):
+         return self.visual(img)
+     def encode_text(self,text):
+         token_embedding=self.token_embedding(text)
+         position_embedding=self.positional_embedding[None,:text.shape[1],:].to(self.dtype)
+         text_embedding=token_embedding+position_embedding
+         text_embedding=self.transformer(text_embedding)
+         text_embedding=self.ln_final(text_embedding)
+         # pool at the end-of-text token (the largest token id in each sequence), as in CLIP
+         text_embedding=text_embedding[torch.arange(text.shape[0]),text.argmax(dim=-1)]
+         text_embedding=text_embedding@self.text_projection.to(self.dtype)
+
+         return text_embedding
+
+     def forward(self,image,text):
+         image_features=self.encode_image(image)
+         text_features=self.encode_text(text)
+         # normalized features
+         image_features=image_features/image_features.norm(dim=-1,keepdim=True)
+         text_features=text_features/text_features.norm(dim=-1,keepdim=True)
+         # cosine similarity as logits
+         logit_scale=self.logit_scale.exp()
+         logits_per_image=logit_scale*image_features@text_features.t()
+         logits_per_text=logits_per_image.t()
+         # shape = [global_batch_size, global_batch_size]
+         return logits_per_image,logits_per_text
+
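+ # The two logits matrices above support the standard CLIP-style symmetric
+ # cross-entropy objective. A minimal sketch (not part of this repository's
+ # training code; batch indices act as the matching targets):
+ # targets=torch.arange(logits_per_image.shape[0],device=logits_per_image.device)
+ # loss=(F.cross_entropy(logits_per_image,targets)+F.cross_entropy(logits_per_text,targets))/2
+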
+ class CLIPConfig:
+     def __init__(self):
+         self.vocab_size=49408
+         self.hidden_size=512
+         self.max_position_embeddings=77
+         self.num_hidden_layers=12
+         self.num_attention_heads=8
+         self.layer_norm_eps=1e-5
+         self.activation_function="Quickgelu"
+         self.dtype=torch.float16
+         self.device=torch.device("cuda:0")
+         self.logit_scale_init=4.6052
+         self.num_virtual_tokens=20
+         self.token_dim=self.hidden_size
+         self.encoder_hidden_size=self.hidden_size
+ CLIPconfig=CLIPConfig()
+ model=CLIP(CLIPconfig)
+ # load the pretrained weights; strict=False presumably tolerates keys (such as the
+ # prefix embeddings) that differ between this module and the checkpoint
+ model.load_state_dict(torch.load(r'./Mix_CLIP.pth',weights_only=True),strict=False)
+
+ #---------------------------------------------
+
+ ########### PreProcess Pipelines ##########
+
+ #-------------------------------------------------
+
+ import pickle
+ with open('./preprocess.pkl','rb') as f:
+     preprocess = pickle.load(f)
+ with open('./tokenize.pkl','rb') as f:
+     tokenizer=pickle.load(f)
+
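+ # Minimal usage sketch (illustrative only). It assumes `preprocess` is a
+ # torchvision-style transform mapping a PIL image to a (3, 224, 224) tensor and
+ # `tokenizer` behaves like OpenAI CLIP's tokenize(), returning (N, 77) token ids;
+ # neither is guaranteed by the pickled objects. The image path and label texts
+ # are placeholders. Kept commented out so it does not run on import.
+ #
+ # from PIL import Image
+ # labels=["a happy face","a sad face","an angry face"]
+ # image=preprocess(Image.open("example.jpg")).unsqueeze(0).to(CLIPconfig.device).to(CLIPconfig.dtype)
+ # text=tokenizer(labels).to(CLIPconfig.device)
+ # with torch.no_grad():
+ #     logits_per_image,logits_per_text=model(image,text)
+ #     probs=logits_per_image.softmax(dim=-1)
+ # print(probs)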
preprocess.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:51ff0f1d35da9d25c16b5a82957cfb43b76d01a94084c501ec4a9180dc4b53aa
+ size 1116
tokenize.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b7cd84774d43b4d7513f250b615ecd579a9a0c852f3e011043330407f7ca93e1
+ size 37