---
license: apache-2.0
tags:
- vision
- image-classification
datasets:
- imagenet-1k
---

This model is a fork of [facebook/levit-256](https://huggingface.co/facebook/levit-256), where:

* `nn.BatchNorm2d` and `nn.Conv2d` are fused
* `nn.BatchNorm1d` and `nn.Linear` are fused

and the optimized model is exported to the ONNX format.

## How to use

```python
import torch
from optimum.onnxruntime.modeling_ort import ORTModelForImageClassification
from transformers import AutoModelForImageClassification

pt_model = AutoModelForImageClassification.from_pretrained("facebook/levit-256")
pt_model.eval()

ort_model = ORTModelForImageClassification.from_pretrained("fxmarty/levit-256-onnx")

inp = {"pixel_values": torch.rand(1, 3, 224, 224)}

with torch.no_grad():
    res = pt_model(**inp)
res_ort = ort_model(**inp)

assert torch.allclose(res.logits, res_ort.logits, atol=1e-4)
```

## Benchmarking

More than 2x throughput with batch normalization folding and ONNX Runtime 🔥 (latencies in ms):

```
PyTorch runtime:

{'latency_50': 22.3024695,
 'latency_90': 23.1230725,
 'latency_95': 23.2653985,
 'latency_99': 23.60095705,
 'latency_999': 23.865580469999998,
 'latency_mean': 22.442956878923766,
 'latency_std': 0.46544295612971265,
 'nb_forwards': 446,
 'throughput': 44.6}

Optimum-onnxruntime runtime:

{'latency_50': 9.302445,
 'latency_90': 9.782875,
 'latency_95': 9.9071944,
 'latency_99': 11.084606999999997,
 'latency_999': 12.035858692000001,
 'latency_mean': 9.357703552853133,
 'latency_std': 0.4018553286992142,
 'nb_forwards': 1069,
 'throughput': 106.9}
```

The numbers above were obtained with `optimum`'s `TimeBenchmark`, reusing `pt_model` and `ort_model` from the snippet above:

```python
import torch
from pprint import pprint

from optimum.runs_base import TimeBenchmark

time_benchmark_ort = TimeBenchmark(
    model=ort_model,
    batch_size=1,
    input_length=224,
    model_input_names={"pixel_values"},
    warmup_runs=10,
    duration=10,
)
results_ort = time_benchmark_ort.execute()

with torch.no_grad():
    time_benchmark_pt = TimeBenchmark(
        model=pt_model,
        batch_size=1,
        input_length=224,
        model_input_names={"pixel_values"},
        warmup_runs=10,
        duration=10,
    )
    results_pt = time_benchmark_pt.execute()

print("PyTorch runtime:\n")
pprint(results_pt)

print("\nOptimum-onnxruntime runtime:\n")
pprint(results_ort)
```
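For reference, the `nn.BatchNorm2d` + `nn.Conv2d` fusion applied to this model can be sketched in plain PyTorch. This is a minimal illustration of the folding math, not the exact optimization code used to produce this checkpoint, and `fuse_conv_bn` is a hypothetical helper name:

```python
import torch
import torch.nn as nn

@torch.no_grad()
def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    """Fold an eval-mode BatchNorm2d into the preceding Conv2d.

    For y = gamma * (conv(x) - mu) / sigma + beta, the fused conv has:
      W' = W * gamma / sigma          (scaled per output channel)
      b' = (b - mu) * gamma / sigma + beta
    """
    fused = nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        groups=conv.groups,
        bias=True,
    )
    b = conv.bias if conv.bias is not None else torch.zeros(conv.out_channels)
    sigma = torch.sqrt(bn.running_var + bn.eps)
    scale = bn.weight / sigma  # gamma / sigma, one value per output channel
    fused.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
    fused.bias.copy_((b - bn.running_mean) * scale + bn.bias)
    return fused
```

Because BatchNorm is an affine transform at inference time, the fused convolution produces the same outputs as the original conv + BN pair (up to floating-point rounding) while saving one op per layer.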