jiangchengchengNLP committed
Commit b11ecdd · verified · 1 Parent(s): 66704c9

Upload my files
.ipynb_checkpoints/EmotionCLIP-checkpoint.py ADDED
@@ -0,0 +1,151 @@
"""
The ViT transformer has no causal mask: every position can attend to every other
position, so there is no causal relationship between patches (or only a very weak one).

Text generation still needs a causal mask.
"""
import torch.nn.functional as F
from VIT import model as VIT
from Text_Encoder import text_encoder as transformer
import torch.nn as nn
import torch
from Text_Encoder import MLP

class Prompt_block(nn.Module):
    def __init__(self,config):
        super(Prompt_block,self).__init__()
        self.prompt_embedding=nn.Embedding(config.prompt_num,config.hidden_size,dtype=config.dtype,device=config.device)
    def forward(self,text_embeddings):
        b,_,_=text_embeddings.size()
        n,dim=self.prompt_embedding.weight.size()
        """
        new_embeddings=[]
        for batch,index_ in enumerate(index):
            text_embedding=text_embeddings[0]
            text_embedding=torch.cat((text_embedding[:index_,:],self.prompt_embedding.weight,text_embedding[index_:,:]),0)
            new_embeddings.append(text_embedding)
        stacked_embedding= torch.stack(new_embeddings, dim=0)
        return stacked_embedding
        """
        # Insert the learned prompt tokens right after the first (BOS) token.
        text_embeddings=torch.cat((text_embeddings[:,0:1,:],self.prompt_embedding.weight.expand(b,n,dim),text_embeddings[:,1:,:]),1)
        return text_embeddings


class CLIP(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.visual=VIT
        self.device=config.device
        self.dtype=config.dtype
        self.token_embedding=nn.Embedding(config.vocab_size,config.hidden_size,dtype=config.dtype,device=config.device)
        self.max_position_embeddings=config.max_position_embeddings
        self.prompt_num=config.prompt_num
        self.transformer=transformer
        # Add a prompt block.
        self.prompt_block=Prompt_block(config)
        self.positional_embedding=nn.Parameter(torch.empty(config.max_position_embeddings,config.hidden_size,device=config.device))
        self.ln_final=nn.LayerNorm(config.hidden_size,eps=config.layer_norm_eps,dtype=config.dtype,device=config.device)
        self.text_projection=nn.Parameter(torch.empty(config.hidden_size,config.hidden_size,device=config.device))
        self.logit_scale=nn.Parameter(torch.empty([],dtype=config.dtype,device=config.device)*config.logit_scale_init,requires_grad=False)
    def encode_image(self,img,use_emotion=True):
        cls_embedding=self.visual(img,use_emotion)
        #cls_embedding:[batch_size,1,512],image_embedding:[batch_size,7,512]
        return cls_embedding
    def encode_text(self,text,use_emotion=True):
        # Reserve room for the 20 prompt tokens.
        b,n=text.size()
        index=text.argmax(dim=-1)
        text_embedding=self.token_embedding(text)
        #text_embedding=self.prompt_block(index,text_embedding)
        if n==self.max_position_embeddings-self.prompt_num:
            text_embedding=self.prompt_block(text_embedding)
            index=index+torch.tensor(20,device=index.device,dtype=index.dtype)
        position_embedding=self.positional_embedding[None,:text_embedding.shape[1],:].to(self.dtype)
        text_embedding=position_embedding+text_embedding
        text_embedding=self.transformer(text_embedding,use_emotion=use_emotion)
        text_embedding=self.ln_final(text_embedding)
        # index points at the EOS token (the largest token id), which is used for pooling.
        #print(index[0],index_new[0],text_embedding.shape)
        text_embedding=text_embedding[torch.arange(text.shape[0]),index]
        text_embedding=text_embedding@self.text_projection.to(self.dtype)

        return text_embedding

    def forward(self,image,text,use_emotion=True):
        image_features=self.encode_image(image,use_emotion)
        text_features=self.encode_text(text,use_emotion)
        # normalized features
        image_features=image_features/image_features.norm(dim=-1,keepdim=True)
        text_features=text_features/text_features.norm(dim=-1,keepdim=True)
        # cosine similarity as logits
        logit_scale=self.logit_scale.exp()
        logits_per_image=logit_scale*image_features@text_features.t()
        logits_per_text=logits_per_image.t()
        # shape = [global_batch_size, global_batch_size]
        return logits_per_image,logits_per_text

class Config:
    def __init__(self):
        self.vocab_size=49408
        self.image_dim=768
        self.num_patches=49
        self.patch_size=32
        self.hidden_size=512
        self.prompt_num=20
        self.max_position_embeddings=77
        self.num_hidden_layers=12
        self.num_attention_heads=8
        self.head_size=64
        self.layer_norm_eps=1e-5
        self.activation_function="Quickgelu"
        self.dtype=torch.float16
        self.device=torch.device("cuda:0")
        self.logit_scale_init=4.6052
        self.num_virtual_tokens=20
        self.token_dim=self.hidden_size
        self.encoder_hidden_size=self.hidden_size

config=Config()
model=CLIP(config)
# Load the pretrained weights.
model.load_state_dict(torch.load(r'/root/autodl-tmp/true_Emoset/EmotionCLIP_v2.bin',weights_only=True,map_location='cpu'),strict=True)
"""
for name, param in model.named_parameters():
    if 'prefix' not in name and 'prompt' not in name and 'ln' not in name: # if the name contains none of 'prefix', 'prompt', 'ln'
        print(name,"'s requires_grad turn off.")
        param.requires_grad = False # freeze this parameter
    else:
        print(name,"'s requires_grad turn on.")
        param.requires_grad = True # keep this parameter trainable
"""

# Compile the model.
#model=torch.compile(model)
import pickle
from PIL import Image
import clip
with open('./preprocess.pkl','rb') as f:
    preprocess = pickle.load(f)
with open('./tokenize.pkl','rb') as f:
    tokenizer=pickle.load(f)
device=config.device
image = preprocess(Image.open("spider.jpg")).unsqueeze(0).to(device)
text = tokenizer(["This picture conveys a sense of fear", "This picture conveys a sense of contentment", "This picture conveys a sense of anger","This picture conveys a sense of sadness","This picture conveys a sense of neutral","This picture conveys a sense of disgust","This picture conveys a sense of excitement","This picture conveys a sense of awe","This picture conveys a sense of amusement"],context_length=57).to(device)
#context_length=57
with torch.no_grad():
    logits_per_image, logits_per_text = model(image.to(config.dtype), text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()
print("Emotion recognition:",probs)
# Save the weights with the prefix merged in.
import torch
torch.save(model.state_dict(),'./upload/EmotionCLIP-V2.pth')
# Generalization performance
"""
text=tokenizer(['This is a spider.','This is a dog','This is a cat'],context_length=57).to(device)
with torch.no_grad():
    logits_per_image, logits_per_text = model(image.to(config.dtype), text,use_emotion=False)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()
print("Generalization recognition:",probs)
"""
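
The commented-out loop above is the training-time setup: everything except the prefix, prompt, and LayerNorm parameters is frozen. A minimal sketch of that selective freezing as a reusable helper, assuming the same name patterns (the function name count_trainable and the parameter counting are illustrative additions, not part of the uploaded file):

def count_trainable(model, keep=('prefix', 'prompt', 'ln')):
    """Freeze everything except parameters whose name matches `keep`,
    then report how many scalar parameters stay trainable."""
    trainable, frozen = 0, 0
    for name, param in model.named_parameters():
        if any(tag in name for tag in keep):
            param.requires_grad = True
            trainable += param.numel()
        else:
            param.requires_grad = False
            frozen += param.numel()
    return trainable, frozen

# Example (after `model = CLIP(config)`):
# trainable, frozen = count_trainable(model)
# print(f"trainable: {trainable:,}  frozen: {frozen:,}")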
.ipynb_checkpoints/Text_Encoder-checkpoint.py ADDED
@@ -0,0 +1,192 @@
import torch
import torch.nn as nn
import math
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn import functional as F


class PrefixEncoder(torch.nn.Module):
    def __init__(self,config):
        super(PrefixEncoder,self).__init__()
        self.config=config
        self.device=config.device
        self.dtype=config.dtype
        self.num_virtual_tokens=config.num_virtual_tokens
        self.embedding=torch.nn.Embedding(config.num_virtual_tokens,config.token_dim,device=config.device,dtype=config.dtype)
        self.token_dim=config.token_dim
        self.encoder_hidden_size=config.encoder_hidden_size
        self.num_layers=config.num_layers
        self.transformer=torch.nn.Sequential(
            torch.nn.Linear(self.token_dim,self.encoder_hidden_size,device=self.device,dtype=self.dtype),
            torch.nn.Tanh(),
            torch.nn.Linear(self.encoder_hidden_size,self.num_layers*2*self.token_dim,device=self.device,dtype=self.dtype),
        )
    def forward(self,input_ids,batch_size):
        input_ids=input_ids.unsqueeze(0)
        prefix_embedding=self.embedding(input_ids)
        prefix_embedding=self.transformer(prefix_embedding)
        # Bake the reparameterized prefix into a single frozen parameter and drop the
        # embedding/MLP, so only the merged prefix remains in the state dict.
        self.register_parameter("prefix_embedding",nn.Parameter(prefix_embedding,requires_grad=False))
        prefix_embedding=prefix_embedding.expand(batch_size,self.num_virtual_tokens,self.num_layers*2*self.token_dim)
        prefix_embedding=prefix_embedding.reshape(batch_size,self.num_virtual_tokens,self.num_layers,2,self.token_dim)
        prefix_embedding=prefix_embedding.permute(3,2,0,1,4)
        del self.embedding
        del self.transformer
        k,v=prefix_embedding.chunk(2,dim=0)
        return (k.squeeze(0),v.squeeze(0))


class Transformer(nn.Module):
    def __init__(self,config):
        super(Transformer,self).__init__()
        self.resblocks=nn.ModuleList([ResidualAttentionBlock(config) for _ in range(config.num_layers)])
        self.prefix=PrefixEncoder(config)
        prefix_tokens=torch.arange(0,config.num_virtual_tokens,device=config.device,dtype=torch.long)
        self.register_buffer("prefix_tokens",prefix_tokens)
    def forward(self,hidden_state,use_emotion):
        if use_emotion:
            #print("text transformer prefix activated.")
            b,n,h=hidden_state.shape
            prefix_k,prefix_v=self.prefix(self.prefix_tokens,b)
            for index,resblock in enumerate(self.resblocks):
                hidden_state=resblock(hidden_state,prefix_k[index],prefix_v[index])
            return hidden_state
        else:
            for index,resblock in enumerate(self.resblocks):
                hidden_state=resblock(hidden_state)
            return hidden_state


class ResidualAttentionBlock(nn.Module):
    def __init__(self,config):
        super(ResidualAttentionBlock,self).__init__()
        self.ln_1=nn.LayerNorm(config.hidden_size,eps=config.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
        self.ln_2=nn.LayerNorm(config.hidden_size,eps=config.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
        #self.attn=nn.MultiheadAttention(config.hidden_size,config.num_heads,device=config.device,dtype=config.dtype)
        self.attn=MultiHeadAttention(config)
        self.mlp=MLP(config)
    def forward(self,hidden_state,prefix_k=None,prefix_v=None):
        residual=hidden_state
        hidden_state=self.ln_1(hidden_state)
        hidden_state=self.attn(hidden_state,prefix_k,prefix_v)
        hidden_state=residual+hidden_state
        residual=hidden_state
        hidden_state=self.ln_2(hidden_state)
        hidden_state=self.mlp(hidden_state)
        hidden_state=residual+hidden_state
        return hidden_state

class MultiHeadAttention(nn.Module):
    def __init__(self,config):
        super(MultiHeadAttention,self).__init__()
        self.hidden_size=config.hidden_size
        self.num_heads=config.num_heads
        self.head_size=self.hidden_size//self.num_heads
        # Packed in-projection weight and bias for Q, K, V (frozen; loaded from the pretrained CLIP).
        self.in_proj_weight=nn.Parameter(torch.empty(3*config.hidden_size,config.hidden_size,device=config.device,dtype=config.dtype),requires_grad=False)
        self.in_proj_bias=nn.Parameter(torch.empty(3*config.hidden_size,device=config.device,dtype=config.dtype),requires_grad=False)
        #self.q_linear=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device)
        #self.k_linear=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device)
        #self.v_linear=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device)
        self.out_proj=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device,dtype=config.dtype)
    def forward(self,hidden_state,prefix_k=None,prefix_v=None):
        b,n,c=hidden_state.shape
        #q=self.q_linear(hidden_state).view(b,n,self.num_heads,self.head_size).permute(0,2,1,3)
        #k=self.k_linear(hidden_state).view(b,n,self.num_heads,self.head_size).permute(0,2,3,1)
        #v=self.v_linear(hidden_state).view(b,n,self.num_heads,self.head_size).permute(0,2,1,3)
        q,k,v=(torch.matmul(hidden_state,self.in_proj_weight.T)+self.in_proj_bias.expand(b,n,-1)).chunk(3,dim=-1)
        if prefix_k is not None and prefix_v is not None:
            # Prepend the prefix to the key/value sequences.
            k=torch.cat((prefix_k,k),dim=1)
            v=torch.cat((prefix_v,v),dim=1)
            #print("model origin k :",k[:,0,0])
        bk,nk,hk=k.shape
        bq,nq,hq=q.shape
        q=q.view(bq,nq,self.num_heads,self.head_size).permute(0,2,1,3)
        k=k.view(bk,nk,self.num_heads,self.head_size).permute(0,2,1,3)
        v=v.view(bk,nk,self.num_heads,self.head_size).permute(0,2,1,3)
        attention_logits=F.scaled_dot_product_attention(q, k, v)
        attention_logits=attention_logits.permute(0,2,1,3).contiguous().view(bk,nq,self.hidden_size)
        attention_output=self.out_proj(attention_logits)
        return attention_output


class GELU(nn.Module):
    """
    The error function erf:
        erf(x) = 2/sqrt(pi) * integral(exp(-t^2), t=0..x)
    Here t is a dummy integration variable running from 0 to x. Concretely:
        x is the input to the error function and the upper limit of the integral,
        t sweeps the interval [0, x], with exp(-t^2) evaluated at every point,
        exp(-t^2) is the integrand, the Gaussian density (up to normalisation) at t.
    The integral therefore accumulates the Gaussian probability mass over the interval [0, x].
    """
    def forward(self,x):
        return 0.5*x*(1.0+torch.erf(x/math.sqrt(2.0)))

class QuickGELU(nn.Module):
    def __init__(self):
        super(QuickGELU,self).__init__()
    def forward(self,x):
        old_dtype=x.dtype
        x=x.to(torch.float32)
        return (x*torch.sigmoid(1.702*x)).to(old_dtype)


class MLP(nn.Module):
    def __init__(self,config):
        super(MLP,self).__init__()
        self.hidden_size=config.hidden_size
        self.c_fc=nn.Linear(self.hidden_size,4*self.hidden_size,device=config.device,bias=True,dtype=config.dtype)
        self.gelu=QuickGELU()
        self.c_proj=nn.Linear(self.hidden_size*4,self.hidden_size,device=config.device,bias=True,dtype=config.dtype)
    def forward(self,hidden_state):
        hidden_state=self.c_fc(hidden_state)
        hidden_state=self.gelu(hidden_state)
        hidden_state=self.c_proj(hidden_state)
        return hidden_state

class Config:
    def __init__(self,vocab_size,max_position_embeddings,hidden_size,num_layers,num_heads,device,dtype):
        self.vocab_size=vocab_size
        self.max_position_embeddings=max_position_embeddings
        self.hidden_size=hidden_size
        self.num_layers=num_layers
        self.num_heads=num_heads
        self.device=device
        self.dtype=dtype
        self.norm_eps=1e-5
        self.num_virtual_tokens=20
        self.token_dim=hidden_size
        self.encoder_hidden_size=hidden_size
config=Config(
    vocab_size=49408,
    max_position_embeddings=77,
    hidden_size=512,
    num_layers=12,
    num_heads=8,
    device=torch.device('cuda:0'),
    dtype=torch.float16
)
class TextEncoder(nn.Module):
    # Note: this class is not used below; text_encoder is built directly from Transformer.
    def __init__(self,config):
        super(TextEncoder,self).__init__()
        self.token_embedding=nn.Embedding(config.vocab_size,config.hidden_size,device=config.device,dtype=config.dtype)
        self.positional_embedding=nn.Parameter(torch.zeros(config.max_position_embeddings,config.hidden_size,device=config.device,dtype=config.dtype),requires_grad=False)
        self.transformer=Transformer(config)
        self.ln_final=nn.LayerNorm(config.hidden_size,eps=config.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
    def forward(self,input_ids):
        b,n=input_ids.shape
        prompt_embedding,token_embeddings=self.token_embedding(input_ids)
        position_ids=torch.arange(n,device=config.device,dtype=torch.long).unsqueeze(0).expand(b,n)
        position_embeddings=self.positional_embedding[position_ids]
        embeddings=token_embeddings+position_embeddings
        embeddings=torch.cat((prompt_embedding,embeddings),dim=1)
        embeddings=self.transformer(embeddings)
        embeddings=self.ln_final(embeddings)
        return embeddings

text_encoder=Transformer(config)
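
For reference, the reshaping done in PrefixEncoder.forward can be followed with plain tensors. This standalone sketch (illustrative only, using the same dimensions as the config above: 20 virtual tokens, 12 layers, width 512) shows how the flat prefix becomes per-layer key/value tensors:

import torch

batch, num_virtual_tokens, num_layers, token_dim = 2, 20, 12, 512
flat = torch.zeros(1, num_virtual_tokens, num_layers * 2 * token_dim)

x = flat.expand(batch, num_virtual_tokens, num_layers * 2 * token_dim)
x = x.reshape(batch, num_virtual_tokens, num_layers, 2, token_dim)
x = x.permute(3, 2, 0, 1, 4)      # (2, num_layers, batch, num_virtual_tokens, token_dim)
k, v = x.chunk(2, dim=0)          # split the K/V axis
print(k.squeeze(0).shape)         # torch.Size([12, 2, 20, 512]) -> prefix_k[layer] fed to each resblock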
.ipynb_checkpoints/VIT-checkpoint.py ADDED
@@ -0,0 +1,243 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import os
import sys

# Prefix tuning, following the HuggingFace PEFT implementation.
class PrefixEncoder(torch.nn.Module):
    def __init__(self,config):
        super(PrefixEncoder,self).__init__()
        self.config=config
        self.device=config.device
        self.dtype=config.dtype
        self.num_virtual_tokens=config.num_virtual_tokens
        self.embedding=torch.nn.Embedding(config.num_virtual_tokens,config.token_dim,device=config.device,dtype=config.dtype)
        self.token_dim=config.token_dim
        self.encoder_hidden_size=config.encoder_hidden_size
        self.num_layers=config.num_layers
        self.transformer=torch.nn.Sequential(
            torch.nn.Linear(self.token_dim,self.encoder_hidden_size,device=self.device,dtype=self.dtype),
            torch.nn.Tanh(),
            torch.nn.Linear(self.encoder_hidden_size,self.num_layers*2*self.token_dim,device=self.device,dtype=self.dtype),
        )
    def forward(self,input_ids,batch_size):
        input_ids=input_ids.unsqueeze(0)
        prefix_embedding=self.embedding(input_ids)
        prefix_embedding=self.transformer(prefix_embedding)
        # Bake the reparameterized prefix into a single frozen parameter and drop the
        # embedding/MLP, so only the merged prefix remains in the state dict.
        self.register_parameter("prefix_embedding",nn.Parameter(prefix_embedding,requires_grad=False))
        prefix_embedding=prefix_embedding.expand(batch_size,self.num_virtual_tokens,self.num_layers*2*self.token_dim)
        prefix_embedding=prefix_embedding.reshape(batch_size,self.num_virtual_tokens,self.num_layers,2,self.token_dim)
        prefix_embedding=prefix_embedding.permute(3,2,0,1,4)
        del self.embedding
        del self.transformer
        k,v=prefix_embedding.chunk(2,dim=0)
        return (k.squeeze(0),v.squeeze(0))


import torch
import torch.nn as nn
import math
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn import functional as F
def position_embedding(x,position_ids):
    hidden_size=x.size(2)
    seq_len=x.size(1)
    div_term=torch.exp(torch.arange(0,hidden_size,2,device=x.device).float()*(-math.log(10000.0)/hidden_size))
    positional_encoding=torch.zeros(seq_len,hidden_size,device=x.device)
    positional_encoding[:,0::2]=torch.sin(position_ids.float()[:,None]*div_term)
    positional_encoding[:,1::2]=torch.cos(position_ids.float()[:,None]*div_term)
    positional_encoding=positional_encoding.unsqueeze(0)
    return positional_encoding


class VisionTransformer(nn.Module):
    def __init__(self,config):
        super(VisionTransformer,self).__init__()
        self.image_channel=config.image_channel
        self.hidden_size=config.hidden_size
        self.norm_eps=config.norm_eps
        self.patch_size=config.patch_size
        self.output_dim=config.output_dim
        self.dtype=config.dtype
        self.num_patches=config.num_patches
        self.num_virtual_tokens=config.num_virtual_tokens if hasattr(config,"num_virtual_tokens") else None
        self.conv1=nn.Conv2d(self.image_channel,self.hidden_size,self.patch_size,stride=self.patch_size,bias=False,device=config.device,dtype=config.dtype)
        self.ln_pre=nn.LayerNorm(self.hidden_size,eps=self.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
        self.transformer=Transformer(config)
        #self.position_ids=torch.arange(config.num_patches+1,dtype=torch.long,device=config.device)
        #self.position_embeddings=nn.Parameter(torch.zeros(1,config.num_patches+1,config.hidden_size))
        #nn.init.normal_(self.position_embeddings)
        # cls token, used as the image-level representation.
        #self.cls_token=nn.Parameter(torch.zeros(1,1,config.hidden_size,device=config.device))
        # The class token is loaded from the pretrained weights and kept frozen here.
        self.class_embedding=nn.Parameter(torch.empty(config.hidden_size,device=config.device),requires_grad=False)
        # The positional embedding is likewise a learned parameter (frozen here).
        self.positional_embedding=nn.Parameter(torch.empty(config.num_patches+1,config.hidden_size,device=config.device),requires_grad=False)
        # Projection to the shared embedding space (frozen here).
        self.proj=nn.Parameter(torch.empty(config.hidden_size,config.output_dim,device=config.device,dtype=config.dtype),requires_grad=False)
        self.ln_post=nn.LayerNorm(self.hidden_size,eps=self.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
    def forward(self,hidden_state,use_emotion):
        b,c,h,w=hidden_state.shape
        # Patchify the image into embedding vectors.
        hidden_state=self.conv1(hidden_state)
        hidden_state=hidden_state.reshape(b,self.hidden_size,-1).transpose(1,2)
        # Prepend the cls token embedding.
        hidden_state=torch.cat((self.class_embedding.expand(b,1,-1).to(hidden_state.dtype),hidden_state),dim=1)
        # (alternative) fixed sinusoidal position embedding from the original Transformer paper:
        #hidden_state=hidden_state+position_embedding(hidden_state,self.position_ids)
        hidden_state=hidden_state+self.positional_embedding.unsqueeze(0).to(hidden_state.dtype)
        hidden_state=self.ln_pre(hidden_state)
        hidden_state=self.transformer(hidden_state,use_emotion)
        # Take the cls-token output; the image patch outputs are not used here.
        cls_state=hidden_state[:,0,:]
        cls_state=self.ln_post(cls_state)
        cls_state=torch.matmul(cls_state,self.proj)
        #image_state=hidden_state[:,1:,:]
        #image_state size (batch_size,49,768)
        return cls_state

class Transformer(nn.Module):
    def __init__(self,config):
        super(Transformer,self).__init__()
        self.resblocks=nn.ModuleList([ResidualAttentionBlock(config) for _ in range(config.num_layers)])
        self.prefix=PrefixEncoder(config)
        prefix_tokens=torch.arange(0,config.num_virtual_tokens,device=config.device,dtype=torch.long)
        self.register_buffer("prefix_tokens",prefix_tokens)
    def forward(self,hidden_state,use_emotion):
        if use_emotion:
            b,n,h=hidden_state.shape
            prefix_k,prefix_v=self.prefix(self.prefix_tokens,b)
            for index,resblock in enumerate(self.resblocks):
                # Before each layer, hand that layer's prefix K/V to the resblock for concatenation.
                hidden_state=resblock(hidden_state,prefix_k[index],prefix_v[index])
            return hidden_state
        else:
            for index,resblock in enumerate(self.resblocks):
                hidden_state=resblock(hidden_state)
            return hidden_state


class ResidualAttentionBlock(nn.Module):
    def __init__(self,config):
        super(ResidualAttentionBlock,self).__init__()
        self.ln_1=nn.LayerNorm(config.hidden_size,eps=config.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
        self.ln_2=nn.LayerNorm(config.hidden_size,eps=config.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
        #self.attn=nn.MultiheadAttention(config.hidden_size,config.num_heads,device=config.device,dtype=config.dtype)
        self.attn=MultiHeadAttention(config)
        self.mlp=MLP(config)
    def forward(self,hidden_state,prefix_k=None,prefix_v=None):
        residual=hidden_state
        hidden_state=self.ln_1(hidden_state)
        hidden_state=self.attn(hidden_state,prefix_k,prefix_v)
        hidden_state=residual+hidden_state
        residual=hidden_state
        hidden_state=self.ln_2(hidden_state)
        hidden_state=self.mlp(hidden_state)
        hidden_state=residual+hidden_state
        return hidden_state

class MultiHeadAttention(nn.Module):
    def __init__(self,config):
        super(MultiHeadAttention,self).__init__()
        self.hidden_size=config.hidden_size
        self.num_heads=config.num_heads
        self.head_size=self.hidden_size//self.num_heads
        # Packed in-projection weight and bias for Q, K, V (frozen; loaded from the pretrained CLIP).
        self.in_proj_weight=nn.Parameter(torch.empty(3*config.hidden_size,config.hidden_size,device=config.device,dtype=config.dtype),requires_grad=False)
        self.in_proj_bias=nn.Parameter(torch.empty(3*config.hidden_size,device=config.device,dtype=config.dtype),requires_grad=False)
        #self.q_linear=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device)
        #self.k_linear=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device)
        #self.v_linear=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device)
        self.out_proj=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device,dtype=config.dtype)
    def forward(self,hidden_state,prefix_k=None,prefix_v=None):
        b,n,h=hidden_state.shape
        #q=self.q_linear(hidden_state).view(b,n,self.num_heads,self.head_size).permute(0,2,1,3)
        #k=self.k_linear(hidden_state).view(b,n,self.num_heads,self.head_size).permute(0,2,3,1)
        #v=self.v_linear(hidden_state).view(b,n,self.num_heads,self.head_size).permute(0,2,1,3)
        q,k,v=(torch.matmul(hidden_state,self.in_proj_weight.T)+self.in_proj_bias.expand(b,n,-1)).chunk(3,dim=-1)
        if prefix_k is not None and prefix_v is not None:
            # Prepend the prefix to the key/value sequences.
            #print("origional k.shape",prefix_k.shape)
            k=torch.cat((prefix_k,k),dim=1)
            v=torch.cat((prefix_v,v),dim=1)
            #print("model original k :",k[:,0,0])
        bk,nk,hk=k.shape
        bq,nq,hq=q.shape
        q=q.view(bq,nq,self.num_heads,self.head_size).permute(0,2,1,3)
        k=k.view(bk,nk,self.num_heads,self.head_size).permute(0,2,1,3)
        v=v.view(bk,nk,self.num_heads,self.head_size).permute(0,2,1,3)
        attention_logits=F.scaled_dot_product_attention(q, k, v)
        attention_logits=attention_logits.permute(0,2,1,3).contiguous().view(bk,nq,self.hidden_size)
        attention_output=self.out_proj(attention_logits)
        return attention_output


class GELU(nn.Module):
    """
    The error function erf:
        erf(x) = 2/sqrt(pi) * integral(exp(-t^2), t=0..x)
    Here t is a dummy integration variable running from 0 to x. Concretely:
        x is the input to the error function and the upper limit of the integral,
        t sweeps the interval [0, x], with exp(-t^2) evaluated at every point,
        exp(-t^2) is the integrand, the Gaussian density (up to normalisation) at t.
    The integral therefore accumulates the Gaussian probability mass over the interval [0, x].
    """
    def forward(self,x):
        old_dtype=x.dtype
        x=x.to(torch.float32)
        return (0.5*x*(1.0+torch.erf(x/math.sqrt(2.0)))).to(old_dtype)

class QuickGELU(nn.Module):
    def __init__(self):
        super(QuickGELU,self).__init__()
    def forward(self,x):
        old_dtype=x.dtype
        x=x.to(torch.float32)
        return (x*torch.sigmoid(1.702*x)).to(old_dtype)


class MLP(nn.Module):
    def __init__(self,config):
        super(MLP,self).__init__()
        self.hidden_size=config.hidden_size
        self.c_fc=nn.Linear(self.hidden_size,4*self.hidden_size,device=config.device,bias=True,dtype=config.dtype)
        self.gelu=QuickGELU()
        self.c_proj=nn.Linear(self.hidden_size*4,self.hidden_size,device=config.device,bias=True,dtype=config.dtype)
    def forward(self,hidden_state):
        hidden_state=self.c_fc(hidden_state)
        hidden_state=self.gelu(hidden_state)
        hidden_state=self.c_proj(hidden_state)
        return hidden_state


class ViTConfig:
    def __init__(self,image_channel,hidden_size,num_heads,num_layers,patch_size,num_patches,output_dim,norm_eps,device):
        self.image_channel=image_channel
        self.hidden_size=hidden_size
        self.num_heads=num_heads
        self.num_layers=num_layers
        self.patch_size=patch_size
        self.num_patches=num_patches
        self.norm_eps=norm_eps
        self.device=device
        self.dtype=torch.float16
        self.patch_token_num=self.hidden_size//self.patch_size**2+1
        self.output_dim=output_dim
        self.num_virtual_tokens=20
        self.token_dim=self.hidden_size
        self.encoder_hidden_size=self.hidden_size

config=ViTConfig(3,768,12,12,32,49,512,1e-5,torch.device("cuda"))
model=VisionTransformer(config)
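
A quick shape check of the patch embedding used by VisionTransformer. This is a standalone sketch; the 224x224 input size is an assumption inferred from num_patches=49 and patch_size=32 (7x7 patches plus the class token gives the 50-token sequence that positional_embedding expects):

import torch
import torch.nn as nn

conv1 = nn.Conv2d(3, 768, kernel_size=32, stride=32, bias=False)
img = torch.randn(1, 3, 224, 224)
patches = conv1(img)                                   # (1, 768, 7, 7)
tokens = patches.reshape(1, 768, -1).transpose(1, 2)   # (1, 49, 768)
cls = torch.zeros(1, 1, 768)
print(torch.cat((cls, tokens), dim=1).shape)           # torch.Size([1, 50, 768])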
Dog sad.jpg ADDED
EmotionCLIP-V2.pth ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b3d83b57423a070150ca67c87f7ad2a163b531a890632f4ebe3cf1c12a08ffd9
size 304602701
EmotionCLIP.py ADDED
@@ -0,0 +1,183 @@
"""
The ViT transformer has no causal mask: every position can attend to every other
position, so there is no causal relationship between patches (or only a very weak one).

Text generation still needs a causal mask.
"""
import torch.nn.functional as F
from VIT import model as VIT
from Text_Encoder import text_encoder as transformer
import torch.nn as nn
import torch
from Text_Encoder import MLP

class Prompt_block(nn.Module):
    def __init__(self,config):
        super(Prompt_block,self).__init__()
        self.prompt_embedding=nn.Embedding(config.prompt_num,config.hidden_size,dtype=config.dtype,device=config.device)
    def forward(self,text_embeddings):
        b,_,_=text_embeddings.size()
        n,dim=self.prompt_embedding.weight.size()
        """
        new_embeddings=[]
        for batch,index_ in enumerate(index):
            text_embedding=text_embeddings[0]
            text_embedding=torch.cat((text_embedding[:index_,:],self.prompt_embedding.weight,text_embedding[index_:,:]),0)
            new_embeddings.append(text_embedding)
        stacked_embedding= torch.stack(new_embeddings, dim=0)
        return stacked_embedding
        """
        # Insert the learned prompt tokens right after the first (BOS) token.
        text_embeddings=torch.cat((text_embeddings[:,0:1,:],self.prompt_embedding.weight.expand(b,n,dim),text_embeddings[:,1:,:]),1)
        return text_embeddings


class CLIP(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.visual=VIT
        self.device=config.device
        self.dtype=config.dtype
        self.token_embedding=nn.Embedding(config.vocab_size,config.hidden_size,dtype=config.dtype,device=config.device)
        self.max_position_embeddings=config.max_position_embeddings
        self.prompt_num=config.prompt_num
        self.transformer=transformer
        # Add a prompt block.
        self.prompt_block=Prompt_block(config)
        self.positional_embedding=nn.Parameter(torch.empty(config.max_position_embeddings,config.hidden_size,device=config.device))
        self.ln_final=nn.LayerNorm(config.hidden_size,eps=config.layer_norm_eps,dtype=config.dtype,device=config.device)
        self.text_projection=nn.Parameter(torch.empty(config.hidden_size,config.hidden_size,device=config.device))
        self.logit_scale=nn.Parameter(torch.empty([],dtype=config.dtype,device=config.device)*config.logit_scale_init,requires_grad=False)
    def encode_image(self,img,use_emotion=True):
        cls_embedding=self.visual(img,use_emotion)
        #cls_embedding:[batch_size,1,512],image_embedding:[batch_size,7,512]
        return cls_embedding
    def encode_text(self,text,use_emotion=True):
        # Reserve room for the 20 prompt tokens.
        b,n=text.size()
        index=text.argmax(dim=-1)
        text_embedding=self.token_embedding(text)
        #text_embedding=self.prompt_block(index,text_embedding)
        if n==self.max_position_embeddings-self.prompt_num:
            text_embedding=self.prompt_block(text_embedding)
            index=index+torch.tensor(20,device=index.device,dtype=index.dtype)
        position_embedding=self.positional_embedding[None,:text_embedding.shape[1],:].to(self.dtype)
        text_embedding=position_embedding+text_embedding
        text_embedding=self.transformer(text_embedding,use_emotion=use_emotion)
        text_embedding=self.ln_final(text_embedding)
        # index points at the EOS token (the largest token id), which is used for pooling.
        #print(index[0],index_new[0],text_embedding.shape)
        text_embedding=text_embedding[torch.arange(text.shape[0]),index]
        text_embedding=text_embedding@self.text_projection.to(self.dtype)

        return text_embedding

    def forward(self,image,text,use_emotion=True):
        image_features=self.encode_image(image,use_emotion)
        text_features=self.encode_text(text,use_emotion)
        # normalized features
        image_features=image_features/image_features.norm(dim=-1,keepdim=True)
        text_features=text_features/text_features.norm(dim=-1,keepdim=True)
        # cosine similarity as logits
        logit_scale=self.logit_scale.exp()
        logits_per_image=logit_scale*image_features@text_features.t()
        logits_per_text=logits_per_image.t()
        # shape = [global_batch_size, global_batch_size]
        return logits_per_image,logits_per_text

class Config:
    def __init__(self):
        self.vocab_size=49408
        self.image_dim=768
        self.num_patches=49
        self.patch_size=32
        self.hidden_size=512
        self.prompt_num=20
        self.max_position_embeddings=77
        self.num_hidden_layers=12
        self.num_attention_heads=8
        self.head_size=64
        self.layer_norm_eps=1e-5
        self.activation_function="Quickgelu"
        self.dtype=torch.float16
        self.device=torch.device("cuda:0")
        self.logit_scale_init=4.6052
        self.num_virtual_tokens=20
        self.token_dim=self.hidden_size
        self.encoder_hidden_size=self.hidden_size

config=Config()
model=CLIP(config)
# Load the pretrained weights.
model.load_state_dict(torch.load(r'./EmotionCLIP-V2.pth',weights_only=True,map_location='cpu'),strict=True)
"""
for name, param in model.named_parameters():
    if 'prefix' not in name and 'prompt' not in name and 'ln' not in name: # if the name contains none of 'prefix', 'prompt', 'ln'
        print(name,"'s requires_grad turn off.")
        param.requires_grad = False # freeze this parameter
    else:
        print(name,"'s requires_grad turn on.")
        param.requires_grad = True # keep this parameter trainable
"""

# Compile the model.
#model=torch.compile(model)
import pickle
from PIL import Image
import numpy as np
import clip
with open('./preprocess.pkl','rb') as f:
    preprocess = pickle.load(f)
with open('./tokenize.pkl','rb') as f:
    tokenizer=pickle.load(f)
device=config.device
image = preprocess(Image.open("Dog sad.jpg")).unsqueeze(0).to(device)
# Emotion recognition
labels=[
    'amusement',
    'anger',
    'awe',
    'contentment',
    'disgust',
    'excitement',
    'fear',
    'sadness',
    'neutral'
]
text_list=[ f"This picture conveys a sense of {label}" for label in labels]
tokens= tokenizer(text_list,context_length=57).to(device)

with torch.no_grad():
    logits_per_image, logits_per_text = model(image.to(config.dtype), tokens)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

# Get the predicted label.
predicted_index = np.argmax(probs, axis=1)
predicted_label=labels[predicted_index[0]]

print("Emotion recognition:", probs)
print("Predicted emotion label:", predicted_label)

# Generalization performance
labels=[
    'spider',
    'dog',
    'cat',
    'fish'
]
text_list=[ f"This is a {label}" for label in labels]
tokens= tokenizer(text_list,context_length=57).to(device)

with torch.no_grad():
    logits_per_image, logits_per_text = model(image.to(config.dtype), tokens, use_emotion=False)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

# Get the predicted label.
predicted_index = np.argmax(probs, axis=1)
predicted_label=labels[predicted_index[0]]

print("Generalization recognition:", probs)
print("Predicted generalization label:", predicted_label)
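
The similarity head in CLIP.forward can be read in isolation: after L2-normalisation the matrix product is a cosine-similarity matrix scaled by exp(logit_scale), and logit_scale_init=4.6052, which is approximately ln(100), gives the usual CLIP temperature of about 1/100. A small self-contained sketch with random features (illustrative only, not tied to the weights above):

import math
import torch

image_features = torch.randn(4, 512)
text_features = torch.randn(9, 512)
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
logit_scale = torch.tensor(4.6052).exp()                               # ~100.0
logits_per_image = logit_scale * image_features @ text_features.t()   # (4, 9) scaled cosine similarities
probs = logits_per_image.softmax(dim=-1)                               # one distribution over the 9 prompts per image
print(probs.shape, math.isclose(logit_scale.item(), 100.0, rel_tol=1e-3))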
Text_Encoder.py ADDED
@@ -0,0 +1,182 @@
import torch
import torch.nn as nn
import math
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn import functional as F


class PrefixEncoder(torch.nn.Module):
    def __init__(self,config):
        super(PrefixEncoder,self).__init__()
        self.config=config
        self.device=config.device
        self.dtype=config.dtype
        self.num_virtual_tokens=config.num_virtual_tokens
        self.token_dim=config.token_dim
        self.encoder_hidden_size=config.encoder_hidden_size
        self.num_layers=config.num_layers
        # Merged prefix: the Embedding + MLP reparameterization used in the checkpoint
        # version has already been baked into this single frozen parameter.
        self.prefix_embedding=nn.Parameter(torch.empty(1,self.num_virtual_tokens,self.num_layers*2*self.token_dim,device=config.device,dtype=config.dtype),requires_grad=False)
    def forward(self,input_ids,batch_size):
        prefix_embedding=self.prefix_embedding
        prefix_embedding=prefix_embedding.expand(batch_size,self.num_virtual_tokens,self.num_layers*2*self.token_dim)
        prefix_embedding=prefix_embedding.reshape(batch_size,self.num_virtual_tokens,self.num_layers,2,self.token_dim)
        prefix_embedding=prefix_embedding.permute(3,2,0,1,4)
        k,v=prefix_embedding.chunk(2,dim=0)
        return (k.squeeze(0),v.squeeze(0))


class Transformer(nn.Module):
    def __init__(self,config):
        super(Transformer,self).__init__()
        self.resblocks=nn.ModuleList([ResidualAttentionBlock(config) for _ in range(config.num_layers)])
        self.prefix=PrefixEncoder(config)
        prefix_tokens=torch.arange(0,config.num_virtual_tokens,device=config.device,dtype=torch.long)
        self.register_buffer("prefix_tokens",prefix_tokens)
    def forward(self,hidden_state,use_emotion):
        if use_emotion:
            #print("text transformer prefix activated.")
            b,n,h=hidden_state.shape
            prefix_k,prefix_v=self.prefix(self.prefix_tokens,b)
            for index,resblock in enumerate(self.resblocks):
                hidden_state=resblock(hidden_state,prefix_k[index],prefix_v[index])
            return hidden_state
        else:
            for index,resblock in enumerate(self.resblocks):
                hidden_state=resblock(hidden_state)
            return hidden_state


class ResidualAttentionBlock(nn.Module):
    def __init__(self,config):
        super(ResidualAttentionBlock,self).__init__()
        self.ln_1=nn.LayerNorm(config.hidden_size,eps=config.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
        self.ln_2=nn.LayerNorm(config.hidden_size,eps=config.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
        #self.attn=nn.MultiheadAttention(config.hidden_size,config.num_heads,device=config.device,dtype=config.dtype)
        self.attn=MultiHeadAttention(config)
        self.mlp=MLP(config)
    def forward(self,hidden_state,prefix_k=None,prefix_v=None):
        residual=hidden_state
        hidden_state=self.ln_1(hidden_state)
        hidden_state=self.attn(hidden_state,prefix_k,prefix_v)
        hidden_state=residual+hidden_state
        residual=hidden_state
        hidden_state=self.ln_2(hidden_state)
        hidden_state=self.mlp(hidden_state)
        hidden_state=residual+hidden_state
        return hidden_state

class MultiHeadAttention(nn.Module):
    def __init__(self,config):
        super(MultiHeadAttention,self).__init__()
        self.hidden_size=config.hidden_size
        self.num_heads=config.num_heads
        self.head_size=self.hidden_size//self.num_heads
        # Packed in-projection weight and bias for Q, K, V (frozen; loaded from the pretrained CLIP).
        self.in_proj_weight=nn.Parameter(torch.empty(3*config.hidden_size,config.hidden_size,device=config.device,dtype=config.dtype),requires_grad=False)
        self.in_proj_bias=nn.Parameter(torch.empty(3*config.hidden_size,device=config.device,dtype=config.dtype),requires_grad=False)
        #self.q_linear=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device)
        #self.k_linear=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device)
        #self.v_linear=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device)
        self.out_proj=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device,dtype=config.dtype)
    def forward(self,hidden_state,prefix_k=None,prefix_v=None):
        b,n,c=hidden_state.shape
        #q=self.q_linear(hidden_state).view(b,n,self.num_heads,self.head_size).permute(0,2,1,3)
        #k=self.k_linear(hidden_state).view(b,n,self.num_heads,self.head_size).permute(0,2,3,1)
        #v=self.v_linear(hidden_state).view(b,n,self.num_heads,self.head_size).permute(0,2,1,3)
        q,k,v=(torch.matmul(hidden_state,self.in_proj_weight.T)+self.in_proj_bias.expand(b,n,-1)).chunk(3,dim=-1)
        if prefix_k is not None and prefix_v is not None:
            # Prepend the prefix to the key/value sequences.
            k=torch.cat((prefix_k,k),dim=1)
            v=torch.cat((prefix_v,v),dim=1)
            #print("model origin k :",k[:,0,0])
        bk,nk,hk=k.shape
        bq,nq,hq=q.shape
        q=q.view(bq,nq,self.num_heads,self.head_size).permute(0,2,1,3)
        k=k.view(bk,nk,self.num_heads,self.head_size).permute(0,2,1,3)
        v=v.view(bk,nk,self.num_heads,self.head_size).permute(0,2,1,3)
        attention_logits=F.scaled_dot_product_attention(q, k, v)
        attention_logits=attention_logits.permute(0,2,1,3).contiguous().view(bk,nq,self.hidden_size)
        attention_output=self.out_proj(attention_logits)
        return attention_output


class GELU(nn.Module):
    """
    The error function erf:
        erf(x) = 2/sqrt(pi) * integral(exp(-t^2), t=0..x)
    Here t is a dummy integration variable running from 0 to x. Concretely:
        x is the input to the error function and the upper limit of the integral,
        t sweeps the interval [0, x], with exp(-t^2) evaluated at every point,
        exp(-t^2) is the integrand, the Gaussian density (up to normalisation) at t.
    The integral therefore accumulates the Gaussian probability mass over the interval [0, x].
    """
    def forward(self,x):
        return 0.5*x*(1.0+torch.erf(x/math.sqrt(2.0)))

class QuickGELU(nn.Module):
    def __init__(self):
        super(QuickGELU,self).__init__()
    def forward(self,x):
        old_dtype=x.dtype
        x=x.to(torch.float32)
        return (x*torch.sigmoid(1.702*x)).to(old_dtype)


class MLP(nn.Module):
    def __init__(self,config):
        super(MLP,self).__init__()
        self.hidden_size=config.hidden_size
        self.c_fc=nn.Linear(self.hidden_size,4*self.hidden_size,device=config.device,bias=True,dtype=config.dtype)
        self.gelu=QuickGELU()
        self.c_proj=nn.Linear(self.hidden_size*4,self.hidden_size,device=config.device,bias=True,dtype=config.dtype)
    def forward(self,hidden_state):
        hidden_state=self.c_fc(hidden_state)
        hidden_state=self.gelu(hidden_state)
        hidden_state=self.c_proj(hidden_state)
        return hidden_state

class Config:
    def __init__(self,vocab_size,max_position_embeddings,hidden_size,num_layers,num_heads,device,dtype):
        self.vocab_size=vocab_size
        self.max_position_embeddings=max_position_embeddings
        self.hidden_size=hidden_size
        self.num_layers=num_layers
        self.num_heads=num_heads
        self.device=device
        self.dtype=dtype
        self.norm_eps=1e-5
        self.num_virtual_tokens=20
        self.token_dim=hidden_size
        self.encoder_hidden_size=hidden_size
config=Config(
    vocab_size=49408,
    max_position_embeddings=77,
    hidden_size=512,
    num_layers=12,
    num_heads=8,
    device=torch.device('cuda:0'),
    dtype=torch.float16
)
class TextEncoder(nn.Module):
    # Note: this class is not used below; text_encoder is built directly from Transformer.
    def __init__(self,config):
        super(TextEncoder,self).__init__()
        self.token_embedding=nn.Embedding(config.vocab_size,config.hidden_size,device=config.device,dtype=config.dtype)
        self.positional_embedding=nn.Parameter(torch.zeros(config.max_position_embeddings,config.hidden_size,device=config.device,dtype=config.dtype),requires_grad=False)
        self.transformer=Transformer(config)
        self.ln_final=nn.LayerNorm(config.hidden_size,eps=config.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
    def forward(self,input_ids):
        b,n=input_ids.shape
        prompt_embedding,token_embeddings=self.token_embedding(input_ids)
        position_ids=torch.arange(n,device=config.device,dtype=torch.long).unsqueeze(0).expand(b,n)
        position_embeddings=self.positional_embedding[position_ids]
        embeddings=token_embeddings+position_embeddings
        embeddings=torch.cat((prompt_embedding,embeddings),dim=1)
        embeddings=self.transformer(embeddings)
        embeddings=self.ln_final(embeddings)
        return embeddings

text_encoder=Transformer(config)
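
This PrefixEncoder differs from the .ipynb_checkpoints version only in that the Embedding + MLP reparameterization has already been merged into the single prefix_embedding parameter. One way that merge could be done is sketched below, under the assumption that reparam_encoder is an instance of the checkpoint-style PrefixEncoder and merged_encoder an instance of the class above; the helper name merge_prefix is hypothetical, not part of the upload:

import torch

def merge_prefix(reparam_encoder, merged_encoder, prefix_tokens):
    """Run the checkpoint-style Embedding + MLP once and copy the result
    into the flat, frozen prefix parameter of the merged encoder."""
    with torch.no_grad():
        flat = reparam_encoder.transformer(reparam_encoder.embedding(prefix_tokens.unsqueeze(0)))
        merged_encoder.prefix_embedding.copy_(flat)   # shape (1, num_virtual_tokens, num_layers*2*token_dim)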
VIT.py ADDED
@@ -0,0 +1,233 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import os
import sys

# Prefix tuning, following the HuggingFace PEFT implementation.
class PrefixEncoder(torch.nn.Module):
    def __init__(self,config):
        super(PrefixEncoder,self).__init__()
        self.config=config
        self.device=config.device
        self.dtype=config.dtype
        self.num_virtual_tokens=config.num_virtual_tokens
        self.token_dim=config.token_dim
        self.encoder_hidden_size=config.encoder_hidden_size
        self.num_layers=config.num_layers
        # Merged prefix: the Embedding + MLP reparameterization used in the checkpoint
        # version has already been baked into this single frozen parameter.
        self.prefix_embedding=nn.Parameter(torch.empty(1,self.num_virtual_tokens,self.num_layers*2*self.token_dim,device=config.device,dtype=config.dtype),requires_grad=False)
    def forward(self,input_ids,batch_size):
        prefix_embedding=self.prefix_embedding
        prefix_embedding=prefix_embedding.expand(batch_size,self.num_virtual_tokens,self.num_layers*2*self.token_dim)
        prefix_embedding=prefix_embedding.reshape(batch_size,self.num_virtual_tokens,self.num_layers,2,self.token_dim)
        prefix_embedding=prefix_embedding.permute(3,2,0,1,4)
        k,v=prefix_embedding.chunk(2,dim=0)
        return (k.squeeze(0),v.squeeze(0))


import torch
import torch.nn as nn
import math
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn import functional as F
def position_embedding(x,position_ids):
    hidden_size=x.size(2)
    seq_len=x.size(1)
    div_term=torch.exp(torch.arange(0,hidden_size,2,device=x.device).float()*(-math.log(10000.0)/hidden_size))
    positional_encoding=torch.zeros(seq_len,hidden_size,device=x.device)
    positional_encoding[:,0::2]=torch.sin(position_ids.float()[:,None]*div_term)
    positional_encoding[:,1::2]=torch.cos(position_ids.float()[:,None]*div_term)
    positional_encoding=positional_encoding.unsqueeze(0)
    return positional_encoding


class VisionTransformer(nn.Module):
    def __init__(self,config):
        super(VisionTransformer,self).__init__()
        self.image_channel=config.image_channel
        self.hidden_size=config.hidden_size
        self.norm_eps=config.norm_eps
        self.patch_size=config.patch_size
        self.output_dim=config.output_dim
        self.dtype=config.dtype
        self.num_patches=config.num_patches
        self.num_virtual_tokens=config.num_virtual_tokens if hasattr(config,"num_virtual_tokens") else None
        self.conv1=nn.Conv2d(self.image_channel,self.hidden_size,self.patch_size,stride=self.patch_size,bias=False,device=config.device,dtype=config.dtype)
        self.ln_pre=nn.LayerNorm(self.hidden_size,eps=self.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
        self.transformer=Transformer(config)
        #self.position_ids=torch.arange(config.num_patches+1,dtype=torch.long,device=config.device)
        #self.position_embeddings=nn.Parameter(torch.zeros(1,config.num_patches+1,config.hidden_size))
        #nn.init.normal_(self.position_embeddings)
        # cls token, used as the image-level representation.
        #self.cls_token=nn.Parameter(torch.zeros(1,1,config.hidden_size,device=config.device))
        # The class token is loaded from the pretrained weights and kept frozen here.
        self.class_embedding=nn.Parameter(torch.empty(config.hidden_size,device=config.device),requires_grad=False)
        # The positional embedding is likewise a learned parameter (frozen here).
        self.positional_embedding=nn.Parameter(torch.empty(config.num_patches+1,config.hidden_size,device=config.device),requires_grad=False)
        # Projection to the shared embedding space (frozen here).
        self.proj=nn.Parameter(torch.empty(config.hidden_size,config.output_dim,device=config.device,dtype=config.dtype),requires_grad=False)
        self.ln_post=nn.LayerNorm(self.hidden_size,eps=self.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
    def forward(self,hidden_state,use_emotion):
        b,c,h,w=hidden_state.shape
        # Patchify the image into embedding vectors.
        hidden_state=self.conv1(hidden_state)
        hidden_state=hidden_state.reshape(b,self.hidden_size,-1).transpose(1,2)
        # Prepend the cls token embedding.
        hidden_state=torch.cat((self.class_embedding.expand(b,1,-1).to(hidden_state.dtype),hidden_state),dim=1)
        # (alternative) fixed sinusoidal position embedding from the original Transformer paper:
        #hidden_state=hidden_state+position_embedding(hidden_state,self.position_ids)
        hidden_state=hidden_state+self.positional_embedding.unsqueeze(0).to(hidden_state.dtype)
        hidden_state=self.ln_pre(hidden_state)
        hidden_state=self.transformer(hidden_state,use_emotion)
        # Take the cls-token output; the image patch outputs are not used here.
        cls_state=hidden_state[:,0,:]
        cls_state=self.ln_post(cls_state)
        cls_state=torch.matmul(cls_state,self.proj)
        #image_state=hidden_state[:,1:,:]
        #image_state size (batch_size,49,768)
        return cls_state

class Transformer(nn.Module):
    def __init__(self,config):
        super(Transformer,self).__init__()
        self.resblocks=nn.ModuleList([ResidualAttentionBlock(config) for _ in range(config.num_layers)])
        self.prefix=PrefixEncoder(config)
        prefix_tokens=torch.arange(0,config.num_virtual_tokens,device=config.device,dtype=torch.long)
        self.register_buffer("prefix_tokens",prefix_tokens)
    def forward(self,hidden_state,use_emotion):
        if use_emotion:
            b,n,h=hidden_state.shape
            prefix_k,prefix_v=self.prefix(self.prefix_tokens,b)
            for index,resblock in enumerate(self.resblocks):
                # Before each layer, hand that layer's prefix K/V to the resblock for concatenation.
                hidden_state=resblock(hidden_state,prefix_k[index],prefix_v[index])
            return hidden_state
        else:
            for index,resblock in enumerate(self.resblocks):
                hidden_state=resblock(hidden_state)
            return hidden_state


class ResidualAttentionBlock(nn.Module):
    def __init__(self,config):
        super(ResidualAttentionBlock,self).__init__()
        self.ln_1=nn.LayerNorm(config.hidden_size,eps=config.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
        self.ln_2=nn.LayerNorm(config.hidden_size,eps=config.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
        #self.attn=nn.MultiheadAttention(config.hidden_size,config.num_heads,device=config.device,dtype=config.dtype)
        self.attn=MultiHeadAttention(config)
        self.mlp=MLP(config)
    def forward(self,hidden_state,prefix_k=None,prefix_v=None):
        residual=hidden_state
        hidden_state=self.ln_1(hidden_state)
        hidden_state=self.attn(hidden_state,prefix_k,prefix_v)
        hidden_state=residual+hidden_state
        residual=hidden_state
        hidden_state=self.ln_2(hidden_state)
        hidden_state=self.mlp(hidden_state)
        hidden_state=residual+hidden_state
        return hidden_state

class MultiHeadAttention(nn.Module):
    def __init__(self,config):
        super(MultiHeadAttention,self).__init__()
        self.hidden_size=config.hidden_size
        self.num_heads=config.num_heads
        self.head_size=self.hidden_size//self.num_heads
        # Packed in-projection weight and bias for Q, K, V (frozen; loaded from the pretrained CLIP).
        self.in_proj_weight=nn.Parameter(torch.empty(3*config.hidden_size,config.hidden_size,device=config.device,dtype=config.dtype),requires_grad=False)
        self.in_proj_bias=nn.Parameter(torch.empty(3*config.hidden_size,device=config.device,dtype=config.dtype),requires_grad=False)
        #self.q_linear=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device)
        #self.k_linear=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device)
        #self.v_linear=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device)
        self.out_proj=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device,dtype=config.dtype)
    def forward(self,hidden_state,prefix_k=None,prefix_v=None):
        b,n,h=hidden_state.shape
        #q=self.q_linear(hidden_state).view(b,n,self.num_heads,self.head_size).permute(0,2,1,3)
        #k=self.k_linear(hidden_state).view(b,n,self.num_heads,self.head_size).permute(0,2,3,1)
        #v=self.v_linear(hidden_state).view(b,n,self.num_heads,self.head_size).permute(0,2,1,3)
        q,k,v=(torch.matmul(hidden_state,self.in_proj_weight.T)+self.in_proj_bias.expand(b,n,-1)).chunk(3,dim=-1)
        if prefix_k is not None and prefix_v is not None:
            # Prepend the prefix to the key/value sequences.
            #print("origional k.shape",prefix_k.shape)
            k=torch.cat((prefix_k,k),dim=1)
            v=torch.cat((prefix_v,v),dim=1)
            #print("model original k :",k[:,0,0])
        bk,nk,hk=k.shape
        bq,nq,hq=q.shape
        q=q.view(bq,nq,self.num_heads,self.head_size).permute(0,2,1,3)
        k=k.view(bk,nk,self.num_heads,self.head_size).permute(0,2,1,3)
        v=v.view(bk,nk,self.num_heads,self.head_size).permute(0,2,1,3)
        attention_logits=F.scaled_dot_product_attention(q, k, v)
        attention_logits=attention_logits.permute(0,2,1,3).contiguous().view(bk,nq,self.hidden_size)
        attention_output=self.out_proj(attention_logits)
        return attention_output


class GELU(nn.Module):
    """
    The error function erf:
        erf(x) = 2/sqrt(pi) * integral(exp(-t^2), t=0..x)
    Here t is a dummy integration variable running from 0 to x. Concretely:
        x is the input to the error function and the upper limit of the integral,
        t sweeps the interval [0, x], with exp(-t^2) evaluated at every point,
        exp(-t^2) is the integrand, the Gaussian density (up to normalisation) at t.
    The integral therefore accumulates the Gaussian probability mass over the interval [0, x].
    """
    def forward(self,x):
        old_dtype=x.dtype
        x=x.to(torch.float32)
        return (0.5*x*(1.0+torch.erf(x/math.sqrt(2.0)))).to(old_dtype)

class QuickGELU(nn.Module):
    def __init__(self):
        super(QuickGELU,self).__init__()
    def forward(self,x):
        old_dtype=x.dtype
        x=x.to(torch.float32)
        return (x*torch.sigmoid(1.702*x)).to(old_dtype)


class MLP(nn.Module):
    def __init__(self,config):
        super(MLP,self).__init__()
        self.hidden_size=config.hidden_size
        self.c_fc=nn.Linear(self.hidden_size,4*self.hidden_size,device=config.device,bias=True,dtype=config.dtype)
        self.gelu=QuickGELU()
        self.c_proj=nn.Linear(self.hidden_size*4,self.hidden_size,device=config.device,bias=True,dtype=config.dtype)
    def forward(self,hidden_state):
        hidden_state=self.c_fc(hidden_state)
        hidden_state=self.gelu(hidden_state)
        hidden_state=self.c_proj(hidden_state)
        return hidden_state


class ViTConfig:
    def __init__(self,image_channel,hidden_size,num_heads,num_layers,patch_size,num_patches,output_dim,norm_eps,device):
        self.image_channel=image_channel
        self.hidden_size=hidden_size
        self.num_heads=num_heads
        self.num_layers=num_layers
        self.patch_size=patch_size
        self.num_patches=num_patches
        self.norm_eps=norm_eps
        self.device=device
        self.dtype=torch.float16
        self.patch_token_num=self.hidden_size//self.patch_size**2+1
        self.output_dim=output_dim
        self.num_virtual_tokens=20
        self.token_dim=self.hidden_size
        self.encoder_hidden_size=self.hidden_size

config=ViTConfig(3,768,12,12,32,49,512,1e-5,torch.device("cuda"))
model=VisionTransformer(config)
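
How the prefix changes the attention shapes inside MultiHeadAttention: the queries keep the original sequence length while the keys and values gain num_virtual_tokens extra positions, so the output length still matches the input. A standalone sketch with the ViT dimensions above (50 image tokens including the cls token, 20 virtual tokens, 12 heads of size 64; illustrative only):

import torch
import torch.nn.functional as F

batch, heads, n_img, n_prefix, head_size = 2, 12, 50, 20, 64
q = torch.randn(batch, heads, n_img, head_size)
k = torch.randn(batch, heads, n_img + n_prefix, head_size)
v = torch.randn(batch, heads, n_img + n_prefix, head_size)
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)   # torch.Size([2, 12, 50, 64]) -- same length as q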
__pycache__/Text_Encoder.cpython-312.pyc ADDED
Binary file (14.6 kB).
 
__pycache__/VIT.cpython-312.pyc ADDED
Binary file (17.7 kB).
 
preprocess.pkl ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:51ff0f1d35da9d25c16b5a82957cfb43b76d01a94084c501ec4a9180dc4b53aa
size 1116
tokenize.pkl ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b7cd84774d43b4d7513f250b615ecd579a9a0c852f3e011043330407f7ca93e1
size 37