File size: 3,391 Bytes
d5dfd96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch.nn as nn
import torch

def quantize(tensor, scale, zero_point, is_asym=False):
    """Fake-quantize ``tensor`` to an 8-bit integer grid.

    Args:
        tensor: floating-point tensor to quantize.
        scale: quantization scale (tensor or scalar, broadcastable to ``tensor``).
        zero_point: quantization zero point (broadcastable to ``tensor``).
        is_asym: if True use the unsigned uint8 range [0, 255]
            (asymmetric), otherwise the signed int8 range [-128, 127].

    Returns:
        Tensor of quantized integer values (stored as floats), clamped to
        the selected 8-bit range.
    """
    if is_asym:
        qmin, qmax = 0., 255.
    else:
        qmin, qmax = -128., 127.
    # Clamp AFTER adding the zero point: saturation must happen in the
    # integer domain. The previous version clamped first and then added
    # zero_point, letting saturated values land outside [qmin, qmax].
    return torch.clamp(torch.round(tensor / scale) + zero_point, qmin, qmax)

def dequantize(tensor, scale, zero_point):
    """Map quantized integer values back to floats: ``(q - zp) * scale``."""
    shifted = tensor - zero_point
    return shifted * scale


class QuantLinear(nn.Module):
    """Linear layer with SmoothQuant input scaling and fake quantize/dequantize
    of both weight and input before the matmul.

    ``quant_param`` supplies flat values plus shapes for the smoothquant
    multiplier and the weight/input scale and zero-point buffers.
    """

    def __init__(self, quant_param):
        super().__init__()
        smooth = torch.tensor(quant_param['smoothquant_mul']).view(quant_param['smoothquant_mul_shape'])
        self.register_buffer('mul_factor', smooth)
        self.linear = nn.Linear(128, 128)
        # Register every quantization parameter as a non-trainable buffer so it
        # moves with the module (.to/.cuda) and lands in the state_dict.
        for key in ('weight_scale', 'weight_zp', 'input_scale', 'input_zp'):
            buf = torch.tensor(quant_param[key]).view(quant_param[key + '_shape'])
            self.register_buffer(key, buf)

    def forward(self, x):
        # SmoothQuant: pre-scale the activation before quantizing it.
        smoothed = x * self.mul_factor
        # Weights use the asymmetric (uint8) grid, inputs the symmetric (int8) grid.
        w_q = quantize(self.linear.weight, self.weight_scale, self.weight_zp, is_asym=True)
        x_q = quantize(smoothed, self.input_scale, self.input_zp, is_asym=False)
        w_dq = dequantize(w_q, self.weight_scale, self.weight_zp)
        x_dq = dequantize(x_q, self.input_scale, self.input_zp)
        return torch.nn.functional.linear(x_dq, w_dq, self.linear.bias)

class QuantConv2d(nn.Module):
    """Conv2d layer with SmoothQuant input scaling and fake quantize/dequantize
    of both weight and input before the convolution.

    ``quant_param`` supplies flat values plus shapes for the smoothquant
    multiplier and the weight/input scale and zero-point buffers.
    """

    def __init__(self, quant_param):
        super().__init__()
        mul_factor = torch.tensor(quant_param['smoothquant_mul']).view(quant_param['smoothquant_mul_shape'])
        self.register_buffer('mul_factor', mul_factor)
        self.conv2d = nn.Conv2d(128, 128, 3)
        # Quantization parameters are buffers: non-trainable, but they follow
        # the module on .to()/.cuda() and are saved in the state_dict.
        weight_scale = torch.tensor(quant_param['weight_scale']).view(quant_param['weight_scale_shape'])
        weight_zp = torch.tensor(quant_param['weight_zp']).view(quant_param['weight_zp_shape'])
        input_scale = torch.tensor(quant_param['input_scale']).view(quant_param['input_scale_shape'])
        input_zp = torch.tensor(quant_param['input_zp']).view(quant_param['input_zp_shape'])
        self.register_buffer('weight_scale', weight_scale)
        self.register_buffer('weight_zp', weight_zp)
        self.register_buffer('input_scale', input_scale)
        self.register_buffer('input_zp', input_zp)

    def forward(self, x):
        # SmoothQuant: pre-scale the activation before quantizing it.
        scaled_x = x * self.mul_factor
        # BUG FIX: the original read self.linear.weight, but this module owns
        # self.conv2d (copy-paste from QuantLinear) — it raised AttributeError.
        quant_weight = quantize(self.conv2d.weight, self.weight_scale, self.weight_zp, is_asym=True)
        quant_input = quantize(scaled_x, self.input_scale, self.input_zp, is_asym=False)
        dequantized_weight = dequantize(quant_weight, self.weight_scale, self.weight_zp)
        dequantized_input = dequantize(quant_input, self.input_scale, self.input_zp)
        out = torch.nn.functional.conv2d(dequantized_input, dequantized_weight, self.conv2d.bias)
        return out