How to Convert a PyTorch Model to ONNX Format

Posted: 2018-09-27

Converting between model formats can seem tricky or intimidating, but ONNX makes it easier. First, though, we have to get our PyTorch model into the ONNX format. This involves both the weights and the network architecture defined by a PyTorch model class (inheriting from nn.Module).

I don't write out the model classes here. Instead, I want to share the steps and code from the point of already having the class definition and some weights (either in memory or in a weights file on disk). One could also do this with the pre-trained models from the torchvision library.

The General Steps

  1. Define the model class (if using a custom model)
  2. Train the model and load the weights (a .pth file by convention) into something usually called the state_dict, or load a pre-trained model such as those found in torchvision.models
  3. Create a properly shaped input tensor
  4. (Optional) Give the input and output layers names (to later reference back)
  5. Export to ONNX format with the PyTorch ONNX exporter

Prerequisites

  1. PyTorch and torchvision installed
  2. A PyTorch model class and model weights

Using a Custom Model Class and Weights File

The Python looks something like:

import torch
import torch.onnx

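# A model class instance (class not shown)
# (weights_path, sample_batch_size, channel, height and width are placeholders to fill in)
model = MyModelClass()

# Load the weights from a file (.pth usually)
state_dict = torch.load(weights_path)

# Load the weights into the network architecture defined by our class
model.load_state_dict(state_dict)

# Switch to inference mode so layers like dropout and batch norm behave consistently
model.eval()

# Create the right input shape (e.g. for an image)
dummy_input = torch.randn(sample_batch_size, channel, height, width)

# Run the exporter on the model with the dummy input and write the ONNX file
torch.onnx.export(model, dummy_input, "onnx_model_name.onnx")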

The state dictionary, or state_dict, is a Python dict containing parameter values and persistent buffers. (Docs)
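
To make "parameter values and persistent buffers" concrete, here is a small sketch that inspects the state_dict of a single batch norm layer. The choice of nn.BatchNorm1d is just for illustration, since it has both learnable parameters and buffers; it is not part of the model discussed in this post.

```python
import torch
import torch.nn as nn

# A small layer that has both learnable parameters and persistent buffers
layer = nn.BatchNorm1d(3)
state = layer.state_dict()

# Each entry maps a name to a tensor
for name, tensor in state.items():
    print(name, tuple(tensor.shape))

# Parameters are what the optimizer updates; buffers are saved state
# (such as running statistics) that is not trained by gradient descent
param_names = [name for name, _ in layer.named_parameters()]
buffer_names = [name for name, _ in layer.named_buffers()]
print("parameters:", param_names)
print("buffers:", buffer_names)
```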

Note: The preferred way of saving the weights is with torch.save(the_model.state_dict(), <name_here.pth>). (Docs)
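
As a rough illustration of that save-and-restore pattern, the sketch below round-trips a state_dict. The nn.Linear layer stands in for a real model, and an in-memory buffer stands in for a .pth file so nothing is written to disk; both are assumptions made for the example.

```python
import io

import torch
import torch.nn as nn

the_model = nn.Linear(4, 2)

# Save only the state_dict; the buffer plays the role of a .pth file
buffer = io.BytesIO()
torch.save(the_model.state_dict(), buffer)
buffer.seek(0)

# Restore into a freshly constructed model with the same architecture
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load(buffer))

weights_match = torch.equal(the_model.weight, restored.weight)
print(weights_match)
```

Saving the state_dict rather than the whole model object keeps the file independent of how the class is pickled, at the cost of needing the class definition again when loading.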

A Pre-Trained Model from torchvision

If you are using one of the pre-trained vision models from torchvision.models, all you need to do is the following, shown here for AlexNet:

import torch
import torchvision.models as models

# Use an existing model from torchvision. Note that it will be
# downloaded if not already on your computer (might take time).
# (Newer torchvision releases prefer a weights= argument over pretrained=True.)
model = models.alexnet(pretrained=True)

# Create some sample input in the shape this model expects
dummy_input = torch.randn(10, 3, 224, 224)

# It's optional to label the input and output layers
input_names = [ "actual_input_1" ] + [ "learned_%d" % i for i in range(16) ]
output_names = [ "output1" ]

# Use the exporter from torch to convert to an ONNX
# model (which contains the weights and net architecture)
torch.onnx.export(model, dummy_input, "alexnet.onnx", verbose=True,
                  input_names=input_names, output_names=output_names)

Note that the pre-trained model weights that come with torchvision.models are downloaded to a folder under your home directory, ~/.torch/models, in case you go looking for them later.
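
If the weights are not where you expect them, you can ask PyTorch where its download cache lives. This sketch assumes a PyTorch release that provides torch.hub.get_dir(); such releases typically keep downloaded weights in a checkpoints subfolder of that directory rather than in ~/.torch/models, so the exact location depends on your version.

```python
import os

import torch
import torch.hub

# Directory PyTorch uses for hub downloads on this machine
hub_dir = torch.hub.get_dir()

# Downloaded weight files usually end up in this subfolder
checkpoint_dir = os.path.join(hub_dir, "checkpoints")

print(hub_dir)
print(checkpoint_dir, "exists:", os.path.isdir(checkpoint_dir))
```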

Summary

Here, I showed how to take a pre-trained PyTorch model (a weights object and a network class object) and convert it to the ONNX format (which contains both the weights and the net structure).

More References

  1. Example: End-to-end AlexNet from PyTorch to Caffe2
  2. ONNX GitHub
  3. PyTorch.org