{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "149f5bcc",
   "metadata": {},
   "source": [
    "# Introduction to Artificial Neural Networks"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3986b826",
   "metadata": {},
   "source": [
    "# GPU\n",
    "\n",
    "This lecture was built using a machine with the latest CUDA and CUDANN frameworks installed with access to a GPU.\n",
    "\n",
    "To run this lecture on [Google Colab](https://colab.research.google.com/), click on the “play” icon top right, select Colab, and set the runtime environment to include a GPU.\n",
    "\n",
    "To run this lecture on your own machine, you need to install the software listed following this notice."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfa728f5",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "!pip install --upgrade jax"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "970ccfab",
   "metadata": {},
   "source": [
    "In addition to what’s included in base Anaconda, we need to install the following packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b735efb",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "!pip install kaleido\n",
    "!conda install -y -c plotly plotly plotly-orca retrying"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "84408049",
   "metadata": {},
   "source": [
    ">**Note**\n",
    ">\n",
    ">If you are running this on Google Colab the above cell will\n",
    "present an error. This is because Google Colab doesn’t use Anaconda to manage\n",
    "the Python packages. However this lecture will still execute as Google Colab\n",
    "has `plotly` installed."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "407e3a2d",
   "metadata": {},
   "source": [
    "## Overview\n",
    "\n",
    "Substantial parts of **machine learning** and **artificial intelligence** are about\n",
    "\n",
    "- approximating an unknown function with a known function  \n",
    "- estimating the known function from a set of data on the left- and right-hand variables  \n",
    "\n",
    "\n",
    "This lecture describes the structure of a plain vanilla **artificial neural network**  (ANN) of a type that is widely used to\n",
    "approximate a function $ f $ that maps   $ x $ in  a space $ X $ into  $ y $ in a space $ Y $.\n",
    "\n",
    "To introduce elementary concepts, we study an example in which $ x $ and $ y $ are scalars.\n",
    "\n",
    "We’ll describe the following concepts that are brick and mortar for neural networks:\n",
    "\n",
    "- a neuron  \n",
    "- an activation function  \n",
    "- a network of neurons  \n",
    "- A neural network as a composition of functions  \n",
    "- back-propagation and its relationship  to the chain rule of differential calculus  "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8dc54bdd",
   "metadata": {},
   "source": [
    "## A Deep (but not Wide) Artificial Neural Network\n",
    "\n",
    "We describe a  “deep” neural network of “width” one.\n",
    "\n",
    "**Deep** means that the network composes a large number of functions organized into nodes of a graph.\n",
    "\n",
    "**Width** refers to the number of right hand  side variables on the right hand side of the function being approximated.\n",
    "\n",
    "Setting “width” to one means that the network    composes just univariate functions.\n",
    "\n",
    "Let $ x \\in \\mathbb{R} $ be a scalar and $ y \\in \\mathbb{R} $ be another scalar.\n",
    "\n",
    "We assume  that $ y $ is  a nonlinear function of $ x $:\n",
    "\n",
    "$$\n",
    "y = f(x)\n",
    "$$\n",
    "\n",
    "We want to approximate  $ f(x) $ with another function that we define recursively.\n",
    "\n",
    "For a network of depth $ N \\geq 1 $, each **layer** $ i =1, \\ldots N $ consists of\n",
    "\n",
    "- an input $ x_i $  \n",
    "- an **affine function** $ w_i x_i + bI $, where $ w_i $ is a scalar **weight** placed on the input $ x_i $ and $ b_i $ is a scalar **bias**  \n",
    "- an **activation function** $ h_i $ that takes $ (w_i x_i + b_i) $ as an argument and produces an output $ x_{i+1} $  \n",
    "\n",
    "\n",
    "An example of an activation function $ h $ is the **sigmoid** function\n",
    "\n",
    "$$\n",
    "h (z) = \\frac{1}{1 + e^{-z}}\n",
    "$$\n",
    "\n",
    "Another popular activation function is the **rectified linear unit** (ReLU) function\n",
    "\n",
    "$$\n",
    "h(z) = \\max (0, z)\n",
    "$$\n",
    "\n",
    "Yet another activation function is the identity function\n",
    "\n",
    "$$\n",
    "h(z) = z\n",
    "$$\n",
    "\n",
    "As activation functions below, we’ll use the sigmoid function for layers $ 1 $ to $ N-1 $ and the identity function for  layer $ N $.\n",
    "\n",
    "To approximate a function $ f(x) $ we construct   $ \\hat f(x) $  by proceeding as follows.\n",
    "\n",
    "Let\n",
    "\n",
    "$$\n",
    "l_{i}\\left(x\\right)=w_{i}x+b_{i} .\n",
    "$$\n",
    "\n",
    "We construct  $ \\hat f $ by iterating on compositions of functions $ h_i \\circ l_i $:\n",
    "\n",
    "$$\n",
    "f(x)\\approx\\hat{f}(x)=h_{N}\\circ l_{N}\\circ h_{N-1}\\circ l_{1}\\circ\\cdots\\circ h_{1}\\circ l_{1}(x)\n",
    "$$\n",
    "\n",
    "If $ N >1 $, we call the right side a “deep” neural net.\n",
    "\n",
    "The larger is the integer $ N $, the “deeper” is the neural net.\n",
    "\n",
    "Evidently,  if we know  the parameters $ \\{w_i, b_i\\}_{i=1}^N $, then we can compute\n",
    "$ \\hat f(x) $ for a given $ x = \\tilde x $ by iterating on the recursion\n",
    "\n",
    "\n",
    "<a id='equation-eq-recursion'></a>\n",
    "$$\n",
    "x_{i+1} = h_i \\circ l_i(x_i) , \\quad, i = 1, \\ldots N \\tag{14.1}\n",
    "$$\n",
    "\n",
    "starting from $ x_1 = \\tilde x $.\n",
    "\n",
    "The value of $ x_{N+1} $ that emerges from this iterative scheme\n",
    "equals $ \\hat f(\\tilde x) $."
   ]
  },
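  {
   "cell_type": "markdown",
   "id": "f1a2c9b7",
   "metadata": {},
   "source": [
    "Before turning to JAX, here is a minimal pure-Python sketch of the recursion [(14.1)](#equation-eq-recursion), using the sigmoid activation for layers $ 1 $ to $ N-1 $ and the identity activation for layer $ N $.\n",
    "\n",
    "The helper names, weights, biases, and input below are illustrative only; they are not part of the calibrated network built later in this lecture."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7d4e2a1",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "# A sketch of the forward recursion x_{i+1} = h_i(w_i x_i + b_i)\n",
    "# with made-up weights and biases (illustrative only)\n",
    "from math import exp\n",
    "\n",
    "def sigmoid(z):\n",
    "    return 1 / (1 + exp(-z))\n",
    "\n",
    "def f_hat(x, weights, biases):\n",
    "    N = len(weights)\n",
    "    for i, (w, b) in enumerate(zip(weights, biases), start=1):\n",
    "        h = (lambda z: z) if i == N else sigmoid  # identity in the final layer\n",
    "        x = h(w * x + b)\n",
    "    return x\n",
    "\n",
    "f_hat(0.5, weights=[1.0, -2.0, 0.5], biases=[0.1, 0.3, -0.2])"
   ]
  },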
  {
   "cell_type": "markdown",
   "id": "6ef9b6f2",
   "metadata": {},
   "source": [
    "## Calibrating  Parameters\n",
    "\n",
    "We now consider a  neural network like the one describe above  with width 1, depth $ N $,  and activation functions $ h_{i} $ for $ 1\\leqslant i\\leqslant N $ that map $ \\mathbb{R} $ into itself.\n",
    "\n",
    "Let $ \\left\\{ \\left(w_{i},b_{i}\\right)\\right\\} _{i=1}^{N} $ denote a sequence of weights and biases.\n",
    "\n",
    "As mentioned above, for a given input $ x_{1} $, our approximating function $ \\hat f $ evaluated\n",
    "at $ x_1 $ equals the “output” $ x_{N+1} $ from our network that  can be computed by iterating on $ x_{i+1}=h_{i}\\left(w_{i}x_{i}+b_{i}\\right) $.\n",
    "\n",
    "For a given **prediction** $ \\hat{y} (x) $ and **target** $ y= f(x) $, consider the loss function\n",
    "\n",
    "$$\n",
    "\\mathcal{L} \\left(\\hat{y},y\\right)(x)=\\frac{1}{2}\\left(\\hat{y}-y\\right)^{2}(x) .\n",
    "$$\n",
    "\n",
    "This criterion is a function of the parameters $ \\left\\{ \\left(w_{i},b_{i}\\right)\\right\\} _{i=1}^{N} $\n",
    "and the point $ x $.\n",
    "\n",
    "We’re interested in solving the following problem:\n",
    "\n",
    "$$\n",
    "\\min_{\\left\\{ \\left(w_{i},b_{i}\\right)\\right\\} _{i=1}^{N}} \\int {\\mathcal L}\\left(x_{N+1},y\\right)(x) d \\mu(x)\n",
    "$$\n",
    "\n",
    "where $ \\mu(x) $ is some measure of  points $ x \\in \\mathbb{R} $ over which we want a good approximation $ \\hat f(x) $ to $ f(x) $.\n",
    "\n",
    "Stack  weights and biases into a vector of parameters $ p $:\n",
    "\n",
    "$$\n",
    "p = \\begin{bmatrix}     \n",
    "  w_1 \\cr \n",
    "  b_1 \\cr\n",
    "  w_2 \\cr\n",
    "  b_2 \\cr\n",
    "  \\vdots \\cr\n",
    "  w_N \\cr\n",
    "  b_N \n",
    "\\end{bmatrix}\n",
    "$$\n",
    "\n",
    "Applying a “poor man’s version” of a **stochastic gradient descent** algorithm for finding a zero of a function leads to the following update rule for parameters:\n",
    "\n",
    "\n",
    "<a id='equation-eq-sgd'></a>\n",
    "$$\n",
    "p_{k+1}=p_k-\\alpha\\frac{d \\mathcal{L}}{dx_{N+1}}\\frac{dx_{N+1}}{dp_k} \\tag{14.2}\n",
    "$$\n",
    "\n",
    "where $ \\frac{d {\\mathcal L}}{dx_{N+1}}=-\\left(x_{N+1}-y\\right) $ and $ \\alpha > 0 $ is a step size.\n",
    "\n",
    "(See [this](https://en.wikipedia.org/wiki/Gradient_descent#Description) and [this](https://en.wikipedia.org/wiki/Newton%27s_method) to gather insights about how stochastic gradient descent\n",
    "relates to Newton’s method.)\n",
    "\n",
    "To implement one step of this parameter update rule, we want  the vector of derivatives $ \\frac{dx_{N+1}}{dp_k} $.\n",
    "\n",
    "In the neural network literature, this step is accomplished by what is known as **back propagation**."
   ]
  },
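  {
   "cell_type": "markdown",
   "id": "c2f8a4d6",
   "metadata": {},
   "source": [
    "As a quick sanity check of the formula $ \\frac{d \\mathcal{L}}{dx_{N+1}} = x_{N+1}-y $, the following sketch compares a finite-difference approximation of the loss derivative with that expression; the numbers and helper name are ours, chosen only for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9e5b3c7",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "# A sketch checking dL/dx_{N+1} = x_{N+1} - y for L = 0.5 * (x_{N+1} - y)**2\n",
    "def quadratic_loss(x_out, y):\n",
    "    return 0.5 * (x_out - y) ** 2\n",
    "\n",
    "x_out, y_target, eps = 2.0, 3.0, 1e-6\n",
    "fd = (quadratic_loss(x_out + eps, y_target)\n",
    "      - quadratic_loss(x_out - eps, y_target)) / (2 * eps)\n",
    "fd, x_out - y_target  # both are approximately -1.0"
   ]
  },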
  {
   "cell_type": "markdown",
   "id": "76d0f097",
   "metadata": {},
   "source": [
    "## Back Propagation and the Chain Rule\n",
    "\n",
    "Thanks to  properties of\n",
    "\n",
    "- the chain and product rules for differentiation from differential calculus, and  \n",
    "- lower triangular matrices  \n",
    "\n",
    "\n",
    "back propagation can actually be  accomplished in one step by\n",
    "\n",
    "- inverting a lower triangular matrix,  and  \n",
    "- matrix multiplication  \n",
    "\n",
    "\n",
    "(This idea  is from the last 7 minutes of this great youtube video by MIT’s Alan Edelman)\n",
    "\n",
    "Here goes.\n",
    "\n",
    "Define the derivative of $ h(z) $ with respect to $ z $ evaluated at $ z = z_i $  as $ \\delta_i $:\n",
    "\n",
    "$$\n",
    "\\delta_i = \\frac{d}{d z} h(z)|_{z=z_i}\n",
    "$$\n",
    "\n",
    "or\n",
    "\n",
    "$$\n",
    "\\delta_{i}=h'\\left(w_{i}x_{i}+b_{i}\\right).\n",
    "$$\n",
    "\n",
    "Repeated application of the chain rule and product rule to our recursion [(14.1)](#equation-eq-recursion) allows us to obtain:\n",
    "\n",
    "$$\n",
    "dx_{i+1}=\\delta_{i}\\left(dw_{i}x_{i}+w_{i}dx_{i}+b_{i}\\right)\n",
    "$$\n",
    "\n",
    "After imposing $ dx_{1}=0 $, we get the following system of equations:\n",
    "\n",
    "$$\n",
    "\\left(\\begin{array}{c}\n",
    "dx_{2}\\\\\n",
    "\\vdots\\\\\n",
    "dx_{N+1}\n",
    "\\end{array}\\right)=\\underbrace{\\left(\\begin{array}{ccccc}\n",
    "\\delta_{1}w_{1} & \\delta_{1} & 0 & 0 & 0\\\\\n",
    "0 & 0 & \\ddots & 0 & 0\\\\\n",
    "0 & 0 & 0 & \\delta_{N}w_{N} & \\delta_{N}\n",
    "\\end{array}\\right)}_{D}\\left(\\begin{array}{c}\n",
    "dw_{1}\\\\\n",
    "db_{1}\\\\\n",
    "\\vdots\\\\\n",
    "dw_{N}\\\\\n",
    "db_{N}\n",
    "\\end{array}\\right)+\\underbrace{\\left(\\begin{array}{cccc}\n",
    "0 & 0 & 0 & 0\\\\\n",
    "w_{2} & 0 & 0 & 0\\\\\n",
    "0 & \\ddots & 0 & 0\\\\\n",
    "0 & 0 & w_{N} & 0\n",
    "\\end{array}\\right)}_{L}\\left(\\begin{array}{c}\n",
    "dx_{2}\\\\\n",
    "\\vdots\\\\\n",
    "dx_{N+1}\n",
    "\\end{array}\\right)\n",
    "$$\n",
    "\n",
    "or\n",
    "\n",
    "$$\n",
    "d x = D dp + L dx\n",
    "$$\n",
    "\n",
    "which implies that\n",
    "\n",
    "$$\n",
    "dx = (I -L)^{-1} D dp\n",
    "$$\n",
    "\n",
    "which in turn  implies\n",
    "\n",
    "$$\n",
    "\\left(\\begin{array}{c}\n",
    "dx_{N+1}/dw_{1}\\\\\n",
    "dx_{N+1}/db_{1}\\\\\n",
    "\\vdots\\\\\n",
    "dx_{N+1}/dw_{N}\\\\\n",
    "dx_{N+1}/db_{N}\n",
    "\\end{array}\\right)=e_{N}\\left(I-L\\right)^{-1}D.\n",
    "$$\n",
    "\n",
    "We can then solve the above problem by applying our update for $ p $ multiple times for a collection of input-output pairs $ \\left\\{ \\left(x_{1}^{i},y^{i}\\right)\\right\\} _{i=1}^{M} $ that we’ll call our “training set”."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1ab5df7",
   "metadata": {},
   "source": [
    "## Training Set\n",
    "\n",
    "Choosing a  training set amounts to a choice of measure $ \\mu $ in the above  formulation of our  function approximation problem as a minimization problem.\n",
    "\n",
    "In this spirit,  we shall use a uniform grid of, say, 50 or 200 points.\n",
    "\n",
    "There are many possible approaches to the minimization  problem posed above:\n",
    "\n",
    "- batch gradient descent in which you use an average gradient over the training set  \n",
    "- stochastic gradient descent in which you sample points randomly and use individual gradients  \n",
    "- something in-between (so-called “mini-batch gradient descent”)  \n",
    "\n",
    "\n",
    "The update rule [(14.2)](#equation-eq-sgd) described above  amounts  to a stochastic gradient descent algorithm."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d53882fc",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "from IPython.display import Image\n",
    "import jax.numpy as jnp\n",
    "from jax import grad, jit, jacfwd, vmap\n",
    "from jax import random\n",
    "import jax\n",
    "import plotly.graph_objects as go"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00c46953",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "# A helper function to randomly initialize weights and biases\n",
    "# for a dense neural network layer\n",
    "def random_layer_params(m, n, key, scale=1.):\n",
    "    w_key, b_key = random.split(key)\n",
    "    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))\n",
    "\n",
    "# Initialize all layers for a fully-connected neural network with sizes \"sizes\"\n",
    "def init_network_params(sizes, key):\n",
    "    keys = random.split(key, len(sizes))\n",
    "    return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3085e7ed",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "def compute_xδw_seq(params, x):\n",
    "    # Initialize arrays\n",
    "    δ = jnp.zeros(len(params))\n",
    "    xs = jnp.zeros(len(params) + 1)\n",
    "    ws = jnp.zeros(len(params))\n",
    "    bs = jnp.zeros(len(params))\n",
    "    \n",
    "    h = jax.nn.sigmoid\n",
    "    \n",
    "    xs = xs.at[0].set(x)\n",
    "    for i, (w, b) in enumerate(params[:-1]):\n",
    "        output = w * xs[i] + b\n",
    "        activation = h(output[0, 0])\n",
    "        \n",
    "        # Store elements\n",
    "        δ = δ.at[i].set(grad(h)(output[0, 0]))\n",
    "        ws = ws.at[i].set(w[0, 0])\n",
    "        bs = bs.at[i].set(b[0])\n",
    "        xs = xs.at[i+1].set(activation)\n",
    "\n",
    "    final_w, final_b = params[-1]\n",
    "    preds = final_w * xs[-2] + final_b\n",
    "    \n",
    "    # Store elements\n",
    "    δ = δ.at[-1].set(1.)\n",
    "    ws = ws.at[-1].set(final_w[0, 0])\n",
    "    bs = bs.at[-1].set(final_b[0])\n",
    "    xs = xs.at[-1].set(preds[0, 0])\n",
    "    \n",
    "    return xs, δ, ws, bs\n",
    "    \n",
    "\n",
    "def loss(params, x, y):\n",
    "    xs, δ, ws, bs = compute_xδw_seq(params, x)\n",
    "    preds = xs[-1]\n",
    "    \n",
    "    return 1 / 2 * (y - preds) ** 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3368a3f6",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "# Parameters\n",
    "N = 3  # Number of layers\n",
    "layer_sizes = [1, ] * (N + 1)\n",
    "param_scale = 0.1\n",
    "step_size = 0.01\n",
    "params = init_network_params(layer_sizes, random.PRNGKey(1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b48925b",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "x = 5\n",
    "y = 3\n",
    "xs, δ, ws, bs = compute_xδw_seq(params, x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "475280c6",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "dxs_ad = jacfwd(lambda params, x: compute_xδw_seq(params, x)[0], argnums=0)(params, x)\n",
    "dxs_ad_mat = jnp.block([dx.reshape((-1, 1)) for dx_tuple in dxs_ad for dx in dx_tuple ])[1:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c97cb59e",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "jnp.block([[δ * xs[:-1]], [δ]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10ee0647",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "L = jnp.diag(δ * ws, k=-1)\n",
    "L = L[1:, 1:]\n",
    "\n",
    "D = jax.scipy.linalg.block_diag(*[row.reshape((1, 2)) for row in jnp.block([[δ * xs[:-1]], [δ]]).T])\n",
    "\n",
    "dxs_la = jax.scipy.linalg.solve_triangular(jnp.eye(N) - L, D, lower=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd4f27dd",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "# Check that the `dx` generated by the linear algebra method\n",
    "# are the same as the ones generated using automatic differentiation\n",
    "jnp.max(jnp.abs(dxs_ad_mat - dxs_la))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8d3973c",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "grad_loss_ad = jnp.block([dx.reshape((-1, 1)) for dx_tuple in grad(loss)(params, x, y) for dx in dx_tuple ])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0266355a",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "# Check that the gradient of the loss is the same for both approaches\n",
    "jnp.max(jnp.abs(-(y - xs[-1]) * dxs_la[-1] - grad_loss_ad))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e34cd400",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "@jit\n",
    "def update_ad(params, x, y):\n",
    "    grads = grad(loss)(params, x, y)\n",
    "    return [(w - step_size * dw, b - step_size * db)\n",
    "          for (w, b), (dw, db) in zip(params, grads)]\n",
    "\n",
    "@jit\n",
    "def update_la(params, x, y):\n",
    "    xs, δ, ws, bs = compute_xδw_seq(params, x)\n",
    "    N = len(params)\n",
    "    L = jnp.diag(δ * ws, k=-1)\n",
    "    L = L[1:, 1:]\n",
    "\n",
    "    D = jax.scipy.linalg.block_diag(*[row.reshape((1, 2)) for row in jnp.block([[δ * xs[:-1]], [δ]]).T])\n",
    "    \n",
    "    dxs_la = jax.scipy.linalg.solve_triangular(jnp.eye(N) - L, D, lower=True)\n",
    "    \n",
    "    grads = -(y - xs[-1]) * dxs_la[-1]\n",
    "    \n",
    "    return [(w - step_size * dw, b - step_size * db) \n",
    "            for (w, b), (dw, db) in zip(params, grads.reshape((-1, 2)))]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7bd967a6",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "# Check that both updates are the same\n",
    "update_la(params, x, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f47b6203",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "update_ad(params, x, y)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "289598ad",
   "metadata": {},
   "source": [
    "## Example 1\n",
    "\n",
    "Consider the function\n",
    "\n",
    "$$\n",
    "f\\left(x\\right)=-3x+2\n",
    "$$\n",
    "\n",
    "on $ \\left[0.5,3\\right] $.\n",
    "\n",
    "We use a uniform grid of 200 points and update the parameters for each point on the grid 300 times.\n",
    "\n",
    "$ h_{i} $ is the sigmoid activation function for all layers except the final one for which we use the identity function and $ N=3 $.\n",
    "\n",
    "Weights are initialized randomly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48f0ccf9",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "def f(x):\n",
    "    return -3 * x + 2\n",
    "\n",
    "M = 200\n",
    "grid = jnp.linspace(0.5, 3, num=M)\n",
    "f_val = f(grid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75c9e997",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "indices = jnp.arange(M)\n",
    "key = random.PRNGKey(0)\n",
    "\n",
    "def train(params, grid, f_val, key, num_epochs=300):\n",
    "    for epoch in range(num_epochs):\n",
    "        key, _ = random.split(key)\n",
    "        random_permutation = random.permutation(random.PRNGKey(1), indices)\n",
    "        for x, y in zip(grid[random_permutation], f_val[random_permutation]):\n",
    "            params = update_la(params, x, y)\n",
    "            \n",
    "    return params "
   ]
  },
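  {
   "cell_type": "markdown",
   "id": "e8a1c5f2",
   "metadata": {},
   "source": [
    "The `train` function above implements the pure stochastic variant: it updates the parameters one grid point at a time.\n",
    "\n",
    "For comparison, here is a minimal sketch of a mini-batch step that averages automatic-differentiation gradients of `loss` over a small batch of points; the function name and batch size are ours, chosen only for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3b6d9a4",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "# A sketch of one mini-batch gradient descent step (illustrative only)\n",
    "def minibatch_step(params, x_batch, y_batch, step_size=0.01):\n",
    "    # Per-point gradients, averaged across the batch\n",
    "    grads = [grad(loss)(params, x, y) for x, y in zip(x_batch, y_batch)]\n",
    "    avg_grads = jax.tree_util.tree_map(lambda *g: sum(g) / len(g), *grads)\n",
    "    return [(w - step_size * dw, b - step_size * db)\n",
    "            for (w, b), (dw, db) in zip(params, avg_grads)]\n",
    "\n",
    "# One illustrative step on the first 10 grid points, starting from the\n",
    "# previously initialized params\n",
    "params_mb = minibatch_step(params, grid[:10], f_val[:10])"
   ]
  },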
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "403d4d5e",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "# Parameters\n",
    "N = 3  # Number of layers\n",
    "layer_sizes = [1, ] * (N + 1)\n",
    "params_ex1 = init_network_params(layer_sizes, key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6898137",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "%%time \n",
    "params_ex1 = train(params_ex1, grid, f_val, key, num_epochs=500)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e69a07d",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "predictions = vmap(compute_xδw_seq, in_axes=(None, 0))(params_ex1, grid)[0][:, -1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29781858",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "fig = go.Figure()\n",
    "fig.add_trace(go.Scatter(x=grid, y=f_val, name=r'$-3x+2$'))\n",
    "fig.add_trace(go.Scatter(x=grid, y=predictions, name='Approximation'))\n",
    "\n",
    "# Export to PNG file\n",
    "Image(fig.to_image(format=\"png\"))\n",
    "# fig.show() will provide interactive plot when running\n",
    "# notebook locally"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2032d0d3",
   "metadata": {},
   "source": [
    "## How Deep?\n",
    "\n",
    "It  is  fun to think about how deepening the neural net for the above example affects the quality of  approximation\n",
    "\n",
    "- If the network is too deep, you’ll run into the [vanishing gradient problem](http://neuralnetworksanddeeplearning.com/chap5.html)  \n",
    "- Other parameters such as the step size and the number of epochs can be as  important or more important than the number of layers in the situation considered in this lecture.  \n",
    "- Indeed, since $ f $ is a linear function of $ x $, a one-layer network with the identity map as an activation would probably work best.  "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e4523b31",
   "metadata": {},
   "source": [
    "## Example 2\n",
    "\n",
    "We use the same setup as for the previous example with\n",
    "\n",
    "$$\n",
    "f\\left(x\\right)=\\log\\left(x\\right)\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ce69565",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "def f(x):\n",
    "    return jnp.log(x)\n",
    "\n",
    "grid = jnp.linspace(0.5, 3, num=M)\n",
    "f_val = f(grid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f2a434a",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "# Parameters\n",
    "N = 1  # Number of layers\n",
    "layer_sizes = [1, ] * (N + 1)\n",
    "params_ex2_1 = init_network_params(layer_sizes, key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a03159e",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "# Parameters\n",
    "N = 2  # Number of layers\n",
    "layer_sizes = [1, ] * (N + 1)\n",
    "params_ex2_2 = init_network_params(layer_sizes, key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2880bb3d",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "# Parameters\n",
    "N = 3  # Number of layers\n",
    "layer_sizes = [1, ] * (N + 1)\n",
    "params_ex2_3 = init_network_params(layer_sizes, key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "867a4342",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "params_ex2_1 = train(params_ex2_1, grid, f_val, key, num_epochs=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "617804f9",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "params_ex2_2 = train(params_ex2_2, grid, f_val, key, num_epochs=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78bff774",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "params_ex2_3 = train(params_ex2_3, grid, f_val, key, num_epochs=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "636d1d3b",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "predictions_1 = vmap(compute_xδw_seq, in_axes=(None, 0))(params_ex2_1, grid)[0][:, -1]\n",
    "predictions_2 = vmap(compute_xδw_seq, in_axes=(None, 0))(params_ex2_2, grid)[0][:, -1]\n",
    "predictions_3 = vmap(compute_xδw_seq, in_axes=(None, 0))(params_ex2_3, grid)[0][:, -1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df289072",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "fig = go.Figure()\n",
    "fig.add_trace(go.Scatter(x=grid, y=f_val, name=r'$\\log{x}$'))\n",
    "fig.add_trace(go.Scatter(x=grid, y=predictions_1, name='One-layer neural network'))\n",
    "fig.add_trace(go.Scatter(x=grid, y=predictions_2, name='Two-layer neural network'))\n",
    "fig.add_trace(go.Scatter(x=grid, y=predictions_3, name='Three-layer neural network'))\n",
    "\n",
    "# Export to PNG file\n",
    "Image(fig.to_image(format=\"png\"))\n",
    "# fig.show() will provide interactive plot when running\n",
    "# notebook locally"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b5449e0",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "## to check that gpu is activated in environment\n",
    "\n",
    "from jax.lib import xla_bridge\n",
    "print(xla_bridge.get_backend().platform)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "409c71cc",
   "metadata": {},
   "source": [
    ">**Note**\n",
    ">\n",
    ">**Cloud Environment:** This lecture site is built in a server environment that doesn’t have access to a `gpu`\n",
    "If you run this lecture locally this lets you know where your code is being executed, either\n",
    "via the `cpu` or the `gpu`"
   ]
  }
 ],
 "metadata": {
  "date": 1747120412.4908237,
  "filename": "back_prop.md",
  "kernelspec": {
   "display_name": "Python",
   "language": "python3",
   "name": "python3"
  },
  "title": "Introduction to Artificial Neural Networks"
 },
 "nbformat": 4,
 "nbformat_minor": 5
}