Transpile code

Convert a torch function to jax with just one line of code.

⚠️ If you are running this notebook in Colab, you will have to install Ivy and some dependencies manually. You can do so by running the cell below ⬇️

!pip install ivy

If you want to run the notebook locally but don't have Ivy installed just yet, you can check out the Get Started section of the docs.
Using what we learnt in the previous two notebooks about Unify and Compile, the workflow for converting directly from torch to jax would be as follows: first unify to ivy code, and then compile to the jax backend:
import ivy
import torch

ivy.set_backend("jax")

def normalize(x):
    mean = torch.mean(x)
    std = torch.std(x)
    return torch.div(torch.sub(x, mean), std)

# convert the function to Ivy code
ivy_normalize = ivy.unify(normalize)

# compile the Ivy code into jax functions
jax_normalize = ivy.compile(ivy_normalize)
normalize is now compiled to jax, ready to be integrated into your wider jax project.
This workflow is common, and so in order to avoid repeated calls to ivy.unify followed by ivy.compile, there is another convenience function, ivy.transpile, which basically acts as a shorthand for this pair of function calls:
jax_normalize = ivy.transpile(normalize, source="torch", to="jax")
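Conceptually, this one-liner is just the composition of the two earlier steps. A minimal sketch of that idea, using placeholder stand-ins rather than Ivy's real internals:

```python
# Hypothetical illustration only: ivy.transpile behaves like
# unify followed by compile, chained into one call.
def transpile_sketch(fn, unify, compile_fn):
    unified_fn = unify(fn)         # framework code -> framework-agnostic code
    return compile_fn(unified_fn)  # framework-agnostic code -> target backend

# demo with stand-in steps that just tag the function
unify = lambda f: ("unified", f)
compile_fn = lambda u: ("compiled", u)
print(transpile_sketch(abs, unify, compile_fn)[0])  # compiled
```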
Again, normalize is now a jax function, ready to be integrated into your jax project.
import jax

jax.config.update('jax_enable_x64', True)
key = jax.random.PRNGKey(42)
x = jax.random.uniform(key, shape=(10,))
print(jax_normalize(x))
[-0.93968587 0.26075466 -0.22723222 -1.06276492 -0.47426987 1.72835908
1.71737559 -0.50411096 -0.65419174 0.15576624]
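As a sanity check on the output above: normalize simply standardizes its input to zero mean and unit standard deviation. The same computation in plain NumPy (using ddof=1, matching torch.std's default unbiased estimator):

```python
import numpy as np

def normalize_np(x):
    # subtract the mean, divide by the sample standard deviation --
    # the same computation as the torch version above
    return (x - np.mean(x)) / np.std(x, ddof=1)

x = np.linspace(0.0, 1.0, 10)
y = normalize_np(x)
print(y.mean())  # ~0
```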
Round Up

That's it, you can now transpile code from one framework to another with one line of code! However, there are still other important topics to master before you're ready to unify ML code like a pro 🥷. In the next notebooks we'll be learning about the various different ways that ivy.unify, ivy.compile and ivy.transpile can be called, and what implications each of these has!