Felix Marty
---
license: apache-2.0
tags:
- vision
- image-classification
datasets:
- imagenet-1k
---
This model is a fork of [facebook/levit-256](https://huggingface.co/facebook/levit-256), where:

* `nn.BatchNorm2d` and `nn.Conv2d` are fused
* `nn.BatchNorm1d` and `nn.Linear` are fused

The optimized model is then converted to the ONNX format.
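
To illustrate what fusing a batch normalization layer into the preceding layer means, here is a minimal, framework-free sketch for the linear case. It assumes inference mode (fixed running statistics) and a batch norm applied directly after the linear layer. The helper names (`linear`, `batch_norm`, `fold_batch_norm`) and the toy values are hypothetical and are not taken from this repository or from the script used to produce this model.

```python
import math

def linear(x, weight, bias):
    """Compute y = W x + b for a single input vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]

def batch_norm(y, gamma, beta, mean, var, eps=1e-5):
    """Inference-mode batch norm using fixed running statistics."""
    return [g * (yi - m) / math.sqrt(v + eps) + bt
            for yi, g, bt, m, v in zip(y, gamma, beta, mean, var)]

def fold_batch_norm(weight, bias, gamma, beta, mean, var, eps=1e-5):
    """Return (W', b') so that linear(x, W', b') equals batch_norm(linear(x, W, b))."""
    # Per-output-feature scale applied by the batch norm
    scale = [g / math.sqrt(v + eps) for g, v in zip(gamma, var)]
    folded_weight = [[s * w for w in row] for row, s in zip(weight, scale)]
    folded_bias = [s * (b - m) + bt for s, b, m, bt in zip(scale, bias, mean, beta)]
    return folded_weight, folded_bias

# Toy parameters (hypothetical values): 3 input features, 2 output features
weight = [[0.5, -1.0, 2.0], [1.5, 0.25, -0.75]]
bias = [0.1, -0.2]
gamma, beta = [1.2, 0.8], [0.05, -0.3]
mean, var = [0.4, -0.1], [2.0, 0.5]
x = [1.0, 2.0, 3.0]

reference = batch_norm(linear(x, weight, bias), gamma, beta, mean, var)
folded_weight, folded_bias = fold_batch_norm(weight, bias, gamma, beta, mean, var)
fused = linear(x, folded_weight, folded_bias)

# The fused layer should match the two-step computation up to rounding error
max_abs_diff = max(abs(r - f) for r, f in zip(reference, fused))
print(max_abs_diff)
```

The same per-output-channel rescaling applies to a convolution followed by `nn.BatchNorm2d`: each output filter is multiplied by its scale and the bias is adjusted in the same way, so one layer is evaluated instead of two.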
## How to use
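```python
import torch
from optimum.onnxruntime.modeling_ort import ORTModelForImageClassification
from transformers import AutoModelForImageClassification

# Reference PyTorch model, in inference mode
pt_model = AutoModelForImageClassification.from_pretrained("facebook/levit-256")
pt_model.eval()

# Fused model exported to ONNX, run with ONNX Runtime through Optimum
ort_model = ORTModelForImageClassification.from_pretrained("fxmarty/levit-256-onnx")

# Random input of shape (batch, channels, height, width)
inp = {"pixel_values": torch.rand(1, 3, 224, 224)}
with torch.no_grad():
    res = pt_model(**inp)
    res_ort = ort_model(**inp)

# Both models should produce the same logits up to numerical tolerance
assert torch.allclose(res.logits, res_ort.logits, atol=1e-4)
```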
## Benchmarking
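Batch normalization folding combined with ONNX Runtime gives more than 2x the throughput of the PyTorch baseline at batch size 1 (106.9 vs. 44.6 forward passes per second) 🔥

Latencies are in milliseconds and throughput is in forward passes per second, measured over a 10-second run.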
```
PyTorch runtime:

{'latency_50': 22.3024695,
 'latency_90': 23.1230725,
 'latency_95': 23.2653985,
 'latency_99': 23.60095705,
 'latency_999': 23.865580469999998,
 'latency_mean': 22.442956878923766,
 'latency_std': 0.46544295612971265,
 'nb_forwards': 446,
 'throughput': 44.6}

Optimum-onnxruntime runtime:

{'latency_50': 9.302445,
 'latency_90': 9.782875,
 'latency_95': 9.9071944,
 'latency_99': 11.084606999999997,
 'latency_999': 12.035858692000001,
 'latency_mean': 9.357703552853133,
 'latency_std': 0.4018553286992142,
 'nb_forwards': 1069,
 'throughput': 106.9}
```
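
As a quick sanity check on the speedup claim, the ratio can be derived directly from the reported numbers. The sketch below only copies values from the output above. The dictionary names `pt` and `ort` are illustrative, and the 10-second duration is the `duration` value used in the benchmark script.

```python
# Values copied from the benchmark output above
pt = {"latency_mean": 22.442956878923766, "nb_forwards": 446, "throughput": 44.6}
ort = {"latency_mean": 9.357703552853133, "nb_forwards": 1069, "throughput": 106.9}
duration_s = 10  # benchmark duration in seconds

# Speedup measured two ways: throughput ratio and mean-latency ratio
throughput_speedup = ort["throughput"] / pt["throughput"]
latency_speedup = pt["latency_mean"] / ort["latency_mean"]

# Throughput should equal the number of forward passes divided by the duration
pt_forwards_per_s = pt["nb_forwards"] / duration_s
ort_forwards_per_s = ort["nb_forwards"] / duration_s

print(f"throughput speedup:   {throughput_speedup:.2f}x")
print(f"mean latency speedup: {latency_speedup:.2f}x")
```

Both ratios come out at roughly 2.4x, which is consistent with the "more than 2x" figure.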
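The numbers above were produced with the following script, which reuses `torch`, `pt_model` and `ort_model` from the "How to use" snippet:

```python
from optimum.runs_base import TimeBenchmark
from pprint import pprint

# Benchmark the ONNX Runtime model: 10 warmup runs, then 10 seconds of timed runs
time_benchmark_ort = TimeBenchmark(
    model=ort_model,
    batch_size=1,
    input_length=224,
    model_input_names={"pixel_values"},
    warmup_runs=10,
    duration=10
)
results_ort = time_benchmark_ort.execute()

# Benchmark the PyTorch model with the same settings, without gradient tracking
with torch.no_grad():
    time_benchmark_pt = TimeBenchmark(
        model=pt_model,
        batch_size=1,
        input_length=224,
        model_input_names={"pixel_values"},
        warmup_runs=10,
        duration=10
    )
    results_pt = time_benchmark_pt.execute()

print("PyTorch runtime:\n")
pprint(results_pt)
print("\nOptimum-onnxruntime runtime:\n")
pprint(results_ort)
```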