Accelerating MMPreTrain models with JAX
Accelerate your MMPreTrain models by converting them to JAX for faster inference.
Installations
Make sure you run this demo with GPU enabled!
!pip install -U -q openmim && mim install -q "mmpretrain>=1.0.0rc8"
!pip install -q ivy
!pip install -q dm-haiku
Let’s now import Ivy and the libraries we’ll use in this example:
import jax
import ivy
import torch
import requests
import numpy as np
from PIL import Image
import time
import torchvision
from mmpretrain import get_model, list_models
from mmengine import ConfigDict
Sanity check to make sure the checkpoint name is correct against mmpretrain’s model zoo:
checkpoint_name = "convnext-tiny_32xb128-noema_in1k"
list_models(checkpoint_name)
['convnext-tiny_32xb128-noema_in1k']
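If you are not sure of the exact checkpoint name, list_models also accepts a wildcard pattern to browse the model zoo. A quick sketch (the pattern is illustrative, and the matches depend on the installed mmpretrain version):
# Browse the zoo for related ConvNeXt checkpoints (illustrative pattern)
list_models("*convnext-tiny*")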
Now we can load the ConvNeXt model from OpenMMLab’s mmpretrain library:
jax.config.update("jax_enable_x64", True)
model = get_model(checkpoint_name, pretrained=True, device='cuda')
We will also need a sample image to pass during tracing, so let’s use the appropriate transforms to get the corresponding torch tensors.
def get_scale(cfg):
    # Recursively walk the mmpretrain config to find the input image scale
    if isinstance(cfg, ConfigDict):
        # A transform entry that defines both a 'type' and a 'scale' holds the input size
        if cfg.get('type', False) and cfg.get('scale', False):
            return cfg['scale']
        else:
            # Otherwise keep searching the nested config values
            for k in cfg.keys():
                input_shape = get_scale(cfg[k])
                if input_shape:
                    return input_shape
    elif isinstance(cfg, list):
        # Pipelines are lists of transform blocks; search each one in turn
        for block in cfg:
            input_shape = get_scale(block)
            if input_shape:
                return input_shape
    else:
        return None
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
input_shape = get_scale(model._config.train_pipeline)
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((input_shape, input_shape)),
    torchvision.transforms.ToTensor()
])
tensor_image = transform(image).unsqueeze(0).to("cuda")
And finally, let’s transpile the model to haiku!
transpiled_graph = ivy.transpile(model, to="haiku", args=(tensor_image,))
After transpiling our model, we can see what the improvement in runtime efficiency looks like. For this, let’s compile the original PyTorch model using torch.compile:
tensor_image = transform(image).unsqueeze(0).to("cuda")
def _f(args):
  return model(args)
comp_model = torch.compile(_f)
_ = comp_model(tensor_image)
Let’s now do the equivalent transformation on our new haiku model using JAX just-in-time compilation:
tensor_image = transform(image).unsqueeze(0).to("cuda")
np_image = tensor_image.detach().cpu().numpy()
jax_image = jax.device_put(jax.numpy.asarray(np_image), device=jax.devices()[0])
import haiku as hk

# Wrap the transpiled graph in a function so Haiku can transform it
def _forward(args):
  module = transpiled_graph()
  return module(args)

rng_key = jax.random.PRNGKey(42)

# Transform the forward function and initialise its parameters with the sample image
jax_mlp_forward = hk.transform(_forward)
params = jax_mlp_forward.init(rng=rng_key, args=jax_image)

# JIT-compile the apply function for fast repeated inference
apply = jax.jit(jax_mlp_forward.apply)
_ = apply(params, None, jax_image)
Now that we have both models optimized, let’s see how their runtime speeds compare to each other!
%timeit comp_model(tensor_image)
8.06 ms ± 2.7 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit apply(params, None, jax_image).block_until_ready()
6.08 ms ± 13.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
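A quick note on the timing: JAX dispatches work asynchronously, so the benchmark calls block_until_ready() to make sure we measure the full computation rather than just the time to enqueue it. A minimal sketch of timing a single call by hand (the helper name timed_jax_call is just for illustration):
def timed_jax_call(x):
  # JAX returns as soon as the work is dispatched, so block on the result
  # to time the actual GPU execution
  st = time.perf_counter()
  out = apply(params, None, x)
  out.block_until_ready()
  return out, (time.perf_counter() - st) * 1000  # latency in ms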
As expected, we have made the model significantly faster with just one line of code! Latency gets even better on a V100 GPU, where we can get up to a 2-3x increase in execution speed! 🚀
Finally, as a sanity check, let’s load a different image and make sure that the results are the same in both models:
url = "http://images.cocodataset.org/train2017/000000283921.jpg"
image = Image.open(requests.get(url, stream=True).raw)
tensor_image = transform(image).unsqueeze(0).to("cuda")
np_image = tensor_image.detach().cpu().numpy()
jax_image = jax.device_put(jax.numpy.asarray(np_image), device=jax.devices()[0])
st = time.perf_counter()
out_torch = comp_model(tensor_image)
et = time.perf_counter()
print(f'Torch call took: {(et - st) * 1000:.2f}ms')
st = time.perf_counter()
out_jax = apply(params, None, jax_image)
et = time.perf_counter()
print(f'Jax call took: {(et - st) * 1000:.2f}ms')
print(np.allclose(out_torch.detach().cpu().numpy(), out_jax, atol=1e-4))
Torch call took: 6.66ms
Jax call took: 2.53ms
True
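As an extra check, we can also confirm that both models predict the same class for this image. A small sketch, assuming the outputs are raw ImageNet-1k classification logits:
# Compare the predicted class indices (assumes raw classification logits)
torch_pred = int(out_torch.detach().cpu().numpy().argmax(axis=-1)[0])
jax_pred = int(np.asarray(out_jax).argmax(axis=-1)[0])
print(torch_pred, jax_pred, torch_pred == jax_pred)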
That’s pretty much it! The results from both models are the same, and we have achieved a solid speedup by using Ivy’s transpiler to convert the model to JAX!