"""Compare QuantLinear's default forward path against its ``qop=True`` path.

Builds a randomly initialised ``nn.Linear`` layer and a hand-made set of
SmoothQuant / quantization parameters. It copies the float weights into
``QuantLinear``, runs the same input through both execution modes
(``ql(i)`` and ``ql(i, qop=True)``), and prints their shapes and the
element-wise and maximum absolute difference.
"""
import torch
import torch.nn as nn

from math_model import QuantLinear

torch.manual_seed(0)

batch_size = 1
out_ch = 128
in_ch = 64

# NOTE: the seeded random draws happen in this order: input, Linear init,
# smoothquant_mul. Keep that order when editing, or the test values change.
i = 2 * torch.rand((batch_size, in_ch)) - 1.0  # uniform in [-1, 1)
l = nn.Linear(in_ch, out_ch, bias=True)

# Everything below is evaluation-only. Disable autograd so that tensors
# derived from `l.weight` (which requires grad) and the outputs do not
# build or carry autograd graphs.
with torch.no_grad():
    # Per-output-channel max |w|; shared by the weight scale and zero point.
    w_absmax = torch.max(torch.abs(l.weight), dim=1).values

    quant_params = {
        'smoothquant_mul': torch.rand((in_ch,)),
        'smoothquant_mul_shape': (1, in_ch),
        # Per-channel scale chosen so that max |w| maps to 128.
        'weight_scale': w_absmax / 128.0,
        'weight_scale_shape': (out_ch, 1),
        # Test zero point: the channel mean expressed in scale steps,
        # offset by 128 and clamped to [0, 255].
        'weight_zp': torch.clamp(
            torch.round(torch.mean(l.weight, dim=1) * (128 / w_absmax)) + 128,
            0,
            255,
        ),
        'weight_zp_shape': (out_ch, 1),
        # Per-tensor input scale chosen so that max |x| maps to 128.
        'input_scale': torch.max(torch.abs(i)) / 128.0,
        'input_scale_shape': tuple(),
        # NOTE(review): this is a (1,)-shaped tensor paired with the scalar
        # shape `()`; presumably QuantLinear reshapes it -- confirm.
        'input_zp': torch.zeros((1,)),
        'input_zp_shape': tuple(),
    }
    print(quant_params)

    ql = QuantLinear(in_ch, out_ch, quant_params)
    # Use the same float weights/bias as the reference layer.
    ql.linear.load_state_dict(l.state_dict())

    o_qdq = ql(i)
    o_qop = ql(i, qop=True)

print(o_qdq.shape)
print(o_qop.shape)
print(o_qdq - o_qop)
# A single scalar summary is easier to read than the raw difference tensor.
print("max abs diff:", torch.max(torch.abs(o_qdq - o_qop)).item())