File size: 8,124 Bytes
d5dfd96
 
 
 
 
 
 
 
049c65f
d5dfd96
 
 
 
 
 
 
eb5a5f6
d5dfd96
 
 
eb5a5f6
d5dfd96
 
 
 
 
 
 
 
 
eb5a5f6
 
d5dfd96
 
 
 
 
 
 
 
eb5a5f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d5dfd96
eb5a5f6
d5dfd96
 
 
eb5a5f6
d5dfd96
 
 
 
 
 
 
 
 
eb5a5f6
 
d5dfd96
eb5a5f6
d5dfd96
 
 
 
 
eb5a5f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import torch.nn as nn
import torch

def quantize(tensor, scale, zero_point, is_asym=False):
    """Map ``tensor`` onto an 8-bit integer grid (values kept in floating point).

    The value is computed as ``clamp(round(tensor / scale + zero_point), lo, hi)``.
    ``[lo, hi]`` is ``[0, 255]`` (unsigned 8-bit range) when ``is_asym`` is True,
    otherwise ``[-128, 127]`` (signed 8-bit range).

    Args:
        tensor: Values to quantize.
        scale: Quantization step; broadcast against ``tensor``.
        zero_point: Offset added after scaling; broadcast against ``tensor``.
        is_asym: Selects the unsigned range instead of the signed one.

    Returns:
        A tensor of rounded, clamped values. It keeps a floating-point dtype;
        callers cast to an integer dtype themselves if needed.
    """
    lower, upper = (0., 255.) if is_asym else (-128., 127.)
    rounded = torch.round(tensor / scale + zero_point)
    return torch.clamp(rounded, torch.tensor(lower), torch.tensor(upper))

def dequantize(tensor, scale, zero_point):
    """Map quantized values back to real values: ``(tensor - zero_point) * scale``.

    ``scale`` and ``zero_point`` broadcast against ``tensor``. Integer inputs
    are promoted by PyTorch's usual type-promotion rules when combined with
    floating-point operands.
    """
    centered = tensor - zero_point
    return centered * scale


class QuantLinear(nn.Module):
    """``nn.Linear`` with 8-bit weight/input quantization, runnable two ways.

    ``qdq_forward`` emulates quantization in floating point ("fake
    quantization"). ``qop_forward`` emulates an integer GEMM followed by a
    weight zero-point correction and a final dequantization.

    Args:
        in_ch: Number of input features given to ``nn.Linear``.
        out_ch: Number of output features given to ``nn.Linear``.
        quant_param: Mapping with the quantization parameters. For each of
            ``smoothquant_mul``, ``weight_scale``, ``weight_zp``,
            ``input_scale`` and ``input_zp``, the value is converted with
            ``torch.tensor`` and reshaped with ``.view`` to the matching
            ``<name>_shape`` entry. The result is registered as a buffer
            (``smoothquant_mul`` is stored as ``mul_factor``).
    """

    def __init__(self, in_ch, out_ch, quant_param):
        super().__init__()
        # Per the key name, presumably a SmoothQuant per-channel multiplier
        # applied to the input before quantization.
        mul_factor = torch.tensor(quant_param['smoothquant_mul']).view(quant_param['smoothquant_mul_shape'])
        self.register_buffer('mul_factor', mul_factor)
        # Weights/bias are freshly initialized here; presumably they are loaded
        # afterwards (e.g. via load_state_dict) — confirm with the caller.
        self.linear = nn.Linear(in_ch, out_ch)
        weight_scale = torch.tensor(quant_param['weight_scale']).view(quant_param['weight_scale_shape'])
        weight_zp = torch.tensor(quant_param['weight_zp']).view(quant_param['weight_zp_shape'])
        input_scale = torch.tensor(quant_param['input_scale']).view(quant_param['input_scale_shape'])
        input_zp = torch.tensor(quant_param['input_zp']).view(quant_param['input_zp_shape'])
        self.register_buffer('weight_scale', weight_scale)
        self.register_buffer('weight_zp', weight_zp)
        self.register_buffer('input_scale', input_scale)
        self.register_buffer('input_zp', input_zp)

    # I.e., "fake quantization"
    def qdq_forward(self, x):
        """Quantize-dequantize forward pass computed in floating point.

        The weight is quantized to the unsigned range [0, 255] and the scaled
        input to the signed range [-128, 127]. Both are dequantized again, and
        ``F.linear`` is applied with the layer's bias.
        """
        scaled_x = x * self.mul_factor
        quant_weight = quantize(self.linear.weight, self.weight_scale, self.weight_zp, is_asym=True)
        quant_input = quantize(scaled_x, self.input_scale, self.input_zp, is_asym=False)
        dequantized_weight = dequantize(quant_weight, self.weight_scale, self.weight_zp)
        dequantized_input = dequantize(quant_input, self.input_scale, self.input_zp)
        out = torch.nn.functional.linear(dequantized_input, dequantized_weight, self.linear.bias)
        return out

    # Accelerated version
    def qop_forward(self, x):
        """Forward pass emulating an integer linear kernel.

        With an integer kernel and a non-zero weight zero-point, the raw
        product sum(q_x * q_w) must be corrected by -sum(q_x) * weight_zp per
        output channel. The code then rescales by weight_scale * input_scale
        and adds the bias.

        NOTE(review): the input zero-point is applied when quantizing but is
        NOT corrected for here. This matches ``qdq_forward`` only when
        ``input_zp`` is zero (the input uses the symmetric range).
        NOTE: the accumulation is emulated in float32, which holds integers
        exactly only up to 2**24. Very long dot products may therefore round.
        """
        scaled_x = x * self.mul_factor
        quant_weight = quantize(self.linear.weight, self.weight_scale, self.weight_zp, is_asym=True).to(torch.uint8)
        quant_input = quantize(scaled_x, self.input_scale, self.input_zp, is_asym=False).to(torch.int8)
        # Convert inputs to FP32 to avoid F.linear quantizing the output to int8.
        quant_output = torch.nn.functional.linear(quant_input.to(torch.float32), quant_weight.to(torch.float32), None).to(torch.int32)
        # Sum of quantized inputs over the dot-product dimension, kept as (..., 1).
        input_sum = torch.sum(quant_input, dim=-1).to(torch.int32).unsqueeze(-1)
        # Cast to int32 BEFORE negating: casting a negative zero-point to uint8
        # wraps around (e.g. -128 -> 128) and corrupts the correction term.
        neg_weight_zp = -self.weight_zp.to(torch.int32)
        zp_shape = [1] * (quant_input.ndim - 1) + [self.weight_zp.nelement()]
        correction = input_sum * neg_weight_zp.view(zp_shape)
        quant_output = quant_output + correction
        combined_scale = self.weight_scale * self.input_scale
        scale_shape = [1] * (quant_output.ndim - 1) + [combined_scale.nelement()]
        output = dequantize(quant_output, combined_scale.view(scale_shape), 0.0)
        output += self.linear.bias
        return output

    def forward(self, x, qop=False):
        """Run ``qop_forward`` when ``qop`` is True, else ``qdq_forward``."""
        if qop:
            return self.qop_forward(x)
        else:
            return self.qdq_forward(x)

class QuantConv2d(nn.Module):
    """``nn.Conv2d`` with 8-bit weight/input quantization, runnable two ways.

    ``qdq_forward`` emulates quantization in floating point ("fake
    quantization"). ``qop_forward`` emulates an integer convolution followed by
    a weight zero-point correction and a final dequantization.

    Args:
        in_ch: Input channels given to ``nn.Conv2d``.
        out_ch: Output channels given to ``nn.Conv2d``.
        kernel_size: Kernel size given to ``nn.Conv2d``. Stride, padding,
            dilation and groups stay at their defaults.
        quant_param: Mapping with the quantization parameters. For each of
            ``smoothquant_mul``, ``weight_scale``, ``weight_zp``,
            ``input_scale`` and ``input_zp``, the value is converted with
            ``torch.tensor`` and reshaped with ``.view`` to the matching
            ``<name>_shape`` entry. The result is registered as a buffer
            (``smoothquant_mul`` is stored as ``mul_factor``).
    """

    def __init__(self, in_ch, out_ch, kernel_size, quant_param):
        super().__init__()
        # Per the key name, presumably a SmoothQuant per-channel multiplier
        # applied to the input before quantization.
        mul_factor = torch.tensor(quant_param['smoothquant_mul']).view(quant_param['smoothquant_mul_shape'])
        self.register_buffer('mul_factor', mul_factor)
        # Weights/bias are freshly initialized here; presumably they are loaded
        # afterwards (e.g. via load_state_dict) — confirm with the caller.
        self.conv2d = nn.Conv2d(in_ch, out_ch, kernel_size)
        weight_scale = torch.tensor(quant_param['weight_scale']).view(quant_param['weight_scale_shape'])
        weight_zp = torch.tensor(quant_param['weight_zp']).view(quant_param['weight_zp_shape'])
        input_scale = torch.tensor(quant_param['input_scale']).view(quant_param['input_scale_shape'])
        input_zp = torch.tensor(quant_param['input_zp']).view(quant_param['input_zp_shape'])
        self.register_buffer('weight_scale', weight_scale)
        self.register_buffer('weight_zp', weight_zp)
        self.register_buffer('input_scale', input_scale)
        self.register_buffer('input_zp', input_zp)

    # I.e., "fake quantization"
    def qdq_forward(self, x):
        """Quantize-dequantize forward pass computed in floating point.

        The weight is quantized to the unsigned range [0, 255] and the scaled
        input to the signed range [-128, 127]. Both are dequantized again, and
        ``F.conv2d`` is applied with the layer's bias.
        """
        scaled_x = x * self.mul_factor
        quant_weight = quantize(self.conv2d.weight, self.weight_scale, self.weight_zp, is_asym=True)
        quant_input = quantize(scaled_x, self.input_scale, self.input_zp, is_asym=False)
        dequantized_weight = dequantize(quant_weight, self.weight_scale, self.weight_zp)
        dequantized_input = dequantize(quant_input, self.input_scale, self.input_zp)
        out = torch.nn.functional.conv2d(dequantized_input, dequantized_weight, self.conv2d.bias)
        return out

    # Accelerated version
    def qop_forward(self, x):
        """Forward pass emulating an integer conv2d kernel.

        The zero-point correction is the same as in the linear case:
        subtract sum(q_x over each receptive field) * weight_zp per output
        channel. The receptive-field sum is obtained by appending an extra
        output channel whose kernel is all ones, instead of running an explicit
        Im2Col + sum. That channel is split off after the convolution. The
        result is rescaled by weight_scale * input_scale and the bias is added.

        NOTE(review): the input zero-point is applied when quantizing but is
        NOT corrected for here. This matches ``qdq_forward`` only when
        ``input_zp`` is zero (the input uses the symmetric range).
        NOTE: the accumulation is emulated in float32, which holds integers
        exactly only up to 2**24. Large receptive fields may therefore round.
        """
        scaled_x = x * self.mul_factor
        quant_weight = quantize(self.conv2d.weight, self.weight_scale, self.weight_zp, is_asym=True).to(torch.uint8)
        # Extra all-ones output channel; created on the weight's device so that
        # torch.cat also works when the module is not on the CPU.
        ones_shape = [1] + list(quant_weight.shape[1:])
        ones_kernel = torch.ones(ones_shape, dtype=torch.uint8, device=quant_weight.device)
        quant_weight = torch.cat((quant_weight, ones_kernel), dim=0)
        quant_input = quantize(scaled_x, self.input_scale, self.input_zp, is_asym=False).to(torch.int8)
        # Convert inputs to FP32 to avoid F.conv2d quantizing the output to int8.
        quant_output = torch.nn.functional.conv2d(quant_input.to(torch.float32), quant_weight.to(torch.float32), None).to(torch.int32)
        # Slice with -1: (not -1) to keep the channel axis: the (N, 1, H, W) sum
        # then broadcasts against (1, C_out, 1, 1) zero-points for any batch N.
        window_sum = quant_output[:, -1:, :, :]
        # Cast to int32 BEFORE negating: casting a negative zero-point to uint8
        # wraps around (e.g. -128 -> 128) and corrupts the correction term.
        neg_weight_zp = -self.weight_zp.to(torch.int32)
        zp_shape = [1, self.weight_zp.nelement()] + [1] * (quant_output.ndim - 2)
        correction = window_sum * neg_weight_zp.view(zp_shape)
        quant_output = quant_output[:, :-1, :, :] + correction
        combined_scale = self.weight_scale * self.input_scale
        scale_shape = [1, combined_scale.nelement()] + [1] * (quant_output.ndim - 2)
        output = dequantize(quant_output, combined_scale.view(scale_shape), 0.0)
        output += self.conv2d.bias.view([1, self.conv2d.bias.nelement()] + [1] * (quant_output.ndim - 2))
        return output

    def forward(self, x, qop=False):
        """Run ``qop_forward`` when ``qop`` is True, else ``qdq_forward``."""
        if qop:
            return self.qop_forward(x)
        else:
            return self.qdq_forward(x)