Quickstart
Get up to speed with Ivy with a quick, general introduction of its features and capabilities!
⚠️ If you are running this notebook in Colab, you will have to install Ivy and some extra packages manually. You can do so by running the cell below ⬇️
!pip install ivy
!pip install dm-haiku
!pip install kornia
!pip install timm
!pip install pyvis
exit()
If you want to run the notebook locally but don’t have Ivy installed just yet, you can check out the Setting Up section of the docs.
In this notebook we’ll go over the basic aspects of Ivy, which is both a transpiler and an ML framework. You can use it to write framework-agnostic code and to integrate code from any framework into your existing code, tools, or infrastructure!
Let’s import Ivy and get started!
import ivy
Get familiar with Ivy
When used as an ML framework, Ivy allows you to write framework-agnostic code that can be executed as native code in your framework of choice. This means that when executed, Ivy code will use the appropriate functions and data structures of that framework, allowing you to seamlessly integrate your code with any other code and to leverage any framework-specific benefits.
To change the backend, we simply have to call ivy.set_backend() and pass the framework we want to use as a string, for example:
ivy.set_backend("torch")
Now let’s take a look at the data structures of Ivy. The main one is ivy.Array, which is an abstraction of the array class of the backends with extra functionalities. You can also access the corresponding class directly through ivy.NativeArray.
There is also another structure called the ivy.Container, which is a subclass of dict that is optimized for recursive operations. If you want to learn more about it, you can refer to the Ivy documentation.
ivy.set_backend("torch")
x = ivy.array([1, 2, 3])
print(type(x))
x = ivy.native_array([1, 2, 3])
print(type(x))
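To picture what "recursive operations" on a dict-like structure means, here is a minimal plain-Python sketch. The helper map_nested and the params tree are hypothetical names made up for this illustration; they are not part of Ivy and do not show how ivy.Container is implemented.

```python
# Hypothetical helper: NOT part of Ivy, it only illustrates recursion over a nested dict.
def map_nested(fn, tree):
    """Apply fn to every leaf of a nested dict, preserving the structure."""
    if isinstance(tree, dict):
        return {key: map_nested(fn, value) for key, value in tree.items()}
    return fn(tree)


# A made-up parameter tree, shaped like the weights of a two-layer model.
params = {
    "linear0": {"w": 2.0, "b": 0.5},
    "linear1": {"w": -1.0, "b": 0.0},
}

# One call touches every leaf, however deeply it is nested.
doubled = map_nested(lambda leaf: leaf * 2, params)
print(doubled)
```

A structure built around this idea lets an optimizer or a loss apply the same operation to every parameter of a model in one call.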
Functional API
In a similar manner, Ivy’s functional API wraps the functional API of the backends, and therefore uses native operations under the hood. Let’s see an example of this.
ivy.set_backend("jax")
x1, x2 = ivy.array([[1], [2], [3]]), ivy.array([[1, 2, 3]])
output = ivy.matmul(x1, x2)
print(type(output.to_native()))
ivy.set_backend("tensorflow")
x1, x2 = ivy.array([[1], [2], [3]]), ivy.array([[1, 2, 3]])
output = ivy.matmul(x1, x2)
print(type(output.to_native()))
ivy.set_backend("torch")
x1, x2 = ivy.array([[1], [2], [3]]), ivy.array([[1, 2, 3]])
output = ivy.matmul(x1, x2)
print(type(output.to_native()))
As expected, calling ivy.matmul with different backends performs the corresponding operation in each one of the frameworks.
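Whatever the backend, the numerical result of the call is the same. As a framework-free reference, the sketch below computes the same (3, 1) by (1, 3) product with plain Python lists; the matmul helper here is a hand-written illustration and not the Ivy function.

```python
# Plain-Python matrix product, used only to show what the (3, 1) @ (1, 3) call computes.
def matmul(a, b):
    rows, inner, cols = len(a), len(b), len(b[0])
    return [
        [sum(a[i][k] * b[k][j] for k in range(inner)) for j in range(cols)]
        for i in range(rows)
    ]


x1 = [[1], [2], [3]]   # shape (3, 1)
x2 = [[1, 2, 3]]       # shape (1, 3)
result = matmul(x1, x2)  # shape (3, 3): the outer product of the two vectors
print(result)
```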
Using the functional API, we can define any framework-independent function that we want:
def sigmoid(z):
    return ivy.divide(1, (1 + ivy.exp(-z)))
Stateful API
Ivy also has a stateful API which builds on its functional API and the ivy.Container class to provide high-level classes such as optimizers, network layers, or trainable modules.
The most important stateful class within Ivy is ivy.Module, which can be used to create trainable layers and entire networks. A very simple example of an ivy.Module could be:
class Regressor(ivy.Module):
    def __init__(self, input_dim, output_dim):
        self.linear0 = ivy.Linear(input_dim, 128)
        self.linear1 = ivy.Linear(128, output_dim)
        ivy.Module.__init__(self)
    def _forward(self, x):
        x = self.linear0(x)
        x = ivy.functional.relu(x)
        x = self.linear1(x)
        return x
To use this model, we would simply have to set a backend and instantiate the model:
ivy.set_backend('torch')  # set backend to PyTorch
model = Regressor(input_dim=1, output_dim=1)
optimizer = ivy.Adam(0.1)
Now we can generate some sample data and train the model using Ivy as well.
n_training_examples = 2000
noise = ivy.random.random_normal(shape=(n_training_examples, 1), mean=0, std=0.1)
x = ivy.linspace(-6, 3, n_training_examples).reshape((n_training_examples, 1))
y = 0.2 * x ** 2 + 0.5 * x + 0.1 + noise
def loss_fn(pred, target):
    return ivy.mean((pred - target)**2)
for epoch in range(50):
    # forward pass
    pred = model(x)
    # compute loss and gradients
    loss, grads = ivy.execute_with_gradients(lambda v: loss_fn(pred, y), model.v)
    # update parameters
    model.v = optimizer.step(model.v, grads)
    # print current loss
    print(f'Epoch: {epoch + 1:2d} --- Loss: {ivy.to_numpy(loss).item():.5f}')
print('Finished training!')
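The three steps of the loop (forward pass, loss and gradients, parameter update) can also be written out by hand. The sketch below does so in plain Python for a made-up one-variable linear model. It uses plain gradient descent instead of Adam, with hand-derived gradients of the mean squared error, so it only mirrors the structure of the loop above and not Ivy's internals.

```python
# Toy data from a known line y = 2x + 1 (made up for this sketch).
xs = [0.0, 1.0, 2.0, 3.0]
ys = [2.0 * x + 1.0 for x in xs]

w, b = 0.0, 0.0          # parameters to learn
learning_rate = 0.05
losses = []

for epoch in range(200):
    # forward pass
    preds = [w * x + b for x in xs]

    # mean squared error and its hand-derived gradients
    errors = [p - y for p, y in zip(preds, ys)]
    loss = sum(e * e for e in errors) / len(xs)
    grad_w = 2.0 * sum(e * x for e, x in zip(errors, xs)) / len(xs)
    grad_b = 2.0 * sum(errors) / len(xs)

    # parameter update (plain gradient descent, not Adam)
    w -= learning_rate * grad_w
    b -= learning_rate * grad_b
    losses.append(loss)

print(f"w={w:.3f}, b={b:.3f}, first loss={losses[0]:.3f}, last loss={losses[-1]:.6f}")
```

In the Ivy version, the gradient computation and the update rule are what ivy.execute_with_gradients and optimizer.step take care of.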
Compiling code
We have just explored how to create framework-agnostic functions and models with Ivy. Nonetheless, due to the wrapping Ivy performs on top of native functions, there is a slight performance overhead introduced with each function call. To address this, we can use Ivy’s Graph Compiler.
The purpose of the Graph Compiler is to extract a fully functional, efficient graph composed only of functions from the corresponding functional APIs of the underlying framework (backend).
Besides removing the overhead introduced by Ivy, the Graph Compiler can also be used with functions and modules written in any framework. In this case, it will decompose any high-level API into a fully functional graph of functions from said framework.
As an example, let’s write a simple normalize function using Ivy:
def normalize(x):
    mean = ivy.mean(x)
    std = ivy.std(x)
    return ivy.divide(ivy.subtract(x, mean), std)
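When comparing the output of a function like normalize across libraries, the standard deviation convention matters. The standard-library sketch below normalizes the same three values with the population formula (divide by n) and with the sample formula (divide by n - 1). The helper normalize_reference is a hypothetical name for this illustration. Which convention ivy.std and torch.std default to is stated here only as an assumption (correction of 0 for Ivy, Bessel's correction for PyTorch) and is worth confirming in their documentation.

```python
import statistics


def normalize_reference(values, sample=False):
    """Normalize a list of floats; sample=True divides the variance by n - 1 instead of n."""
    mean = statistics.fmean(values)
    std = statistics.stdev(values) if sample else statistics.pstdev(values)
    return [(v - mean) / std for v in values]


data = [1.0, 2.0, 3.0]
population_result = normalize_reference(data)            # population std (divide by n)
sample_result = normalize_reference(data, sample=True)   # sample std (divide by n - 1)
print(population_result)
print(sample_result)
```

The two results differ by a constant factor, so a mismatch of this kind between two frameworks usually points to the convention and not to a bug.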
To compile this function, simply call ivy.compile(). To specify the underlying framework, you can pass the name of the framework as an argument using to. Otherwise, the current backend will be used by default.
import torch
x0 = torch.tensor([1., 2., 3.])
normalize_comp = ivy.compile(normalize, to="torch", args=(x0,))
This results in the following graph:
from IPython.display import HTML
normalize_comp.show(fname="graph.html", notebook=True)
HTML(filename="graph.html")
As anticipated, the compiled function, which uses native torch operations directly, is faster than the original function:
%%timeit
normalize(x0)
%%timeit
normalize_comp(x0)
Additionally, we can set the return_backend_compiled_fn arg to True to apply the (native) target framework compilation function to Ivy’s compiled graph, making the resulting function even more efficient.
normalize_native_comp = ivy.compile(normalize, return_backend_compiled_fn=True, to="torch", args=(x0,))
%%timeit
normalize_native_comp(x0)
In the example above, we compiled the function eagerly, which means that the compilation process happened immediately, as we have passed the arguments for tracing. However, if we don’t pass any arguments to the compile function, compilation will occur lazily, and the graph will be built only when we call the compiled function for the first time. To summarize:
import torch
x1 = torch.tensor([1., 2., 3.])
# Arguments are available -> compilation happens eagerly
eager_graph = ivy.compile(normalize, to="torch", args=(x1,))
# eager_graph is now torch code and runs efficiently
ret = eager_graph(x1)
# Arguments are not available -> compilation happens lazily
lazy_graph = ivy.compile(normalize, to="torch")
# The compiled graph is initialized, compilation will happen here
ret = lazy_graph(x1)
# lazy_graph is now torch code and runs efficiently
ret = lazy_graph(x1)
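The eager and lazy behaviour summarized above follows a common pattern that can be sketched without Ivy. In the plain-Python illustration below, toy_compile and build_graph are hypothetical stand-ins: they do no tracing at all and only record when the "expensive" step runs, so they show the timing of the work and nothing about Ivy's actual implementation.

```python
# Hypothetical stand-in for an expensive tracing step; NOT Ivy's implementation.
def build_graph(fn, log):
    log.append(f"built graph for {fn.__name__}")
    return fn  # a real compiler would return an optimized callable


def toy_compile(fn, log, args=None):
    """Build immediately when example args are given, otherwise on the first call."""
    if args is not None:
        return build_graph(fn, log)      # eager: the work happens right now

    cache = {}

    def lazy_wrapper(*call_args):
        if "graph" not in cache:         # lazy: the work happens on first use
            cache["graph"] = build_graph(fn, log)
        return cache["graph"](*call_args)

    return lazy_wrapper


def double(x):
    return 2 * x


log = []
eager = toy_compile(double, log, args=(1,))
built_after_eager = len(log)     # the graph has already been built

lazy = toy_compile(double, log)
built_after_lazy = len(log)      # nothing new has been built yet

lazy(3)
lazy(4)
built_after_calls = len(log)     # built once, on the first call only
print(built_after_eager, built_after_lazy, built_after_calls)
```

The practical consequence is the same as in the summary: with lazy compilation, the first call pays the compilation cost and later calls do not.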
Ivy as a Transpiler
We have just learned how to write framework-agnostic code and compile it into an efficient graph. However, many codebases, libraries, and models have already been developed (and will continue to be!) using other frameworks.
To allow for speed-of-thought research and development, Ivy also allows you to use any code directly in your project, regardless of the framework it was written in. No matter what ML code you want to use, Ivy’s Transpiler is the tool for the job 🛠️
Any function
Let’s start by transpiling a very simple torch function.
def normalize(x):
    mean = torch.mean(x)
    std = torch.std(x)
    return torch.div(torch.sub(x, mean), std)
jax_normalize = ivy.transpile(normalize, source="torch", to="jax")
Similar to compile, the transpile function can be used eagerly or lazily. In this particular example, transpilation is being performed lazily, since we haven’t passed any arguments or keyword arguments to ivy.transpile.
import jax
key = jax.random.PRNGKey(42)
jax.config.update('jax_enable_x64', True)
x = jax.random.uniform(key, shape=(10,))
jax_out = jax_normalize(x)
print(jax_out, type(jax_out))
That’s pretty much it! You can now use any function you need in your projects regardless of the framework you’re using 🚀
However, transpiling functions one by one is far from ideal. But don’t worry, with transpile, you can transpile entire libraries at once and easily bring them into your projects. Let’s see how this works by transpiling kornia, a widely-used computer vision library written in torch:
Any library
import kornia
import requests
import jax.numpy as jnp
import numpy as np
from PIL import Image
Let’s get the transpiled library by calling transpile.
jax_kornia = ivy.transpile(kornia, source="torch", to="jax")
Now let’s get a sample image and preprocess it so that it has the format kornia expects:
url = "http://images.cocodataset.org/train2017/000000000034.jpg"
raw_img = Image.open(requests.get(url, stream=True).raw)
img = jnp.transpose(jnp.array(raw_img), (2, 0, 1))
img = jnp.expand_dims(img, 0) / 255
display(raw_img)
And we can call any function from kornia in jax, as simple as that!
out = jax_kornia.enhance.sharpness(img, 10)
type(out)
Finally, let’s see if the transformation has been applied correctly:
np_image = np.uint8(np.array(out[0])*255)
display(Image.fromarray(np.transpose(np_image, (1, 2, 0))))
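The layout conversions used in this section (channels-last image to a batched channels-first array in [0, 1], and back) can be checked in isolation. The NumPy sketch below runs the round trip on a small made-up array instead of a downloaded picture. It adds a rounding step before the cast back to uint8, which the cell above does not have, so that floating-point error cannot shift a pixel value by one.

```python
import numpy as np

# Made-up 4x6 RGB image in the (height, width, channels) uint8 layout.
hwc = np.arange(4 * 6 * 3, dtype=np.uint8).reshape(4, 6, 3)

# Forward: channels first, add a batch axis, scale to [0, 1].
nchw = np.expand_dims(np.transpose(hwc, (2, 0, 1)), 0) / 255

# Backward: drop the batch axis, rescale and round, return to channels last.
restored = np.transpose(np.round(nchw[0] * 255).astype(np.uint8), (1, 2, 0))

print(hwc.shape, nchw.shape, restored.shape)
print(np.array_equal(hwc, restored))
```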
It’s worth noting that every operation in the transpiled functions is performed natively in the target framework, which means that gradients can be tracked and the resulting functions are fully differentiable. Even after transpilation, you can still take advantage of the powerful features of your chosen framework.
While transpiling functions and libraries is useful, trainable modules play a critical role in ML and DL. The good news is that Ivy makes it just as easy to transpile modules and models from one framework to another with just one line of code.
Any model
For the purpose of this demonstration, let’s define a very basic CNN block using the Sequential API of keras.
import tensorflow as tf
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 3)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10, activation='softmax')
])
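To know what to expect from this block before and after transpilation, it helps to work out its shapes and parameter counts by hand. The plain-Python sketch below does the arithmetic, assuming the Keras defaults for Conv2D (no padding, stride 1); conv2d_valid_output is a hypothetical helper written for this illustration.

```python
def conv2d_valid_output(size, kernel, stride=1):
    """Output length along one spatial axis for a convolution without padding."""
    return (size - kernel) // stride + 1


height = conv2d_valid_output(28, 3)      # spatial size after the 3x3 convolution
width = conv2d_valid_output(28, 3)
flattened = height * width * 32          # 32 filters, flattened into one vector
conv_params = (3 * 3 * 3 + 1) * 32       # 3x3 kernel, 3 input channels, one bias per filter
dense_params = (flattened + 1) * 10      # weights plus one bias per output unit

print(height, width, flattened, conv_params, dense_params)
```

The final Dense layer has 10 units, so a single input image should produce an output of shape (1, 10) in either framework.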
The model we just defined is an instance of tf.keras.Model. Using ivy.transpile, we can effortlessly convert it into a torch.nn.Module, for instance.
input_array = tf.random.normal((1, 28, 28, 3))
torch_model = ivy.transpile(model, to="torch", args=(input_array,))
After transpilation, we can pass a torch tensor and obtain the expected output. As mentioned previously, all operations are now PyTorch native functions, making them differentiable. Additionally, Ivy automatically converts all parameters of the original model to the new one, allowing you to transpile pre-trained models and fine-tune them in your preferred framework.
isinstance(torch_model, torch.nn.Module)
input_array = torch.rand((1, 28, 28, 3)).to(ivy.default_device(as_native=True))
torch_model.to(ivy.default_device(as_native=True))
output_array = torch_model(input_array)
print(output_array)
While we have only transpiled a simple model for demonstration purposes, we can certainly transpile more complex models as well. Let’s take a model from timm and see how we can build upon transpiled modules.
import timm
We will only be using the encoder, so we can remove the unnecessary layers by setting num_classes=0, and then pass pretrained=True to download the pre-trained parameters.
mlp_encoder = timm.create_model("mixer_b16_224", pretrained=True, num_classes=0)
Let’s transpile the model to tensorflow with ivy.transpile 🔀
noise = torch.randn(1, 3, 224, 224)
tf_mlp_encoder = ivy.transpile(mlp_encoder, to="tensorflow", args=(noise,))
And now let’s build a model on top of our pretrained encoder!
class Classifier(tf.keras.Model):
    def __init__(self):
        super(Classifier, self).__init__()
        self.encoder = tf_mlp_encoder
        self.output_dense = tf.keras.layers.Dense(units=1000, activation="softmax")
    def call(self, x):
        x = self.encoder(x)
        return self.output_dense(x)
model = Classifier()
x = tf.random.normal(shape=(1, 3, 224, 224))
ret = model(x)
print(type(ret), ret.shape)
As the encoder now consists of tensorflow functions, we can extend the transpiled modules as much as we want, leveraging existing weights and the tools and infrastructure of all frameworks 🚀
Round Up
That’s about it! You are now prepared to start using Ivy on your own! However, there are still plenty of useful resources to explore. If you want to delve deeper into Ivy’s features and learn how to use them, you can visit the Demos page and go through the notebooks 📚