14. Introduction to Artificial Neural Networks

GPU

This lecture was built using a machine with the latest CUDA and cuDNN frameworks installed and with access to a GPU.

To run this lecture on Google Colab, click on the “play” icon top right, select Colab, and set the runtime environment to include a GPU.

To run this lecture on your own machine, you need to install the software listed following this notice.

!pip install --upgrade jax

In addition to what’s included in base Anaconda, we need to install the following packages:

!pip install kaleido
!conda install -y -c plotly plotly plotly-orca retrying

Note

If you are running this on Google Colab, the above cell will present an error. This is because Google Colab doesn’t use Anaconda to manage Python packages. However, this lecture will still execute, as Google Colab has plotly installed.

14.1. Overview

Substantial parts of machine learning and artificial intelligence are about

  • approximating an unknown function with a known function

  • estimating the known function from a set of data on the left- and right-hand variables

This lecture describes the structure of a plain vanilla artificial neural network (ANN) of a type that is widely used to approximate a function f that maps x in a space X into y in a space Y.

To introduce elementary concepts, we study an example in which x and y are scalars.

We’ll describe the following concepts, which are the brick and mortar of neural networks:

  • a neuron

  • an activation function

  • a network of neurons

  • a neural network as a composition of functions

  • back-propagation and its relationship to the chain rule of differential calculus

14.2. A Deep (but not Wide) Artificial Neural Network

We describe a “deep” neural network of “width” one.

Deep means that the network composes a large number of functions organized into nodes of a graph.

Width refers to the number of right-hand-side variables of the function being approximated.

Setting “width” to one means that the network composes just univariate functions.

Let $x \in \mathbb{R}$ be a scalar and $y \in \mathbb{R}$ be another scalar.

We assume that y is a nonlinear function of x:

$$y = f(x)$$

We want to approximate f(x) with another function that we define recursively.

For a network of depth $N \geq 1$, each layer $i = 1, \ldots, N$ consists of

  • an input $x_i$

  • an affine function $w_i x_i + b_i$, where $w_i$ is a scalar weight placed on the input $x_i$ and $b_i$ is a scalar bias

  • an activation function $h_i$ that takes $(w_i x_i + b_i)$ as an argument and produces an output $x_{i+1}$

An example of an activation function $h$ is the sigmoid function

$$h(z) = \frac{1}{1 + e^{-z}}$$

Another popular activation function is the rectified linear unit (ReLU) function

$$h(z) = \max(0, z)$$

Yet another activation function is the identity function

$$h(z) = z$$

As activation functions below, we’ll use the sigmoid function for layers $1$ to $N-1$ and the identity function for layer $N$.
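The three activation functions above take only a few lines of JAX. This is a minimal sketch; the names `sigmoid`, `relu` and `identity` are illustrative choices, while the lecture's own code below relies on `jax.nn.sigmoid`.

```python
import jax
import jax.numpy as jnp

# Illustrative scalar activation functions (names chosen for this sketch)
def sigmoid(z):
    return 1.0 / (1.0 + jnp.exp(-z))

def relu(z):
    return jnp.maximum(0.0, z)

def identity(z):
    return z

z = jnp.array([-2.0, 0.0, 2.0])
print(sigmoid(z))   # values strictly between 0 and 1; sigmoid(0) = 0.5
print(relu(z))      # negative inputs are mapped to 0
print(identity(z))  # inputs returned unchanged

# jax.nn.sigmoid computes the same function as the hand-written version
print(jnp.allclose(sigmoid(z), jax.nn.sigmoid(z)))
```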

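To approximate a function $f(x)$ we construct $\hat f(x)$ by proceeding as follows.

Let

$$l_i(x) = w_i x + b_i.$$

We construct $\hat f$ by iterating on compositions of functions $h_i \circ l_i$:

$$f(x) \approx \hat f(x) = h_N \circ l_N \circ h_{N-1} \circ l_{N-1} \circ \cdots \circ h_1 \circ l_1(x)$$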

If $N > 1$, we call the right side a “deep” neural net.

The larger the integer $N$, the “deeper” the neural net.

Evidently, if we know the parameters $\{w_i, b_i\}_{i=1}^N$, then we can compute $\hat f(x)$ for a given $x = \tilde x$ by iterating on the recursion

$$x_{i+1} = h_i \circ l_i(x_i), \quad i = 1, \ldots, N \tag{14.1}$$

starting from $x_1 = \tilde x$.

The value of $x_{N+1}$ that emerges from this iterative scheme equals $\hat f(\tilde x)$.
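As a sketch of recursion (14.1), the loop below computes $x_1, x_2, \ldots, x_{N+1}$ for scalar weights and biases. It assumes, as stated above, the sigmoid for layers $1$ to $N-1$ and the identity for layer $N$; the function name `forward` and the parameter values are hypothetical.

```python
import jax
import jax.numpy as jnp

def forward(ws, bs, x_tilde):
    """Iterate x_{i+1} = h_i(w_i x_i + b_i) for i = 1, ..., N.

    Assumption for this sketch: sigmoid for layers 1..N-1, identity for layer N.
    Returns the whole sequence x_1, ..., x_{N+1}.
    """
    N = len(ws)
    xs = [x_tilde]
    for i in range(N):
        z = ws[i] * xs[-1] + bs[i]                          # affine map l_i
        h = jax.nn.sigmoid if i < N - 1 else (lambda u: u)  # activation h_i
        xs.append(h(z))
    return jnp.array(xs)

# Hypothetical parameters for a depth-3 network
ws = jnp.array([0.5, -1.0, 2.0])
bs = jnp.array([0.1, 0.2, -0.3])
xs = forward(ws, bs, 1.0)
print(xs[-1])   # x_{N+1}, i.e. f_hat evaluated at x_tilde = 1.0
```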

14.3. Calibrating Parameters

We now consider a neural network like the one described above with width 1, depth $N$, and activation functions $h_i$ for $1 \leq i \leq N$ that map $\mathbb{R}$ into itself.

Let $\{(w_i, b_i)\}_{i=1}^N$ denote a sequence of weights and biases.

As mentioned above, for a given input $x_1$, our approximating function $\hat f$ evaluated at $x_1$ equals the “output” $x_{N+1}$ from our network, which can be computed by iterating on $x_{i+1} = h_i(w_i x_i + b_i)$.

For a given prediction $\hat y(x)$ and target $y = f(x)$, consider the loss function

$$\mathcal{L}(\hat y, y)(x) = \frac{1}{2}(\hat y - y)^2(x).$$

This criterion is a function of the parameters $\{(w_i, b_i)\}_{i=1}^N$ and the point $x$.

We’re interested in solving the following problem:

$$\min_{\{(w_i, b_i)\}_{i=1}^N} \int \mathcal{L}(x_{N+1}, y)(x) \, d\mu(x)$$

where $\mu(x)$ is some measure of points $x \in \mathbb{R}$ over which we want a good approximation $\hat f(x)$ to $f(x)$.

Stack weights and biases into a vector of parameters $p$:

$$p = \begin{bmatrix} w_1 \\ b_1 \\ w_2 \\ b_2 \\ \vdots \\ w_N \\ b_N \end{bmatrix}$$
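In JAX, one way to move between a per-layer list of $(w_i, b_i)$ pairs and the stacked vector $p$ is `jax.flatten_util.ravel_pytree`, which flattens the leaves depth-first and so produces the ordering $w_1, b_1, \ldots, w_N, b_N$. The parameter values below are hypothetical.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical parameters in the [(w_1, b_1), ..., (w_N, b_N)] layout, N = 3
params = [(jnp.array(0.5), jnp.array(0.1)),
          (jnp.array(-1.0), jnp.array(0.2)),
          (jnp.array(2.0), jnp.array(-0.3))]

# Leaves are flattened in order w_1, b_1, w_2, b_2, ..., w_N, b_N
p, unravel = ravel_pytree(params)
print(p)   # a length-2N vector laid out like p above

# unravel maps a flat vector back into the list-of-tuples structure
params_again = unravel(p)
```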

Applying a “poor man’s version” of a stochastic gradient descent algorithm for finding a zero of a function leads to the following update rule for parameters:

$$p_{k+1} = p_k - \alpha \frac{d\mathcal{L}}{dx_{N+1}} \frac{dx_{N+1}}{dp_k} \tag{14.2}$$

where $\frac{d\mathcal{L}}{dx_{N+1}} = x_{N+1} - y$ and $\alpha > 0$ is a step size.

(See this and this to gather insights about how stochastic gradient descent relates to Newton’s method.)

To implement one step of this parameter update rule, we want the vector of derivatives $\frac{dx_{N+1}}{dp_k}$.

In the neural network literature, this step is accomplished by what is known as back propagation.
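A minimal sketch of one step of update rule (14.2), assuming a hypothetical depth-2 network (sigmoid, then identity) with $p = [w_1, b_1, w_2, b_2]$. Here `jax.grad` computes $dx_{N+1}/dp$ by automatic differentiation rather than by the back propagation derived below; multiplying it by $d\mathcal{L}/dx_{N+1} = x_{N+1} - y$ is the chain rule, so the step coincides with a gradient step on $\mathcal{L}$ itself.

```python
import jax
import jax.numpy as jnp

def f_hat(p, x):
    """Hypothetical depth-2 network, p = [w_1, b_1, w_2, b_2]."""
    x2 = jax.nn.sigmoid(p[0] * x + p[1])   # layer 1: sigmoid
    return p[2] * x2 + p[3]                # layer 2: identity, gives x_{N+1}

def loss(p, x, y):
    return 0.5 * (f_hat(p, x) - y) ** 2

def sgd_step(p, x, y, alpha=0.01):
    dL_dx = f_hat(p, x) - y                # dL/dx_{N+1}
    dx_dp = jax.grad(f_hat)(p, x)          # dx_{N+1}/dp, same shape as p
    return p - alpha * dL_dx * dx_dp       # update rule (14.2)

p = jnp.array([0.5, 0.1, 2.0, -0.3])
x, y = 1.5, 1.0
p_new = sgd_step(p, x, y)
print(loss(p, x, y), loss(p_new, x, y))   # the loss at this point falls slightly
```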

14.4. Back Propagation and the Chain Rule

Thanks to properties of

  • the chain and product rules for differentiation from differential calculus, and

  • lower triangular matrices

back propagation can actually be accomplished in one step by

  • inverting a lower triangular matrix, and

  • matrix multiplication

(This idea is from the last 7 minutes of this great YouTube video by MIT’s Alan Edelman.)

Here goes.

Define the derivative of $h(z)$ with respect to $z$ evaluated at $z = z_i$ as $\delta_i$:

$$\delta_i = \frac{d}{dz} h(z)\Big|_{z = z_i}$$

or

$$\delta_i = h'(w_i x_i + b_i).$$
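For the sigmoid, $\delta_i$ has the closed form $h'(z) = h(z)\,(1 - h(z))$, which is largest at $z = 0$, where it equals $1/4$. The sketch below checks this with `jax.grad`; the grid of $z$ values is an arbitrary illustrative choice.

```python
import jax
import jax.numpy as jnp

sigmoid = jax.nn.sigmoid
d_sigmoid = jax.vmap(jax.grad(sigmoid))   # elementwise derivative h'(z)

z = jnp.linspace(-6.0, 6.0, 121)
delta = d_sigmoid(z)

# Closed form for the sigmoid derivative: h'(z) = h(z) * (1 - h(z))
closed_form = sigmoid(z) * (1.0 - sigmoid(z))
print(jnp.max(jnp.abs(delta - closed_form)))   # close to zero
print(jnp.max(delta))                          # approximately 0.25, at z = 0
```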

Repeated application of the chain rule and product rule to our recursion (14.1) allows us to obtain:

$$dx_{i+1} = \delta_i \left(dw_i \, x_i + w_i \, dx_i + db_i\right)$$

After imposing $dx_1 = 0$, we get the following system of equations:

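$$
\begin{pmatrix} dx_2 \\ \vdots \\ dx_{N+1} \end{pmatrix}
=
\underbrace{\begin{pmatrix}
\delta_1 x_1 & \delta_1 & 0 & 0 & \cdots & 0 & 0 \\
0 & 0 & \delta_2 x_2 & \delta_2 & \cdots & 0 & 0 \\
\vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\
0 & 0 & 0 & 0 & \cdots & \delta_N x_N & \delta_N
\end{pmatrix}}_{D}
\begin{pmatrix} dw_1 \\ db_1 \\ \vdots \\ dw_N \\ db_N \end{pmatrix}
+
\underbrace{\begin{pmatrix}
0 & 0 & \cdots & 0 & 0 \\
\delta_2 w_2 & 0 & \cdots & 0 & 0 \\
0 & \delta_3 w_3 & \cdots & 0 & 0 \\
\vdots & \vdots & \ddots & \vdots & \vdots \\
0 & 0 & \cdots & \delta_N w_N & 0
\end{pmatrix}}_{L}
\begin{pmatrix} dx_2 \\ \vdots \\ dx_{N+1} \end{pmatrix}
$$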

or

$$dx = D\,dp + L\,dx$$

which implies that

$$dx = (I - L)^{-1} D\,dp$$

which in turn implies

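$$\begin{pmatrix} dx_{N+1}/dw_1 & dx_{N+1}/db_1 & \cdots & dx_{N+1}/dw_N & dx_{N+1}/db_N \end{pmatrix} = e_N (I - L)^{-1} D,$$

where $e_N$ is the $1 \times N$ row vector with a one in its last position and zeros elsewhere.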
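Because $I - L$ is lower triangular with ones on its diagonal, the row vector $g = e_N (I - L)^{-1}$ can be found without forming an inverse: it solves $(I - L)^\top g^\top = e_N^\top$, whose upper-triangular structure gives $g_N = 1$ and $g_j = \delta_{j+1} w_{j+1} g_{j+1}$ for $j = N-1, \ldots, 1$. Running that recursion from the output layer back toward the input is back propagation; the gradient entries are then $g_j \delta_j x_j$ (for $w_j$) and $g_j \delta_j$ (for $b_j$). The sketch below, with hypothetical function names and the sigmoid/identity activations assumed above, compares the recursion with the explicit matrix formula.

```python
import jax
import jax.numpy as jnp

def forward(ws, bs, x1):
    """Forward pass returning x_1..x_{N+1} and delta_1..delta_N.
    Assumes sigmoid for layers 1..N-1 and identity for layer N."""
    N = len(ws)
    xs, deltas = [x1], []
    for i in range(N):
        z = ws[i] * xs[-1] + bs[i]
        if i < N - 1:
            xs.append(jax.nn.sigmoid(z))
            deltas.append(jax.grad(jax.nn.sigmoid)(z))   # delta_i = h'(z_i)
        else:
            xs.append(z)
            deltas.append(1.0)                           # identity: h'(z) = 1
    return jnp.array(xs), jnp.array(deltas)

def backprop(ws, bs, x1):
    """dx_{N+1}/dp via the backward recursion g_N = 1, g_j = delta_{j+1} w_{j+1} g_{j+1}."""
    xs, deltas = forward(ws, bs, x1)
    N = len(ws)
    g = [None] * N
    g[N - 1] = 1.0
    for j in range(N - 2, -1, -1):                       # 0-based indices
        g[j] = deltas[j + 1] * ws[j + 1] * g[j + 1]
    g = jnp.array(g)
    dw = g * deltas * xs[:-1]                            # g_j delta_j x_j
    db = g * deltas                                      # g_j delta_j
    return jnp.stack([dw, db], axis=1).reshape(-1)       # order w_1, b_1, ..., w_N, b_N

def matrix_form(ws, bs, x1):
    """The same gradient from e_N (I - L)^{-1} D, built explicitly."""
    xs, deltas = forward(ws, bs, x1)
    N = len(ws)
    rows = jnp.arange(N)
    D = jnp.zeros((N, 2 * N))
    D = D.at[rows, 2 * rows].set(deltas * xs[:-1])
    D = D.at[rows, 2 * rows + 1].set(deltas)
    L = jnp.diag((deltas * ws)[1:], k=-1)                # L[j, j-1] = delta_j w_j
    e_N = jnp.zeros(N).at[-1].set(1.0)
    return e_N @ jnp.linalg.solve(jnp.eye(N) - L, D)

ws = jnp.array([0.5, -1.0, 2.0])   # hypothetical parameters
bs = jnp.array([0.1, 0.2, -0.3])
print(backprop(ws, bs, 1.5))
print(matrix_form(ws, bs, 1.5))
```

The recursion touches each layer once, whereas the matrix version builds and solves an $N \times N$ system; both give the same derivatives.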

We can then solve the above problem by applying our update for $p$ multiple times for a collection of input-output pairs $\{(x_1^i, y^i)\}_{i=1}^M$ that we’ll call our “training set”.

14.5. Training Set

Choosing a training set amounts to a choice of measure μ in the above formulation of our function approximation problem as a minimization problem.

In this spirit, we shall use a uniform grid of, say, 50 or 200 points.
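With a uniform grid of $M$ points, the measure $\mu$ puts weight $1/M$ on each grid point, so the integral in the minimization problem becomes the average loss over the grid. The sketch below, assuming a hypothetical depth-2 network, fixed parameters and the target $\log x$ on $[0.5, 3]$, evaluates that average for several grid sizes; as $M$ grows it settles toward the average of the loss under the uniform distribution on the interval.

```python
import jax
import jax.numpy as jnp

def f_hat(p, x):
    # Hypothetical depth-2 network: sigmoid hidden layer, identity output layer
    return p[2] * jax.nn.sigmoid(p[0] * x + p[1]) + p[3]

def objective(p, grid, f):
    """Average of 0.5 * (x_{N+1} - y)^2 over the grid, i.e. the integral of the
    loss with respect to the measure putting weight 1/M on each grid point."""
    preds = jax.vmap(f_hat, in_axes=(None, 0))(p, grid)
    return jnp.mean(0.5 * (preds - f(grid)) ** 2)

p = jnp.array([1.0, 0.0, 1.0, 0.0])   # hypothetical, untrained parameters
for M in (50, 200, 2000):
    grid = jnp.linspace(0.5, 3.0, num=M)
    print(M, float(objective(p, grid, jnp.log)))
```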

There are many possible approaches to the minimization problem posed above:

  • batch gradient descent in which you use an average gradient over the training set

  • stochastic gradient descent in which you sample points randomly and use individual gradients

  • something in-between (so-called “mini-batch gradient descent”)

The update rule (14.2) described above amounts to a stochastic gradient descent algorithm.
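The three approaches differ only in how many points enter each averaged gradient. A minimal sketch, assuming a hypothetical depth-2 network and plain gradient steps: `batch_size=M` gives batch gradient descent, `batch_size=1` gives stochastic gradient descent as in (14.2), and anything in between gives mini-batch gradient descent. The step size and epoch count are arbitrary choices.

```python
import jax
import jax.numpy as jnp

def f_hat(p, x):
    # Hypothetical depth-2 network: sigmoid hidden layer, identity output layer
    return p[2] * jax.nn.sigmoid(p[0] * x + p[1]) + p[3]

def loss(p, x, y):
    return 0.5 * (f_hat(p, x) - y) ** 2

# Per-point gradients, vectorized over a batch of (x, y) pairs
per_point_grad = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))

@jax.jit
def step(p, xb, yb, alpha):
    """One step along the average gradient over the batch (xb, yb)."""
    return p - alpha * jnp.mean(per_point_grad(p, xb, yb), axis=0)

def train(p, grid, y, key, batch_size, num_epochs=100, alpha=0.1):
    M = grid.shape[0]
    for _ in range(num_epochs):
        key, subkey = jax.random.split(key)
        perm = jax.random.permutation(subkey, M)   # new random order each epoch
        for start in range(0, M, batch_size):
            idx = perm[start:start + batch_size]
            p = step(p, grid[idx], y[idx], alpha)
    return p

grid = jnp.linspace(0.5, 3.0, 50)
y = jnp.log(grid)
p0 = jnp.array([1.0, 0.0, 1.0, 0.0])
p_mini = train(p0, grid, y, jax.random.PRNGKey(0), batch_size=10, num_epochs=50)
```

With equal-sized batches the jitted `step` compiles once and is reused for every batch.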

from IPython.display import Image
import jax.numpy as jnp
from jax import grad, jit, jacfwd, vmap
from jax import random
import jax
import plotly.graph_objects as go
# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_layer_params(m, n, key, scale=1.):
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
    keys = random.split(key, len(sizes))
    return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]
def compute_xδw_seq(params, x):
    # Initialize arrays
    δ = jnp.zeros(len(params))
    xs = jnp.zeros(len(params) + 1)
    ws = jnp.zeros(len(params))
    bs = jnp.zeros(len(params))
    
    h = jax.nn.sigmoid
    
    xs = xs.at[0].set(x)
    for i, (w, b) in enumerate(params[:-1]):
        output = w * xs[i] + b
        activation = h(output[0, 0])
        
        # Store elements
        δ = δ.at[i].set(grad(h)(output[0, 0]))
        ws = ws.at[i].set(w[0, 0])
        bs = bs.at[i].set(b[0])
        xs = xs.at[i+1].set(activation)

    final_w, final_b = params[-1]
    preds = final_w * xs[-2] + final_b
    
    # Store elements
    δ = δ.at[-1].set(1.)
    ws = ws.at[-1].set(final_w[0, 0])
    bs = bs.at[-1].set(final_b[0])
    xs = xs.at[-1].set(preds[0, 0])
    
    return xs, δ, ws, bs
    

def loss(params, x, y):
    xs, δ, ws, bs = compute_xδw_seq(params, x)
    preds = xs[-1]
    
    return 1 / 2 * (y - preds) ** 2
# Parameters
N = 3  # Number of layers
layer_sizes = [1, ] * (N + 1)
param_scale = 0.1
step_size = 0.01
params = init_network_params(layer_sizes, random.PRNGKey(1))
x = 5
y = 3
xs, δ, ws, bs = compute_xδw_seq(params, x)
dxs_ad = jacfwd(lambda params, x: compute_xδw_seq(params, x)[0], argnums=0)(params, x)
dxs_ad_mat = jnp.block([dx.reshape((-1, 1)) for dx_tuple in dxs_ad for dx in dx_tuple ])[1:]
jnp.block([[δ * xs[:-1]], [δ]])
Array([[1.0165801 , 0.06087969, 0.09382247],
       [0.20331602, 0.08501981, 1.        ]], dtype=float32)
L = jnp.diag(δ * ws, k=-1)
L = L[1:, 1:]

D = jax.scipy.linalg.block_diag(*[row.reshape((1, 2)) for row in jnp.block([[δ * xs[:-1]], [δ]]).T])

dxs_la = jax.scipy.linalg.solve_triangular(jnp.eye(N) - L, D, lower=True)
# Check that the `dx` generated by the linear algebra method
# are the same as the ones generated using automatic differentiation
jnp.max(jnp.abs(dxs_ad_mat - dxs_la))
Array(0., dtype=float32)
grad_loss_ad = jnp.block([dx.reshape((-1, 1)) for dx_tuple in grad(loss)(params, x, y) for dx in dx_tuple ])
# Check that the gradient of the loss is the same for both approaches
jnp.max(jnp.abs(-(y - xs[-1]) * dxs_la[-1] - grad_loss_ad))
Array(5.9604645e-08, dtype=float32)
@jit
def update_ad(params, x, y):
    grads = grad(loss)(params, x, y)
    return [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)]

@jit
def update_la(params, x, y):
    xs, δ, ws, bs = compute_xδw_seq(params, x)
    N = len(params)
    L = jnp.diag(δ * ws, k=-1)
    L = L[1:, 1:]

    D = jax.scipy.linalg.block_diag(*[row.reshape((1, 2)) for row in jnp.block([[δ * xs[:-1]], [δ]]).T])
    
    dxs_la = jax.scipy.linalg.solve_triangular(jnp.eye(N) - L, D, lower=True)
    
    grads = -(y - xs[-1]) * dxs_la[-1]
    
    return [(w - step_size * dw, b - step_size * db) 
            for (w, b), (dw, db) in zip(params, grads.reshape((-1, 2)))]
    
# Check that both updates are the same
update_la(params, x, y)
[(Array([[-0.00826643]], dtype=float32), Array([0.94700736], dtype=float32)),
 (Array([[-2.0638916]], dtype=float32), Array([-0.7872697], dtype=float32)),
 (Array([[1.6248171]], dtype=float32), Array([1.5765371], dtype=float32))]
update_ad(params, x, y)
[(Array([[-0.00826644]], dtype=float32), Array([0.94700736], dtype=float32)),
 (Array([[-2.0638916]], dtype=float32), Array([-0.7872697], dtype=float32)),
 (Array([[1.6248171]], dtype=float32), Array([1.5765371], dtype=float32))]

14.6. Example 1

Consider the function

$$f(x) = -3x + 2$$

on $[0.5, 3]$.

We use a uniform grid of 200 points and update the parameters for each point on the grid 500 times.

$h_i$ is the sigmoid activation function for all layers except the final one, for which we use the identity function, and $N = 3$.

Weights are initialized randomly.

def f(x):
    return -3 * x + 2

M = 200
grid = jnp.linspace(0.5, 3, num=M)
f_val = f(grid)
indices = jnp.arange(M)
key = random.PRNGKey(0)

def train(params, grid, f_val, key, num_epochs=300):
    for epoch in range(num_epochs):
        # Split off a fresh subkey so that each epoch visits the grid
        # in a new random order (a fixed key would repeat one order)
        key, subkey = random.split(key)
        random_permutation = random.permutation(subkey, indices)
        for x, y in zip(grid[random_permutation], f_val[random_permutation]):
            params = update_la(params, x, y)

    return params
# Parameters
N = 3  # Number of layers
layer_sizes = [1, ] * (N + 1)
params_ex1 = init_network_params(layer_sizes, key)
%%time 
params_ex1 = train(params_ex1, grid, f_val, key, num_epochs=500)
CPU times: user 24 s, sys: 4.18 s, total: 28.2 s
Wall time: 18.1 s
predictions = vmap(compute_xδw_seq, in_axes=(None, 0))(params_ex1, grid)[0][:, -1]
fig = go.Figure()
fig.add_trace(go.Scatter(x=grid, y=f_val, name=r'$-3x+2$'))
fig.add_trace(go.Scatter(x=grid, y=predictions, name='Approximation'))

# Export to PNG file
Image(fig.to_image(format="png"))
# fig.show() will provide interactive plot when running
# notebook locally
[Figure: the function $-3x+2$ and its neural-network approximation on the grid]

14.7. How Deep?

It is fun to think about how deepening the neural net for the above example affects the quality of approximation:

  • If the network is too deep, you’ll run into the vanishing gradient problem

  • Other parameters such as the step size and the number of epochs can be as important or more important than the number of layers in the situation considered in this lecture.

  • Indeed, since f is a linear function of x, a one-layer network with the identity map as an activation would probably work best.
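To see the vanishing gradient problem concretely: $dx_{N+1}/dw_1$ contains the product $\delta_1 \prod_{j=2}^{N} \delta_j w_j$, and every sigmoid derivative is at most $1/4$. The sketch below fixes all weights after the first at $1$ and all biases at $0$ (an illustrative assumption, not the lecture's trained parameters) and shows the derivative with respect to $w_1$ shrinking roughly geometrically as the depth $N$ grows.

```python
import jax

def deep_net(w1, x, N, w=1.0, b=0.0):
    """Width-one net: sigmoid for layers 1..N-1, identity for layer N.
    Layer 1 uses weight w1; layers 2..N use the fixed weight w (illustrative)."""
    x_i = x
    for i in range(N):
        weight = w1 if i == 0 else w
        z = weight * x_i + b
        x_i = jax.nn.sigmoid(z) if i < N - 1 else z
    return x_i

# Derivative of the output x_{N+1} with respect to w_1 for increasing depth
for N in (2, 4, 8, 16):
    print(N, float(jax.grad(deep_net)(1.0, 1.0, N)))
```

With these weights each extra sigmoid layer multiplies the derivative by a factor below $1/4$, so at depth 16 it is already smaller than $4^{-15}$.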

14.8. Example 2

We use the same setup as for the previous example with

$$f(x) = \log(x)$$
def f(x):
    return jnp.log(x)

grid = jnp.linspace(0.5, 3, num=M)
f_val = f(grid)
# Parameters
N = 1  # Number of layers
layer_sizes = [1, ] * (N + 1)
params_ex2_1 = init_network_params(layer_sizes, key)
# Parameters
N = 2  # Number of layers
layer_sizes = [1, ] * (N + 1)
params_ex2_2 = init_network_params(layer_sizes, key)
# Parameters
N = 3  # Number of layers
layer_sizes = [1, ] * (N + 1)
params_ex2_3 = init_network_params(layer_sizes, key)
params_ex2_1 = train(params_ex2_1, grid, f_val, key, num_epochs=300)
params_ex2_2 = train(params_ex2_2, grid, f_val, key, num_epochs=300)
params_ex2_3 = train(params_ex2_3, grid, f_val, key, num_epochs=300)
predictions_1 = vmap(compute_xδw_seq, in_axes=(None, 0))(params_ex2_1, grid)[0][:, -1]
predictions_2 = vmap(compute_xδw_seq, in_axes=(None, 0))(params_ex2_2, grid)[0][:, -1]
predictions_3 = vmap(compute_xδw_seq, in_axes=(None, 0))(params_ex2_3, grid)[0][:, -1]
fig = go.Figure()
fig.add_trace(go.Scatter(x=grid, y=f_val, name=r'$\log{x}$'))
fig.add_trace(go.Scatter(x=grid, y=predictions_1, name='One-layer neural network'))
fig.add_trace(go.Scatter(x=grid, y=predictions_2, name='Two-layer neural network'))
fig.add_trace(go.Scatter(x=grid, y=predictions_3, name='Three-layer neural network'))

# Export to PNG file
Image(fig.to_image(format="png"))
# fig.show() will provide interactive plot when running
# notebook locally
[Figure: $\log x$ and its one-, two-, and three-layer neural-network approximations]
## to check that gpu is activated in environment
# jax.default_backend() replaces the deprecated jax.lib.xla_bridge.get_backend()
import jax
print(jax.default_backend())
gpu

Note

Cloud Environment: This lecture site is built in a server environment that doesn’t have access to a GPU. If you run this lecture locally, this lets you know where your code is being executed, either on the CPU or on the GPU.