# Install the dependencies for this demo.
!pip install -q ivy
!pip install -q dm-haiku
!git clone https://github.com/unifyai/models.git
# Installing models package from cloned repository! 😄
!cd models/ && pip install .
# Stop the kernel so the notebook restarts and the freshly installed packages
# become importable. Do not re-run this cell after the restart, or it will
# restart the notebook again.
exit()
Image Segmentation with Ivy UNet
Use the Ivy UNet model for image segmentation.
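So that the newly installed packages can be imported, the notebook restarts automatically after the first cell has run.
After the restart, run the remaining cells (for example, select the cell after the installation cell and use Runtime -> Run after). Don't re-run the installation cell: it calls exit() and would restart the notebook again.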
Make sure you run this demo with GPU enabled!
Imports
import ivy
import torch
import numpy as np
Data Preparation
Custom Preprocessing
# ref: https://github.com/milesial/Pytorch-UNet/blob/2f62e6b1c8e98022a6418d31a76f6abd800e5ae7/utils/data_loading.py#L65
def preprocess(mask_values, pil_img, scale, is_mask):
    """Resize a PIL image and convert it to a numpy array for the UNet.

    Args:
        mask_values: pixel values identifying each class in a mask; only used
            when ``is_mask`` is true (the i-th value becomes class index ``i``).
        pil_img: the PIL image to preprocess.
        scale: resize factor applied to both width and height.
        is_mask: True to build an integer class mask, False for a model input.

    Returns:
        If ``is_mask``: an int64 ``(H, W)`` array of class indices.
        Otherwise: a channels-first ``(C, H, W)`` array (``(1, H, W)`` for a
        2-D image), divided by 255 when any value exceeds 1.

    Raises:
        AssertionError: if ``scale`` shrinks either side to zero pixels.
    """
    w, h = pil_img.size
    new_w, new_h = int(scale * w), int(scale * h)
    assert new_w > 0 and new_h > 0, 'Scale is too small, resized images would have no pixel'
    # Nearest-neighbour never invents new pixel values, so mask labels stay exact.
    resample = Image.NEAREST if is_mask else Image.BICUBIC
    img = np.asarray(pil_img.resize((new_w, new_h), resample=resample))

    if is_mask:
        mask = np.zeros((new_h, new_w), dtype=np.int64)
        for i, v in enumerate(mask_values):
            if img.ndim == 2:
                mask[img == v] = i
            else:
                # Multi-channel mask: a pixel matches only if every channel equals v.
                mask[(img == v).all(-1)] = i
        return mask

    # Model input: (H, W) -> (1, H, W); (H, W, C) -> (C, H, W).
    if img.ndim == 2:
        img = img[np.newaxis, ...]
    else:
        img = img.transpose((2, 0, 1))
    # Heuristic: any value above 1 is taken to mean 0-255 pixel values.
    if (img > 1).any():
        img = img / 255.0
    return img
Load the image example 🖼️
# Preprocess image
from PIL import Image
!wget https://raw.githubusercontent.com/unifyai/models/master/images/car.jpg
= "car.jpg"
filename = Image.open(filename)
full_img = torch.from_numpy(preprocess(None, full_img, 0.5, False)).unsqueeze(0).to("cuda") torch_img
# Convert to ivy
"torch")
ivy.set_backend(= ivy.asarray(torch_img.permute((0, 2, 3, 1)), dtype="float32", device="gpu:0")
img = img.cpu().numpy() img_numpy
Visualise image
# Render the downloaded image inline. IPython's Image is aliased so it does not
# shadow PIL's Image, which preprocess and mask_to_image rely on.
from IPython.display import Image as IPyImage
from IPython.display import display

display(IPyImage(filename))
Model Inference
Initializing Native Torch UNet
# Reference model: the `unet_carvana` entry point of milesial/Pytorch-UNet via torch.hub.
torch_unet = torch.hub.load('milesial/Pytorch-UNet', 'unet_carvana', pretrained=True, scale=1.0)
torch_unet.to("cuda")
# Inference mode (affects layers such as BatchNorm).
torch_unet.eval()
Initializing Ivy UNet with Pretrained Weights ⬇️
When `pretrained=True`, the model is initialized with the pretrained weights 🔗.
# Load the Ivy UNet from ivy_models: 3 input channels (RGB), 2 output classes,
# initialized with pretrained weights.
import ivy_models
ivy_unet = ivy_models.unet_carvana(n_channels=3, n_classes=2, pretrained=True)
Compile the forward pass for efficiency.
# Compile the forward pass, using the preprocessed image as the example input.
ivy_unet.compile(args=(img,))
Custom masking function
# ref: https://github.com/milesial/Pytorch-UNet/blob/2f62e6b1c8e98022a6418d31a76f6abd800e5ae7/predict.py#L62
def mask_to_image(mask: np.ndarray, mask_values):
    """Convert a class-index mask (or per-class scores) into a PIL image.

    Args:
        mask: ``(H, W)`` array of class indices, or a ``(K, H, W)`` array of
            per-class scores that is reduced with argmax over axis 0.
        mask_values: output value for each class index. A list/tuple of
            per-channel values for each class gives a multi-channel uint8
            image; exactly ``[0, 1]`` (or ``(0, 1)``) gives a boolean image;
            anything else gives a single-channel uint8 image.

    Returns:
        The image produced by ``Image.fromarray`` on the filled array.
    """
    height, width = mask.shape[-2], mask.shape[-1]
    if isinstance(mask_values[0], (list, tuple)):
        out = np.zeros((height, width, len(mask_values[0])), dtype=np.uint8)
    elif list(mask_values) == [0, 1]:
        out = np.zeros((height, width), dtype=bool)
    else:
        out = np.zeros((height, width), dtype=np.uint8)

    if mask.ndim == 3:
        mask = np.argmax(mask, axis=0)

    for i, v in enumerate(mask_values):
        out[mask == i] = v

    return Image.fromarray(out)
Use the model to segment your images 🚀
First, we will generate the reference mask from the reference model.
- Torch UNet
# Reference prediction from the native PyTorch model.
# no_grad: this is inference only, so don't build an autograd graph that would
# keep every intermediate activation alive on the GPU.
with torch.no_grad():
    torch_output = torch_unet(torch_img.to(torch.float32))
    # Resize logits back to the original image; PIL's size is (W, H),
    # interpolate expects (H, W).
    torch_output = torch.nn.functional.interpolate(
        torch_output, (full_img.size[1], full_img.size[0]), mode="bilinear"
    )
# Per-pixel class index: argmax over the channel axis, then take batch element 0.
torch_mask = torch_output.argmax(dim=1)
torch_mask = torch_mask[0].squeeze().cpu().numpy()
torch_result = mask_to_image(torch_mask, [0, 1])
torch_result
Next, we will generate the mask from the native Ivy implementation.
- Ivy UNet
# Same post-processing with the Ivy model (torch backend).
output = ivy_unet(img)
# The Ivy model works channels-last; move channels to axis 1 before resizing.
# ivy.permute_dims works whether `output` is a native tensor or an ivy.Array.
output = ivy.interpolate(
    ivy.permute_dims(output, (0, 3, 1, 2)),
    (full_img.size[1], full_img.size[0]),
    mode="bilinear",
)
mask = output.argmax(axis=1)
mask = ivy.squeeze(mask[0], axis=None).to_numpy()
result = mask_to_image(mask, [0, 1])
result
Great! The native Ivy model and the Torch model produce the same mask!
TensorFlow backend
Let’s look at using the TensorFlow backend.
import tensorflow as tf

# Switch Ivy to the TensorFlow backend and rebuild the model under it.
ivy.set_backend("tensorflow")
ivy_unet = ivy_models.unet_carvana(n_channels=3, n_classes=2, pretrained=True)
# Re-create the input from the host copy so it is a TensorFlow-backed array.
img_tf = ivy.asarray(img_numpy)
ivy_unet = ivy.compile(ivy_unet, args=(img_tf,))

output = ivy_unet(img_tf)
# Post-process with native TensorFlow ops: channels-last -> channels-first,
# resize to (H, W), argmax over classes.
output = ivy.interpolate(
    tf.transpose(output, (0, 3, 1, 2)),
    (full_img.size[1], full_img.size[0]),
    mode="bilinear",
)
mask = tf.math.argmax(output, axis=1)
mask = tf.squeeze(mask[0], axis=None).numpy()
result = mask_to_image(mask, [0, 1])
result
As expected, we ended up with the same mask as before. Note that with the TensorFlow backend we could use native TensorFlow functions (`tf.transpose`, `tf.math.argmax`, `tf.squeeze`) for the post-processing.
JAX
Next up is the JAX backend. We’ve used a lot of GPU memory so far, so we’ll free up some space first.
# Release GPU memory held by the previous models before loading another backend.
# torch_output is dropped too: if it carries an autograd graph, that graph keeps
# the torch model's intermediate activations alive, and empty_cache() can't free them.
del torch_unet
del torch_output
del ivy_unet
torch.cuda.empty_cache()

# Overrides JAX's default behavior of preallocating 75% of GPU memory.
# Temporary fix until this is handled by ivy's graph compiler.
# The variable must be set before JAX initializes its GPU backend.
import os
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import jax
jax.config.update("jax_enable_x64", True)

ivy.set_backend("jax")
ivy_unet = ivy_models.unet_carvana(n_channels=3, n_classes=2, pretrained=True)
img_jax = ivy.asarray(img_numpy)
output = ivy_unet(img_jax)
# Channels-last -> channels-first, then resize to the original (H, W).
output = ivy.interpolate(
    ivy.permute_dims(output, (0, 3, 1, 2)),
    (full_img.size[1], full_img.size[0]),
    mode="bilinear",
)
mask = output.argmax(axis=1)
mask = ivy.squeeze(mask[0], axis=None).to_numpy()
result = mask_to_image(mask, [0, 1])
result
/usr/local/lib/python3.10/dist-packages/ivy/func_wrapper.py:242: UserWarning: Creating many views will lead to overhead when performing inplace updates with this backend
warnings.warn(
Once again, we ended up with the same mask as in the reference torch implementation!
Appendix: the Ivy native implementation of UNet
class UNET(ivy.Module):
    """Ivy UNet: encoder/decoder segmentation network with skip connections.

    Args:
        n_channels: number of channels in the input image.
        n_classes: number of output classes (channels of the returned logits).
        bilinear: if True, the decoder uses bilinear upsampling and the channel
            counts of the deeper stages are divided by 2 (``self.factor``);
            otherwise it uses transposed convolutions.
        v: optional variables, passed through to ``ivy.Module``.
    """

    def __init__(self, n_channels, n_classes, bilinear=False, v=None):
        # Set before super().__init__(), which is expected to invoke _build
        # (and _build reads these attributes).
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.factor = 2 if bilinear else 1
        super(UNET, self).__init__(v=v)

    def _build(self, *args, **kwargs):
        # Encoder.
        self.inc = UNetDoubleConv(self.n_channels, 64)
        self.down1 = UNetDown(64, 128)
        self.down2 = UNetDown(128, 256)
        self.down3 = UNetDown(256, 512)
        self.down4 = UNetDown(512, 1024 // self.factor)
        # Decoder.
        self.up1 = UNetUp(1024, 512 // self.factor, self.bilinear)
        self.up2 = UNetUp(512, 256 // self.factor, self.bilinear)
        self.up3 = UNetUp(256, 128 // self.factor, self.bilinear)
        self.up4 = UNetUp(128, 64, self.bilinear)
        self.outc = UNetOutConv(64, self.n_classes)

    def _forward(self, x):
        # Encoder: keep every stage's output for the skip connections.
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        # Decoder: upsample and merge with the matching encoder output.
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits
class UNetDoubleConv(ivy.Module):
    """Two consecutive (Conv2D 3x3 -> BatchNorm2D -> ReLU) stages.

    Args:
        in_channels: channels of the input.
        out_channels: channels produced by the second convolution.
        mid_channels: channels produced by the first convolution; falls back to
            ``out_channels`` when falsy.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        # Set before super().__init__(), which is expected to invoke _build.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.mid_channels = mid_channels if mid_channels else out_channels
        super(UNetDoubleConv, self).__init__()

    def _build(self, *args, **kwargs):
        # Conv2D positional args: presumably filter shape, stride 1, padding 1.
        self.double_conv = ivy.Sequential(
            ivy.Conv2D(
                self.in_channels, self.mid_channels, [3, 3], 1, 1, with_bias=False
            ),
            ivy.BatchNorm2D(self.mid_channels),
            ivy.ReLU(),
            ivy.Conv2D(
                self.mid_channels, self.out_channels, [3, 3], 1, 1, with_bias=False
            ),
            ivy.BatchNorm2D(self.out_channels),
            ivy.ReLU(),
        )

    def _forward(self, x):
        return self.double_conv(x)
class UNetDown(ivy.Module):
    """Downscaling with maxpool then double conv.

    Args:
        in_channels: channels of the input.
        out_channels: channels produced by the double convolution.
    """

    def __init__(self, in_channels, out_channels):
        # Set before super().__init__(), which is expected to invoke _build.
        self.in_channels = in_channels
        self.out_channels = out_channels
        super().__init__()

    def _build(self, *args, **kwargs):
        self.maxpool_conv = ivy.Sequential(
            # Presumably a 2x2 window, stride 2, no padding.
            ivy.MaxPool2D(2, 2, 0),
            UNetDoubleConv(self.in_channels, self.out_channels),
        )

    def _forward(self, x):
        return self.maxpool_conv(x)
class UNetUp(ivy.Module):
    """Upscaling then double conv.

    Upsamples the deeper feature map ``x1`` and pads it to the spatial size of
    the skip connection ``x2``. It then concatenates ``(x2, x1)`` on the
    channel axis and applies a UNetDoubleConv. Tensors are channels-last (BHWC).

    Args:
        in_channels: channels of the concatenated tensor fed to the double conv.
        out_channels: channels produced by the double conv.
        bilinear: if True, upsample by 2 with bilinear interpolation (no
            weights); otherwise use a transposed convolution that halves the
            channel count.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        # Set before super().__init__(), which is expected to invoke _build.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.bilinear = bilinear
        super().__init__()

    def _build(self, *args, **kwargs):
        if self.bilinear:
            # Bilinear upsampling has no parameters; it is applied in _upsample
            # on the actual input instead of being stored as a layer.
            self.conv = UNetDoubleConv(
                self.in_channels, self.out_channels, self.in_channels // 2
            )
        else:
            # Presumably a 2x2 filter with stride 2, i.e. 2x spatial upsampling.
            self.up = ivy.Conv2DTranspose(
                self.in_channels, self.in_channels // 2, [2, 2], 2, "VALID"
            )
            self.conv = UNetDoubleConv(self.in_channels, self.out_channels)

    def _upsample(self, x):
        """Upsample the BHWC tensor ``x`` by a factor of 2 in height and width."""
        if not self.bilinear:
            return self.up(x)
        # ivy.interpolate expects channels-first input: BHWC -> BCHW and back.
        x = ivy.permute_dims(x, (0, 3, 1, 2))
        x = ivy.interpolate(
            x,
            (2 * x.shape[2], 2 * x.shape[3]),
            mode="bilinear",
            align_corners=True,
        )
        return ivy.permute_dims(x, (0, 2, 3, 1))

    def _forward(self, x1, x2):
        x1 = self._upsample(x1)
        # input is BHWC: axis 1 is height, axis 2 is width.
        diff_H = x2.shape[1] - x1.shape[1]
        diff_W = x2.shape[2] - x1.shape[2]
        # Split any size mismatch the same way on both axes, with the odd pixel
        # going to the bottom/right. This matches the PyTorch reference
        # F.pad(x1, [dW // 2, dW - dW // 2, dH // 2, dH - dH // 2]).
        pad_width = (
            (0, 0),
            (diff_H // 2, diff_H - diff_H // 2),
            (diff_W // 2, diff_W - diff_W // 2),
            (0, 0),
        )
        x1 = ivy.constant_pad(x1, pad_width)
        x = ivy.concat((x2, x1), axis=3)
        return self.conv(x)
class UNetOutConv(ivy.Module):
    """Final 1x1 convolution that maps features to per-class logits.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: number of output channels (one per class).
    """

    def __init__(self, in_channels, out_channels):
        # Stored before super().__init__(), which is expected to invoke _build.
        self.in_channels = in_channels
        self.out_channels = out_channels
        super().__init__()

    def _build(self, *args, **kwargs):
        kernel = [1, 1]
        self.conv = ivy.Conv2D(self.in_channels, self.out_channels, kernel, 1, 0)

    def _forward(self, x):
        return self.conv(x)