---
license: apache-2.0
tags:
- vision
- image-classification
datasets:
- imagenet-1k
---

This model is a fork of [facebook/levit-256](https://huggingface.co/facebook/levit-256), where:

* `nn.BatchNorm2d` and `nn.Conv2d` are fused
* `nn.BatchNorm1d` and `nn.Linear` are fused

and the optimized model is converted to the ONNX format.

The layer fusion leverages torch.fx, using the `FuseBatchNorm2dInConv2d` and `FuseBatchNorm1dInLinear` transformations, soon available out-of-the-box in 🤗 Optimum: https://huggingface.co/docs/optimum/main/en/fx/optimization#the-transformation-guide
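
As a quick illustration of what batch normalization folding does, the BatchNorm statistics are rewritten into the preceding layer's weight and bias, so one fused layer computes the same function in a single op. A minimal sketch using PyTorch's built-in eval-mode fusion helper (for illustration only, not the torch.fx transformations used to produce this model):

```python
import torch
from torch import nn
from torch.nn.utils.fusion import fuse_conv_bn_eval

conv = nn.Conv2d(3, 8, kernel_size=3).eval()
bn = nn.BatchNorm2d(8).eval()
with torch.no_grad():  # give the BatchNorm non-trivial statistics
    bn.running_mean.uniform_(-1, 1)
    bn.running_var.uniform_(0.5, 1.5)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-1, 1)

# Folding rewrites the conv parameters so that fused(x) == bn(conv(x)):
#   w' = w * gamma / sqrt(running_var + eps)
#   b' = (b - running_mean) * gamma / sqrt(running_var + eps) + beta
fused = fuse_conv_bn_eval(conv, bn)

x = torch.rand(1, 3, 32, 32)
assert torch.allclose(bn(conv(x)), fused(x), atol=1e-5)
```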

## How to use

```python
from optimum.onnxruntime import ORTModelForImageClassification
from transformers import AutoFeatureExtractor

from PIL import Image
import requests

preprocessor = AutoFeatureExtractor.from_pretrained("fxmarty/levit-256-onnx")
ort_model = ORTModelForImageClassification.from_pretrained("fxmarty/levit-256-onnx")

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

inputs = preprocessor(images=image, return_tensors="pt")
outputs = ort_model(**inputs)

predicted_class_idx = outputs.logits.argmax(-1).item()
print("Predicted class:", ort_model.config.id2label[predicted_class_idx])
```
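
The exported model should also work with the regular `transformers` pipeline API, reusing `ort_model`, `preprocessor`, and `url` from above (a sketch, assuming an Optimum version where `ORTModel`s are accepted by `pipeline`):

```python
from transformers import pipeline

pipe = pipeline("image-classification", model=ort_model, feature_extractor=preprocessor)
print(pipe(url))
```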

To be safe, check as well that the ONNX model returns the same logits as the original PyTorch model (fusion and the ONNX export slightly change the floating-point arithmetic, hence the tolerance):

```python
import torch

from optimum.onnxruntime import ORTModelForImageClassification
from transformers import AutoModelForImageClassification

pt_model = AutoModelForImageClassification.from_pretrained("facebook/levit-256")
pt_model.eval()

ort_model = ORTModelForImageClassification.from_pretrained("fxmarty/levit-256-onnx")

# Dummy input at the model's expected resolution (1x3x224x224)
inp = {"pixel_values": torch.rand(1, 3, 224, 224)}

with torch.no_grad():
    res = pt_model(**inp)
res_ort = ort_model(**inp)

assert torch.allclose(res.logits, res_ort.logits, atol=1e-4)
```
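
The same parity check can be repeated on a real input, reusing `preprocessor` and `image` from the first snippet:

```python
inputs = preprocessor(images=image, return_tensors="pt")

with torch.no_grad():
    res = pt_model(**inputs)
res_ort = ort_model(**inputs)

assert torch.allclose(res.logits, res_ort.logits, atol=1e-4)
```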

## Benchmarking

More than 2x throughput with batch normalization folding and ONNX Runtime 🔥

Below you can find latency percentiles and means (in ms), along with each model's throughput (in iterations/s). Throughput here is simply the number of forward passes divided by the 10 s benchmark duration; the mean latency drops from roughly 22.4 ms to 9.4 ms, a ~2.4x speedup.

```
PyTorch runtime:

{'latency_50': 22.3024695,
 'latency_90': 23.1230725,
 'latency_95': 23.2653985,
 'latency_99': 23.60095705,
 'latency_999': 23.865580469999998,
 'latency_mean': 22.442956878923766,
 'latency_std': 0.46544295612971265,
 'nb_forwards': 446,
 'throughput': 44.6}

Optimum-onnxruntime runtime:

{'latency_50': 9.302445,
 'latency_90': 9.782875,
 'latency_95': 9.9071944,
 'latency_99': 11.084606999999997,
 'latency_999': 12.035858692000001,
 'latency_mean': 9.357703552853133,
 'latency_std': 0.4018553286992142,
 'nb_forwards': 1069,
 'throughput': 106.9}

```

Run on your own machine with:

```python
import torch

from optimum.runs_base import TimeBenchmark
from pprint import pprint

# Reuse `ort_model` and `pt_model` from the snippets above.

time_benchmark_ort = TimeBenchmark(
    model=ort_model,
    batch_size=1,
    input_length=224,
    model_input_names={"pixel_values"},
    warmup_runs=10,
    duration=10  # measure for 10 seconds
)

results_ort = time_benchmark_ort.execute()

# Disable gradient tracking for a fair PyTorch measurement
with torch.no_grad():
    time_benchmark_pt = TimeBenchmark(
        model=pt_model,
        batch_size=1,
        input_length=224,
        model_input_names={"pixel_values"},
        warmup_runs=10,
        duration=10
    )

    results_pt = time_benchmark_pt.execute()

print("PyTorch runtime:\n")
pprint(results_pt)

print("\nOptimum-onnxruntime runtime:\n")
pprint(results_ort)
```
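
From the two result dictionaries, the mean-latency speedup can be computed directly:

```python
speedup = results_pt["latency_mean"] / results_ort["latency_mean"]
print(f"Mean latency speedup: {speedup:.2f}x")  # ~2.4x on the run above
```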