# sdxl-quant-int8 / test_quant_conv2d.py
# Author: nickfraser
# Commit 4024f9d: "Updated math model to target int8 x int8 kernels."
import torch
import torch.nn as nn
from math_model import QuantConv2d
# Smoke test for QuantConv2d: build a float nn.Conv2d, derive per-output-channel
# weight and per-tensor input quantization parameters from it, load the same
# float weights into QuantConv2d, and compare its default forward output with
# its `qop=True` forward output.
# NOTE(review): the names suggest o_qdq is the quantize-dequantize emulation and
# o_qop the quantized-operator path; the actual semantics live in math_model.
torch.manual_seed(0)

batch_size = 1
out_ch = 128
in_ch = 64
k = 3
h = 5
w = 5

# Random input uniformly drawn from [-1, 1).
i = 2 * torch.rand((batch_size, in_ch, h, w)) - 1.
l = nn.Conv2d(in_ch, out_ch, k, bias=True)

# The quantization parameters are constants derived from the float weights.
# Detach so they do not carry an autograd graph (l.weight requires grad).
weight = l.weight.detach()
# Per-output-channel absolute maximum of the weights (computed once, reused below).
weight_absmax = torch.max(torch.abs(torch.flatten(weight, start_dim=1)), dim=1).values
# Per-channel scale chosen so that the channel's absmax maps to 128.
weight_scale = weight_absmax / 128.
# Per-channel zero point: the channel mean expressed in units of weight_scale,
# rounded, shifted by 128 and clamped to [0, 255].
weight_zp = torch.clamp(
    torch.round(torch.mean(weight, dim=(1, 2, 3)) * (128 / weight_absmax)) + 128,
    0,
    255,
)

quant_params = {
    # Random per-input-channel multiplier; how it is applied is defined in
    # math_model (presumably a SmoothQuant-style scaling; verify there).
    'smoothquant_mul': torch.rand((in_ch,)),
    'smoothquant_mul_shape': (1, in_ch, 1, 1),
    'weight_scale': weight_scale,
    'weight_scale_shape': (out_ch, 1, 1, 1),
    'weight_zp': weight_zp,
    'weight_zp_shape': (out_ch, 1, 1, 1),
    # Per-tensor input scale: the input's absmax maps to 128, zero point 0.
    'input_scale': torch.max(torch.abs(i)) / 128.,
    'input_scale_shape': tuple(),
    'input_zp': torch.zeros((1,)),
    'input_zp_shape': tuple(),
}
print(quant_params)

ql = QuantConv2d(in_ch, out_ch, k, quant_params)
# Copy the reference conv's float weight and bias into the quantized module.
ql.conv2d.load_state_dict(l.state_dict())

# Inference-only comparison: no autograd graph is needed for the outputs.
with torch.no_grad():
    o_qdq = ql(i)
    o_qop = ql(i, qop=True)

diff = o_qdq - o_qop
print(o_qdq.shape)
print(o_qop.shape)
print(diff)
# One-number summary; the full tensor print above is truncated by torch.
print(f"max abs diff: {diff.abs().max().item()}")