#-----------------------------------
######## apply AWQ-INT4 ############
#-----------------------------------
# Qlinear_4
import json

import torch
import torch.nn as nn
class Qlinear_4(nn.Module):
    """Linear layer whose weights are stored as packed 4-bit integers (AWQ-style)."""
    def __init__(self, weight, scale, zero_point, scale_factor, bias=None, dtype=torch.bfloat16):
        super(Qlinear_4, self).__init__()
        self.weight = nn.Parameter(weight, requires_grad=False)            # packed int4 weights, two per int8
        self.scale = nn.Parameter(scale.to(dtype), requires_grad=False)    # per-output-channel dequant scale
        self.zero_point = nn.Parameter(zero_point.to(dtype), requires_grad=False)
        self.scale_factor = nn.Parameter(scale_factor.to(dtype), requires_grad=False)  # AWQ per-input-channel scale
        self.bias = nn.Parameter(bias, requires_grad=False) if bias is not None else None
        self.dtype = dtype

    def forward(self, x):
        # AWQ folds the per-input-channel scale into the weights, so the
        # activations are divided by the same factor here.
        x = x / self.scale_factor.to(x.dtype)
        # Dequantize: shift by the stored zero point, then rescale per output channel.
        weight_decode = (self.unpack_weight(self.weight).to(self.dtype) + self.zero_point) * self.scale
        if self.bias is not None:
            out = x @ weight_decode.T.to(x.dtype) + self.bias.to(x.dtype)
        else:
            out = x @ weight_decode.T.to(x.dtype)
        return out
    @staticmethod
    def pack_weight(weight):
        # Pack two unsigned 4-bit values (expected range [0, 15]) into each int8:
        # even rows go to the high nibble, odd rows to the low nibble.
        weight = weight.to(torch.int8)
        packed_weight = torch.zeros(weight.shape[0] // 2, weight.shape[1], dtype=torch.int8, device=weight.device)
        packed_weight |= weight[:weight.shape[0] // 2 * 2:2] << 4
        packed_weight |= weight[1:weight.shape[0] // 2 * 2:2] & 0x0F
        return packed_weight
    @staticmethod
    def unpack_weight(packed_weight):
        # Unpack each int8 back into two 4-bit values: the high nibble fills the
        # even rows, the low nibble fills the odd rows.
        weight = torch.zeros(packed_weight.shape[0] * 2, packed_weight.shape[1], dtype=torch.int8, device=packed_weight.device)
        weight[0::2] = (packed_weight >> 4) & 0x0F
        weight[1::2] = packed_weight & 0x0F
        return weight
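
# A minimal round-trip sketch for the packing scheme above (not in the original
# upload): 4-bit values in [0, 15] should survive pack -> unpack intact.
#   w = torch.randint(0, 16, (4, 8), dtype=torch.int8)
#   assert torch.equal(Qlinear_4.unpack_weight(Qlinear_4.pack_weight(w)), w)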
#apply
def apply_quantized_layers(model, config, prefix=''):
    """Recursively replace nn.Linear modules flagged in `config` with Qlinear_4 shells."""
    for name, module in model.named_children():
        full_name = prefix + ('.' if prefix else '') + name
        if isinstance(module, nn.Linear):
            if config.get(full_name, {}).get("quantized", False):
                # Build placeholder parameters for Qlinear_4; the real values
                # are filled in later by load_state_dict.
                weight = torch.zeros_like(module.weight.data)   # zero weights, packed below
                packed_weight = Qlinear_4.pack_weight(weight)
                scale = torch.zeros((module.weight.size(0), 1), device=module.weight.device)   # one scale per output channel
                zero_point = torch.zeros_like(scale)            # same shape as scale
                scale_factor = torch.ones(module.weight.size(1), device=module.weight.device)  # per-input-channel, init to 1
                # Swap the Linear for a Qlinear_4 instance.
                setattr(model, name, Qlinear_4(
                    weight=packed_weight,
                    scale=scale,
                    zero_point=zero_point,
                    scale_factor=scale_factor,
                    bias=module.bias,
                    dtype=torch.bfloat16,
                ))
        else:
            apply_quantized_layers(module, config, full_name)   # recurse into submodules
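
# Usage sketch (assumption, with a toy model and an in-memory config dict
# keyed by the full dotted module name as built above):
#   toy = nn.Sequential(nn.Linear(8, 4))
#   apply_quantized_layers(toy, {'0': {'quantized': True}})
#   print(type(toy[0]))  # -> Qlinear_4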
#config read
def load_config(config_path):
    with open(config_path, 'r') as f:
        config = json.load(f)
    return config
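
# The config JSON is assumed to map each Linear layer's full dotted name to a
# "quantized" flag, as read by apply_quantized_layers; the layer names below
# are hypothetical:
# {
#     "model.layers.0.self_attn.q_proj": {"quantized": true},
#     "lm_head": {"quantized": false}
# }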
def apply_AWQ(model, quantization_config_path, quantization_weight_path):
    """Replace the flagged Linear layers, then load the quantized checkpoint."""
    quantization_config = load_config(quantization_config_path)
    apply_quantized_layers(model, quantization_config)
    state_dict = torch.load(quantization_weight_path, map_location='cpu')
    model.load_state_dict(state_dict)
    del state_dict
    torch.cuda.empty_cache()
    return True
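
if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the original upload): check the
    # pack/unpack round trip and run a forward pass through a quantized toy layer.
    w = torch.randint(0, 16, (4, 8), dtype=torch.int8)
    assert torch.equal(Qlinear_4.unpack_weight(Qlinear_4.pack_weight(w)), w)
    toy = nn.Sequential(nn.Linear(8, 4))
    apply_quantized_layers(toy, {'0': {'quantized': True}})
    x = torch.randn(2, 8, dtype=torch.bfloat16)
    print(toy(x).shape)  # expected: torch.Size([2, 4])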