Accelerating MMPreTrain models with JAX
Accelerate your MMPreTrain models by converting them to JAX for faster inference.
Installations
Make sure you run this demo with GPU enabled!
!pip install -U -q openmim && mim install -q "mmpretrain>=1.0.0rc8"
!pip install -q ivy
!pip install -q dm-haiku
Let’s now import Ivy and the libraries we’ll use in this example:
import jax
import ivy
import torch
import requests
import numpy as np
from PIL import Image
import time
import torchvision
from mmpretrain import get_model, list_models
from mmengine import ConfigDict
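Since everything below assumes a GPU runtime, a quick device check (not part of the original demo, just a convenience) confirms that both torch and JAX can see it:
print(torch.cuda.is_available())  # should print True on a GPU runtime
print(jax.devices())              # should list a CUDA/GPU device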
As a sanity check, let’s make sure the checkpoint name is valid against mmpretrain’s model zoo:
= "convnext-tiny_32xb128-noema_in1k"
checkpoint_name list_models(checkpoint_name)
['convnext-tiny_32xb128-noema_in1k']
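If the name had not matched, list_models also accepts wildcard patterns (as far as we are aware, this is the behaviour in mmpretrain 1.x), which is handy for browsing related checkpoints:
list_models("convnext*in1k")  # e.g. ConvNeXt checkpoints trained on ImageNet-1k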
Now we can load the ConvNeXt model from OpenMMLab’s mmpretrain library:
"jax_enable_x64", True)
jax.config.update(
= get_model(checkpoint_name, pretrained=True, device='cuda') model
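If you are curious about what the helper below will traverse, you can peek at the training pipeline config attached to the model (purely optional):
print(model._config.train_pipeline)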
We will also need a sample image to pass during tracing, so let’s use the appropriate transforms to get the corresponding torch tensor:
def get_scale(cfg):
    # Recursively search the (nested) pipeline config for a transform that
    # defines a 'scale' key, and return that input resolution.
    if type(cfg) == ConfigDict:
        if cfg.get('type', False) and cfg.get('scale', False):
            return cfg['scale']
        else:
            for k in cfg.keys():
                input_shape = get_scale(cfg[k])
                if input_shape:
                    return input_shape
    elif type(cfg) == list:
        for block in cfg:
            input_shape = get_scale(block)
            if input_shape:
                return input_shape
    else:
        return None
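To make the recursion concrete, here is a toy pipeline (not the actual ConvNeXt config, just an illustrative sketch) showing how get_scale digs out the scale entry:
example_pipeline = [
    ConfigDict(type='LoadImageFromFile'),
    ConfigDict(type='RandomResizedCrop', scale=224),
]
print(get_scale(example_pipeline))  # 224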
= "http://images.cocodataset.org/val2017/000000039769.jpg"
url = Image.open(requests.get(url, stream=True).raw)
image = get_scale(model._config.train_pipeline)
input_shape = torchvision.transforms.Compose([
transform
torchvision.transforms.Resize((input_shape, input_shape)),
torchvision.transforms.ToTensor()
])= transform(image).unsqueeze(0).to("cuda") tensor_image
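As a quick check, the resulting tensor should have a batched NCHW shape (for this checkpoint the crop size is typically 224, but it depends on what get_scale returned):
print(tensor_image.shape)  # e.g. torch.Size([1, 3, 224, 224])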
And finally, let’s transpile the model to haiku!
transpiled_graph = ivy.transpile(model, to="haiku", args=(tensor_image,))
After transpiling our model, we can see what the improvement in runtime efficiency looks like. For this, let’s compile the original PyTorch model using torch.compile:
tensor_image = transform(image).unsqueeze(0).to("cuda")

def _f(args):
    return model(args)

comp_model = torch.compile(_f)
_ = comp_model(tensor_image)
Let’s now do the equivalent transformation with our new haiku model by using JAX’s just-in-time compilation:
tensor_image = transform(image).unsqueeze(0).to("cuda")
np_image = tensor_image.detach().cpu().numpy()
jax_image = jax.device_put(jax.numpy.asarray(np_image), device=jax.devices()[0])

import haiku as hk

def _forward(args):
    module = transpiled_graph()
    return module(args)

rng_key = jax.random.PRNGKey(42)
jax_mlp_forward = hk.transform(_forward)
params = jax_mlp_forward.init(rng=rng_key, args=jax_image)
apply = jax.jit(jax_mlp_forward.apply)
_ = apply(params, None, jax_image)
Now that we have both models optimized, let’s see how their runtime speeds compare to each other!
%timeit comp_model(tensor_image)
8.06 ms ± 2.7 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit apply(params, None, jax_image).block_until_ready()
6.08 ms ± 13.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
As expected, we have made the model significantly faster with just one line of code! Latency gets even better on a V100 GPU, where we can get up to a 2-3x increase in execution speed! 🚀
Finally, as a sanity check, let’s load a different image and make sure that the results are the same for both models:
= "http://images.cocodataset.org/train2017/000000283921.jpg"
url = Image.open(requests.get(url, stream=True).raw)
image = transform(image).unsqueeze(0).to("cuda")
tensor_image = tensor_image.detach().cpu().numpy()
np_image = jax.device_put(jax.numpy.asarray(np_image), device=jax.devices()[0])
jax_image
st = time.perf_counter()
out_torch = comp_model(tensor_image)
et = time.perf_counter()
print(f'Torch call took: {(et - st) * 1000:.2f}ms')

st = time.perf_counter()
out_jax = apply(params, None, jax_image)
et = time.perf_counter()
print(f'Jax call took: {(et - st) * 1000:.2f}ms')

print(np.allclose(out_torch.detach().cpu().numpy(), out_jax, atol=1e-4))
Torch call took: 6.66ms
Jax call took: 2.53ms
True
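For an extra check beyond np.allclose, you could also compare the predicted classes (assuming the outputs are raw class logits):
print(int(np.argmax(out_torch.detach().cpu().numpy())), int(np.argmax(out_jax)))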
That’s pretty much it! The results from both models are the same, but we have achieved a solid speedup by using Ivy’s transpiler to convert the model to JAX!