levit-256-onnx
Felix Marty
metadata
license: apache-2.0
tags:
  - vision
  - image-classification
datasets:
  - imagenet-1k

This model is a fork of facebook/levit-256, where:

  • nn.BatchNorm2d and nn.Conv2d are fused
  • nn.BatchNorm1d and nn.Linear are fused

and the optimized model is then converted to the ONNX format.

The layer fusion is done with torch.fx, using the FuseBatchNorm2dInConv2d and FuseBatchNorm1dInLinear transformations, which will soon be available out of the box in 🤗 Optimum. Check out the transformation guide: https://huggingface.co/docs/optimum/main/en/fx/optimization#the-transformation-guide
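Why the fusion is exact: in eval mode, batch normalization uses fixed running statistics, so it reduces to a per-channel affine map, gamma * (y - mean) / sqrt(var + eps) + beta, that can be absorbed into the weights and bias of the preceding Linear or Conv2d. The NumPy sketch below is our own illustration of that folding, not the code of the Optimum transformations; the function names are hypothetical and eps=1e-5 is assumed (PyTorch's BatchNorm default). It uses a 1x1 convolution only to keep the check short; larger kernels fold the same way, since each output filter is simply rescaled.

```python
import numpy as np


def fold_bn_into_linear(W, b, gamma, beta, running_mean, running_var, eps=1e-5):
    """Fold an eval-mode BatchNorm1d that follows y = x @ W.T + b.

    W has shape (out_features, in_features); b and all BN arrays have shape (out_features,).
    """
    scale = gamma / np.sqrt(running_var + eps)  # per-output-feature multiplier
    return W * scale[:, None], (b - running_mean) * scale + beta


def fold_bn_into_conv2d(W, b, gamma, beta, running_mean, running_var, eps=1e-5):
    """Same folding for a Conv2d weight of shape (out_channels, in_channels, kH, kW)."""
    scale = gamma / np.sqrt(running_var + eps)  # per-output-channel multiplier
    return W * scale[:, None, None, None], (b - running_mean) * scale + beta


def batchnorm_eval(y, gamma, beta, running_mean, running_var, eps=1e-5):
    """Eval-mode batch norm over axis 1 (features for Linear, channels for NCHW conv)."""
    shape = (1, -1) + (1,) * (y.ndim - 2)  # broadcast per-channel params over axis 1
    mean, var = running_mean.reshape(shape), running_var.reshape(shape)
    return (y - mean) / np.sqrt(var + eps) * gamma.reshape(shape) + beta.reshape(shape)


def conv1x1(x, W, b):
    """A 1x1 Conv2d (stride 1, no padding) on NCHW input, written as an einsum."""
    return np.einsum("nchw,oc->nohw", x, W[:, :, 0, 0]) + b[None, :, None, None]


rng = np.random.default_rng(0)
C_in, C_out = 3, 4
gamma, beta = rng.normal(size=C_out), rng.normal(size=C_out)
mean, var = rng.normal(size=C_out), rng.uniform(0.5, 2.0, size=C_out)
bn = (gamma, beta, mean, var)

# Linear -> BatchNorm1d versus the single fused Linear
W_lin, b_lin = rng.normal(size=(C_out, C_in)), rng.normal(size=C_out)
x_lin = rng.normal(size=(2, C_in))
ref_lin = batchnorm_eval(x_lin @ W_lin.T + b_lin, *bn)
W_f, b_f = fold_bn_into_linear(W_lin, b_lin, *bn)
fused_lin = x_lin @ W_f.T + b_f

# Conv2d (1x1 kernel for brevity) -> BatchNorm2d versus the single fused Conv2d
W_conv, b_conv = rng.normal(size=(C_out, C_in, 1, 1)), rng.normal(size=C_out)
x_conv = rng.normal(size=(2, C_in, 5, 5))
ref_conv = batchnorm_eval(conv1x1(x_conv, W_conv, b_conv), *bn)
Wc_f, bc_f = fold_bn_into_conv2d(W_conv, b_conv, *bn)
fused_conv = conv1x1(x_conv, Wc_f, bc_f)

print(np.allclose(ref_lin, fused_lin), np.allclose(ref_conv, fused_conv))
```

Both checks should print True: the fused layer reproduces the layer-plus-batch-norm pair up to floating-point rounding while doing one operation instead of two.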

How to use

from optimum.onnxruntime.modeling_ort import ORTModelForImageClassification
from transformers import AutoFeatureExtractor

from PIL import Image
import requests

preprocessor = AutoFeatureExtractor.from_pretrained("fxmarty/levit-256-onnx")
ort_model = ORTModelForImageClassification.from_pretrained("fxmarty/levit-256-onnx")

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

inputs = preprocessor(images=image, return_tensors="pt")
outputs = ort_model(**inputs)  # runs the ONNX model with ONNX Runtime

# The class with the highest logit is the prediction
predicted_class_idx = outputs.logits.argmax(-1).item()
print("Predicted class:", ort_model.config.id2label[predicted_class_idx])

To be safe, also check that the ONNX model returns the same logits as the PyTorch model:

import torch
from optimum.onnxruntime.modeling_ort import ORTModelForImageClassification
from transformers import AutoModelForImageClassification

pt_model = AutoModelForImageClassification.from_pretrained("facebook/levit-256")
pt_model.eval()

ort_model = ORTModelForImageClassification.from_pretrained("fxmarty/levit-256-onnx")

inp = {"pixel_values": torch.rand(1, 3, 224, 224)}

with torch.no_grad():
    res = pt_model(**inp)
res_ort = ort_model(**inp)

assert torch.allclose(res.logits, res_ort.logits, atol=1e-4)

Benchmarking

More than 2x throughput with batch normalization folding and ONNX Runtime 🔥

Below are the latency percentiles, mean and standard deviation (in ms), the number of forward passes completed during the 10-second run (nb_forwards), and each model's throughput (in iterations/s).
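In the results below, throughput equals nb_forwards divided by the 10-second run duration (446 / 10 = 44.6 and 1069 / 10 = 106.9). The sketch below shows how such a summary can be computed from a list of per-forward latencies. It is our own hypothetical helper, not Optimum's implementation: we assume latency_999 is the 99.9th percentile and that the standard deviation is the population one (ddof=0).

```python
import numpy as np

# Hypothetical helper; the key names mirror the result dictionaries shown below.
PERCENTILES = {"latency_50": 50, "latency_90": 90, "latency_95": 95, "latency_99": 99, "latency_999": 99.9}


def summarize_latencies(latencies_ms, duration_s):
    """Summarize per-forward latencies (ms) collected during a run of duration_s seconds."""
    lat = np.asarray(latencies_ms, dtype=float)
    stats = {key: float(np.percentile(lat, q)) for key, q in PERCENTILES.items()}
    stats["latency_mean"] = float(lat.mean())
    stats["latency_std"] = float(lat.std())  # population std (ddof=0): an assumption
    stats["nb_forwards"] = int(lat.size)
    stats["throughput"] = lat.size / duration_s  # forward passes per second
    return stats


# Toy data (not measured): 99 forwards at 9 ms plus one 19 ms outlier within a 1 s window.
stats = summarize_latencies([9.0] * 99 + [19.0], duration_s=1.0)
print(stats["latency_50"], stats["nb_forwards"], stats["throughput"])
```

Note how a single outlier barely moves the median but dominates the highest percentile, which is why the tail percentiles are reported alongside the mean.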

PyTorch runtime:

{'latency_50': 22.3024695,
 'latency_90': 23.1230725,
 'latency_95': 23.2653985,
 'latency_99': 23.60095705,
 'latency_999': 23.865580469999998,
 'latency_mean': 22.442956878923766,
 'latency_std': 0.46544295612971265,
 'nb_forwards': 446,
 'throughput': 44.6}

Optimum-onnxruntime runtime:

{'latency_50': 9.302445,
 'latency_90': 9.782875,
 'latency_95': 9.9071944,
 'latency_99': 11.084606999999997,
 'latency_999': 12.035858692000001,
 'latency_mean': 9.357703552853133,
 'latency_std': 0.4018553286992142,
 'nb_forwards': 1069,
 'throughput': 106.9}
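The "more than 2x" claim can be checked directly against these two measurements: dividing the throughputs, or the mean latencies the other way round, gives roughly 2.4x in both cases. The numbers come from the runs above; on other hardware the ratio will differ.

```python
# Numbers copied from the two result dictionaries above.
pytorch = {"latency_mean": 22.442956878923766, "throughput": 44.6}
onnxruntime = {"latency_mean": 9.357703552853133, "throughput": 106.9}

throughput_gain = onnxruntime["throughput"] / pytorch["throughput"]
latency_gain = pytorch["latency_mean"] / onnxruntime["latency_mean"]
print(f"throughput: {throughput_gain:.2f}x, mean latency: {latency_gain:.2f}x")
```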

Run the benchmark on your own machine with the following (it reuses pt_model and ort_model from the previous snippet):

from pprint import pprint

import torch
from optimum.runs_base import TimeBenchmark

time_benchmark_ort = TimeBenchmark(
    model=ort_model,
    batch_size=1,
    input_length=224,
    model_input_names={"pixel_values"},
    warmup_runs=10,
    duration=10
)

results_ort = time_benchmark_ort.execute()

with torch.no_grad():
    time_benchmark_pt = TimeBenchmark(
        model=pt_model,
        batch_size=1,
        input_length=224,
        model_input_names={"pixel_values"},
        warmup_runs=10,
        duration=10
    )

    results_pt = time_benchmark_pt.execute()

print("PyTorch runtime:\n")
pprint(results_pt)

print("\nOptimum-onnxruntime runtime:\n")
pprint(results_ort)
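If you would rather not depend on optimum.runs_base, a minimal stand-in for the timing loop is sketched below. This is our own approximation, not Optimum's TimeBenchmark, and time_forwards is a hypothetical name: it makes warmup_runs untimed calls, then calls the model repeatedly for a fixed wall-clock duration, mirroring the warmup_runs and duration parameters used above.

```python
import time
from typing import Callable, List


def time_forwards(fn: Callable[[], object], warmup_runs: int = 10, duration_s: float = 10.0) -> List[float]:
    """Call fn warmup_runs times untimed, then repeatedly for about duration_s seconds.

    Returns the latency of each timed call in milliseconds.
    """
    for _ in range(warmup_runs):
        fn()  # warm up caches, lazy initialization, etc.
    latencies_ms: List[float] = []
    start = time.perf_counter()
    while time.perf_counter() - start < duration_s:
        t0 = time.perf_counter()
        fn()
        latencies_ms.append((time.perf_counter() - t0) * 1e3)
    return latencies_ms


# Stand-in workload so the sketch runs without any model.
latencies = time_forwards(lambda: sum(range(10_000)), warmup_runs=2, duration_s=0.05)
print(len(latencies), "calls, throughput:", len(latencies) / 0.05, "it/s")
```

With the models from above you would pass, for example, `lambda: ort_model(**inp)` with `inp = {"pixel_values": torch.rand(1, 3, 224, 224)}`, and make the PyTorch call inside `torch.no_grad()`.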