TDM 40200: Project 1 — 2023

Motivation: JAX is a Python library used for high-performance numerical computing and machine learning research. In the upcoming series of projects, this library will be used to build a small neural network from scratch.

Context: This is the first project of the semester. We are going to start slowly by reviewing the JAX library, which we will use in the following series of projects.

Scope: Python, JAX, numpy
Questions
The JAX documentation (https://jax.readthedocs.io/) will be useful for this project.
Question 1
Use JAX to perform the following numeric operations.

- Create and display a 2x3 array called `first` of the values 1 through 6 in row-major order.
- Create and display a 3x2 array called `second` where each element is the value 2. Use the JAX equivalent of this function.
- Multiply `first` and `second` together into a resulting `third` and display the result.
- Display the shape of `third`.
- Multiply all values of `third` by 2 and display the result.
- Multiply `third`, element-wise, by the matrix formed from the values 1 through 4 in row-major order and display the result. Use the JAX equivalent of this function.
- Use the JAX equivalent of this function to add a row to `second` containing the values 3 and 4. Save the result to `fourth`.
- Use the JAX equivalent of this function to add a column to `fourth` containing the value 1 repeated 4 times. Save the result to `fifth`.
- Given the JAX array created by the following code, change the 1 to 7.

  ```python
  my_array = jnp.array([[2,2,2], [2,1,3]])
  ```

- Read this section of the JAX docs and explain (in your own words) why the following code, which should solve the previous question, does not work. A generic sketch of the relevant update pattern follows this list.

  ```python
  my_array[1,1] = 7
  ```
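As a point of reference (not a solution to the items above), here is a minimal sketch of JAX's functional update pattern on a throwaway array; the array and values here are arbitrary.

```python
import jax.numpy as jnp

# a throwaway 2x2 array, for illustration only
demo = jnp.zeros((2, 2))

# in-place assignment like demo[0, 1] = 5 raises an error in JAX;
# .at[...].set(...) instead returns a NEW array with the updated value
updated = demo.at[0, 1].set(5.0)

print(demo)     # the original array is unchanged
print(updated)  # a copy with element (0, 1) replaced by 5.0
```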
- Code used to solve this problem.
- Output from running the code.
- A markdown cell containing an explanation of why `my_array[1,1] = 7` does not work in JAX.
Question 2
Read this section of the JAX documentation to recall the way you can generate random numbers using JAX.

Check out the following code.

```python
import numpy as np

np.random.seed(0)

def bar(): return np.random.uniform()
def baz(): return np.random.uniform()
def foo(): return bar() + 2 * baz()

print(foo())
```
If this were written using JAX, JAX may well attempt to parallelize the `bar` and `baz` functions to run at the same time. As a result, whether `bar` or `baz` ran first would be unclear. This, unfortunately, would change the result of `foo`, making the code difficult to replicate reliably.

To illustrate this, `foo` could be the result of either of the following functions.

```python
def foo1(): return bar() + 2*baz()
# or
def foo2(): return 2*baz() + bar()
```

If `bar` executes first (as in `foo1`), you will end up with a different result than if `baz` executes first (as in `foo2`). The following code illustrates this.
```python
import numpy as np
import random

def bar(): return np.random.uniform()
def baz(): return np.random.uniform()
def foo1(): return bar() + 2*baz()
def foo2(): return 2*baz() + bar()

def foo(*funcs):
    functions = list(funcs)
    random.shuffle(functions)
    return functions[0]()
```
Running the following will sometimes give you the result 1.9791922366721637, and sometimes 1.812816374227069.

```python
np.random.seed(0)
foo(foo1, foo2)
```
The way JAX generates random values is different, and prevents such issues. At the same time, the way JAX generates random values is not as straightforward as NumPy's. Fill in the `?` parts of the following code. The resulting code should reliably output the same value, 2.3250647, regardless of whether `bar` or `baz` executes first.
```python
import jax
import random

key = jax.random.PRNGKey(0)
key, *subkeys = jax.random.split(key, num=?)

def bar(key):
    return ?

def baz(key):
    return ?

def foo1(key1, key2):
    return bar(key1) + 2*baz(key2)

def foo2(key1, key2):
    return 2*baz(key2) + bar(key1)

def foo(funcs, keys):
    functions = list(funcs)
    random.shuffle(functions)
    return ?

# the following code will always produce 2.3250647, regardless of whether bar or baz executes first
# this means this code is reproducible even in the scenario where the `bar` and `baz` functions are parallelized
key = jax.random.PRNGKey(0)
key, *subkeys = jax.random.split(key, num=3)
print(foo((foo1, foo2), (subkeys[0], subkeys[1])))
```
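As a reference point (separate from the exercise above, and using `jax.random.normal` rather than whatever distribution the blanks call for), the following sketch shows why explicit keys make results independent of execution order.

```python
import jax

key = jax.random.PRNGKey(42)    # arbitrary seed, for this demo only
k1, k2 = jax.random.split(key)  # two independent subkeys

# each draw depends only on its own key, so evaluation order is irrelevant
a = jax.random.normal(k1)
b = jax.random.normal(k2)
print(a + b)

# evaluating in the opposite order produces the identical sum
b2 = jax.random.normal(k2)
a2 = jax.random.normal(k1)
print(a2 + b2)
```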
- Code used to solve this problem.
- Output from running the code.
Question 3
At the heart of JAX is the automatic differentiation system. This system allows JAX to compute gradients of functions automatically. This is extremely powerful. Write a function called `my_function` that accepts the value `x` and returns the value 14x^3 + 13x. Test it out given the value of x=17.

Next, use the powerful JAX function `grad` to create a new function called `my_gradient` that accepts the value `x` and returns the gradient of `my_function` at `x`. Test it out given the value of x=17. What was the result? Does the result match the value you get when you plug x=17 into the derivative of `my_function`?
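To see the general pattern (with a made-up function, not the one this question asks for), a minimal `jax.grad` sketch looks like this:

```python
import jax

# a made-up demo function, for illustration only
def f(x):
    return x**2 + 3*x

df = jax.grad(f)  # df(x) computes the derivative, 2x + 3

print(f(2.0))     # 10.0
print(df(2.0))    # 7.0, which matches 2*2 + 3
```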
- Code used to solve this problem.
- Output from running the code.
Question 4
Another key utility in JAX is the `jit` function. JIT stands for "just in time". Just-in-time compilation is a trick that can be used in some situations to greatly reduce the execution time of some code by compiling it. The compiled version of the code has a myriad of optimizations applied to it that speed up your code.
Take the following, arbitrary code, and execute it in your notebook.

```python
%%time
key = jax.random.PRNGKey(0)

def my_model(keys):
    return 14*jax.random.normal(keys[0])**2 + 13*jax.random.normal(keys[1])

for i in range(100000):
    key, *subkeys = jax.random.split(key, 3)
    my_value = my_model(subkeys)
```
How long did it take? Now, use the `@jax.jit` decorator to apply the `jit` transformation to your `my_model` function to use just-in-time compilation to speed up the code. Did it work?

Well, actually, just slapping the `@jax.jit` decorator on the function is not good enough. Why? Because JAX uses asynchronous dispatch by default. What this means is that, by default, JAX will return control to Python as soon as possible, even before the function has been fully evaluated. So while it may appear as if all 100000 loops have executed, in reality, they may not have.

To properly test whether the JIT trick has sped things up, we need to synchronously wait for our code to finish executing. This can easily be accomplished using the `block_until_ready` method available on the values that JIT-compiled functions return.
For example, the following code will synchronously wait for the `my_func` function to finish executing.

```python
@jax.jit
def my_func():
    return 1

my_func().block_until_ready()
```
Repeat the experiment but make sure we are synchronously waiting for the code to finish executing. How long did it take? Did it work?
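If it helps, here is a generic sketch of that measurement pattern (with a made-up function, not the `my_model` from this question); the first call includes compilation time, while later calls reuse the compiled code.

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def demo(x):
    # a made-up toy computation, for illustration only
    return (x @ x).sum()

x = jnp.ones((500, 500))

start = time.time()
demo(x).block_until_ready()   # first call: traces and compiles, then runs
print("first call: ", time.time() - start)

start = time.time()
demo(x).block_until_ready()   # later calls: reuse the compiled version
print("second call:", time.time() - start)
```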
- Code used to solve this problem.
- Output from running the code.
Question 5
At this point in time you may be thinking — let's go! I can just slap `@jax.jit` on all my functions and make magically fast code! Well, not so fast. There are some caveats to using JAX and `jit` that you should be aware of.

By default, JAX will "trace" parameters to determine their effect on inputs of a specific shape and type. In JAX, control flow cannot depend on these "traced" values. For example, the following code will not work, because `num_loops` is relied on to determine how many times to loop.
```python
%%time
def my_function(x, num_loops):
    for i in range(num_loops):
        pass

fast_my_function = jax.jit(my_function)
fast_my_function(14, 1000000)
```
What is the solution to this problem? How do we fix it? Well, a fix is not always possible; however, we can choose to mark certain arguments as static, meaning they are not "traced". If a parameter is marked as static, the function can still be JIT compiled. The catch is that any time the function is called with a new value for one of its static parameters, the function has to be recompiled for that new static value. So, this is only useful if you will only occasionally change the parameter.

It just so happens that the provided snippet of code is a good candidate for this, as the user will only occasionally decide to change `num_loops` when running the code!
You can mark a parameter as static by specifying the argument position using the `static_argnums` argument to `jax.jit`, or by specifying the argument name using the `static_argnames` argument to `jax.jit`.
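For instance, the two forms are interchangeable; here is a sketch using a throwaway function (not the one from this question).

```python
import jax

def scale(x, factor):
    # `factor` drives Python-level control flow, so it must be static
    return x * factor if factor > 0 else -x * factor

# mark `factor` static by position...
by_position = jax.jit(scale, static_argnums=(1,))
# ...or by name
by_name = jax.jit(scale, static_argnames=("factor",))

print(by_position(3.0, 2))  # compiles for factor=2, then runs
print(by_name(3.0, 2))      # a separate jitted function; compiles once too
```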
Force the `num_loops` argument to be static and use the `jax.jit` decorator to compile the function. Test out the function, in order, using the following code cells.

```python
%%time
def my_function(x, num_loops):
    for i in range(num_loops):
        pass

fast_my_function = jax.jit(my_function, static_argnums=(1,))
```

```python
%%time
fast_my_function(14, 1000000)
```

```python
%%time
fast_my_function(14, 1000000)
```

```python
%%time
fast_my_function(14, 999999)
```
Do your best to explain why the last code cell was once again slower.
In addition, the shapes and dimensions of all inputs and outputs must be determinable ahead of time. For example, the following will fail.
```python
%%time
def my_function(x, arr_cols):
    my_array = jnp.full((2, arr_cols), 5)

fast_my_function = jax.jit(my_function)
fast_my_function(5, 5)
```
You can, once again, fix this by specifying the static parameters.

```python
%%time
def my_function(x, arr_cols):
    my_array = jnp.full((2, arr_cols), 5)

fast_my_function = jax.jit(my_function, static_argnums=(1,))
```
And, once again, JAX will recompile every time that the static argument changes.

```python
%%time
# slow, first time compiling
fast_my_function(5, 5)
```

```python
%%time
# fast, already compiled with static argument of 5
fast_my_function(5, 5)
```

```python
%%time
# slow, recompiling with static argument of 6
fast_my_function(5, 6)
```
- Code used to solve this problem.
- Output from running the code.
Please make sure to double check that your submission is complete, and contains all of your code and output, before submitting. If you are on a spotty internet connection, it is recommended to download your submission after submitting it to make sure that what you think you submitted is what you actually submitted. In addition, please review our submission guidelines before submitting your project.