File size: 4,155 Bytes
bcd2e51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#----------------------------------

########apply AWQ-INT4#############

#-----------------------------------

# Qlinear_4
import torch.nn as nn
import torch 
class Qlinear_4(nn.Module):
    """Linear layer whose weight is stored as packed 4-bit codes (AWQ-INT4).

    Two 4-bit codes share one int8 element along the output dimension
    (see ``pack_weight``). Every forward pass unpacks the codes and
    dequantizes them as ``(q + zero_point) * scale``. It divides the input
    by ``scale_factor`` and then computes ``x @ W.T (+ bias)``.

    NOTE(review): ``zero_point`` is *added* to the codes, not subtracted.
    Presumably the checkpoint stores it pre-negated; confirm against the
    exporter.

    Args:
        weight: packed int8 tensor as produced by ``pack_weight``, i.e.
            ``(rows // 2, cols)`` for an unpacked ``(rows, cols)`` code matrix.
        scale: multiplier for the dequantized codes, cast to ``dtype``.
            Must broadcast against the unpacked ``(out, in)`` weight.
        zero_point: offset added to the codes before scaling, cast to ``dtype``.
        scale_factor: divisor applied to the input (broadcast against ``x``),
            cast to ``dtype``.
        bias: optional bias added to the output, or ``None``.
        dtype: floating dtype used for the quantization parameters and
            dequantized weight.

    Because ``pack_weight`` drops the last row when the row count is odd,
    ``out_features`` is assumed to be even.
    """

    def __init__(self, weight, scale, zero_point, scale_factor, bias=None, dtype=torch.bfloat16):
        super().__init__()
        # All parameters are frozen: this layer is inference-only.
        self.weight = nn.Parameter(weight, requires_grad=False)
        self.scale = nn.Parameter(scale.to(dtype), requires_grad=False)
        self.zero_point = nn.Parameter(zero_point.to(dtype), requires_grad=False)
        self.scale_factor = nn.Parameter(scale_factor.to(dtype), requires_grad=False)
        self.bias = nn.Parameter(bias, requires_grad=False) if bias is not None else None
        self.dtype = dtype

    def forward(self, x):
        """Return ``(x / scale_factor) @ dequant(weight).T + bias``."""
        x = x / self.scale_factor.to(self.dtype)
        codes = self.unpack_weight(self.weight).to(self.dtype)
        weight_decode = (codes + self.zero_point) * self.scale
        out = x @ weight_decode.T.to(x.dtype)
        if self.bias is not None:
            out = out + self.bias.to(x.dtype)
        return out

    @staticmethod
    def pack_weight(weight):
        """Pack 4-bit codes into int8, two rows per packed row.

        Even rows go to the high nibble and odd rows to the low nibble. Only
        the low 4 bits of each value are kept, so codes should lie in
        [0, 15]. Out-of-range values are stored as their low 4 bits instead
        of corrupting the neighbouring nibble. With an odd row count the
        last row is dropped.

        Args:
            weight: 2-D tensor of integer codes, shape ``(rows, cols)``.

        Returns:
            int8 tensor of shape ``(rows // 2, cols)`` on ``weight``'s device.
        """
        weight = weight.to(torch.int8)
        rows = weight.shape[0] // 2 * 2
        # Mask before combining so a value outside [0, 15] cannot set bits
        # in the other nibble.
        high = (weight[0:rows:2] & 0x0F) << 4
        low = weight[1:rows:2] & 0x0F
        return high | low

    @staticmethod
    def unpack_weight(packed_weight):
        """Inverse of ``pack_weight``.

        Args:
            packed_weight: int8 tensor of shape ``(n, cols)``.

        Returns:
            int8 tensor of shape ``(2 * n, cols)`` with codes in [0, 15].
            Even rows come from the high nibbles, odd rows from the low nibbles.
        """
        n_rows = packed_weight.shape[0]
        weight = torch.empty(n_rows * 2, packed_weight.shape[1], dtype=torch.int8,
                             device=packed_weight.device)
        # Mask after the shift: ">>" on int8 is arithmetic and sign-extends.
        weight[0::2] = (packed_weight >> 4) & 0x0F
        weight[1::2] = packed_weight & 0x0F
        return weight
#apply
def apply_quantized_layers(model, config, prefix='', dtype=torch.bfloat16):
    """Recursively replace configured ``nn.Linear`` children with ``Qlinear_4``.

    Each child is addressed by its dotted path from the root module. An
    ``nn.Linear`` child is replaced when its entry in ``config`` has
    ``"quantized": true``. Paths missing from ``config``, or without a
    ``"quantized"`` key, are left as ordinary ``nn.Linear`` layers.
    Non-Linear children are descended into.

    The replacement holds zero-filled placeholders for the packed weight,
    ``scale`` and ``zero_point``, a ``scale_factor`` of ones, and the
    original bias. The real values are expected to be loaded afterwards,
    e.g. via ``load_state_dict``.

    Args:
        model: module whose children are modified in place.
        config: mapping of dotted layer path -> dict with a ``"quantized"`` flag.
        prefix: dotted path of ``model`` itself; empty for the root.
        dtype: dtype passed to each ``Qlinear_4`` (defaults to bfloat16).
    """
    for name, module in model.named_children():
        full_name = f"{prefix}.{name}" if prefix else name
        if not isinstance(module, nn.Linear):
            apply_quantized_layers(module, config, full_name, dtype)  # recurse into submodules
            continue

        # A missing entry means "not quantized" instead of crashing on None[...].
        entry = config.get(full_name) or {}
        if not entry.get("quantized", False) == True:  # noqa: E712 - same strict check as before
            continue

        device = module.weight.device
        out_features, in_features = module.weight.shape
        # Placeholders only; they are packed so the Parameter shapes match the checkpoint.
        packed_weight = Qlinear_4.pack_weight(
            torch.zeros(out_features, in_features, dtype=torch.int8, device=device)
        )
        scale = torch.zeros((out_features, 1), device=device)  # one scale per output row
        zero_point = torch.zeros_like(scale)
        scale_factor = torch.ones(in_features, device=device)  # one divisor per input feature

        # Replacing an existing key during named_children() iteration is safe
        # because the number of children does not change.
        setattr(model, name, Qlinear_4(
            weight=packed_weight,
            scale=scale,
            zero_point=zero_point,
            scale_factor=scale_factor,
            bias=module.bias,
            dtype=dtype,
        ))


#config read
import json
def load_config(config_path):
    """Read and return the JSON quantization config at ``config_path``.

    The file is decoded as UTF-8, as the JSON spec requires. The result
    therefore does not depend on the platform's default locale encoding.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def apply_AWQ(model, quantization_config_path, quantization_weight_path):
    """Convert ``model`` to AWQ-INT4 layers and load the quantized checkpoint.

    Steps:
      1. Read the JSON layer config from ``quantization_config_path``.
      2. Swap the configured ``nn.Linear`` layers for ``Qlinear_4`` placeholders.
      3. Load the state dict at ``quantization_weight_path`` (mapped to CPU)
         into the converted model. ``load_state_dict`` uses its default
         strict matching.

    Returns:
        True once loading finishes. Failures surface as exceptions.
    """
    layer_config = load_config(quantization_config_path)
    apply_quantized_layers(model, layer_config)

    # SECURITY NOTE(review): torch.load unpickles the file, which can execute
    # arbitrary code. Only load trusted checkpoints, or consider
    # weights_only=True on torch versions that support it.
    checkpoint = torch.load(quantization_weight_path, map_location='cpu')
    model.load_state_dict(checkpoint)

    # Drop the CPU copy, then hand cached-but-unused CUDA memory back to the
    # driver (presumably left over from the replaced float layers).
    del checkpoint
    torch.cuda.empty_cache()
    return True