Posted: 2018-09-27

It might seem tricky or intimidating to convert model formats, but ONNX makes it easier. First, though, we must get our PyTorch model into the ONNX format. This involves both the weights and the network architecture defined by a PyTorch model class (inheriting from nn.Module).

I don't write out the model classes here; however, I wanted to share the steps and code starting from the point where you have the class definition and some weights (either in memory or in a model file). One could also do this with the pre-trained models from the torchvision library.

The General Steps

  1. Define the model class (if using a custom model)
  2. Train the model and/or load the weights (by convention usually from a .pth or .pt file) as a dictionary called the state_dict. Note that we are only loading the weights from the file, not the architecture. A pre-trained model such as those found in torchvision.models may also be used with the provided weights (using pretrained=True; see below).
  3. Create a properly shaped input tensor (it can be some sample data; the important part is the shape)
  4. (Optional) Give the input and output layers names (to reference them later)
  5. Export to ONNX format with the PyTorch ONNX exporter
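Since the post does not show a model class, here is a minimal sketch of steps 1 and 3 using a hypothetical TinyNet (the class name, layer sizes, and 32×32 input size are all assumptions chosen for illustration). It shows that the dummy input only needs the right shape, (batch, channels, height, width), and that a single forward pass is an easy way to confirm the model accepts it.

```python
import torch
import torch.nn as nn


# Hypothetical stand-in for "MyModelClass" (step 1): a tiny image classifier.
class TinyNet(nn.Module):
    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)  # collapses any H x W to 1 x 1
        self.fc = nn.Linear(8, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.relu(self.conv(x))
        x = self.pool(x).flatten(1)  # shape: (batch, 8)
        return self.fc(x)


model = TinyNet()
model.eval()  # inference mode before tracing/export

# Step 3: only the shape (batch, channels, height, width) matters, not the values
dummy_input = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    out = model(dummy_input)
print(out.shape)  # torch.Size([1, 10])
```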

Prerequisites

  1. PyTorch and torchvision installed
  2. A PyTorch model class and model weights

Using a Custom Model Class and Weights File

The Python looks something like:

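import torch
import torch.onnx

# A model class instance (class not shown)
model = MyModelClass()

# Load the weights from a file (.pth usually); map_location="cpu"
# lets weights saved on a GPU load on a CPU-only machine
weights_path = "my_model_weights.pth"  # placeholder path
state_dict = torch.load(weights_path, map_location="cpu")

# Load the weights now into a model net architecture defined by our class
model.load_state_dict(state_dict)

# Put layers like dropout and batch norm into inference mode before exporting
model.eval()

# Create the right input shape (e.g. for an image); placeholder values
sample_batch_size, channel, height, width = 1, 3, 224, 224
dummy_input = torch.randn(sample_batch_size, channel, height, width)

torch.onnx.export(model, dummy_input, "onnx_model_name.onnx")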

The state dictionary, or state_dict, is a Python dict containing parameter values and persistent buffers. (Docs)
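To see what "parameter values and persistent buffers" means in practice, this sketch prints the state_dict of a small, made-up nn.Sequential model. The BatchNorm1d layer contributes buffers (running_mean, running_var, num_batches_tracked) that are saved alongside the learned weights and biases even though they are not trainable parameters.

```python
import torch.nn as nn

# A small model with both learnable parameters and persistent buffers
model = nn.Sequential(nn.Linear(4, 2), nn.BatchNorm1d(2))

# Keys are "<submodule index>.<tensor name>"; values are tensors
for name, tensor in model.state_dict().items():
    print(name, tuple(tensor.shape))

# parameters() excludes the buffers: only the 4 weight/bias tensors
print(len(list(model.parameters())))  # 4
```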

Note: The preferred way of saving the weights is with torch.save(the_model.state_dict(), <name_here.pth>). (Docs)
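A minimal save/load round trip with that pattern might look like the sketch below. It writes to an in-memory io.BytesIO buffer instead of a .pth file so nothing is left on disk; a file path works the same way. The nn.Linear(4, 2) model is just an illustrative placeholder.

```python
import io

import torch
import torch.nn as nn

model = nn.Linear(4, 2)

# Save only the weights (the recommended approach), here to an in-memory buffer
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)  # rewind so torch.load reads from the start

# Later: rebuild the same architecture and load the saved weights into it
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load(buffer))

print(torch.equal(model.weight, restored.weight))  # True
```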

A Pre-Trained Model from torchvision

If you are using the pretrained vision models from torchvision.models, all you need to do is the following (e.g., for AlexNet):

import torch
import torchvision.models as models

# Use an existing model from torchvision; note it will download
# the weights if not already on your computer (might take time).
# Newer torchvision versions use weights=models.AlexNet_Weights.DEFAULT instead.
model = models.alexnet(pretrained=True)

# Put the dropout layers into inference mode before exporting
model.eval()

# Create some sample input in the shape this model expects
dummy_input = torch.randn(10, 3, 224, 224)

# It's optional to label the input and output layers
input_names = [ "actual_input_1" ] + [ "learned_%d" % i for i in range(16) ]
output_names = [ "output1" ]

# Use the exporter from torch to convert to an ONNX
# model (that has the weights and net arch)
torch.onnx.export(model, dummy_input, "alexnet.onnx", verbose=True,
                  input_names=input_names, output_names=output_names)

Note: the pretrained model weights that come with torchvision.models were saved to the home folder ~/.torch/models, in case you go looking for them later (newer PyTorch versions use a cache under ~/.cache/torch instead).
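Rather than guessing the folder, you can ask PyTorch where its download cache lives. This sketch assumes a PyTorch version that provides torch.hub.get_dir() (newer than the one this post was written against), in which torchvision's downloaded weight files end up in the checkpoints subfolder of that directory.

```python
import os

import torch

# torch.hub.get_dir() returns the hub cache root (it honours $TORCH_HOME);
# downloaded weight files go into its "checkpoints" subfolder
checkpoint_dir = os.path.join(torch.hub.get_dir(), "checkpoints")
print(checkpoint_dir)

# List any weight files already downloaded (the folder may not exist yet)
if os.path.isdir(checkpoint_dir):
    print(sorted(os.listdir(checkpoint_dir)))
```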

Summary

Here, I showed how to take a pre-trained PyTorch model (a weights object and a network class object) and convert it to the ONNX format, which contains both the weights and the network structure.

At the time of writing, we cannot import an ONNX model for use in PyTorch. Other projects are working on this, as shown in this list.

More References

  1. Example: End-to-end AlexNet from PyTorch to Caffe2
  2. ONNX GitHub
  3. PyTorch.org
  4. For a more complicated example, see this conversion