import torch
import torch.nn as nn


def quantize(tensor, scale, zero_point, is_asym=False):
    """Fake-quantize a floating-point tensor to an 8-bit integer grid."""
    if is_asym:
        # Asymmetric (unsigned) 8-bit range
        clamp_min, clamp_max = torch.tensor(0.), torch.tensor(255.)
    else:
        # Symmetric (signed) 8-bit range
        clamp_min, clamp_max = torch.tensor(-128.), torch.tensor(127.)
    # Shift by the zero point before clamping so the result stays inside the integer range
    quant_tensor = torch.clamp(torch.round(tensor / scale) + zero_point, clamp_min, clamp_max)
    return quant_tensor


def dequantize(tensor, scale, zero_point):
    """Map quantized integer values back to floating point."""
    return (tensor - zero_point) * scale


class QuantLinear(nn.Module):
    def __init__(self, quant_param):
        super().__init__()
        # SmoothQuant scaling factor applied to the input activations
        mul_factor = torch.tensor(quant_param['smoothquant_mul']).view(quant_param['smoothquant_mul_shape'])
        self.register_buffer('mul_factor', mul_factor)
        self.linear = nn.Linear(128, 128)
        # Quantization parameters: asymmetric for weights, symmetric for inputs
        weight_scale = torch.tensor(quant_param['weight_scale']).view(quant_param['weight_scale_shape'])
        weight_zp = torch.tensor(quant_param['weight_zp']).view(quant_param['weight_zp_shape'])
        input_scale = torch.tensor(quant_param['input_scale']).view(quant_param['input_scale_shape'])
        input_zp = torch.tensor(quant_param['input_zp']).view(quant_param['input_zp_shape'])
        self.register_buffer('weight_scale', weight_scale)
        self.register_buffer('weight_zp', weight_zp)
        self.register_buffer('input_scale', input_scale)
        self.register_buffer('input_zp', input_zp)

    def forward(self, x):
        # SmoothQuant: rescale activations to migrate quantization difficulty to the weights
        scaled_x = x * self.mul_factor
        # Quantize-dequantize (fake quantization) of weights and inputs
        quant_weight = quantize(self.linear.weight, self.weight_scale, self.weight_zp, is_asym=True)
        quant_input = quantize(scaled_x, self.input_scale, self.input_zp, is_asym=False)
        dequantized_weight = dequantize(quant_weight, self.weight_scale, self.weight_zp)
        dequantized_input = dequantize(quant_input, self.input_scale, self.input_zp)
        out = torch.nn.functional.linear(dequantized_input, dequantized_weight, self.linear.bias)
        return out


class QuantConv2d(nn.Module):
    def __init__(self, quant_param):
        super().__init__()
        # SmoothQuant scaling factor applied to the input activations
        mul_factor = torch.tensor(quant_param['smoothquant_mul']).view(quant_param['smoothquant_mul_shape'])
        self.register_buffer('mul_factor', mul_factor)
        self.conv2d = nn.Conv2d(128, 128, 3)
        # Quantization parameters: asymmetric for weights, symmetric for inputs
        weight_scale = torch.tensor(quant_param['weight_scale']).view(quant_param['weight_scale_shape'])
        weight_zp = torch.tensor(quant_param['weight_zp']).view(quant_param['weight_zp_shape'])
        input_scale = torch.tensor(quant_param['input_scale']).view(quant_param['input_scale_shape'])
        input_zp = torch.tensor(quant_param['input_zp']).view(quant_param['input_zp_shape'])
        self.register_buffer('weight_scale', weight_scale)
        self.register_buffer('weight_zp', weight_zp)
        self.register_buffer('input_scale', input_scale)
        self.register_buffer('input_zp', input_zp)

    def forward(self, x):
        # SmoothQuant: rescale activations to migrate quantization difficulty to the weights
        scaled_x = x * self.mul_factor
        # Quantize-dequantize (fake quantization) of the conv weights and inputs
        quant_weight = quantize(self.conv2d.weight, self.weight_scale, self.weight_zp, is_asym=True)
        quant_input = quantize(scaled_x, self.input_scale, self.input_zp, is_asym=False)
        dequantized_weight = dequantize(quant_weight, self.weight_scale, self.weight_zp)
        dequantized_input = dequantize(quant_input, self.input_scale, self.input_zp)
        out = torch.nn.functional.conv2d(dequantized_input, dequantized_weight, self.conv2d.bias)
        return out
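

# Usage sketch: the quant_param values below are illustrative placeholders
# (assumed per-tensor scales/zero-points of shape (1,)), not the output of a
# real calibration run; in practice they would come from a SmoothQuant/PTQ
# calibration step that produces this dictionary.
if __name__ == "__main__":
    example_quant_param = {
        'smoothquant_mul': 1.0, 'smoothquant_mul_shape': (1,),
        'weight_scale': 0.01, 'weight_scale_shape': (1,),
        'weight_zp': 128.0, 'weight_zp_shape': (1,),
        'input_scale': 0.02, 'input_scale_shape': (1,),
        'input_zp': 0.0, 'input_zp_shape': (1,),
    }
    layer = QuantLinear(example_quant_param)
    x = torch.randn(4, 128)
    y = layer(x)      # fake-quantized forward pass
    print(y.shape)    # torch.Size([4, 128])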