# EmotionCLIP-V2 / VIT.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import os
import sys
from torch.nn.attention import SDPBackend, sdpa_kernel
#prefix tuning, following the Hugging Face (PEFT) implementation
class PrefixEncoder(torch.nn.Module):
    def __init__(self,config):
        super(PrefixEncoder,self).__init__()
        self.config=config
        self.device=config.device
        self.dtype=config.dtype
        self.num_virtual_tokens=config.num_virtual_tokens
        self.token_dim=config.token_dim
        self.encoder_hidden_size=config.encoder_hidden_size
        self.num_layers=config.num_layers
        #one prefix vector per virtual token, holding a key and a value for every layer; frozen (requires_grad=False)
        self.prefix_embedding=nn.Parameter(torch.empty(1,self.num_virtual_tokens,self.num_layers*2*self.token_dim,device=config.device,dtype=config.dtype),requires_grad=False)
    def forward(self,input_ids,batch_size):
        prefix_embedding=self.prefix_embedding
        prefix_embedding=prefix_embedding.expand(batch_size,self.num_virtual_tokens,self.num_layers*2*self.token_dim)
        prefix_embedding=prefix_embedding.reshape(batch_size,self.num_virtual_tokens,self.num_layers,2,self.token_dim)
        prefix_embedding=prefix_embedding.permute(3,2,0,1,4)
        k,v=prefix_embedding.chunk(2,dim=0)
        return (k.squeeze(0),v.squeeze(0))
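# Shape walk-through (a sketch, assuming batch_size=2 and the config defaults below:
# num_virtual_tokens=20, num_layers=12, token_dim=768):
#   prefix_embedding            (1, 20, 12*2*768)
#   -> expand                   (2, 20, 12*2*768)
#   -> reshape                  (2, 20, 12, 2, 768)
#   -> permute(3,2,0,1,4)       (2, 12, 2, 20, 768)
#   -> chunk(2,dim=0)+squeeze   prefix_k and prefix_v, each (num_layers, batch, num_virtual_tokens, token_dim)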
def position_embedding(x,position_ids):
    #fixed sinusoidal position embedding from the original Transformer paper (only referenced by a commented-out line below)
    hidden_size=x.size(2)
    seq_len=x.size(1)
    div_term=torch.exp(torch.arange(0,hidden_size,2,device=x.device).float()*(-math.log(10000.0)/hidden_size))
    positional_encoding=torch.zeros(seq_len,hidden_size,device=x.device)
    positional_encoding[:,0::2]=torch.sin(position_ids.float()[:,None]*div_term)
    positional_encoding[:,1::2]=torch.cos(position_ids.float()[:,None]*div_term)
    positional_encoding=positional_encoding.unsqueeze(0)
    return positional_encoding
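# Minimal usage sketch for the helper above (hypothetical tensors; the function is only
# referenced by a commented-out line in VisionTransformer.forward below):
#   x=torch.zeros(2,50,768)                 # (batch, seq_len, hidden_size)
#   position_ids=torch.arange(50)           # one id per token
#   pe=position_embedding(x,position_ids)   # (1, 50, 768), broadcast over the batch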
class VisionTransformer(nn.Module):
    def __init__(self,config):
        super(VisionTransformer,self).__init__()
        self.image_channel=config.image_channel
        self.hidden_size=config.hidden_size
        self.norm_eps=config.norm_eps
        self.patch_size=config.patch_size
        self.output_dim=config.output_dim
        self.dtype=config.dtype
        self.num_patches=config.num_patches
        self.num_virtual_tokens=config.num_virtual_tokens if hasattr(config,"num_virtual_tokens") else None
        self.conv1=nn.Conv2d(self.image_channel,self.hidden_size,self.patch_size,stride=self.patch_size,bias=False,device=config.device,dtype=config.dtype)
        self.ln_pre=nn.LayerNorm(self.hidden_size,eps=self.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
        self.transformer=Transformer(config)
        #self.position_ids=torch.arange(config.num_patches+1,dtype=torch.long,device=config.device)
        #self.position_embeddings=nn.Parameter(torch.zeros(1,config.num_patches+1,config.hidden_size))
        #nn.init.normal_(self.position_embeddings)
        #cls token, used for the image classification task
        #self.cls_token=nn.Parameter(torch.zeros(1,1,config.hidden_size,device=config.device))
        #the cls token embedding is not a trainable parameter here (requires_grad=False)
        self.class_embedding=nn.Parameter(torch.empty(config.hidden_size,device=config.device),requires_grad=False)
        #the position embedding is a learnable parameter in the original CLIP, but frozen here
        self.positional_embedding=nn.Parameter(torch.empty(config.num_patches+1,config.hidden_size,device=config.device),requires_grad=False)
        #output projection, likewise frozen
        self.proj=nn.Parameter(torch.empty(config.hidden_size,config.output_dim,device=config.device,dtype=config.dtype),requires_grad=False)
        self.ln_post=nn.LayerNorm(self.hidden_size,eps=self.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
    def forward(self,hidden_state,use_emotion):
        b,c,h,w=hidden_state.shape
        #compute the patch embeddings
        hidden_state=self.conv1(hidden_state)
        hidden_state=hidden_state.reshape(b,self.hidden_size,-1).transpose(1,2)
        #prepend the cls token embedding
        hidden_state=torch.cat((self.class_embedding.expand(b,1,-1).to(hidden_state.dtype),hidden_state),dim=1)
        #fixed sinusoidal position embedding from the original Transformer paper (unused)
        #hidden_state=hidden_state+position_embedding(hidden_state,self.position_ids)
        hidden_state=hidden_state+self.positional_embedding.unsqueeze(0).to(hidden_state.dtype)
        hidden_state=self.ln_pre(hidden_state)
        hidden_state=self.transformer(hidden_state,use_emotion)
        #take the cls token output; the image patch outputs are left unused
        cls_state=hidden_state[:,0,:]
        cls_state=self.ln_post(cls_state)
        cls_state=torch.matmul(cls_state,self.proj)
        #image_state=hidden_state[:,1:,:]
        #image_state size (batch_size,49,768)
        return cls_state
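# Forward shape sketch (assuming 224x224 inputs, so (224/32)**2 matches num_patches=49):
#   image                    (b, 3, 224, 224)
#   conv1                    (b, 768, 7, 7) -> reshape/transpose -> (b, 49, 768)
#   + cls token              (b, 50, 768)
#   + positional_embedding   (b, 50, 768)
#   transformer              (b, 50, 768)
#   cls_state                (b, 768) -> ln_post -> @ proj -> (b, output_dim)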
class Transformer(nn.Module):
    def __init__(self,config):
        super(Transformer,self).__init__()
        self.resblocks=nn.ModuleList([ResidualAttentionBlock(config) for _ in range(config.num_layers)])
        self.prefix=PrefixEncoder(config)
        prefix_tokens=torch.arange(0,config.num_virtual_tokens,device=config.device,dtype=torch.long)
        self.register_buffer("prefix_tokens",prefix_tokens)
    def forward(self,hidden_state,use_emotion):
        if use_emotion:
            b,n,h=hidden_state.shape
            prefix_k,prefix_v=self.prefix(self.prefix_tokens,b)
            for index,resblock in enumerate(self.resblocks):
                #before each layer, take that layer's prefix key/value and pass them into the resblock to be concatenated
                hidden_state=resblock(hidden_state,prefix_k[index],prefix_v[index])
            return hidden_state
        else:
            for index,resblock in enumerate(self.resblocks):
                #without the emotion prefix, run the plain residual blocks
                hidden_state=resblock(hidden_state)
            return hidden_state
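# Note: prefix_k/prefix_v have shape (num_layers, batch, num_virtual_tokens, token_dim), so
# prefix_k[index]/prefix_v[index] give each ResidualAttentionBlock its own layer-specific
# prefix keys and values, i.e. a per-layer "past key/value" in the prefix-tuning sense.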
class ResidualAttentionBlock(nn.Module):
    def __init__(self,config):
        super(ResidualAttentionBlock,self).__init__()
        self.ln_1=nn.LayerNorm(config.hidden_size,eps=config.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
        self.ln_2=nn.LayerNorm(config.hidden_size,eps=config.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
        #self.attn=nn.MultiheadAttention(config.hidden_size,config.num_heads,device=config.device,dtype=config.dtype)
        self.attn=MultiHeadAttention(config)
        self.mlp=MLP(config)
    def forward(self,hidden_state,prefix_k=None,prefix_v=None):
        residual=hidden_state
        hidden_state=self.ln_1(hidden_state)
        hidden_state=self.attn(hidden_state,prefix_k,prefix_v)
        hidden_state=residual+hidden_state
        residual=hidden_state
        hidden_state=self.ln_2(hidden_state)
        hidden_state=self.mlp(hidden_state)
        hidden_state=residual+hidden_state
        return hidden_state
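# The block uses the pre-norm residual layout of CLIP's ViT:
#   x = x + attn(ln_1(x));  x = x + mlp(ln_2(x))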
class MultiHeadAttention(nn.Module):
    def __init__(self,config):
        super(MultiHeadAttention,self).__init__()
        self.hidden_size=config.hidden_size
        self.num_heads=config.num_heads
        self.head_size=self.hidden_size//self.num_heads
        #packed q/k/v projection weight and bias, frozen here (requires_grad=False)
        self.in_proj_weight=nn.Parameter(torch.empty(3*config.hidden_size,config.hidden_size,device=config.device,dtype=config.dtype),requires_grad=False)
        self.in_proj_bias=nn.Parameter(torch.empty(3*config.hidden_size,device=config.device,dtype=config.dtype),requires_grad=False)
        #self.q_linear=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device)
        #self.k_linear=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device)
        #self.v_linear=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device)
        self.out_proj=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device,dtype=config.dtype)
    def forward(self,hidden_state,prefix_k=None,prefix_v=None):
        b,n,h=hidden_state.shape
        #q=self.q_linear(hidden_state).view(b,n,self.num_heads,self.head_size).permute(0,2,1,3)
        #k=self.k_linear(hidden_state).view(b,n,self.num_heads,self.head_size).permute(0,2,3,1)
        #v=self.v_linear(hidden_state).view(b,n,self.num_heads,self.head_size).permute(0,2,1,3)
        q,k,v=(torch.matmul(hidden_state,self.in_proj_weight.T)+self.in_proj_bias.expand(b,n,-1)).chunk(3,dim=-1)
        if prefix_k is not None and prefix_v is not None:
            #prepend the prefix to the key and value sequences
            #print("original k.shape",prefix_k.shape)
            k=torch.cat((prefix_k,k),dim=1)
            v=torch.cat((prefix_v,v),dim=1)
        #print("model original k :",k[:,0,0])
        bk,nk,hk=k.shape
        bq,nq,hq=q.shape
        q=q.view(bq,nq,self.num_heads,self.head_size).permute(0,2,1,3)
        k=k.view(bk,nk,self.num_heads,self.head_size).permute(0,2,1,3)
        v=v.view(bk,nk,self.num_heads,self.head_size).permute(0,2,1,3)
        attention_logits=F.scaled_dot_product_attention(q, k, v)
        attention_logits=attention_logits.permute(0,2,1,3).contiguous().view(bk,nq,self.hidden_size)
        attention_output=self.out_proj(attention_logits)
        return attention_output
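# Attention with a prefix (a sketch of what F.scaled_dot_product_attention computes here):
# queries come only from the image sequence (length nq), while keys/values are
# [prefix; sequence] (length num_virtual_tokens+nq), so every image token can additionally
# attend to the virtual prefix tokens:
#   softmax(q @ k.transpose(-2,-1) / sqrt(head_size)) @ v
#   with q: (b, num_heads, nq, head_size) and k, v: (b, num_heads, num_virtual_tokens+nq, head_size)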
class GELU(nn.Module):
    """
    Error function erf:
        erf(x)=2/sqrt(pi)*integral(exp(-t^2),t=0..x)
    Here t is a dummy integration variable running over the interval [0,x]. Concretely:
        x is the input to the error function and the upper limit of the integral,
        t sweeps from 0 to x, and exp(-t^2) is evaluated at each point,
        exp(-t^2) is the integrand, i.e. the Gaussian probability density at t.
    Through the integration, the error function accumulates the Gaussian probability mass
    over [0,x]; in other words, its integral part integrates the Gaussian density on [0,x].
    """
    def forward(self,x):
        old_dtype=x.dtype
        x=x.to(torch.float32)
        return (0.5*x*(1.0+torch.erf(x/math.sqrt(2.0)))).to(old_dtype)
class QuickGELU(nn.Module):
    def __init__(self):
        super(QuickGELU,self).__init__()
    def forward(self,x):
        old_dtype=x.dtype
        x=x.to(torch.float32)
        return (x*torch.sigmoid(1.702*x)).to(old_dtype)
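# QuickGELU, x*sigmoid(1.702*x), is the sigmoid approximation of the exact GELU above;
# it matches CLIP's reference implementation, which is why the MLP below uses it.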
class MLP(nn.Module):
    def __init__(self,config):
        super(MLP,self).__init__()
        self.hidden_size=config.hidden_size
        self.c_fc=nn.Linear(self.hidden_size,4*self.hidden_size,device=config.device,bias=True,dtype=config.dtype)
        self.gelu=QuickGELU()
        self.c_proj=nn.Linear(self.hidden_size*4,self.hidden_size,device=config.device,bias=True,dtype=config.dtype)
    def forward(self,hidden_state):
        hidden_state=self.c_fc(hidden_state)
        hidden_state=self.gelu(hidden_state)
        hidden_state=self.c_proj(hidden_state)
        return hidden_state
class ViTConfig:
    def __init__(self,image_channel,hidden_size,num_heads,num_layers,patch_size,num_patches,output_dim,norm_eps,device):
        self.image_channel=image_channel
        self.hidden_size=hidden_size
        self.num_heads=num_heads
        self.num_layers=num_layers
        self.patch_size=patch_size
        self.num_patches=num_patches
        self.norm_eps=norm_eps
        self.device=device
        self.dtype=torch.float16
        #number of patch tokens plus the cls token
        self.patch_token_num=self.num_patches+1
        self.output_dim=output_dim
        self.num_virtual_tokens=20
        self.token_dim=self.hidden_size
        self.encoder_hidden_size=self.hidden_size
config=ViTConfig(3,768,12,12,32,49,512,1e-5,torch.device("cuda"))
model=VisionTransformer(config)
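# Example usage (a sketch; the checkpoint filename and key layout are assumptions -- the
# parameters above are created with torch.empty and frozen, so real weights must be loaded first):
#   state_dict=torch.load("emotionclip_v2_vit.pt",map_location=config.device)  # hypothetical file
#   model.load_state_dict(state_dict,strict=False)
#   images=torch.randn(2,3,224,224,device=config.device,dtype=config.dtype)
#   with torch.no_grad():
#       features=model(images,use_emotion=True)  # (2, 512) image features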