Upload 3 files
- EmotionCLIP.py +353 -0
- preprocess.pkl +3 -0
- tokenize.pkl +3 -0
EmotionCLIP.py
ADDED
@@ -0,0 +1,353 @@
import torch
import torch.nn as nn
import math
from torch.nn import functional as F

#--------------------------------------
############# PUBLIC MODEL CLASS ################
#----------------------------------------
class PrefixEncoder(torch.nn.Module):
    """Provides frozen per-layer prefix key/value tensors for prefix tuning."""
    def __init__(self, config):
        super(PrefixEncoder, self).__init__()
        self.config = config
        self.device = config.device
        self.dtype = config.dtype
        self.num_virtual_tokens = config.num_virtual_tokens
        # self.embedding=torch.nn.Embedding(config.num_virtual_tokens,config.token_dim,device=config.device,dtype=config.dtype)
        self.token_dim = config.token_dim
        self.encoder_hidden_size = config.encoder_hidden_size
        self.num_layers = config.num_layers
        """
        self.transformer=torch.nn.Sequential(
            torch.nn.Linear(self.token_dim,self.encoder_hidden_size,device=self.device,dtype=self.dtype),
            torch.nn.Tanh(),
            torch.nn.Linear(self.encoder_hidden_size,self.num_layers*2*self.token_dim,device=self.device,dtype=self.dtype),
        )
        """
        # One key and one value vector per layer and per virtual token, stored flat.
        self.prefix_embedding = nn.Parameter(
            torch.zeros(1, self.num_virtual_tokens, self.token_dim * 2 * self.num_layers,
                        device=self.device, dtype=self.dtype),
            requires_grad=False)

    def forward(self, batch_size):
        """
        input_ids=input_ids.unsqueeze(0).expand(batch_size,self.num_virtual_tokens)
        prefix_embedding=self.embedding(input_ids)
        prefix_embedding=self.transformer(prefix_embedding)
        self.register_parameter("prefix_embedding",nn.Parameter(prefix_embedding,requires_grad=False))
        """
        # prefix_embedding=self.prefix_embedding.expand(b,self.num_virtual_tokens,self.token_dim*2*self.num_layers)
        # prefix_embedding=prefix_embedding.contiguous().view(2,self.num_layers,prefix_embedding.shape[0],self.num_virtual_tokens,self.token_dim)
        # Expand to the batch and split into key/value tensors of shape
        # (num_layers, batch_size, num_virtual_tokens, token_dim).
        prefix_embedding = self.prefix_embedding.expand(batch_size, self.num_virtual_tokens, self.token_dim * 2 * self.num_layers)
        prefix_embedding = prefix_embedding.reshape(batch_size, self.num_virtual_tokens, self.num_layers, 2, self.token_dim)
        prefix_embedding = prefix_embedding.permute(3, 2, 0, 1, 4)
        k, v = prefix_embedding.chunk(2, dim=0)
        return (k.squeeze(0), v.squeeze(0))

class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super(MultiHeadAttention, self).__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_heads
        self.head_size = self.hidden_size // self.num_heads
        # Packed trainable Q/K/V projection weight and bias.
        self.in_proj_weight = nn.Parameter(torch.zeros(3 * config.hidden_size, config.hidden_size, device=config.device, dtype=config.dtype), requires_grad=True)
        self.in_proj_bias = nn.Parameter(torch.zeros(3 * config.hidden_size, device=config.device, dtype=config.dtype), requires_grad=True)
        # self.q_linear=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device)
        # self.k_linear=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device)
        # self.v_linear=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device)
        self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True, device=config.device, dtype=config.dtype)

    def forward(self, hidden_state, prefix_k=None, prefix_v=None):
        b, n, c = hidden_state.shape
        # q=self.q_linear(hidden_state).view(b,n,self.num_heads,self.head_size).permute(0,2,1,3)
        # k=self.k_linear(hidden_state).view(b,n,self.num_heads,self.head_size).permute(0,2,3,1)
        # v=self.v_linear(hidden_state).view(b,n,self.num_heads,self.head_size).permute(0,2,1,3)
        q, k, v = (torch.matmul(hidden_state, self.in_proj_weight.T) + self.in_proj_bias.expand(b, n, -1)).chunk(3, dim=-1)
        if prefix_k is not None and prefix_v is not None:
            # Prepend the prefix tokens to the key/value sequences.
            k = torch.cat((prefix_k, k), dim=1)
            # print("model k :",k[:,0,0])
            v = torch.cat((prefix_v, v), dim=1)
        bk, nk, hk = k.shape
        bq, nq, hq = q.shape
        q = q.view(bq, nq, self.num_heads, self.head_size).permute(0, 2, 1, 3)
        k = k.view(bk, nk, self.num_heads, self.head_size).permute(0, 2, 1, 3)
        v = v.view(bk, nk, self.num_heads, self.head_size).permute(0, 2, 1, 3)
        attention_logits = F.scaled_dot_product_attention(q, k, v)
        attention_logits = attention_logits.permute(0, 2, 1, 3).contiguous().view(bq, nq, self.hidden_size)
        attention_output = self.out_proj(attention_logits)
        return attention_output

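# Note: the packed in_proj above reproduces F.linear(x, in_proj_weight, in_proj_bias) and
# keeps the same parameter names as nn.MultiheadAttention (in_proj_weight, in_proj_bias,
# out_proj), presumably so that weights exported from a stock CLIP checkpoint load by name.
# When a prefix is supplied, only the key/value sequences grow; the output still has one
# vector per input position (nq).
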
class QuickGELU(nn.Module):
    def __init__(self):
        super(QuickGELU, self).__init__()

    def forward(self, x):
        # Compute in float32 for numerical stability, then cast back.
        old_dtype = x.dtype
        x = x.to(torch.float32)
        return (x * torch.sigmoid(1.702 * x)).to(old_dtype)


class MLP(nn.Module):
    def __init__(self, config):
        super(MLP, self).__init__()
        self.hidden_size = config.hidden_size
        self.c_fc = nn.Linear(self.hidden_size, 4 * self.hidden_size, device=config.device, bias=True, dtype=config.dtype)
        self.gelu = QuickGELU()
        self.c_proj = nn.Linear(self.hidden_size * 4, self.hidden_size, device=config.device, bias=True, dtype=config.dtype)

    def forward(self, hidden_state):
        hidden_state = self.c_fc(hidden_state)
        hidden_state = self.gelu(hidden_state)
        hidden_state = self.c_proj(hidden_state)
        return hidden_state

class ResidualAttentionBlock(nn.Module):
    def __init__(self, config):
        super(ResidualAttentionBlock, self).__init__()
        self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, elementwise_affine=True, device=config.device, dtype=config.dtype)
        self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, elementwise_affine=True, device=config.device, dtype=config.dtype)
        # self.attn=nn.MultiheadAttention(config.hidden_size,config.num_heads,device=config.device,dtype=config.dtype)
        self.attn = MultiHeadAttention(config)
        self.mlp = MLP(config)

    def forward(self, hidden_state, prefix_k=None, prefix_v=None):
        # Pre-norm attention block with residual connections.
        residual = hidden_state
        hidden_state = self.ln_1(hidden_state)
        hidden_state = self.attn(hidden_state, prefix_k, prefix_v)
        hidden_state = residual + hidden_state
        residual = hidden_state
        hidden_state = self.ln_2(hidden_state)
        hidden_state = self.mlp(hidden_state)
        hidden_state = residual + hidden_state
        return hidden_state


class Transformer(nn.Module):
    def __init__(self, config):
        super(Transformer, self).__init__()
        self.resblocks = nn.ModuleList([ResidualAttentionBlock(config) for _ in range(config.num_layers)])
        self.prefix = PrefixEncoder(config)
        # prefix_tokens=torch.arange(0,config.num_virtual_tokens,device=config.device,dtype=torch.long)
        # self.register_buffer("prefix_tokens",prefix_tokens)

    def forward(self, hidden_state):
        b, n, h = hidden_state.shape
        prefix_k, prefix_v = self.prefix(b)
        for index, resblock in enumerate(self.resblocks):
            hidden_state = resblock(hidden_state, prefix_k[index], prefix_v[index])
        return hidden_state

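# Shape summary for the prefix path: PrefixEncoder.forward(batch_size) returns k and v of
# shape (num_layers, batch_size, num_virtual_tokens, token_dim); Transformer.forward passes
# prefix_k[index] / prefix_v[index] of shape (batch_size, num_virtual_tokens, token_dim) to
# layer `index`, where MultiHeadAttention concatenates them in front of that layer's own
# keys and values, so every query also attends to the virtual prefix tokens while the
# hidden-state sequence length itself is unchanged.
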
#-----------------------------------------
############### TEXT ENCODER ----> Transformer ################
#-----------------------------------------

class TextEncoder_Config:
    def __init__(self, vocab_size, max_position_embeddings, hidden_size, num_layers, num_heads, device, dtype):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.device = device
        self.dtype = dtype
        self.norm_eps = 1e-5
        self.num_virtual_tokens = 20
        self.token_dim = hidden_size
        self.encoder_hidden_size = hidden_size


textencoder_config = TextEncoder_Config(
    vocab_size=49408,
    max_position_embeddings=77,
    hidden_size=512,
    num_layers=12,
    num_heads=8,
    device=torch.device('cuda:0'),
    dtype=torch.float16,
)

Encoder_model = Transformer(textencoder_config)

#--------------------------------------------
################### VISION TRANSFORMER ##################
#--------------------------------------------

def position_embedding(x, position_ids):
    # Fixed sinusoidal positional encoding from the original Transformer paper
    # (only referenced by commented-out code below).
    hidden_size = x.size(2)
    seq_len = x.size(1)
    div_term = torch.exp(torch.arange(0, hidden_size, 2, device=x.device).float() * (-math.log(10000.0) / hidden_size))
    positional_encoding = torch.zeros(seq_len, hidden_size, device=x.device)
    positional_encoding[:, 0::2] = torch.sin(position_ids.float()[:, None] * div_term)
    positional_encoding[:, 1::2] = torch.cos(position_ids.float()[:, None] * div_term)
    positional_encoding = positional_encoding.unsqueeze(0)
    return positional_encoding

class VisionTransformer(nn.Module):
    def __init__(self, config):
        super(VisionTransformer, self).__init__()
        self.image_channel = config.image_channel
        self.hidden_size = config.hidden_size
        self.norm_eps = config.norm_eps
        self.patch_size = config.patch_size
        self.output_dim = config.output_dim
        self.dtype = config.dtype
        self.num_virtual_tokens = config.num_virtual_tokens if hasattr(config, "num_virtual_tokens") else None
        self.conv1 = nn.Conv2d(self.image_channel, self.hidden_size, self.patch_size, stride=self.patch_size, bias=False, device=config.device, dtype=config.dtype)
        self.ln_pre = nn.LayerNorm(self.hidden_size, eps=self.norm_eps, elementwise_affine=True, device=config.device, dtype=config.dtype)
        self.transformer = Transformer(config)
        # self.position_ids=torch.arange(config.num_patches+1,dtype=torch.long,device=config.device)
        # self.position_embeddings=nn.Parameter(torch.zeros(1,config.num_patches+1,config.hidden_size))
        # nn.init.normal_(self.position_embeddings)
        # cls token, used for image classification tasks
        # self.cls_token=nn.Parameter(torch.zeros(1,1,config.hidden_size,device=config.device))
        # class (cls) token embedding; trainable here
        self.class_embedding = nn.Parameter(torch.zeros(config.hidden_size, device=config.device), requires_grad=True)
        # the positional embedding is also a learnable parameter
        self.positional_embedding = nn.Parameter(torch.zeros(config.num_patches + 1, config.hidden_size, device=config.device), requires_grad=True)
        # trainable projection to the shared embedding space
        self.proj = nn.Parameter(torch.zeros(config.hidden_size, config.output_dim, device=config.device, dtype=config.dtype), requires_grad=True)
        self.ln_post = nn.LayerNorm(self.hidden_size, eps=self.norm_eps, elementwise_affine=True, device=config.device, dtype=config.dtype)

    def forward(self, hidden_state):
        b, c, h, w = hidden_state.shape
        # patch embedding
        hidden_state = self.conv1(hidden_state)
        hidden_state = hidden_state.reshape(b, self.hidden_size, -1).transpose(1, 2)
        # prepend the cls token embedding
        hidden_state = torch.cat((self.class_embedding.expand(b, 1, -1).to(hidden_state.dtype), hidden_state), dim=1)
        # (alternative) fixed positional embedding from the original Transformer paper
        # hidden_state=hidden_state+position_embedding(hidden_state,self.position_ids)
        hidden_state = hidden_state + self.positional_embedding.unsqueeze(0).to(hidden_state.dtype)
        hidden_state = self.ln_pre(hidden_state)
        hidden_state = self.transformer(hidden_state)
        # extract the cls token output
        if self.num_virtual_tokens is not None:
            hidden_state = hidden_state[:, self.num_virtual_tokens, :]
        else:
            hidden_state = hidden_state[:, 0, :]
        hidden_state = self.ln_post(hidden_state)
        hidden_state = torch.matmul(hidden_state, self.proj)
        return hidden_state

class ViTConfig:
    def __init__(self, image_channel, hidden_size, num_heads, num_layers, patch_size, num_patches, output_dim, norm_eps, device):
        self.image_channel = image_channel
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.norm_eps = norm_eps
        self.device = device
        self.dtype = torch.float16
        self.patch_token_num = self.hidden_size // self.patch_size ** 2 + 1
        self.output_dim = output_dim
        self.num_virtual_tokens = 20
        self.token_dim = self.hidden_size
        self.encoder_hidden_size = self.hidden_size


config = ViTConfig(3, 768, 12, 12, 32, 49, 512, 1e-5, torch.device("cuda"))
VIT_model = VisionTransformer(config)

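# Sanity check on the numbers above, assuming 224x224 inputs (a guess consistent with
# num_patches=49 and patch_size=32): 224 / 32 = 7 patches per side, 7 * 7 = 49 patch
# tokens, plus 1 cls token = 50 positions, matching positional_embedding of shape
# (num_patches + 1, hidden_size) = (50, 768).
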
#-------------------------------------------------
################## PrefixCLIP ###############
#------------------------------------------------

class CLIP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.visual = VIT_model
        self.device = config.device
        self.dtype = config.dtype
        self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size, dtype=config.dtype, device=config.device)
        self.transformer = Encoder_model
        self.positional_embedding = nn.Parameter(torch.randn(config.max_position_embeddings, config.hidden_size, device=config.device))
        self.ln_final = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, dtype=config.dtype, device=config.device)
        self.text_projection = nn.Parameter(torch.empty(config.hidden_size, config.hidden_size, device=config.device))
        self.logit_scale = nn.Parameter(torch.ones([], dtype=config.dtype, device=config.device) * config.logit_scale_init, requires_grad=True)

    def encode_image(self, img):
        return self.visual(img)

    def encode_text(self, text):
        token_embedding = self.token_embedding(text)
        position_embedding = self.positional_embedding[None, :text.shape[1], :].to(self.dtype)
        text_embedding = token_embedding + position_embedding
        text_embedding = self.transformer(text_embedding)
        text_embedding = self.ln_final(text_embedding)
        # take the features at the EOT token (the highest token id in each sequence)
        text_embedding = text_embedding[torch.arange(text.shape[0]), text.argmax(dim=-1)]
        text_embedding = text_embedding @ self.text_projection.to(self.dtype)
        return text_embedding

    def forward(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)
        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()
        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text

class CLIPConfig:
    def __init__(self):
        self.vocab_size = 49408
        self.hidden_size = 512
        self.max_position_embeddings = 77
        self.num_hidden_layers = 12
        self.num_attention_heads = 8
        self.layer_norm_eps = 1e-5
        self.activation_function = "QuickGELU"
        self.dtype = torch.float16
        self.device = torch.device("cuda:0")
        self.logit_scale_init = 4.6052
        self.num_virtual_tokens = 20
        self.token_dim = self.hidden_size
        self.encoder_hidden_size = self.hidden_size


CLIPconfig = CLIPConfig()
model = CLIP(CLIPconfig)
# load pretrained weights
model.load_state_dict(torch.load(r'./Mix_CLIP.pth', weights_only=True), strict=False)

#---------------------------------------------
########### PreProcess Pipelines ##########
#-------------------------------------------------

import pickle

with open('./preprocess.pkl', 'rb') as f:
    preprocess = pickle.load(f)
with open('./tokenize.pkl', 'rb') as f:
    tokenizer = pickle.load(f)

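# Minimal usage sketch, assuming the unpickled `preprocess` maps a PIL image to a CHW
# tensor and `tokenizer` maps a list of strings to token ids (as in OpenAI CLIP); the
# image path and emotion prompts below are placeholders.
if __name__ == "__main__":
    from PIL import Image

    emotions = ["happy", "sad", "angry", "surprised"]
    image = preprocess(Image.open("./example.jpg")).unsqueeze(0).to(CLIPconfig.device, CLIPconfig.dtype)
    text = tokenizer([f"a photo of a {e} person" for e in emotions]).to(CLIPconfig.device)

    model.eval()
    with torch.no_grad():
        logits_per_image, logits_per_text = model(image, text)
        probs = logits_per_image.softmax(dim=-1)  # one probability per emotion prompt
    print(dict(zip(emotions, probs.squeeze(0).tolist())))
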
preprocess.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:51ff0f1d35da9d25c16b5a82957cfb43b76d01a94084c501ec4a9180dc4b53aa
size 1116
tokenize.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b7cd84774d43b4d7513f250b615ecd579a9a0c852f3e011043330407f7ca93e1
size 37