Posted: 2018-09-27
It might seem tricky or intimidating to convert model formats, but ONNX makes it easier. First, though, we must get our PyTorch model into the ONNX format. This involves both the weights and the network architecture defined by a PyTorch model class (inheriting from nn.Module).
I don't write out the model classes here. Instead, I want to share the steps and code starting from the point of having the class definition and some weights (either in memory or from a saved weights file). One could also do this with the pre-trained models from the torchvision library.
The General Steps
- Define the model class (if using a custom model)
- Train the model and/or load the weights, usually from a .pth or .pt file by convention, into something usually called the state_dict
  - Note: we are only loading the weights from a file. A pre-trained model such as those found in torchvision.models may also be used with the provided weights (using pretrained=True, see below).
- Create a properly shaped input vector (can be some sample data; the important part is the shape)
- (Optional) Give the input and output layers names (to later reference back)
- Export to ONNX format with the PyTorch ONNX exporter
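To make the "properly shaped input" step concrete, here is a minimal sketch. The tiny network below is a hypothetical stand-in (not a model from this post) that assumes image input with 3 channels; the point is only that the dummy tensor's shape has to match what the model's forward pass expects, while its values can be random.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for a real image model: 3 input channels, 10 output classes.
tiny_model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)
tiny_model.eval()

# Layout is (batch, channels, height, width); the values are random, only the shape matters.
dummy_input = torch.randn(1, 3, 224, 224)

# A forward pass is a quick way to confirm the shape is accepted before exporting.
with torch.no_grad():
    output = tiny_model(dummy_input)

print(tuple(dummy_input.shape))
print(tuple(output.shape))
```

If the forward pass raises a shape error here, the export would most likely fail on the same input, so this is a cheap check to run first.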
Prerequisites
- PyTorch and torchvision installed
- A PyTorch model class and model weights
Using a Custom Model Class and Weights File
The Python looks something like:
import torch
import torch.onnx
# A model class instance (class not shown)
model = MyModelClass()
# Load the weights from a file (.pth usually)
state_dict = torch.load(weights_path)
# Load the weights now into a model net architecture defined by our class
model.load_state_dict(state_dict)
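# Switch to inference mode so layers such as dropout and batch norm behave consistently
model.eval()
# Create the right input shape (e.g. for an image); the sizes below are placeholder values
sample_batch_size, channel, height, width = 1, 3, 224, 224
dummy_input = torch.randn(sample_batch_size, channel, height, width)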
torch.onnx.export(model, dummy_input, "onnx_model_name.onnx")
The state dictionary, or state_dict, is a Python dict containing parameter values and persistent buffers. (Docs)
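As a small illustration of what a state_dict holds, the sketch below inspects one for a hypothetical two-layer model (chosen only because batch norm carries buffers as well as parameters). The names printed are generated by PyTorch from the layer positions.

```python
import torch.nn as nn

# Hypothetical two-layer model, used only to have something to inspect.
model = nn.Sequential(nn.Linear(4, 3), nn.BatchNorm1d(3))

state_dict = model.state_dict()

# Each entry maps a name to a tensor: learnable parameters (weight, bias)
# and persistent buffers (e.g. the batch norm running statistics).
for name, tensor in state_dict.items():
    print(name, tuple(tensor.shape))
```

Entries such as 1.running_mean are buffers rather than learnable parameters, which is why they appear in the state_dict even though an optimizer never updates them.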
Note: The preferred way of saving the weights is with torch.save(the_model.state_dict(), <name_here.pth>). (Docs)
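A minimal sketch of that save-and-reload round trip follows. The make_model helper and the temporary file name are assumptions made for illustration; in practice the architecture would come from your own model class.

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_model():
    # Hypothetical architecture standing in for a real model class.
    return nn.Linear(4, 2)


trained = make_model()
weights_path = os.path.join(tempfile.mkdtemp(), "weights.pth")

# Save only the weights (the state_dict), not the whole model object.
torch.save(trained.state_dict(), weights_path)

# Later: rebuild the same architecture, then load the saved weights into it.
restored = make_model()
restored.load_state_dict(torch.load(weights_path, map_location="cpu"))

weights_match = torch.equal(trained.weight, restored.weight)
print(weights_match)
```

Because only tensors are stored, the file does not pin you to the exact class definition or module path that was used when the model was saved.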
A Pre-Trained Model from torchvision
If using the torchvision.models pretrained vision models, all you need to do is the following, e.g., for AlexNet:
import torch
import torchvision.models as models
# Use an existing model from Torchvision, note it
# will download this if not already on your computer (might take time)
model = models.alexnet(pretrained=True)
# Create some sample input in the shape this model expects
dummy_input = torch.randn(10, 3, 224, 224)
# It's optional to label the input and output layers
input_names = [ "actual_input_1" ] + [ "learned_%d" % i for i in range(16) ]
output_names = [ "output1" ]
# Use the exporter from torch to convert to onnx
# model (that has the weights and net arch)
torch.onnx.export(model, dummy_input, "alexnet.onnx", verbose=True, input_names=input_names, output_names=output_names)
Note: the pretrained model weights that come with torchvision.models are downloaded to a folder in your home directory, ~/.torch/models, in case you go looking for them later.
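The folder named above reflects the PyTorch release current when this was written. As far as I know, more recent releases keep downloaded weights in a hub cache instead, so the sketch below asks the installed version where that cache lives rather than assuming a path. The checkpoints subfolder name is an assumption based on how recent versions lay out that cache.

```python
import os

import torch
import torch.hub

# Ask the installed PyTorch where its hub cache directory is.
hub_dir = torch.hub.get_dir()

# Assumption: downloaded model weights sit in a "checkpoints" subfolder of that cache.
checkpoints_dir = os.path.join(hub_dir, "checkpoints")
print(checkpoints_dir)
```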
Summary
Here, I showed how to take a pre-trained PyTorch model (a weights object and network class object) and convert it to ONNX format (that contains the weights and net structure).
As of now, we cannot import an ONNX model for use in PyTorch. Other projects are working on this as well, as shown in this list.