14. Introduction to Artificial Neural Networks#
GPU
This lecture was built using a machine with the latest CUDA and cuDNN frameworks installed, with access to a GPU.
To run this lecture on Google Colab, click on the “play” icon top right, select Colab, and set the runtime environment to include a GPU.
To run this lecture on your own machine, you need to install the software listed following this notice.
!pip install --upgrade jax
In addition to what’s included in base Anaconda, we need to install the following packages
!pip install kaleido
!conda install -y -c plotly plotly plotly-orca retrying
Note
If you are running this on Google Colab the above cell will present an error. This is because Google Colab doesn’t use Anaconda to manage the Python packages. However, this lecture will still execute, as Google Colab has plotly installed.
14.1. Overview#
Substantial parts of machine learning and artificial intelligence are about

- approximating an unknown function with a known function
- estimating the known function from a set of data on the left- and right-hand variables
This lecture describes the structure of a plain vanilla artificial neural network (ANN) of a type that is widely used to approximate a function $f$ that maps $x$ in a space $X$ into $y$ in a space $Y$.

To introduce elementary concepts, we study an example in which $x$ and $y$ are scalars.
We’ll describe the following concepts that are brick and mortar for neural networks:
- a neuron
- an activation function
- a network of neurons
- a neural network as a composition of functions
- back-propagation and its relationship to the chain rule of differential calculus
14.2. A Deep (but not Wide) Artificial Neural Network#
We describe a “deep” neural network of “width” one.
Deep means that the network composes a large number of functions organized into nodes of a graph.
Width refers to the number of variables on the right hand side of the function being approximated.
Setting “width” to one means that the network composes just univariate functions.
Let $x \in \mathbb{R}$ be a scalar and $y \in \mathbb{R}$ be another scalar.

We assume that $y$ is a nonlinear function of $x$:

$$
y = f(x)
$$

We want to approximate $f(x)$ with another function that we define recursively.
For a network of depth $N \geq 1$, each layer $i = 1, \ldots, N$ consists of

- an input $x_i$
- an affine function $w_i x_i + b_i$, where $w_i$ is a scalar weight placed on the input $x_i$ and $b_i$ is a scalar bias
- an activation function $h_i$ that takes $(w_i x_i + b_i)$ as an argument and produces an output $x_{i+1}$
An example of an activation function $h$ is the sigmoid function

$$
h(z) = \frac{1}{1 + e^{-z}}
$$

Another popular activation function is the rectified linear unit (ReLU) function

$$
h(z) = \max(0, z)
$$

Yet another activation function is the identity function

$$
h(z) = z
$$

As activation functions below, we’ll use the sigmoid function for layers $1$ to $N-1$ and the identity function for layer $N$.
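For concreteness, here is a minimal sketch, not part of the original lecture, of these three activation functions written with JAX; the names sigmoid_h, relu_h and identity_h are purely illustrative.
import jax
import jax.numpy as jnp
# Sigmoid: h(z) = 1 / (1 + exp(-z))
def sigmoid_h(z):
    return 1 / (1 + jnp.exp(-z))
# Rectified linear unit: h(z) = max(0, z)
def relu_h(z):
    return jnp.maximum(0., z)
# Identity: h(z) = z
def identity_h(z):
    return z
# These agree with the versions that ship with JAX, e.g. jax.nn.sigmoid and jax.nn.relu
print(sigmoid_h(0.5), jax.nn.sigmoid(0.5))
print(relu_h(-1.0), jax.nn.relu(-1.0))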
To approximate a function $f(x)$ we construct $\hat f(x)$ by proceeding as follows.

Let $l_i(x) = w_i x + b_i$.

We construct $\hat f$ by composing the functions $h_i \circ l_i$:

$$
f(x) \approx \hat f(x) = h_N \circ l_N \circ h_{N-1} \circ l_{N-1} \circ \cdots \circ h_1 \circ l_1(x)
$$

If $N > 1$, we call the right side a “deep” neural net.

The larger is the integer $N$, the “deeper” is the neural net.
Evidently, if we know the parameters $\{w_i, b_i\}_{i=1}^N$, then we can compute $\hat f(x)$ for a given $x = \tilde x$ by iterating on the recursion

$$
x_{i+1} = h_i(w_i x_i + b_i), \quad i = 1, \ldots, N \tag{14.1}
$$

starting from $x_1 = \tilde x$.

The value of $x_{N+1}$ that emerges from this iterative scheme equals $\hat f(\tilde x)$.
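As a minimal sketch of recursion (14.1), separate from the lecture’s own implementation below, the following code computes $\hat f(\tilde x)$ by forward iteration, assuming scalar weights w and biases b, sigmoid activations for layers $1, \ldots, N-1$, and the identity activation for layer $N$; the name f_hat is illustrative.
import jax
import jax.numpy as jnp
def f_hat(w, b, x_tilde):
    # Iterate x_{i+1} = h_i(w_i x_i + b_i) starting from x_1 = x_tilde
    N = len(w)
    x = x_tilde
    for i in range(N - 1):              # layers 1, ..., N-1 use the sigmoid
        x = jax.nn.sigmoid(w[i] * x + b[i])
    return w[-1] * x + b[-1]            # layer N uses the identity activation
# Example with N = 3 layers of width one (the parameter values are arbitrary)
w = jnp.array([0.5, -1.0, 2.0])
b = jnp.array([0.1, 0.2, -0.3])
print(f_hat(w, b, 1.5))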
14.3. Calibrating Parameters#
We now consider a neural network like the one described above with width 1, depth $N$, and activation functions $h_i$ for $1 \leq i \leq N$ that map $\mathbb{R}$ into itself.

Let $\{(w_i, b_i)\}_{i=1}^N$ denote a sequence of weights and biases.

As mentioned above, for a given input $x_1$, our approximating function $\hat f$ evaluated at $x_1$ equals the “output” $x_{N+1}$ from our network, which can be computed by iterating on recursion (14.1).

For a given prediction $\hat{y}(x)$ and target $y = f(x)$, consider the loss function

$$
\mathcal{L}\left(\hat{y}, y\right)(x) = \frac{1}{2}\left(\hat{y} - y\right)^2(x)
$$

This criterion is a function of the parameters $\{(w_i, b_i)\}_{i=1}^N$ and the point $x$.
We’re interested in solving the following problem:

$$
\min_{\{(w_i, b_i)\}_{i=1}^N} \int \mathcal{L}\left(\hat{y}(x), y(x)\right) \, d\mu(x)
$$

where $\mu(x)$ is some measure of points $x \in \mathbb{R}$ over which we want a good approximation $\hat f(x)$ to $f(x)$.
Stack weights and biases into a vector of parameters $p$:

$$
p = \begin{bmatrix} w_1 \\ b_1 \\ w_2 \\ b_2 \\ \vdots \\ w_N \\ b_N \end{bmatrix}
$$
Applying a “poor man’s version” of a stochastic gradient descent algorithm for finding a zero of a function leads to the following update rule for parameters:

$$
p_{k+1} = p_k - \alpha \frac{d \mathcal{L}}{d x_{N+1}} \frac{d x_{N+1}}{d p_k} \tag{14.2}
$$

where $\frac{d \mathcal{L}}{d x_{N+1}} = -\left(y - x_{N+1}\right)$ and $\alpha > 0$ is a step size.
(See this and this to gather insights about how stochastic gradient descent relates to Newton’s method.)
To implement one step of this parameter update rule, we want the vector of derivatives $\frac{d x_{N+1}}{d p_k}$.
In the neural network literature, this step is accomplished by what is known as back propagation.
14.4. Back Propagation and the Chain Rule#
Thanks to properties of

- the chain and product rules for differentiation from differential calculus, and
- lower triangular matrices

back propagation can actually be accomplished in one step by

- inverting a lower triangular matrix, and
- matrix multiplication
(This idea is from the last 7 minutes of this great YouTube video by MIT’s Alan Edelman.)
Here goes.
Define the derivative of $h(z)$ with respect to $z$, evaluated at $z = z_i$, as $\delta_i$:

$$
\delta_i = \frac{d}{d z} h(z) \Big|_{z = z_i}
$$

or

$$
\delta_i = h'\left(w_i x_i + b_i\right)
$$
Repeated application of the chain rule and product rule to our recursion (14.1) allows us to obtain:

$$
dx_{i+1} = \delta_i \left( x_i \, dw_i + w_i \, dx_i + db_i \right)
$$
After imposing $dx_1 = 0$, we get the following system of equations:

$$
\underbrace{\begin{bmatrix} dx_2 \\ dx_3 \\ \vdots \\ dx_{N+1} \end{bmatrix}}_{dx}
=
\underbrace{\begin{bmatrix}
\delta_1 x_1 & \delta_1 & 0 & 0 & \cdots & 0 & 0 \\
0 & 0 & \delta_2 x_2 & \delta_2 & \cdots & 0 & 0 \\
\vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\
0 & 0 & 0 & 0 & \cdots & \delta_N x_N & \delta_N
\end{bmatrix}}_{D}
\underbrace{\begin{bmatrix} dw_1 \\ db_1 \\ dw_2 \\ db_2 \\ \vdots \\ dw_N \\ db_N \end{bmatrix}}_{dp}
+
\underbrace{\begin{bmatrix}
0 & 0 & 0 & \cdots & 0 \\
\delta_2 w_2 & 0 & 0 & \cdots & 0 \\
0 & \delta_3 w_3 & 0 & \cdots & 0 \\
\vdots & \vdots & \ddots & \ddots & \vdots \\
0 & 0 & \cdots & \delta_N w_N & 0
\end{bmatrix}}_{L}
\underbrace{\begin{bmatrix} dx_2 \\ dx_3 \\ \vdots \\ dx_{N+1} \end{bmatrix}}_{dx}
$$

or

$$
dx = D \, dp + L \, dx
$$

which implies that

$$
dx = (I - L)^{-1} D \, dp
$$

which in turn implies

$$
\begin{bmatrix}
\frac{dx_{N+1}}{dw_1} & \frac{dx_{N+1}}{db_1} & \cdots & \frac{dx_{N+1}}{dw_N} & \frac{dx_{N+1}}{db_N}
\end{bmatrix}
= e_N (I - L)^{-1} D
$$

where $e_N$ is the $N$-th unit row vector, so that the required derivatives form the last row of $(I - L)^{-1} D$.

We can then solve the above problem by applying our update rule (14.2) for $p$ multiple times for a collection of input-output pairs that we’ll call our “training set”.
14.5. Training Set#
Choosing a training set amounts to a choice of measure $\mu$ in the above formulation of our function approximation problem as a minimization problem.
In this spirit, we shall use a uniform grid of, say, 50 or 200 points.
There are many possible approaches to the minimization problem posed above:

- batch gradient descent in which you use an average gradient over the training set
- stochastic gradient descent in which you sample points randomly and use individual gradients
- something in-between (so-called “mini-batch gradient descent”)
The update rule (14.2) described above amounts to a stochastic gradient descent algorithm.
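To make the distinction concrete, here is a minimal sketch, separate from the lecture’s own code below, of one batch gradient step versus one stochastic gradient step for a toy scalar model; the names loss_fn, batch_update and sgd_update are purely illustrative.
import jax
import jax.numpy as jnp
# Toy loss for a scalar parameter p at a single data point (x, y);
# the "model" here is simply p * x
def loss_fn(p, x, y):
    return 0.5 * (y - p * x) ** 2
# Batch gradient descent: average the gradient over the whole training set
def batch_update(p, xs, ys, step_size=0.01):
    grads = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))(p, xs, ys)
    return p - step_size * grads.mean()
# Stochastic gradient descent: update using the gradient at a single point
def sgd_update(p, x, y, step_size=0.01):
    return p - step_size * jax.grad(loss_fn)(p, x, y)
xs = jnp.linspace(0.5, 3, 50)
ys = 2.0 * xs
p = 0.0
p = batch_update(p, xs, ys)        # one batch step
p = sgd_update(p, xs[0], ys[0])    # one stochastic step
print(p)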
from IPython.display import Image
import jax.numpy as jnp
from jax import grad, jit, jacfwd, vmap
from jax import random
import jax
import plotly.graph_objects as go
# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_layer_params(m, n, key, scale=1.):
w_key, b_key = random.split(key)
return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))
# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
keys = random.split(key, len(sizes))
return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]
def compute_xδw_seq(params, x):
# Initialize arrays
δ = jnp.zeros(len(params))
xs = jnp.zeros(len(params) + 1)
ws = jnp.zeros(len(params))
bs = jnp.zeros(len(params))
h = jax.nn.sigmoid
xs = xs.at[0].set(x)
for i, (w, b) in enumerate(params[:-1]):
output = w * xs[i] + b
activation = h(output[0, 0])
# Store elements
δ = δ.at[i].set(grad(h)(output[0, 0]))
ws = ws.at[i].set(w[0, 0])
bs = bs.at[i].set(b[0])
xs = xs.at[i+1].set(activation)
final_w, final_b = params[-1]
preds = final_w * xs[-2] + final_b
# Store elements
δ = δ.at[-1].set(1.)
ws = ws.at[-1].set(final_w[0, 0])
bs = bs.at[-1].set(final_b[0])
xs = xs.at[-1].set(preds[0, 0])
return xs, δ, ws, bs
def loss(params, x, y):
xs, δ, ws, bs = compute_xδw_seq(params, x)
preds = xs[-1]
return 1 / 2 * (y - preds) ** 2
# Parameters
N = 3 # Number of layers
layer_sizes = [1, ] * (N + 1)
param_scale = 0.1
step_size = 0.01
params = init_network_params(layer_sizes, random.PRNGKey(1))
x = 5
y = 3
xs, δ, ws, bs = compute_xδw_seq(params, x)
dxs_ad = jacfwd(lambda params, x: compute_xδw_seq(params, x)[0], argnums=0)(params, x)
dxs_ad_mat = jnp.block([dx.reshape((-1, 1)) for dx_tuple in dxs_ad for dx in dx_tuple ])[1:]
jnp.block([[δ * xs[:-1]], [δ]])
Array([[1.0165801 , 0.06087969, 0.09382247],
[0.20331602, 0.08501981, 1. ]], dtype=float32)
L = jnp.diag(δ * ws, k=-1)
L = L[1:, 1:]
D = jax.scipy.linalg.block_diag(*[row.reshape((1, 2)) for row in jnp.block([[δ * xs[:-1]], [δ]]).T])
dxs_la = jax.scipy.linalg.solve_triangular(jnp.eye(N) - L, D, lower=True)
# Check that the `dx` generated by the linear algebra method
# are the same as the ones generated using automatic differentiation
jnp.max(jnp.abs(dxs_ad_mat - dxs_la))
Array(0., dtype=float32)
grad_loss_ad = jnp.block([dx.reshape((-1, 1)) for dx_tuple in grad(loss)(params, x, y) for dx in dx_tuple ])
# Check that the gradient of the loss is the same for both approaches
jnp.max(jnp.abs(-(y - xs[-1]) * dxs_la[-1] - grad_loss_ad))
Array(5.9604645e-08, dtype=float32)
@jit
def update_ad(params, x, y):
grads = grad(loss)(params, x, y)
return [(w - step_size * dw, b - step_size * db)
for (w, b), (dw, db) in zip(params, grads)]
@jit
def update_la(params, x, y):
xs, δ, ws, bs = compute_xδw_seq(params, x)
N = len(params)
L = jnp.diag(δ * ws, k=-1)
L = L[1:, 1:]
D = jax.scipy.linalg.block_diag(*[row.reshape((1, 2)) for row in jnp.block([[δ * xs[:-1]], [δ]]).T])
dxs_la = jax.scipy.linalg.solve_triangular(jnp.eye(N) - L, D, lower=True)
grads = -(y - xs[-1]) * dxs_la[-1]
return [(w - step_size * dw, b - step_size * db)
for (w, b), (dw, db) in zip(params, grads.reshape((-1, 2)))]
# Check that both updates are the same
update_la(params, x, y)
[(Array([[-0.00826643]], dtype=float32), Array([0.94700736], dtype=float32)),
(Array([[-2.0638916]], dtype=float32), Array([-0.7872697], dtype=float32)),
(Array([[1.6248171]], dtype=float32), Array([1.5765371], dtype=float32))]
update_ad(params, x, y)
[(Array([[-0.00826644]], dtype=float32), Array([0.94700736], dtype=float32)),
(Array([[-2.0638916]], dtype=float32), Array([-0.7872697], dtype=float32)),
(Array([[1.6248171]], dtype=float32), Array([1.5765371], dtype=float32))]
14.6. Example 1#
Consider the function

$$
f(x) = -3 x + 2
$$

on the interval $[0.5, 3]$.
We use a uniform grid of 200 points and, as in the code below, update the parameters for each point on the grid 500 times.
Weights are initialized randomly.
def f(x):
return -3 * x + 2
M = 200
grid = jnp.linspace(0.5, 3, num=M)
f_val = f(grid)
indices = jnp.arange(M)
key = random.PRNGKey(0)
def train(params, grid, f_val, key, num_epochs=300):
    for epoch in range(num_epochs):
        # Draw a fresh subkey so that the training points are reshuffled every epoch
        key, subkey = random.split(key)
        random_permutation = random.permutation(subkey, indices)
        for x, y in zip(grid[random_permutation], f_val[random_permutation]):
            params = update_la(params, x, y)
    return params
# Parameters
N = 3 # Number of layers
layer_sizes = [1, ] * (N + 1)
params_ex1 = init_network_params(layer_sizes, key)
%%time
params_ex1 = train(params_ex1, grid, f_val, key, num_epochs=500)
CPU times: user 24 s, sys: 4.18 s, total: 28.2 s
Wall time: 18.1 s
predictions = vmap(compute_xδw_seq, in_axes=(None, 0))(params_ex1, grid)[0][:, -1]
fig = go.Figure()
fig.add_trace(go.Scatter(x=grid, y=f_val, name=r'$-3x+2$'))
fig.add_trace(go.Scatter(x=grid, y=predictions, name='Approximation'))
# Export to PNG file
Image(fig.to_image(format="png"))
# fig.show() will provide interactive plot when running
# notebook locally

14.7. How Deep?#
It is fun to think about how deepening the neural net for the above example affects the quality of approximation.

If the network is too deep, you’ll run into the vanishing gradient problem.
Other parameters such as the step size and the number of epochs can be as important or more important than the number of layers in the situation considered in this lecture.
Indeed, since $f(x) = -3x + 2$ is a linear function of $x$, a one-layer network with the identity map as an activation would probably work best.
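As a quick, hedged check of that claim, and not part of the original lecture, the following sketch fits a one-layer network $\hat f(x) = w x + b$ with the identity activation to $f(x) = -3x + 2$ by stochastic gradient descent; the names one_layer_loss and one_layer_sgd_step are illustrative, and the run recovers $w \approx -3$ and $b \approx 2$.
import jax
import jax.numpy as jnp
def one_layer_loss(p, x, y):
    w, b = p
    return 0.5 * (y - (w * x + b)) ** 2    # identity activation: prediction is w * x + b
@jax.jit
def one_layer_sgd_step(p, x, y, step_size=0.05):
    return p - step_size * jax.grad(one_layer_loss)(p, x, y)
grid = jnp.linspace(0.5, 3, 200)
targets = -3 * grid + 2
p = jnp.zeros(2)                           # start from w = 0, b = 0
for _ in range(50):                        # 50 passes over the grid
    for x, y in zip(grid, targets):
        p = one_layer_sgd_step(p, x, y)
print(p)                                   # approximately [-3., 2.]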
14.8. Example 2#
We use the same setup as for the previous example with $f(x) = \log(x)$.
def f(x):
return jnp.log(x)
grid = jnp.linspace(0.5, 3, num=M)
f_val = f(grid)
# Parameters
N = 1 # Number of layers
layer_sizes = [1, ] * (N + 1)
params_ex2_1 = init_network_params(layer_sizes, key)
# Parameters
N = 2 # Number of layers
layer_sizes = [1, ] * (N + 1)
params_ex2_2 = init_network_params(layer_sizes, key)
# Parameters
N = 3 # Number of layers
layer_sizes = [1, ] * (N + 1)
params_ex2_3 = init_network_params(layer_sizes, key)
params_ex2_1 = train(params_ex2_1, grid, f_val, key, num_epochs=300)
params_ex2_2 = train(params_ex2_2, grid, f_val, key, num_epochs=300)
params_ex2_3 = train(params_ex2_3, grid, f_val, key, num_epochs=300)
predictions_1 = vmap(compute_xδw_seq, in_axes=(None, 0))(params_ex2_1, grid)[0][:, -1]
predictions_2 = vmap(compute_xδw_seq, in_axes=(None, 0))(params_ex2_2, grid)[0][:, -1]
predictions_3 = vmap(compute_xδw_seq, in_axes=(None, 0))(params_ex2_3, grid)[0][:, -1]
fig = go.Figure()
fig.add_trace(go.Scatter(x=grid, y=f_val, name=r'$\log{x}$'))
fig.add_trace(go.Scatter(x=grid, y=predictions_1, name='One-layer neural network'))
fig.add_trace(go.Scatter(x=grid, y=predictions_2, name='Two-layer neural network'))
fig.add_trace(go.Scatter(x=grid, y=predictions_3, name='Three-layer neural network'))
# Export to PNG file
Image(fig.to_image(format="png"))
# fig.show() will provide interactive plot when running
# notebook locally

## to check that gpu is activated in environment
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)
gpu
/tmp/ipykernel_3082/1861301157.py:4: DeprecationWarning:
jax.lib.xla_bridge.get_backend is deprecated; use jax.extend.backend.get_backend.
Note
Cloud Environment: This lecture site is built in a server environment that doesn’t have access to a gpu.
If you run this lecture locally, this lets you know where your code is being executed, either via the cpu or the gpu.