Accelerating PyTorch models with JAX
Accelerate your PyTorch models by converting them to JAX for faster inference.
⚠️ If you are running this notebook in Colab, you will have to install Ivy and some dependencies manually. You can do so by running the cell below ⬇️

!pip install -q ivy
!pip install -q transformers
!pip install -q dm-haiku

If you want to run the notebook locally but don't have Ivy installed just yet, you can check out the Get Started section of the docs.
Make sure you run this demo with GPU enabled!
Let’s now import Ivy and the libraries we’ll use in this example:
import jax
import ivy
import torch
import requests
import numpy as np
from PIL import Image
from transformers import AutoModel, AutoFeatureExtractor
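Since the benchmark below relies on the GPU, a quick optional sanity check (not part of the original demo) confirms that both PyTorch and JAX can see an accelerator:

# Optional sanity check: make sure both frameworks detect the GPU.
print(torch.cuda.is_available())  # expected: True
print(jax.devices())              # expected: a list containing a CUDA/GPU device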
Now we can load a ResNet model and its corresponding feature extractor from the Hugging Face transformers library:
"jax_enable_x64", False)
jax.config.update(
= "ResNet"
arch_name = "microsoft/resnet-50"
checkpoint_name
= AutoFeatureExtractor.from_pretrained(checkpoint_name)
feature_extractor = AutoModel.from_pretrained(checkpoint_name) model
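Since we only run the model for inference in this demo, it can also help to put it in eval mode so that layers such as batch norm and dropout behave deterministically; this is an optional step on top of the original demo:

# Optional: switch the model to inference behaviour (affects batch norm / dropout).
model.eval()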
We will also need a sample image to pass during tracing, so let’s use the feature extractor to get the corresponding torch tensors.
= "http://images.cocodataset.org/val2017/000000039769.jpg"
url = Image.open(requests.get(url, stream=True).raw)
image = feature_extractor(
inputs =image, return_tensors="pt"
images )
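If you want to peek at what the feature extractor produced, it returns a dict-like batch; for ResNet-50 preprocessing this is typically a single pixel_values tensor (the exact shape noted below is an assumption based on the default image size):

# Optional: inspect the preprocessed batch returned by the feature extractor.
print(inputs.keys())                 # typically dict_keys(['pixel_values'])
print(inputs["pixel_values"].shape)  # e.g. torch.Size([1, 3, 224, 224])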
And finally, let’s transpile the model to haiku!
transpiled_graph = ivy.transpile(model, to="haiku", kwargs=inputs)
After transpiling our model, we can see what the improvement in runtime efficiency looks like. To do this, let's compile the original PyTorch model using torch.compile:
inputs = feature_extractor(
    images=image, return_tensors="pt"
).to("cuda")

model.to("cuda")

def _f(**kwargs):
    return model(**kwargs)

comp_model = torch.compile(_f)
_ = comp_model(**inputs)
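Keep in mind that CUDA kernels launch asynchronously and the first call to a compiled model triggers compilation, which is why the cell above ends with a warm-up call. If you want stricter measurements than the %%timeit cells below, a hedged sketch with explicit synchronization looks like this:

import time

# Optional: a stricter timing pattern that waits for the GPU to finish.
torch.cuda.synchronize()                 # flush any pending GPU work
start = time.perf_counter()
_ = comp_model(**inputs)
torch.cuda.synchronize()                 # wait for this forward pass to complete
print(f"{(time.perf_counter() - start) * 1e3:.2f} ms")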
Let's now do the equivalent transformation for our new Haiku model by using JAX's just-in-time compilation:
inputs_jax = feature_extractor(
    images=image, return_tensors="jax"
)

import haiku as hk

def _forward(**kwargs):
    module = transpiled_graph()
    return module(**kwargs).last_hidden_state

rng_key = jax.random.PRNGKey(42)
jax_forward = hk.transform(_forward)
params = jax_forward.init(rng=rng_key, **inputs_jax)
jit_apply = jax.jit(jax_forward.apply)
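JAX also dispatches work asynchronously, and jax.jit compiles on the first call, so it's worth warming the function up (and blocking on the result) before timing; a small optional sketch:

# Optional: trigger compilation once so it isn't counted in the benchmark,
# and block until the result is actually ready.
warmup = jit_apply(params, None, **inputs_jax)
warmup.block_until_ready()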
Now that we have both models optimized, let’s see how their runtime speeds compare to each other!
%%timeit
_ = comp_model(**inputs)
9.67 ms ± 2.28 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
%%timeit
out = jit_apply(params, None, **inputs_jax)
4.09 ms ± 9.48 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
As expected, we have made the model significantly faster with just one line of code, getting a ~2x increase in its execution speed! 🚀
Finally, as a sanity check, let's load a different image and make sure that the results are the same for both models:
= "http://images.cocodataset.org/train2017/000000283921.jpg"
url = Image.open(requests.get(url, stream=True).raw)
image = feature_extractor(
inputs =image, return_tensors="pt"
images"cuda")
).to(= feature_extractor(
inputs_jax =image, return_tensors="jax"
images
)= comp_model(**inputs)
out_torch = jit_apply(params, None, **inputs_jax)
out_jax
=1e-4) np.allclose(out_torch.last_hidden_state.detach().cpu().numpy(), out_jax, atol
True
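For a more informative comparison than a single boolean, you can also look at the largest element-wise difference between the two outputs (an optional extra check):

# Optional: report the maximum absolute deviation between the two outputs.
diff = np.abs(out_torch.last_hidden_state.detach().cpu().numpy() - np.asarray(out_jax))
print(diff.max())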
That’s pretty much it! The results from both models are the same, but we have achieved a solid speed up by using Ivy’s transpiler to convert the model to JAX!