How to Transfer a Simple Keras Model to PyTorch – The Hard Way

I had a problem.

I had trained a model in Keras using a TensorFlow backend.

Unfortunately, and for reasons I won’t get into here, I wanted to use that trained model with PyTorch.

I couldn’t find a good library to transfer the models whole, or a good writeup on how to transfer the weights manually, so I figured I would put a super quick writeup of the procedure here.

Note: If you are looking for an introduction to Machine Learning and AI in Python, I’ve heard this free book is good.

Step 0: Train a Model in Keras

For this post, I’m just going to train a deep neural network (NN) to approximate a sine wave. The NN is going to have 5 layers with 32 neurons per layer, using ReLUs for the hidden layers and a tanh for the output layer.
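A minimal sketch of what that training script could look like. The original code is not reproduced here, so several details are assumptions: "5 layers" is read as four hidden Dense layers of 32 units plus one tanh output layer, and the training range, sample count, optimizer, epoch count, and the file name keras_sine.h5 are all placeholders.

```python
import numpy as np
from tensorflow import keras

# Assumed training data: random samples of sin(x) on [-2*pi, 2*pi]
x_train = np.random.uniform(-2 * np.pi, 2 * np.pi, size=(2000, 1)).astype("float32")
y_train = np.sin(x_train)

# Assumed reading of "5 layers": 4 hidden ReLU layers + 1 tanh output layer
keras_model = keras.Sequential([
    keras.Input(shape=(1,)),
    keras.layers.Dense(32, activation="relu"),
    keras.layers.Dense(32, activation="relu"),
    keras.layers.Dense(32, activation="relu"),
    keras.layers.Dense(32, activation="relu"),
    keras.layers.Dense(1, activation="tanh"),
])

keras_model.compile(optimizer="adam", loss="mse")
keras_model.fit(x_train, y_train, epochs=20, batch_size=64, verbose=0)

# Hypothetical file name; the saved model is what gets imported later on
keras_model.save("keras_sine.h5")
```

The tanh output fits this problem because sin(x) stays within [-1, 1], the same range tanh produces.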

Figure: Performance of the Keras model

As can be seen above, the Keras model learned the sine wave quite well, especially in the -pi to pi region.

Step 1: Recreate & Initialize Your Model Architecture in PyTorch

The reason I call this transfer method “The hard way” is because we’re going to have to recreate the network architecture in PyTorch. There are tons of other resources to learn PyTorch.
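Here is one way the matching PyTorch architecture might be written. It is a sketch that assumes the Keras model is four hidden layers of 32 ReLU units followed by a single tanh output; the helper name build_torch_model is made up for illustration.

```python
import torch
import torch.nn as nn


def build_torch_model():
    """Mirror the assumed Keras layout: 4 hidden ReLU layers, 1 tanh output."""
    return nn.Sequential(
        nn.Linear(1, 32), nn.ReLU(),
        nn.Linear(32, 32), nn.ReLU(),
        nn.Linear(32, 32), nn.ReLU(),
        nn.Linear(32, 32), nn.ReLU(),
        nn.Linear(32, 1), nn.Tanh(),
    )


torch_model = build_torch_model()

# Quick shape check with a dummy batch of 8 scalar inputs
dummy_out = torch_model(torch.zeros(8, 1))
print(dummy_out.shape)
```

The layer order and sizes have to match the Keras model exactly, because the weights are copied over position by position.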

Step 2: Import Your Keras Model and Copy the Weights
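A rough sketch of this step. So that the snippet runs on its own, it builds an untrained stand-in model with the assumed layout instead of loading the real one. In practice you would load your trained model with keras.models.load_model, shown commented out with a hypothetical file name.

```python
from tensorflow import keras

# Stand-in for the trained model so this snippet is self-contained.
# In practice: keras_model = keras.models.load_model("keras_sine.h5")  # hypothetical path
keras_model = keras.Sequential(
    [keras.Input(shape=(1,))]
    + [keras.layers.Dense(32, activation="relu") for _ in range(4)]
    + [keras.layers.Dense(1, activation="tanh")]
)

# Collect (kernel, bias) pairs as NumPy arrays, one pair per Dense layer
keras_weights = []
for layer in keras_model.layers:
    params = layer.get_weights()
    if len(params) == 2:  # a Dense layer with a bias returns [kernel, bias]
        kernel, bias = params
        keras_weights.append((kernel, bias))

for i, (kernel, bias) in enumerate(keras_weights):
    print(i, kernel.shape, bias.shape)
```

The printed kernel shapes are in (inputs, outputs) order, which matters when the weights are loaded into PyTorch.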

Step 3: Load Those Weights onto Your PyTorch Model

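Note that when you load the weights into your PyTorch model, you will need to transpose the weights, but not the biases. A Keras Dense layer stores its kernel with shape (inputs, outputs), while a PyTorch Linear layer stores its weight with shape (outputs, inputs). The biases are one-dimensional in both, so they copy over unchanged.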
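The copy itself can be sketched as below. To keep the snippet runnable without Keras, random NumPy arrays in the Keras (inputs, outputs) layout stand in for the extracted weights; the variable names and the layer sizes are assumptions for illustration.

```python
import numpy as np
import torch
import torch.nn as nn

# Stand-in for the (kernel, bias) pairs pulled out of the Keras model.
# Keras Dense kernels have shape (inputs, outputs).
rng = np.random.default_rng(0)
sizes = [1, 32, 32, 32, 32, 1]
keras_weights = [
    (rng.standard_normal((n_in, n_out)).astype(np.float32),
     rng.standard_normal(n_out).astype(np.float32))
    for n_in, n_out in zip(sizes[:-1], sizes[1:])
]

torch_model = nn.Sequential(
    nn.Linear(1, 32), nn.ReLU(),
    nn.Linear(32, 32), nn.ReLU(),
    nn.Linear(32, 32), nn.ReLU(),
    nn.Linear(32, 32), nn.ReLU(),
    nn.Linear(32, 1), nn.Tanh(),
)
linear_layers = [m for m in torch_model if isinstance(m, nn.Linear)]

with torch.no_grad():  # plain copies; nothing here should be tracked by autograd
    for layer, (kernel, bias) in zip(linear_layers, keras_weights):
        # nn.Linear.weight has shape (outputs, inputs), so transpose the kernel
        layer.weight.copy_(torch.from_numpy(kernel.T.copy()))
        # Biases are 1-D in both frameworks, so no transpose is needed
        layer.bias.copy_(torch.from_numpy(bias))
```

If a transpose is forgotten on a non-square layer, copy_ fails with a shape error. On a square 32x32 layer it succeeds silently with the wrong values, which is why the test in the next step matters.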

Step 4: Test and Save Your PyTorch Model

Always test your model before you save it to ensure that no errors slipped by.
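A sketch of such a check, under some loud assumptions: the reference predictions would normally come from keras_model.predict(x), but to keep this snippet self-contained a NumPy forward pass over the same weights stands in for Keras. The tolerance of 1e-4 and the file name torch_sine.pt are placeholders.

```python
import os
import tempfile

import numpy as np
import torch
import torch.nn as nn


def build_torch_model():
    return nn.Sequential(
        nn.Linear(1, 32), nn.ReLU(),
        nn.Linear(32, 32), nn.ReLU(),
        nn.Linear(32, 32), nn.ReLU(),
        nn.Linear(32, 32), nn.ReLU(),
        nn.Linear(32, 1), nn.Tanh(),
    )


def reference_predict(model, x):
    """Stand-in for keras_model.predict(x): same weights, NumPy arithmetic."""
    linears = [m for m in model if isinstance(m, nn.Linear)]
    h = x
    for i, layer in enumerate(linears):
        h = h @ layer.weight.detach().numpy().T + layer.bias.detach().numpy()
        h = np.tanh(h) if i == len(linears) - 1 else np.maximum(h, 0.0)
    return h


torch_model = build_torch_model()  # imagine the Keras weights are already loaded
torch_model.eval()

x = np.linspace(-np.pi, np.pi, 200, dtype=np.float32).reshape(-1, 1)
reference_pred = reference_predict(torch_model, x)
with torch.no_grad():
    torch_pred = torch_model(torch.from_numpy(x)).numpy()

# Test first: refuse to save if the two sets of predictions disagree
max_err = float(np.max(np.abs(reference_pred - torch_pred)))
print("max abs difference:", max_err)
assert max_err < 1e-4, "predictions differ; check the weight transposes"

# Save only the weights (state_dict), then reload them to confirm the round trip
path = os.path.join(tempfile.gettempdir(), "torch_sine.pt")  # hypothetical file name
torch.save(torch_model.state_dict(), path)

reloaded = build_torch_model()
reloaded.load_state_dict(torch.load(path))
reloaded.eval()
```

Saving the state_dict rather than the whole model object means the architecture code has to be available when loading, but the file stays independent of how the class is defined.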

Yep, those look pretty similar. Time to save the model!

Floats vs. Doubles

You’ll notice that there is an error between the Keras and PyTorch model predictions.

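I believe this comes down to floating-point precision: if the Keras weights are stored as doubles (64-bit) while PyTorch stores its weights as floats (32-bit), a small amount of truncation error creeps in. Keras also defaults to 32-bit floats, though, so it is worth checking the dtype of your weights; the gap may just be rounding differences between the two backends. If precision is the issue, you should be able to fix it by calling yourModel.double(), but your inputs will then need to be DoubleTensors as well. For my use case, the truncation didn’t matter, but you should be aware of it.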
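For illustration, switching a model to double precision might look like the sketch below. The tiny two-layer model is a placeholder rather than the network from this post; the point is only the dtype handling.

```python
import torch
import torch.nn as nn

# Placeholder model; substitute your own
model = nn.Sequential(nn.Linear(1, 32), nn.ReLU(), nn.Linear(32, 1), nn.Tanh())

# Convert every parameter from float32 to float64
model = model.double()

# Inputs must match the parameter dtype, so convert them as well
x32 = torch.linspace(-3.14, 3.14, 5).unsqueeze(1)  # float32 by default
x64 = x32.double()

with torch.no_grad():
    y64 = model(x64)

print(y64.dtype)
```

Double precision roughly doubles the memory used by the weights, so it is only worth it when the extra digits actually matter.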

GitHub

This entire project can be found here
