Accelerating MMPreTrain models with JAX

Accelerate your MMPreTrain models by converting them to JAX for faster inference.

Installations

Make sure you run this demo with GPU enabled!

!pip install -U -q openmim && mim install -q "mmpretrain>=1.0.0rc8"
!pip install -q ivy
!pip install -q dm-haiku

Let’s now import Ivy and the libraries we’ll use in this example:

import jax
import ivy
import torch
import requests
import numpy as np
from PIL import Image
import time

import torchvision
from mmpretrain import get_model, list_models
from mmengine import ConfigDict
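
Since the demo assumes a GPU is available, here is a minimal optional check that both PyTorch and JAX can actually see it (a sketch, not part of the pipeline itself):

# Optional sketch: confirm that both PyTorch and JAX detect the GPU.
print(torch.cuda.is_available())   # should print True
print(jax.devices())               # should list a GPU device, e.g. [cuda(id=0)]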

As a sanity check, let's make sure the checkpoint name is valid against mmpretrain's model zoo:

checkpoint_name = "convnext-tiny_32xb128-noema_in1k"
list_models(checkpoint_name)
['convnext-tiny_32xb128-noema_in1k']
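
If you want to explore related checkpoints, list_models can also filter the model zoo by a pattern (a minimal sketch, assuming your mmpretrain version supports pattern-based filtering):

# Sketch: list all checkpoints whose name matches a pattern.
list_models("convnext-tiny")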

Now we can load the ConvNeXt model from OpenMMLab's mmpretrain library:

jax.config.update("jax_enable_x64", True)

model = get_model(checkpoint_name, pretrained=True, device='cuda')

We will also need a sample image to pass during tracing, so let’s use the appropriate transforms to get the corresponding torch tensors.

def get_scale(cfg):
    # Recursively walk the config until we find the first 'scale' entry,
    # which defines the input resolution the model expects.
    if isinstance(cfg, ConfigDict):
        if cfg.get('type', False) and cfg.get('scale', False):
            return cfg['scale']
        for k in cfg.keys():
            input_shape = get_scale(cfg[k])
            if input_shape:
                return input_shape
    elif isinstance(cfg, list):
        for block in cfg:
            input_shape = get_scale(block)
            if input_shape:
                return input_shape
    return None

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
input_shape = get_scale(model._config.train_pipeline)
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((input_shape, input_shape)),
    torchvision.transforms.ToTensor()
])
tensor_image = transform(image).unsqueeze(0).to("cuda")
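
Before transpiling, it can be worth confirming the resolved input resolution and that the PyTorch model runs on the preprocessed tensor (a minimal sketch; the exact resolution depends on the checkpoint's config):

# Sketch: verify the resolved input size and run a baseline forward pass.
print(input_shape, tensor_image.shape)
with torch.no_grad():
    baseline_logits = model(tensor_image)
print(baseline_logits.shape)  # expected to be (batch, num_classes)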

And finally, let's transpile the model to Haiku!

transpiled_graph = ivy.transpile(model, to="haiku", args=(tensor_image,))
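
Note that passing args makes the transpilation eager; as covered in the Lazy vs Eager guide, you can also omit the arguments so that the graph is only transpiled on its first call (a minimal sketch):

# Sketch: lazy transpilation -- the actual conversion happens on the first call
# with concrete arguments.
lazy_graph = ivy.transpile(model, to="haiku")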

After transpiling our model, we can see what the improvement in runtime efficiency looks like. For this, let's compile the original PyTorch model using torch.compile:

tensor_image = transform(image).unsqueeze(0).to("cuda")

def _f(args):
  return model(args)

comp_model = torch.compile(_f)
_ = comp_model(tensor_image)

Let’s now apply the equivalent optimization to our new Haiku model by using JAX’s just-in-time compilation:

tensor_image = transform(image).unsqueeze(0).to("cuda")
np_image = tensor_image.detach().cpu().numpy()
jax_image = jax.device_put(jax.numpy.asarray(np_image), device=jax.devices()[0])

import haiku as hk

def _forward(args):
  module = transpiled_graph()
  return module(args)

rng_key = jax.random.PRNGKey(42)
jax_mlp_forward = hk.transform(_forward)
params = jax_mlp_forward.init(rng=rng_key, args=jax_image)
apply = jax.jit(jax_mlp_forward.apply)
_ = apply(params, None, jax_image)
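
Before benchmarking, a quick parity check that both graphs produce outputs of the same shape (a minimal sketch):

# Sketch: both models should emit logits with the same shape.
out_ref = comp_model(tensor_image)
out_hk = apply(params, None, jax_image)
print(out_ref.shape, out_hk.shape)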

Now that we have both models optimized, let’s see how their runtime speeds compare to each other!

%timeit comp_model(tensor_image)
8.06 ms ± 2.7 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit apply(params, None, jax_image).block_until_ready()
6.08 ms ± 13.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

As expected, we have made the model significantly faster with just one line of code! Latency gets even better on a V100 GPU, where we can get up to a 2-3x increase in execution speed! 🚀

Finally, as a sanity check, let's load a different image and make sure that the results are the same for both models:

url = "http://images.cocodataset.org/train2017/000000283921.jpg"
image = Image.open(requests.get(url, stream=True).raw)
tensor_image = transform(image).unsqueeze(0).to("cuda")
np_image = tensor_image.detach().cpu().numpy()
jax_image = jax.device_put(jax.numpy.asarray(np_image), device=jax.devices()[0])

st = time.perf_counter()
out_torch = comp_model(tensor_image)
et = time.perf_counter()
print(f'Torch call took: {(et - st) * 1000:.2f}ms')

st = time.perf_counter()
out_jax = apply(params, None, jax_image)
et = time.perf_counter()
print(f'Jax call took: {(et - st) * 1000:.2f}ms')

print(np.allclose(out_torch.detach().cpu().numpy(), out_jax, atol=1e-4))
Torch call took: 6.66ms
Jax call took: 2.53ms
True
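
On top of the numerical check, we can also confirm that both models predict the same class index (a minimal sketch):

# Sketch: compare the predicted class indices from both models.
torch_pred = int(out_torch.detach().cpu().numpy().argmax(axis=-1)[0])
jax_pred = int(np.asarray(out_jax).argmax(axis=-1)[0])
print(torch_pred, jax_pred, torch_pred == jax_pred)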

That’s pretty much it! The results from both models are the same, but we have achieved a solid speed up by using Ivy’s transpiler to convert the model to JAX!
