import math

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel


class PrefixEncoder(torch.nn.Module):
    """Maps a set of virtual prefix tokens to per-layer key/value states (prefix tuning)."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.device = config.device
        self.dtype = config.dtype
        self.num_virtual_tokens = config.num_virtual_tokens
        self.token_dim = config.token_dim
        self.encoder_hidden_size = config.encoder_hidden_size
        self.num_layers = config.num_layers
        self.embedding = torch.nn.Embedding(config.num_virtual_tokens, config.token_dim, device=config.device, dtype=config.dtype)
        # Two-layer MLP that expands each virtual token into one key and one value per layer.
        self.transformer = torch.nn.Sequential(
            torch.nn.Linear(self.token_dim, self.encoder_hidden_size, device=self.device, dtype=self.dtype),
            torch.nn.Tanh(),
            torch.nn.Linear(self.encoder_hidden_size, self.num_layers * 2 * self.token_dim, device=self.device, dtype=self.dtype),
        )

    def forward(self, input_ids, batch_size):
        # Compute the prefix once, cache it as a frozen parameter, and free the
        # embedding table and MLP; later calls reuse the cached tensor.
        if not hasattr(self, "prefix_embedding"):
            prefix_embedding = self.embedding(input_ids.unsqueeze(0))
            prefix_embedding = self.transformer(prefix_embedding)
            self.register_parameter("prefix_embedding", nn.Parameter(prefix_embedding, requires_grad=False))
            del self.embedding
            del self.transformer
        prefix_embedding = self.prefix_embedding.expand(batch_size, self.num_virtual_tokens, self.num_layers * 2 * self.token_dim)
        prefix_embedding = prefix_embedding.reshape(batch_size, self.num_virtual_tokens, self.num_layers, 2, self.token_dim)
        # (batch, tokens, layers, 2, dim) -> (2, layers, batch, tokens, dim)
        prefix_embedding = prefix_embedding.permute(3, 2, 0, 1, 4)
        k, v = prefix_embedding.chunk(2, dim=0)
        return k.squeeze(0), v.squeeze(0)


class Transformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.resblocks = nn.ModuleList([ResidualAttentionBlock(config) for _ in range(config.num_layers)])
        self.prefix = PrefixEncoder(config)
        prefix_tokens = torch.arange(0, config.num_virtual_tokens, device=config.device, dtype=torch.long)
        self.register_buffer("prefix_tokens", prefix_tokens)

    def forward(self, hidden_state, use_emotion=False):
        if use_emotion:
            # Feed the prefix key/value states produced by the PrefixEncoder into every block.
            b, n, h = hidden_state.shape
            prefix_k, prefix_v = self.prefix(self.prefix_tokens, b)
            for index, resblock in enumerate(self.resblocks):
                hidden_state = resblock(hidden_state, prefix_k[index], prefix_v[index])
            return hidden_state
        else:
            for resblock in self.resblocks:
                hidden_state = resblock(hidden_state)
            return hidden_state


class ResidualAttentionBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, elementwise_affine=True, device=config.device, dtype=config.dtype)
        self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, elementwise_affine=True, device=config.device, dtype=config.dtype)
        self.attn = MultiHeadAttention(config)
        self.mlp = MLP(config)

    def forward(self, hidden_state, prefix_k=None, prefix_v=None):
        # Pre-LayerNorm residual block: attention then MLP, each with a skip connection.
        residual = hidden_state
        hidden_state = self.ln_1(hidden_state)
        hidden_state = self.attn(hidden_state, prefix_k, prefix_v)
        hidden_state = residual + hidden_state

        residual = hidden_state
        hidden_state = self.ln_2(hidden_state)
        hidden_state = self.mlp(hidden_state)
        hidden_state = residual + hidden_state
        return hidden_state


class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_heads
        self.head_size = self.hidden_size // self.num_heads

        # Packed QKV projection weight and bias, kept frozen (requires_grad=False).
        self.in_proj_weight = nn.Parameter(torch.empty(3 * config.hidden_size, config.hidden_size, device=config.device, dtype=config.dtype), requires_grad=False)
        self.in_proj_bias = nn.Parameter(torch.empty(3 * config.hidden_size, device=config.device, dtype=config.dtype), requires_grad=False)
        self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True, device=config.device, dtype=config.dtype)

    def forward(self, hidden_state, prefix_k=None, prefix_v=None):
        b, n, c = hidden_state.shape
        # Single packed projection, then split into query, key and value.
        q, k, v = F.linear(hidden_state, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
        if prefix_k is not None and prefix_v is not None:
            # Prepend the prefix states along the sequence dimension; they only extend
            # the keys and values, so the output length stays n.
            k = torch.cat((prefix_k, k), dim=1)
            v = torch.cat((prefix_v, v), dim=1)

        bk, nk, hk = k.shape
        q = q.view(b, n, self.num_heads, self.head_size).permute(0, 2, 1, 3)
        k = k.view(bk, nk, self.num_heads, self.head_size).permute(0, 2, 1, 3)
        v = v.view(bk, nk, self.num_heads, self.head_size).permute(0, 2, 1, 3)
        attention_output = F.scaled_dot_product_attention(q, k, v)
        attention_output = attention_output.permute(0, 2, 1, 3).contiguous().view(b, n, self.hidden_size)
        return self.out_proj(attention_output)


class GELU(nn.Module):
    """
    Exact GELU based on the error function erf:

        erf(x) = 2 / sqrt(pi) * integral of exp(-t^2) dt from 0 to x

    Here x is the upper limit of integration, t is the dummy integration
    variable running from 0 to x, and exp(-t^2) is the Gaussian-shaped
    integrand. The integral accumulates Gaussian probability mass over [0, x],
    so 0.5 * (1 + erf(x / sqrt(2))) is the standard normal CDF and
    GELU(x) = x * Phi(x) acts as a smooth, probability-weighted gate on x.
    """

    def forward(self, x):
        # torch.sqrt expects a tensor, so use math.sqrt for the scalar constant.
        return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))


class QuickGELU(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Sigmoid approximation of GELU, computed in float32 for precision and
        # cast back to the input dtype.
        old_dtype = x.dtype
        x = x.to(torch.float32)
        return (x * torch.sigmoid(1.702 * x)).to(old_dtype)
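

# Hypothetical sanity check, not part of the original model code: the exact
# erf-based GELU above should agree with PyTorch's built-in F.gelu up to
# floating-point rounding, while QuickGELU is only a close sigmoid
# approximation (x * sigmoid(1.702 * x)).
def _check_gelu_variants():
    x = torch.linspace(-3.0, 3.0, steps=7)
    exact = GELU()(x)
    builtin = F.gelu(x)        # essentially identical to `exact`
    approx = QuickGELU()(x)    # close, but not bit-identical
    print(torch.max(torch.abs(exact - builtin)).item())
    print(torch.max(torch.abs(exact - approx)).item())
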
class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.c_fc = nn.Linear(self.hidden_size, 4 * self.hidden_size, device=config.device, bias=True, dtype=config.dtype)
        self.gelu = QuickGELU()
        self.c_proj = nn.Linear(4 * self.hidden_size, self.hidden_size, device=config.device, bias=True, dtype=config.dtype)

    def forward(self, hidden_state):
        hidden_state = self.c_fc(hidden_state)
        hidden_state = self.gelu(hidden_state)
        hidden_state = self.c_proj(hidden_state)
        return hidden_state


class Config:
    def __init__(self, vocab_size, max_position_embeddings, hidden_size, num_layers, num_heads, device, dtype):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.device = device
        self.dtype = dtype
        self.norm_eps = 1e-5
        self.num_virtual_tokens = 20
        self.token_dim = hidden_size
        self.encoder_hidden_size = hidden_size


config = Config(
    vocab_size=49408,
    max_position_embeddings=77,
    hidden_size=512,
    num_layers=12,
    num_heads=8,
    device=torch.device('cuda:0'),
    dtype=torch.float16,
)


class TextEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size, device=config.device, dtype=config.dtype)
        self.positional_embedding = nn.Parameter(torch.zeros(config.max_position_embeddings, config.hidden_size, device=config.device, dtype=config.dtype), requires_grad=False)
        self.transformer = Transformer(config)
        self.ln_final = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, elementwise_affine=True, device=config.device, dtype=config.dtype)

    def forward(self, input_ids, use_emotion=False):
        b, n = input_ids.shape
        token_embeddings = self.token_embedding(input_ids)
        # Position ids must be integer indices, not floats in config.dtype.
        position_ids = torch.arange(n, device=input_ids.device, dtype=torch.long).unsqueeze(0).expand(b, n)
        position_embeddings = self.positional_embedding[position_ids]
        embeddings = token_embeddings + position_embeddings
        embeddings = self.transformer(embeddings, use_emotion)
        embeddings = self.ln_final(embeddings)
        return embeddings


text_encoder = TextEncoder(config)
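

# Hypothetical smoke test, not part of the original module: it assumes a CUDA
# device is available (the config above puts every weight on cuda:0 in float16)
# and only checks output shapes, since the frozen attention projections are
# created with torch.empty and would need real weights loaded before the
# outputs are meaningful.
if __name__ == "__main__":
    with torch.no_grad():
        dummy_ids = torch.randint(0, config.vocab_size, (2, config.max_position_embeddings), device=config.device)
        plain = text_encoder(dummy_ids, use_emotion=False)
        with_prefix = text_encoder(dummy_ids, use_emotion=True)
        print(plain.shape, with_prefix.shape)  # both torch.Size([2, 77, 512])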