|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch.nn as nn |
|
import torch |
|
class Qlinear_4(nn.Module):
    """Linear layer whose weight is stored as 4-bit values, two per int8 element.

    ``forward`` computes::

        (x / scale_factor) @ ((unpack(weight) + zero_point) * scale).T + bias

    All tensors are registered as parameters with ``requires_grad=False``.

    Args:
        weight: Packed 2-D tensor as produced by :meth:`pack_weight`.
        scale: Multiplier applied to the unpacked weight after ``zero_point``
            has been added; cast to ``dtype``. Must broadcast against the
            unpacked weight.
        zero_point: Offset added to the unpacked weight; cast to ``dtype``.
        scale_factor: Divisor applied to the input before the matmul; cast
            to ``dtype``.
        bias: Optional bias added to the output; stored without a cast.
        dtype: Dtype used for ``scale``, ``zero_point``, ``scale_factor``
            and for the unpacked weight.
    """

    def __init__(self, weight, scale, zero_point, scale_factor, bias=None, dtype=torch.bfloat16):
        super().__init__()
        self.weight = nn.Parameter(weight, requires_grad=False)
        self.scale = nn.Parameter(scale.to(dtype), requires_grad=False)
        self.zero_point = nn.Parameter(zero_point.to(dtype), requires_grad=False)
        self.scale_factor = nn.Parameter(scale_factor.to(dtype), requires_grad=False)
        self.bias = nn.Parameter(bias, requires_grad=False) if bias is not None else None
        self.dtype = dtype

    def forward(self, x):
        """Rescale ``x``, decode the packed weight and apply the linear map.

        The weight is unpacked and decoded on every call. The decoded weight
        and the bias are cast to the dtype of the rescaled input.
        """
        x = x / self.scale_factor.to(self.dtype)

        weight_decode = (
            self.unpack_weight(self.weight).to(self.dtype) + self.zero_point
        ) * self.scale

        out = x @ weight_decode.T.to(x.dtype)
        if self.bias is not None:
            out = out + self.bias.to(x.dtype)
        return out

    @staticmethod
    def pack_weight(weight):
        """Pack consecutive row pairs of ``weight`` into single int8 rows.

        Row ``2*i`` goes into the high four bits and row ``2*i + 1`` into the
        low four bits of output row ``i``. Only the low four bits of each
        value are kept. If the number of rows is odd, the last row is dropped.

        Returns:
            An int8 tensor with ``weight.shape[0] // 2`` rows on the same
            device as ``weight``.
        """
        weight = weight.to(torch.int8)
        paired_rows = weight.shape[0] // 2 * 2
        # Mask both halves so a value outside 0..15 (e.g. a negative int8)
        # cannot spill into the neighbouring nibble.
        high = weight[0:paired_rows:2] & 0x0F
        low = weight[1:paired_rows:2] & 0x0F
        return (high << 4) | low

    @staticmethod
    def unpack_weight(packed_weight):
        """Inverse of :meth:`pack_weight`.

        Returns:
            An int8 tensor with twice as many rows as ``packed_weight`` and
            values in ``0..15``; the high nibbles land on the even rows and
            the low nibbles on the odd rows.
        """
        weight = torch.empty(
            packed_weight.shape[0] * 2,
            packed_weight.shape[1],
            dtype=torch.int8,
            device=packed_weight.device,
        )
        # The mask also removes the sign extension of the arithmetic shift.
        weight[0::2] = (packed_weight >> 4) & 0x0F
        weight[1::2] = packed_weight & 0x0F
        return weight
|
|
|
def apply_quantized_layers(model, config, prefix=''):
    """Recursively replace configured ``nn.Linear`` children with ``Qlinear_4``.

    Each replacement is a placeholder (zero packed weight, zero scale and
    zero point, unit scale factor) created on the original weight's device;
    the original bias is reused.

    Args:
        model: Module whose children are visited; modified in place.
        config: Mapping from dotted module name to a dict with a
            ``"quantized"`` entry. A linear layer is replaced only when that
            entry equals ``True``; layers that are missing from ``config``
            or lack the entry are left untouched.
        prefix: Dotted name of ``model`` within the root module.
    """
    # Snapshot the children so that swapping one out cannot disturb iteration.
    for name, module in list(model.named_children()):
        full_name = f"{prefix}.{name}" if prefix else name

        if not isinstance(module, nn.Linear):
            apply_quantized_layers(module, config, full_name)
            continue

        # A layer without a config entry used to crash with
        # "'NoneType' object is not subscriptable"; treat it as unquantized.
        layer_config = config.get(full_name) or {}
        # Equality (not truthiness) is kept on purpose so that values such as
        # the string "false" do not enable quantization.
        if not (layer_config.get("quantized") == True):  # noqa: E712
            continue

        device = module.weight.data.device
        out_features = module.weight.size(0)
        in_features = module.weight.size(1)

        # Allocate the packed placeholder directly: packing an all-zero weight
        # yields all zeros, so no full-size temporary is needed.
        packed_weight = torch.zeros(
            out_features // 2, in_features, dtype=torch.int8, device=device
        )
        scale = torch.zeros((out_features, 1), device=device)
        zero_point = torch.zeros_like(scale)
        scale_factor = torch.ones(in_features, device=device)

        quantized = Qlinear_4(
            weight=packed_weight,
            scale=scale,
            zero_point=zero_point,
            scale_factor=scale_factor,
            bias=module.bias,
            dtype=torch.bfloat16,
        )
        setattr(model, name, quantized)
|
|
|
|
|
|
|
import json |
|
def load_config(config_path):
    """Read the JSON file at ``config_path`` and return the parsed object.

    The file is decoded as UTF-8 rather than with the platform's default
    encoding, so the result does not depend on the locale.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        return json.load(f)
|
|
|
def apply_AWQ(model, quantization_config_path, quantization_weight_path):
    """Convert ``model`` to quantized layers and load their weights.

    Steps: read the JSON config, let ``apply_quantized_layers`` rewrite
    ``model`` in place, then load the state dict stored at
    ``quantization_weight_path`` (mapped to CPU) into the model.

    Returns:
        Always ``True``.
    """
    apply_quantized_layers(model, load_config(quantization_config_path))

    # NOTE(review): torch.load may unpickle arbitrary objects unless
    # weights-only loading is in effect - only pass trusted checkpoint files.
    checkpoint = torch.load(quantization_weight_path, map_location='cpu')
    model.load_state_dict(checkpoint)

    # Release the loaded copy before asking CUDA to free its cached blocks.
    del checkpoint
    torch.cuda.empty_cache()
    return True
|
|
|
|