Image Segmentation with Ivy UNet

Use the Ivy UNet model for image segmentation.

So that the installed packages are available, the notebook restarts automatically after the first cell runs.

Once it has restarted, use Runtime -> Run all to run all of the cells.

Make sure you run this demo with GPU enabled!

!pip install -q ivy
!pip install -q dm-haiku
!git clone https://github.com/unifyai/models.git

# Installing models package from cloned repository! 😄
!cd models/ && pip install .

exit()

Imports

import ivy
import torch
import numpy as np
from PIL import Image  # used by preprocess and mask_to_image

Data Preparation

Custom Preprocessing

# ref: https://github.com/milesial/Pytorch-UNet/blob/2f62e6b1c8e98022a6418d31a76f6abd800e5ae7/utils/data_loading.py#L65

def preprocess(mask_values, pil_img, scale, is_mask):
    # Resize by `scale`; nearest-neighbour resampling keeps mask labels intact
    w, h = pil_img.size
    newW, newH = int(scale * w), int(scale * h)
    assert newW > 0 and newH > 0, 'Scale is too small, resized images would have no pixel'
    pil_img = pil_img.resize((newW, newH), resample=Image.NEAREST if is_mask else Image.BICUBIC)
    img = np.asarray(pil_img)

    if is_mask:
        # Map each mask value to its class index
        mask = np.zeros((newH, newW), dtype=np.int64)
        for i, v in enumerate(mask_values):
            if img.ndim == 2:
                mask[img == v] = i
            else:
                mask[(img == v).all(-1)] = i

        return mask

    else:
        # HWC -> CHW (or add a channel axis for grayscale), then scale to [0, 1]
        if img.ndim == 2:
            img = img[np.newaxis, ...]
        else:
            img = img.transpose((2, 0, 1))

        if (img > 1).any():
            img = img / 255.0

        return img
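The two branches of preprocess can be tried on tiny synthetic arrays without PIL. The sketch below is a minimal NumPy-only reproduction of the steps that follow the resize; rgb_mask_to_indices and hwc_to_chw_float are hypothetical helper names, and the 2x2 mask is made-up data.

```python
import numpy as np


def rgb_mask_to_indices(mask_img, mask_values):
    """Map each pixel to the index of its value in mask_values (is_mask=True branch, no resize)."""
    h, w = mask_img.shape[:2]
    mask = np.zeros((h, w), dtype=np.int64)
    for i, v in enumerate(mask_values):
        if mask_img.ndim == 2:
            mask[mask_img == v] = i  # grayscale: compare scalar values
        else:
            mask[(mask_img == v).all(-1)] = i  # RGB: every channel must match
    return mask


def hwc_to_chw_float(img):
    """is_mask=False branch: put channels first and scale 0-255 values to [0, 1]."""
    img = img[np.newaxis, ...] if img.ndim == 2 else img.transpose((2, 0, 1))
    if (img > 1).any():
        img = img / 255.0
    return img


# A 2x2 RGB "mask" where white pixels belong to the foreground class
rgb = np.array([[[0, 0, 0], [255, 255, 255]],
                [[255, 255, 255], [0, 0, 0]]], dtype=np.uint8)
print(rgb_mask_to_indices(rgb, [[0, 0, 0], [255, 255, 255]]).tolist())  # [[0, 1], [1, 0]]
print(hwc_to_chw_float(rgb).shape)  # (3, 2, 2)
```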

Load the image example 🖼️

# Download and load the example image
from PIL import Image
!wget https://raw.githubusercontent.com/unifyai/models/master/images/car.jpg
filename = "car.jpg"
full_img = Image.open(filename)
# Preprocess at half resolution into a (1, 3, H, W) float tensor on the GPU
torch_img = torch.from_numpy(preprocess(None, full_img, 0.5, False)).unsqueeze(0).to("cuda")
# Convert to an Ivy array in channel-last (NHWC) layout
ivy.set_backend("torch")
img = ivy.asarray(torch_img.permute((0, 2, 3, 1)), dtype="float32", device="gpu:0")
# Keep a NumPy copy to feed the other backends later
img_numpy = img.cpu().numpy()
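The permute((0, 2, 3, 1)) call moves the channel axis last, since the Ivy UNet in the appendix consumes BHWC input. A small NumPy sketch on a random stand-in array (not the real image) shows what the reordering does to shapes and indices.

```python
import numpy as np

# Synthetic stand-in for the preprocessed image: batch 1, 3 channels, 4x5 pixels (NCHW)
x_nchw = np.random.default_rng(0).random((1, 3, 4, 5)).astype(np.float32)

# Same reordering as torch_img.permute((0, 2, 3, 1)): NCHW -> NHWC
x_nhwc = np.transpose(x_nchw, (0, 2, 3, 1))
print(x_nhwc.shape)  # (1, 4, 5, 3)

# Values only move: channel c of pixel (i, j) is now at [0, i, j, c]
assert x_nhwc[0, 2, 3, 1] == x_nchw[0, 1, 2, 3]

# The inverse permutation (0, 3, 1, 2), applied later to the model output, restores NCHW
x_back = np.transpose(x_nhwc, (0, 3, 1, 2))
assert np.array_equal(x_back, x_nchw)
```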

Visualise image

from IPython.display import Image as I, display
display(I(filename))

Model Inference

Initializing Native Torch UNet

torch_unet = torch.hub.load('milesial/Pytorch-UNet', 'unet_carvana', pretrained=True, scale=1.0)
torch_unet.to("cuda")
torch_unet.eval()

Initializing Ivy UNet with Pretrained Weights ⬇️

Passing pretrained=True initializes the model with the pretrained weights.

# load the unet model from ivy_models
import ivy_models
ivy_unet = ivy_models.unet_carvana(n_channels=3, n_classes=2, pretrained=True)

Compile the forward pass for efficiency.

ivy_unet.compile(args=(img,))

Custom masking function

# ref: https://github.com/milesial/Pytorch-UNet/blob/2f62e6b1c8e98022a6418d31a76f6abd800e5ae7/predict.py#L62

def mask_to_image(mask: np.ndarray, mask_values):
    if isinstance(mask_values[0], list):
        out = np.zeros((mask.shape[-2], mask.shape[-1], len(mask_values[0])), dtype=np.uint8)
    elif mask_values == [0, 1]:
        out = np.zeros((mask.shape[-2], mask.shape[-1]), dtype=bool)
    else:
        out = np.zeros((mask.shape[-2], mask.shape[-1]), dtype=np.uint8)

    if mask.ndim == 3:
        mask = np.argmax(mask, axis=0)

    for i, v in enumerate(mask_values):
        out[mask == i] = v

    return Image.fromarray(out)
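mask_to_image turns a class-index mask (or per-class scores) into pixel values. To see the mapping without creating a PIL image, here is a hypothetical NumPy-only variant, mask_to_array, that stops just before Image.fromarray; the inputs are toy arrays.

```python
import numpy as np


def mask_to_array(mask, mask_values):
    """Hypothetical NumPy-only variant of mask_to_image that skips Image.fromarray."""
    h, w = mask.shape[-2], mask.shape[-1]
    if isinstance(mask_values[0], list):
        out = np.zeros((h, w, len(mask_values[0])), dtype=np.uint8)  # one colour per class
    elif mask_values == [0, 1]:
        out = np.zeros((h, w), dtype=bool)  # binary background/foreground
    else:
        out = np.zeros((h, w), dtype=np.uint8)  # one grey level per class
    if mask.ndim == 3:
        mask = np.argmax(mask, axis=0)  # per-class scores (C, H, W) -> class indices (H, W)
    for i, v in enumerate(mask_values):
        out[mask == i] = v
    return out


# Class-index mask for a 2x3 image with the binary [0, 1] palette used in this notebook
idx = np.array([[0, 1, 1], [0, 0, 1]])
print(mask_to_array(idx, [0, 1]).tolist())  # [[False, True, True], [False, False, True]]

# The same mask with an RGB palette: class 1 is painted red
rgb = mask_to_array(idx, [[0, 0, 0], [255, 0, 0]])
print(rgb.shape)  # (2, 3, 3)
```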

Use the model to segment your images 🚀

First, we generate a reference mask with the original PyTorch model.

  1. Torch UNet
with torch.no_grad():  # inference only, so no autograd graph is needed
    torch_output = torch_unet(torch_img.to(torch.float32))
# Resize the logits back to the original (height, width) before taking the argmax
torch_output = torch.nn.functional.interpolate(torch_output, (full_img.size[1], full_img.size[0]), mode="bilinear")
torch_mask = torch_output.argmax(dim=1)
torch_mask = torch_mask[0].squeeze().cpu().numpy()
torch_result = mask_to_image(torch_mask, [0,1])
torch_result

Next, we generate the mask with the Ivy-native implementation.

  2. Ivy UNet
output = ivy_unet(img)
output = ivy.interpolate(ivy.permute_dims(output, (0, 3, 1, 2)), (full_img.size[1], full_img.size[0]), mode="bilinear")
mask = output.argmax(axis=1)
mask = ivy.squeeze(mask[0], axis=None).to_numpy()
result = mask_to_image(mask, [0,1])
result
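The Ivy UNet returns channel-last logits, so the cell permutes them to (N, C, H, W) before calling ivy.interpolate, which (as the permute suggests) expects channel-first input; the class axis then sits at axis 1, as in the torch cell. A NumPy sketch on random logits, with the interpolation step left out, confirms that for the argmax alone, permuting and reducing over axis 1 gives the same mask as reducing over the last axis.

```python
import numpy as np

rng = np.random.default_rng(42)
# Hypothetical NHWC output: batch 1, 6x8 pixels, 2 classes
logits_nhwc = rng.standard_normal((1, 6, 8, 2))

# Path used in the notebook: permute to NCHW, then argmax over the class axis (1)
mask_a = np.transpose(logits_nhwc, (0, 3, 1, 2)).argmax(axis=1)[0]

# Equivalent without the permute: argmax over the last (channel) axis
mask_b = logits_nhwc.argmax(axis=-1)[0]

print(mask_a.shape)  # (6, 8)
assert np.array_equal(mask_a, mask_b)
```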

Great! The Ivy-native model and the torch model produce the same segmentation mask!
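Comparing the rendered masks by eye is a coarse check. One way to quantify the agreement is to compare the two arrays directly; mask_agreement below is a hypothetical helper that reports the fraction of matching pixels and the foreground intersection-over-union for binary masks. In the notebook it could be called as mask_agreement(torch_mask, mask); the numbers in the example come from a toy input, not from the car image.

```python
import numpy as np


def mask_agreement(mask_a, mask_b):
    """Return (pixel agreement, foreground IoU) for two binary class-index masks."""
    a = np.asarray(mask_a).astype(bool)
    b = np.asarray(mask_b).astype(bool)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    agreement = float(np.mean(a == b))
    union = np.logical_or(a, b).sum()
    # Treat two empty masks as a perfect match (IoU 1.0)
    iou = float(np.logical_and(a, b).sum() / union) if union else 1.0
    return agreement, iou


# Toy example: the masks differ in one of four pixels
print(mask_agreement([[0, 1], [1, 1]], [[0, 1], [0, 1]]))  # (0.75, 0.6666666666666666)
```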

TensorFlow backend

Let’s look at using the TensorFlow backend.

import tensorflow as tf
ivy.set_backend("tensorflow")

ivy_unet = ivy_models.unet_carvana(n_channels=3, n_classes=2, pretrained=True)
img_tf = ivy.asarray(img_numpy)
ivy_unet = ivy.compile(ivy_unet, args=(img_tf,))
output = ivy_unet(img_tf)
output = ivy.interpolate(tf.transpose(output, (0, 3, 1, 2)), (full_img.size[1], full_img.size[0]), mode="bilinear")
mask = tf.math.argmax(output, axis=1)
mask = tf.squeeze(mask[0], axis=None).numpy()
result = mask_to_image(mask, [0,1])
result

As expected, we ended up with the same mask as before. Note how, with the TensorFlow backend set, we were able to use native TensorFlow functions such as tf.transpose, tf.math.argmax and tf.squeeze for the post-processing.

JAX

Next up is the JAX backend. We’ve used a lot of GPU memory so far, so we’ll free some of it first.

del torch_unet
del ivy_unet
torch.cuda.empty_cache()
# Override JAX's default behaviour of preallocating 75% of GPU memory;
# this must be set before JAX initializes its GPU backend, so do it before importing jax
# Temporary fix until this is handled by Ivy's graph compiler
import os
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import jax

jax.config.update('jax_enable_x64', True)
ivy.set_backend("jax")
ivy_unet = ivy_models.unet_carvana(n_channels=3, n_classes=2, pretrained=True)
img_jax = ivy.asarray(img_numpy)
output = ivy_unet(img_jax)
output = ivy.interpolate(ivy.permute_dims(output, (0, 3, 1, 2)), (full_img.size[1], full_img.size[0]), mode="bilinear")
mask = output.argmax(axis=1)
mask = ivy.squeeze(mask[0], axis=None).to_numpy()
result = mask_to_image(mask, [0,1])
result
/usr/local/lib/python3.10/dist-packages/ivy/func_wrapper.py:242: UserWarning: Creating many views will lead to overhead when performing inplace updates with this backend
  warnings.warn(

Once again, we ended up with the same mask as in the reference torch implementation!
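XLA_PYTHON_CLIENT_ALLOCATOR=platform makes JAX allocate and free GPU memory on demand, which JAX's GPU memory documentation describes as slow. The same documentation describes two lighter options, XLA_PYTHON_CLIENT_PREALLOCATE=false and XLA_PYTHON_CLIENT_MEM_FRACTION; the sketch below sets them from Python. The helper name configure_jax_gpu_memory is hypothetical, and, like the cell above, it only takes effect if it runs before JAX initializes its GPU backend.

```python
import os
from typing import Optional


def configure_jax_gpu_memory(preallocate: bool = False, mem_fraction: Optional[float] = None) -> None:
    """Set JAX GPU-memory environment variables (call before JAX initializes)."""
    # "false" makes JAX allocate GPU memory as needed instead of preallocating a large block
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" if preallocate else "false"
    if mem_fraction is not None:
        # Fraction of total GPU memory to preallocate when preallocation is enabled
        os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{mem_fraction:.2f}"


configure_jax_gpu_memory(preallocate=False)
# import jax  # only import and use JAX after the variables are set
```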

Appendix: the Ivy native implementation of UNet

class UNET(ivy.Module):
    def __init__(self, n_channels, n_classes, bilinear=False, v=None):
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.factor = 2 if bilinear else 1
        super(UNET, self).__init__(v=v)

    def _build(self, *args, **kwargs):
        self.inc = UNetDoubleConv(self.n_channels, 64)
        self.down1 = UNetDown(64, 128)
        self.down2 = UNetDown(128, 256)
        self.down3 = UNetDown(256, 512)
        self.down4 = UNetDown(512, 1024 // self.factor)
        self.up1 = UNetUp(1024, 512 // self.factor, self.bilinear)
        self.up2 = UNetUp(512, 256 // self.factor, self.bilinear)
        self.up3 = UNetUp(256, 128 // self.factor, self.bilinear)
        self.up4 = UNetUp(128, 64, self.bilinear)
        self.outc = UNetOutConv(64, self.n_classes)

    def _forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits


class UNetDoubleConv(ivy.Module):
    def __init__(self, in_channels, out_channels, mid_channels=None):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.mid_channels = mid_channels if mid_channels else out_channels
        super(UNetDoubleConv, self).__init__()

    def _build(self, *args, **kwargs):
        self.double_conv = ivy.Sequential(
            ivy.Conv2D(
                self.in_channels, self.mid_channels, [3, 3], 1, 1, with_bias=False
            ),
            ivy.BatchNorm2D(self.mid_channels),
            ivy.ReLU(),
            ivy.Conv2D(
                self.mid_channels, self.out_channels, [3, 3], 1, 1, with_bias=False
            ),
            ivy.BatchNorm2D(self.out_channels),
            ivy.ReLU(),
        )

    def _forward(self, x):
        return self.double_conv(x)


class UNetDown(ivy.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        self.in_channels = in_channels
        self.out_channels = out_channels
        super().__init__()

    def _build(self, *args, **kwargs):
        self.maxpool_conv = ivy.Sequential(
            ivy.MaxPool2D(2, 2, 0), UNetDoubleConv(self.in_channels, self.out_channels)
        )

    def _forward(self, x):
        return self.maxpool_conv(x)


class UNetUp(ivy.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.bilinear = bilinear
        super().__init__()

    def _build(self, *args, **kwargs):
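        if self.bilinear:
            # ivy.interpolate needs the input array and expects channel-first
            # (BCHW) data, so wrap it to upsample BHWC feature maps by 2x
            def _upsample(x):
                x = ivy.permute_dims(x, (0, 3, 1, 2))
                x = ivy.interpolate(
                    x, (2 * x.shape[2], 2 * x.shape[3]), mode="bilinear", align_corners=True
                )
                return ivy.permute_dims(x, (0, 2, 3, 1))

            self.up = _upsample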
            self.conv = UNetDoubleConv(
                self.in_channels, self.out_channels, self.in_channels // 2
            )
        else:
            self.up = ivy.Conv2DTranspose(
                self.in_channels, self.in_channels // 2, [2, 2], 2, "VALID"
            )
            self.conv = UNetDoubleConv(self.in_channels, self.out_channels)

    def _forward(self, x1, x2):
        x1 = self.up(x1)
        # input is BHWC
        diff_H = x2.shape[1] - x1.shape[1]
        diff_W = x2.shape[2] - x1.shape[2]

        pad_width = (
            (0, 0),
            (diff_H // 2, diff_H - diff_H // 2),
            (diff_W // 2, diff_W - diff_W // 2),
            (0, 0),
        )

        x1 = ivy.constant_pad(x1, pad_width)
        x = ivy.concat((x2, x1), axis=3)
        return self.conv(x)


class UNetOutConv(ivy.Module):
    def __init__(self, in_channels, out_channels):
        self.in_channels = in_channels
        self.out_channels = out_channels
        super(UNetOutConv, self).__init__()

    def _build(self, *args, **kwargs):
        self.conv = ivy.Conv2D(self.in_channels, self.out_channels, [1, 1], 1, 0)

    def _forward(self, x):
        return self.conv(x)
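UNetUp pads the upsampled map whenever MaxPool2D has floored an odd spatial size on the way down. The sketch below traces feature-map shapes through UNET with bilinear=False (the constructor default) and computes the pad widths each decoder step needs, then applies one with NumPy. trace_unet_shapes and pad_and_concat are hypothetical helpers, the 64x75 input size is arbitrary, and the split of the size difference follows the referenced PyTorch UNet (smaller half on the top/left).

```python
import numpy as np


def trace_unet_shapes(h, w, base=64, depth=4):
    """Return encoder (H, W, C) sizes and decoder pad widths for a bilinear=False UNET."""
    enc = [(h, w, base)]  # output of inc: 3x3 convs with padding 1 keep H and W
    for i in range(1, depth + 1):
        ph, pw, _ = enc[-1]
        enc.append((ph // 2, pw // 2, base * 2**i))  # MaxPool2D(2, 2, 0) floors odd sizes
    pads = []
    cur_h, cur_w = enc[-1][:2]
    for skip_h, skip_w, _ in reversed(enc[:-1]):
        up_h, up_w = 2 * cur_h, 2 * cur_w  # Conv2DTranspose: kernel 2, stride 2, VALID
        dh, dw = skip_h - up_h, skip_w - up_w
        # Split the difference with the smaller half first (top / left)
        pads.append(((dh // 2, dh - dh // 2), (dw // 2, dw - dw // 2)))
        cur_h, cur_w = skip_h, skip_w
    return enc, pads


def pad_and_concat(x1, x2, pad_hw):
    """Zero-pad the upsampled BHWC map x1 to x2's size, then concatenate on channels."""
    (top, bottom), (left, right) = pad_hw
    x1 = np.pad(x1, ((0, 0), (top, bottom), (left, right), (0, 0)))
    return np.concatenate((x2, x1), axis=3)


enc, pads = trace_unet_shapes(64, 75)  # hypothetical input size with an odd width
print(enc)   # [(64, 75, 64), (32, 37, 128), (16, 18, 256), (8, 9, 512), (4, 4, 1024)]
print(pads)  # [((0, 0), (0, 1)), ((0, 0), (0, 0)), ((0, 0), (0, 1)), ((0, 0), (0, 1))]

# First decoder step: the upsampled 8x8 map is padded to the 8x9 skip connection
x1 = np.ones((1, 8, 8, 512))
x2 = np.zeros((1, 8, 9, 512))
print(pad_and_concat(x1, x2, pads[0]).shape)  # (1, 8, 9, 1024)
```

With an input whose height and width are both divisible by 16, every pad width comes out as zero.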