Training a Neural Net with Dual Numbers

deep learning
automatic differentiation
Author

Koen Baak

Published

March 8, 2024

The first time someone builds a neural network, it is often “from scratch”, meaning without the use of a Deep Learning framework like Tensorflow or PyTorch. A neural network is actually quite easy to implement. The hard part is the backpropagation algorithm, the code that actually trains the network. To implement backpropagation you will need to compute the gradient of your neural network with respect to its parameters. This is doable by hand for a small network with only a few layers. Although this is fun to do and will give you a good intuition for how neural networks train, the approach soon needs to be abandoned.

Computing gradients by hand simply does not scale, so in practice people do not do it. Instead, they use a Deep Learning framework. These frameworks do all the gradient computation for you, automatically. But how?

The two most used algorithms in automatic differentiation are called forward mode AD and reverse mode AD. For reasons that will be mentioned later on, everyone in Deep Learning uses reverse mode AD. In fact, it would be incredibly stupid and strange to train a neural network with forward mode AD. So of course, that is precisely what we will do in this blog post.

Although it has no practical use in Deep Learning, forward mode AD is still a beautiful technique, and it is often considered to be the more approachable AD technique. It is based on the concept of Dual Numbers. Dual numbers form an alternative number system (similar to the complex numbers) that interacts very nicely with differentiable functions. In this blog post we will introduce them, implement them in Python and train a neural network with them on the famous MNIST dataset.

One last thing before we begin. One might ask “why not just use finite differencing?”. That is, one could reasonably suggest just using the definition of the derivative to compute it: \[\begin{align*} f\pr(x) = \lim_{h\to 0} \dfrac{f(x+h) - f(x)}{h}. \end{align*}\]

By just picking a small \(h\) we can compute an approximation of \(f\pr(x)\). In practice this turns out to be impractical for training deep neural nets. When \(h\) is too large, the error between the estimated value and the true value is too large (this is called the truncation error). When \(h\) is too small, the error caused by the fact that we are working with floating point numbers instead of actual real numbers is too large (the roundoff error). As neural nets involve a lot of computational steps, these errors accumulate and simply become too big.
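To see both error regimes in action, here is a tiny illustrative experiment (the function exp and the step sizes below are arbitrary choices, not anything specific to neural nets):

import numpy as np

# Forward-difference approximation of the derivative of exp at x = 1.
# The exact derivative of exp at 1 is exp(1) itself.
f, x, exact = np.exp, 1.0, np.exp(1.0)
for h in (1e-1, 1e-5, 1e-9, 1e-13):
    approx = (f(x + h) - f(x)) / h
    print(f"h={h:.0e}  absolute error={abs(approx - exact):.1e}")

# For large h the truncation error dominates; for tiny h the roundoff error dominates.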

All the code in this blog post can be found in this repo.

Dual Numbers

We can introduce dual numbers by first recalling complex numbers. A complex number is a number of the form \(a + bi\) with \(a, b\in \R\) and \(i\) a special number satisfying \(i^2 = -1\). Similarly, a dual number is a number of the form \(a+b\eps\) with \(a, b\in\R\) and \(\eps\) a special number satisfying \(\eps^2 = 0\). Although this setup seems extremely similar, the structure of the dual numbers, \(\mathbb D\), is very different from that of the reals or the complex plane. This is because the duals do not form a field. With the real numbers and the complex numbers (and the rationals) we can compute \(x/y\) for any \(x\) and \(y\) as long as \(y\) is not equal to zero. With the dual numbers, we cannot. For example, we cannot divide by \(\eps\): if we could, then \(1 = \eps/\eps\), and multiplying both sides by \(\eps\) would give \(\eps = \eps^2/\eps = 0/\eps = 0\), contradicting the fact that \(\eps \ne 0\). If you are comfortable with abstract algebra, we can define the dual numbers to be the quotient ring \(\R[x]/(x^2)\), and as \((x^2)\) is not a maximal ideal, \(\mathbb D\) is not a field.
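In fact, a dual number \(a + b\eps\) has a multiplicative inverse exactly when \(a \ne 0\), since in that case \[\begin{align*} (a+b\eps)\left(\frac{1}{a} - \frac{b}{a^2}\eps\right) = 1 - \frac{b}{a}\eps + \frac{b}{a}\eps - \frac{b^2}{a^2}\eps^2 = 1. \end{align*}\]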

One of the cool things about dual numbers is their interaction with derivatives. Let \(a, b\in\R\). We have \[\begin{align*} (a+b\eps)^2 &= a^2 + 2ab\eps + b^2\eps^2 \\ &= a^2 + 2ab\eps \\ \\ (a+b\eps)^3 &= (a^2 + 2ab\eps)(a+b\eps)\\ &= a^3 + a^2b\eps + 2a^2b\eps + 2ab^2\eps^2 \\ &= a^3 + 3a^2b\eps \\ \\ &\ldots \\ \\ (a+b\eps)^n &= a^n + na^{n-1}b\eps. \end{align*}\]

Generally, for any polynomial \(p\) we have \[\begin{align*} p(a+b\eps) &= p(a) + p\pr(a)b\eps. \end{align*}\]
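As a quick sanity check, take \(p(x) = x^2 + 3x + 1\) and evaluate it at \(2 + \eps\): \[\begin{align*} p(2 + \eps) = (4 + 4\eps) + (6 + 3\eps) + 1 = 11 + 7\eps, \end{align*}\] and indeed \(p(2) = 11\) and \(p\pr(2) = 2\cdot 2 + 3 = 7\).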

Now let us consider any function \(f\c\R\to\R\). If \(f\) is cool enough, it can be approximated by a sequence of Taylor polynomials. It becomes natural to make the following definition.

Definition.

Let \(f\c\R\to\R\) be an analytic function. The dual extension of \(f\) is the function \(\tilde f\c\mathbb{D}\to\mathbb{D}\) given by \[\begin{align*} a + b\eps \mapsto f(a) + f\pr(a)b\eps. \end{align*}\]

If we think about this for a moment, we catch a first glimpse of how dual numbers can be used for automatic differentiation. Suppose there is some way in which we can evaluate a function \(f\c\R\to\R\) implemented in Python at a dual number of the form \(a+\eps\); then by looking at the result, we get both the values \(f(a)\) and \(f\pr(a)\). How we achieve this will be explained shortly. First, it’s about time we write some code. So let us implement dual numbers in Python1!

dual/number.py
import typing as t

from dual.structure_decorators import (
    commutative_addition,
    commutative_multiplication,
    subtract_using_negative,
)


@subtract_using_negative
@commutative_multiplication
@commutative_addition
class DualNumber:
    def __init__(self, real: float = 0.0, dual: float = 0.0) -> None:
        self.real = real
        self.dual = dual

    def __repr__(self) -> str:
        return f"{self.real} + {self.dual}ε"

    def __add__(self, other: t.Any) -> "DualNumber":
        match other:
            case DualNumber():
                return DualNumber(
                    real=self.real + other.real, dual=self.dual + other.dual
                )
            case float() | int():
                return DualNumber(real=self.real + other, dual=self.dual)
            case _:
                return NotImplemented

    def __mul__(self, other: t.Any) -> "DualNumber":
        match other:
            case DualNumber():
                return DualNumber(
                    real=self.real * other.real,
                    dual=self.real * other.dual + other.real * self.dual,
                )
            case float() | int():
                return DualNumber(real=self.real * other, dual=self.dual * other)
            case _:
                return NotImplemented

    def __neg__(self) -> "DualNumber":
        return self.__mul__(-1)
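To get a feel for the arithmetic, here is a small check (a sketch; it assumes the class above is exposed via the dual package, matching the from dual import DualNumber import used further below):

from dual import DualNumber

x = DualNumber(real=2, dual=1)   # 2 + ε
y = DualNumber(real=3, dual=4)   # 3 + 4ε

print(x + y)    # 5 + 5ε
print(x * y)    # 6 + 11ε, because the ε² term vanishes
print(x - 1.0)  # 1.0 + 1ε, subtraction goes via x + (-1.0)
print(-y)       # -3 + -4ε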

AD with Dual Numbers

Before we can start differentiating arbitrary functions, we need a few more ingredients. So here it is, behold the niceness of dual extensions.

Proposition.

Let \(f\c\R\to\R\) and \(g\c\R\to\R\) be analytic functions. We have \[\begin{align*} \widetilde{f+g} &= \tilde f + \tilde g\\ \widetilde{f\cdot g} &= \tilde f \cdot \tilde g \\ \widetilde{f\circ g} &= \tilde f \circ \tilde g. \end{align*}\]

Proof.

The result for pointwise addition is easy.

The result for pointwise multiplication is just the product rule: \[\begin{align*} (\tilde f \cdot \tilde g)(a+b\eps) &= (f(a) + f\pr(a)b\eps)(g(a) + g\pr(a)b\eps)\\ &= f(a)g(a) + f(a)g\pr(a)b\eps + f\pr(a)g(a)b\eps + f\pr(a)g\pr(a)b^2\eps^2 \\ &= f(a)g(a) + f(a)g\pr(a)b\eps + f\pr(a)g(a)b\eps \\ &= f(a)g(a) + (f\pr(a)g(a) + f(a)g\pr(a))b\eps \\ &= (f\cdot g)(a) + (f\cdot g)\pr(a)b\eps \\ &= \widetilde{f\cdot g} (a + b\eps) \end{align*}\]

The result for composition is just the chain rule: \[\begin{align*} (\tilde f \circ \tilde g)(a + b\eps) &= \tilde f(g(a) + g\pr(a)b\eps) \\ &= f(g(a)) + f\pr(g(a))g\pr(a)b\eps \\ &= (f\circ g)(a) + (f\circ g)\pr(a)b\eps \\ &= \widetilde{f\circ g}(a + b\eps). \end{align*}\]

This result suggests a very effective way to automatically differentiate a very big class of analytic functions. First, we start with a small set of primitive functions and explicitly implement their dual extensions. Then any function that can be built from this small set using only composition, pointwise addition and pointwise multiplication will automatically work as well!

Let’s just do it. The code below automatically computes the derivative of

\[\begin{align*} f(x) = 3e^{\sin(x)} + \sin(x)^2 \end{align*}\]

at \(x=\pi\).

one_variable_autodiff.py
import typing as t
import numpy as np

from dual import DualNumber


eps = DualNumber(dual=1)


def dual_sin(x: DualNumber) -> DualNumber:
    return np.sin(x.real) + np.cos(x.real) * x.dual * eps


def dual_exp(x: DualNumber) -> DualNumber:
    return np.exp(x.real) + np.exp(x.real) * x.dual * eps


def my_func(x: DualNumber) -> DualNumber:
    return dual_exp(dual_sin(x)) * 3 + dual_sin(x) * dual_sin(x)


def compute_derivative(f: t.Callable[[DualNumber], DualNumber], x: float) -> float:
    return f(x + eps).dual


compute_derivative(my_func, np.pi)
-3.0000000000000013  
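As a sanity check, we can differentiate by hand: \(f\pr(x) = 3\cos(x)e^{\sin(x)} + 2\sin(x)\cos(x)\), so \(f\pr(\pi) = 3\cdot(-1)\cdot e^{0} + 0 = -3\), which matches the computed value up to floating point error.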

To Higher Dimensions

It would be nice if we were already almost at the end of this blog post, but alas, we are not done yet. We would be done if neural nets were functions \(\R\to\R\). But of course they are not. Neural nets are functions \(\R^n\to \R^m\) for some large \(n\) and some much smaller \(m\). We will state the definitions quickly, without lingering on them too much; they should feel natural if you think about them for a bit. Note that any \(\bs x \in \mathbb{D}^n\) is of the form \(\bs a + \bs b\eps\) for some \(\bs a, \bs b \in \R^n\).

Definition.

Let \(f\c\R^n\to\R^m\) be an analytic function with \(n, m\ge 1\). The dual extension of \(f\) is the function \(\tilde f\c\mathbb{D}^n\to\mathbb{D}^m\) given by \[\begin{align*} \bs a + \bs b\eps \mapsto f(\bs a) + J_f(\bs a)\bs b\eps. \end{align*}\]

Here \(J_f(\bs a)\) denotes the Jacobian of \(f\) at \(\bs a\). The matrix-vector product \(J_f(\bs a)\bs b\) is called the Jacobian Vector Product, a term you will encounter frequently when reading about automatic differentiation. Note that there is a big difference between our dual extension for multivariate functions and the dual extension for single variable functions. In the single variable case, we automatically get the derivative value at any point \(x\) when evaluating on \(x+\eps\). In the multivariate case, this is not true. If we want to compute the full Jacobian \(J_f(\bs a)\), then we need to evaluate the dual extension at \(\bs a + \bs e_i\eps\) for every \(1\le i \le n\). This is an extremely big problem if we want to use this method of automatic differentiation to train neural networks. For neural networks, \(n\) is the number of parameters of the network, which can be extremely big. This is the reason that forward mode AD is not used in deep learning. In fact, this problem is so big that we will shortly see me jumping through some hoops to get the training on MNIST done in OK time. People doing real deep learning use reverse mode AD, which needs \(m\) evaluations, a much, much smaller number for neural networks. This point is important, so let me highlight it.
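Concretely, evaluating the dual extension at \(\bs a + \bs e_i\eps\) gives \[\begin{align*} \tilde f(\bs a + \bs e_i\eps) = f(\bs a) + J_f(\bs a)\bs e_i\eps, \end{align*}\] and \(J_f(\bs a)\bs e_i\) is exactly the \(i\)-th column of the Jacobian, so one evaluation recovers one column.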

Turns out the goal of this blogpost is ridiculous

Training neural nets with forward mode AD as described in this blog post is extremely inefficient and only for educational and fun purposes. Do not do it for reals.

Of course, neural networks are not functions \(\R^n\to\R^m\) either. They are actually functions \(\R^{\bs n}\to \R^{\bs m}\). Here \(\R^{\bs n}\) denotes the space of all tensors of shape \(\bs n\). Shit, I hear you thinking, should I now go and read that other long blog post just linked? Well, if you are like me and would like to read an overly, unnecessarily formal treatment of tensors, you will hopefully find happiness in that other blog post. Otherwise, you can just pretend that tensors are normal vectors and that all the superscripts \(\bs n\) are not bold at all. In fact, I can summarize the other blog post in one sentence:

One Line Summary of This Post

Tensors are just multi-dimensional arrays and their linear algebra works like you expect it to. You can just pretend that \(\R^{\bs n} = \R^n\) and most people do.

For those who were crazy enough to read the other post, and came back, let’s see the definition of the dual extension in the tensor case.

Definition.

Let \(f\c\R^{\bs n}\to\R^{\bs m}\) be an analytic function with \(\bs n\) a shape of length \(K\) and \(\bs m\) a shape of length \(L\). The dual extension of \(f\) is the function \(\tilde f\c\mathbb{D}^{\bs n}\to\mathbb{D}^{\bs m}\) given by \[\begin{align*} \bs{\mathcal{A}} + \bs{\mathcal{B}}\eps \mapsto f(\bs{\mathcal{A}}) + \mathcal{J}_f(\bs{\mathcal{A}})\ast_L\bs{\mathcal{B}}\eps. \end{align*}\]

Here \(\ast\) denotes the tensordot product and \(\mathcal{J}_f\) the Jacobian tensor.

Time to implement vectors (and tensors) of dual numbers in Python. For the components we will use NumPy arrays, which are vectors when np.ndarray(...).ndim == 1 and tensors otherwise.

dual/tensor.py
import typing as t

import numpy as np
import numpy.typing as npt

from dual.number import DualNumber
from dual.structure_decorators import (
    commutative_addition,
    commutative_multiplication,
    subtract_using_negative,
)


@commutative_addition
@commutative_multiplication
@subtract_using_negative
class DualTensor:
    def __init__(
        self,
        real: npt.ArrayLike | None = None,
        dual: npt.ArrayLike | None = None,
        dtype: npt.DTypeLike = np.float_,
    ) -> None:
        assert real is not None or dual is not None
        assert np.dtype(dtype=dtype).kind == "f"

        self.real = (
            np.asarray(real, dtype=dtype)
            if real is not None
            else np.zeros_like(dual, dtype=dtype)
        )
        self.dual = (
            np.asarray(dual, dtype=dtype)
            if dual is not None
            else np.zeros_like(real, dtype=dtype)
        )

        assert self.real.shape == self.dual.shape

    def __getitem__(self, item: t.Any) -> DualNumber:
        match item:
            case tuple():
                return DualNumber(real=self.real[item], dual=self.dual[item])
            case _:
                return NotImplemented

    def __add__(self, other: t.Any) -> "DualTensor":
        match other:
            case DualTensor() | DualNumber():
                return DualTensor(
                    real=self.real + other.real, dual=self.dual + other.dual
                )
            case float() | int() | np.ndarray():
                return DualTensor(real=self.real + other, dual=self.dual)
            case _:
                return NotImplemented

    def __mul__(self, other: t.Any) -> "DualTensor":
        match other:
            case DualTensor() | DualNumber():
                return DualTensor(
                    real=self.real * other.real,
                    dual=self.real * other.dual + other.real * self.dual,
                )
            case float() | int() | np.ndarray():
                return DualTensor(real=self.real * other, dual=self.dual * other)
            case _:
                return NotImplemented

    def __neg__(self) -> "DualTensor":
        return DualTensor(real=-self.real, other=-self.other)

    def __pow__(self, power, modulo=None) -> "DualTensor":
        match power:
            case int():
                return DualTensor(
                    real=self.real**power,
                    dual=self.dual * power * self.real ** (power - 1),
                )
            case _:
                return NotImplemented

    def __truediv__(self, other):
        match other:
            case DualTensor():
                inv = DualTensor(
                    real=1 / other.real, dual=-other.dual / (other.real**2)
                )
            case _:
                inv = 1 / other
        return self * inv

Note that we also implement division for tensors of dual numbers, even though we claimed not long ago that division is not always possible. However, for \(x, y \in \mathbb{D}\) with the real part of \(y\) nonzero, we can think of \(x/y\) as \(x\cdot\tilde f(y)\) where \(f\c a \mapsto 1/a\).
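A quick check that this behaves like the quotient rule elementwise (a sketch, assuming the dual package from this post is importable):

import numpy as np

from dual.tensor import DualTensor

x = DualTensor(real=np.array([2.0, 6.0]), dual=np.array([1.0, 0.0]))
y = DualTensor(real=np.array([4.0, 3.0]), dual=np.array([0.0, 1.0]))
z = x / y

print(z.real)  # [0.5, 2.0]
print(z.dual)  # [0.25, -0.667...], i.e. (x.dual*y.real - x.real*y.dual) / y.real**2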

The only thing left to do now is to implement the dual versions of some primitive functions. As an example, we consider the simple case of functions \(f\c\R^n \to \R^n\) that are defined pointwise by a function \(g\c\R\to\R\), that is, \(f(\bs{x}) = (g(x_i))_{i=1}^n\). As examples, we can think of np.sin, np.exp, np.sqrt, etc. In this case the Jacobian vector product \(J_f(\bs{x})\bs{v}\) equals \((g\pr(x_i)v_i)_{i=1}^n\) (check this!).

import typing as t

import numpy as np
import numpy.typing as npt

from dual.tensor import DualTensor


def dual_sin(x: DualTensor) -> DualTensor:
    # Pointwise sine: the Jacobian is diagonal, so the JVP is an elementwise product.
    return DualTensor(real=np.sin(x.real), dual=np.cos(x.real) * x.dual)


def jacobian_vector_product(
    f: t.Callable[[DualTensor], DualTensor], x: npt.NDArray, vector: npt.NDArray
) -> npt.NDArray:
    # Evaluating the dual extension at x + vector*ε puts J_f(x) @ vector in the dual part.
    return f(DualTensor(real=x, dual=vector)).dual
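For example, since the Jacobian of the pointwise sine is \(\mathrm{diag}(\cos(x_i))\), a JVP with a standard basis vector picks out one column of it (a small usage sketch of the functions above):

import numpy as np

# Uses dual_sin and jacobian_vector_product as defined above.
x = np.array([0.0, np.pi / 2])
e0 = np.array([1.0, 0.0])
print(jacobian_vector_product(dual_sin, x, e0))  # [1. 0.] — the first column of diag(cos(x))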

We could do this in the way done before for single variable autodiff, by just creating new functions. However, it would be even cooler if we could make our DualTensor class interoperable with numpy. Luckily, numpy is actually very flexible in this regard. By defining a method __array_ufunc__ on a class, we can make numpy call this method whenever we call a numpy ufunc with an instance of our own class as an argument. This is best illustrated with a minimal example:

import numpy as np


class Foo:
    def __array_ufunc__(self, ufunc: np.ufunc, method: str, *args, **kwargs) -> str:
        if ufunc is np.sin and method == "__call__":
            return "Wow Cool!"
        return NotImplemented


f = Foo()
np.sin(f)
'Wow Cool!'
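Here is one way this could look for DualTensor (a sketch; the handler table and the monkeypatched method below are my own illustration covering a few unary ufuncs, not the repo's exact code):

import numpy as np

from dual.tensor import DualTensor

# Map each supported ufunc to a function returning (value, derivative) on the real part.
_DUAL_UFUNCS = {
    np.sin: lambda r: (np.sin(r), np.cos(r)),
    np.exp: lambda r: (np.exp(r), np.exp(r)),
    np.sqrt: lambda r: (np.sqrt(r), 0.5 / np.sqrt(r)),
}


def _array_ufunc(self, ufunc, method, *inputs, **kwargs):
    handler = _DUAL_UFUNCS.get(ufunc)
    if handler is None or method != "__call__":
        return NotImplemented
    value, derivative = handler(self.real)
    # Chain rule: the dual part gets scaled by the derivative at the real part.
    return DualTensor(real=value, dual=derivative * self.dual)


# Attach the hook to the class defined earlier in this post.
DualTensor.__array_ufunc__ = _array_ufunc

With this in place, np.sin(DualTensor(real=x, dual=v)) returns a DualTensor whose dual part is np.cos(x) * v, exactly like the hand-written dual_sin above.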

Training a Neural Net on MNIST with Dual Numbers

Now that we can automatically differentiate functions \(\R^{\bs n} \to \R^{\bs m}\), we can use this to train neural networks. We will train a small neural net to recognize handwritten digits from the MNIST dataset, the “Hello World” of Deep Learning. We start by implementing the data structure of a neural network. The code below is quite straightforward. The only noteworthy piece is the initialization function for the weight matrix of a layer. Initialization of your weights is actually a more delicate issue than it has any right to be. When initializing your weights in a naive way, there is a high probability of your gradient vanishing or exploding. The vanishing gradient problem, as it is often called, kicked me down a lot of times during my training runs, so I ended up implementing more sophisticated initialization techniques. Which technique to use depends on the activation function of the layer. I implemented Xavier Initialization2 (for use with softmax) and He Initialization (for use with relu).

import typing as t
from enum import Enum
import numpy as np
import numpy.typing as npt

from dual.tensor import DualTensor


Tensor: t.TypeAlias = npt.NDArray[float] | DualTensor
LossFunction: t.TypeAlias = t.Callable[[Tensor, Tensor], float]


class Initialize(str, Enum):
    XAVIER = "Xavier"
    HE = "He"


class Layer:
    def __init__(
        self,
        model: "NeuralNetwork",
        n_neurons: int,
        activation_function: t.Callable,
        initialize: Initialize = "Xavier",
    ) -> None:
        self.model = model
        self.prev_layer = model.layers[-1] if model.layers else None
        if self.prev_layer is not None:
            self.prev_layer.next_layer = self
        self.next_layer = None
        self.n_neurons = n_neurons
        self.initialization = Initialize(initialize)
        self.weights = self.initial_weights()
        self.bias = np.zeros(shape=self.n_neurons)
        self.activation_function = activation_function

    def initial_weights(self) -> npt.NDArray[float]:
        shape = (self.n_neurons, self.input_size)
        match self.initialization:
            case Initialize.XAVIER:
                return np.random.uniform(low=-1, high=1, size=shape) * np.sqrt(
                    6 / (self.input_size + self.n_neurons)
                )
            case Initialize.HE:
                return np.random.normal(
                    loc=0.0,
                    scale=np.sqrt(2 / self.input_size),
                    size=shape,
                )

    @property
    def input_size(self) -> int:
        return self.prev_layer.n_neurons if self.prev_layer else self.model.input_size

    def compute_activation(
        self,
        x: Tensor,
        weights: Tensor | None = None,
        bias: Tensor | None = None,
    ) -> Tensor:
        weights = weights if weights is not None else self.weights
        bias = bias if bias is not None else self.bias
        return self.activation_function(weights @ x + bias)


class NeuralNetwork:
    def __init__(self, input_size: int, loss_function: LossFunction) -> None:
        self.layers: list[Layer] = []
        self.input_size = input_size
        self.loss_function = loss_function

    def add_layer(
        self, n_neurons: int, activation_function: t.Callable, initialize: Initialize
    ) -> Layer:
        layer = Layer(
            model=self,
            n_neurons=n_neurons,
            activation_function=activation_function,
            initialize=initialize,
        )
        self.layers.append(layer)
        return layer

    def __call__(self, x: Tensor) -> Tensor:
        result = x
        for layer in self.layers:
            result = layer.compute_activation(x=result)
        return result

    def compute_loss(self, x: Tensor, y: Tensor) -> float:
        return self.loss_function(self(x), y)
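Before adding the gradient code, a tiny smoke test of the forward pass (a sketch; the identity activation and squared-error loss below are just placeholder choices to exercise the classes above):

import numpy as np

def identity(x):
    return x

def squared_error(prediction, target):
    return float(np.sum((prediction - target) ** 2))

net = NeuralNetwork(input_size=3, loss_function=squared_error)
net.add_layer(n_neurons=2, activation_function=identity, initialize="Xavier")

x = np.array([1.0, 2.0, 3.0])
print(net(x))                                     # activations of the 2-neuron layer
print(net.compute_loss(x, np.array([0.0, 0.0])))  # squared error against an all-zero target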

We will now add all the code for backpropagation. This code required more thinking. We can compute the gradient of the weight matrix of a layer as follows:

  • We are given an observation \((\bs x, y)\).
  • We compute the activation \(x\pr\) of the previous layer.
  • For every position \(i, j\) in the weight matrix \(W\) of our layer, we do a forward pass through the remainder of the network with weight matrix \(W + E_{i, j}\eps\). Here \(E_{i, j}\) is the matrix with all zeros, except for a one at position \((i, j)\).
  • We end up with a dual number \(a + b\eps\) that is the loss of the network.
  • The number \(b\) is now the value of the wanted gradient at position \((i, j)\).

The gradient of a bias vector is computed in a similar way.

from dataclasses import dataclass
from datasets import Dataset


@dataclass
class Gradient:
    weights: Tensor
    bias: Tensor


class Layer:
    ...

    def push_forward(
        self,
        x: Tensor,
        y: Tensor,
        weights: Tensor | None = None,
        bias: Tensor | None = None,
    ) -> Tensor:
        x = self.compute_activation(x=x, weights=weights, bias=bias)
        if self.next_layer is not None:
            return self.next_layer.push_forward(x=x, y=y)

        return self.model.loss_function(x, y)

    def compute_gradient(self, x: Tensor, y: Tensor) -> Gradient:
        weights_gradient = np.zeros_like(self.weights)
        with np.nditer(self.weights, flags=["multi_index"]) as it:
            for _ in it:
                dual = np.zeros_like(self.weights)
                dual[it.multi_index] = 1
                dual_parameter = DualTensor(real=self.weights, dual=dual)

                weights_gradient[it.multi_index] = self.push_forward(
                    x=x, y=y, weights=dual_parameter
                ).dual

        bias_gradient = np.zeros_like(self.bias)
        for i in range(self.n_neurons):
            dual = np.zeros_like(self.bias)
            dual[i] = 1
            dual_parameter = DualTensor(real=self.bias, dual=dual)
            bias_gradient[i] = self.push_forward(x=x, y=y, bias=dual_parameter).dual

        return Gradient(weights=weights_gradient, bias=bias_gradient)

    def update_parameters(self, gradient: Gradient, learning_rate: float) -> None:
        self.weights = self.weights - learning_rate * gradient.weights
        self.bias = self.bias - learning_rate * gradient.bias


class NeuralNetwork:
    ...

    def compute_gradients(self, x: Tensor, y: Tensor) -> list[Gradient]:
        gradients = []
        for layer in self.layers:
            gradients.append(layer.compute_gradient(x=x, y=y))
            x = layer.compute_activation(x=x)
        return gradients

    def update_parameters(
        self, gradients: list[Gradient], learning_rate: float
    ) -> None:
        for i, gradient in enumerate(gradients):
            self.layers[i].update_parameters(gradient, learning_rate=learning_rate)

    def accuracy(self, data: Dataset) -> float:
        it = data.iter(batch_size=1)
        correct = 0
        for d in it:
            x = d["input"][0]
            y = d["label"][0]
            y_pred = np.argmax(self(x))
            correct += y_pred == np.argmax(y)
        return correct / data.num_rows

    def train(
        self,
        data: Dataset,
        epochs: int,
        learning_rate: float,
    ) -> None:
        for i in range(epochs):
            for datapoint in data.shuffle().iter(batch_size=1):
                x = datapoint["input"][0]
                y = datapoint["label"][0]
                gradients = self.compute_gradients(x=x, y=y)
                self.update_parameters(gradients=gradients, learning_rate=learning_rate)

We are now ready to test our code against a real dataset: MNIST. We start by preparing our data. Images in the MNIST dataset are 28 by 28 pixels, resulting in an input size of 784. Suppose that we start with a layer with 100 neurons. That means that our first layer has a weight matrix with 78,400 entries. To compute the gradient with respect to this matrix, we need 78,400 passes through our network! So yeah, the conclusion that this was a stupid idea really was correct. I actually tried this, and computing one gradient took more than 2 seconds… Training on the MNIST training set would take around 33 hours. Even I thought this was a bit ridiculous for training on the Hello World dataset of Deep Learning. The solution is of course to make our network smaller, and the most effective way to do that is to reduce the input size. So all images are resized to 12 by 12 pixels. Most digits I sampled were still easily recognizable.

import numpy as np
import numpy.typing as npt
from PIL import Image
from datasets import load_dataset
from albumentations import Resize
import matplotlib.pyplot as plt

from dual import NeuralNetwork

def transform_image(image: Image.Image) -> npt.NDArray:
    image = np.array(image)  # to numpy
    image = Resize(width=12, height=12)(image=image)["image"]  # resize
    image = image / 255  # normalize to [0, 1]
    image = image.flatten()  # flatten to a vector of length 144
    return image

def one_hot_encode_label(label: int) -> npt.NDArray:
    result = np.zeros(10)
    result[label] = 1
    return result

def transform(batch):
    batch["input"] = [transform_image(image) for image in batch["input"]]
    batch["label"] = [one_hot_encode_label(label) for label in batch["label"]]
    return batch


ds = (
    load_dataset("mnist", keep_in_memory=True)
    .with_transform(transform=transform)
    .rename_column(original_column_name="image", new_column_name="input")
)

We now train a small network with two layers. After training for only one hour, the accuracy on the test set is already up to 85%.

plt.imshow(ds["train"][0]["input"].reshape(12, 12), cmap="grey")

Example of a datapoint

def relu(x):
    return np.maximum(x, 0)

def cross_entropy(x, y):
    return -1*np.sum(y * np.log(x))

def softmax(x):
    return np.exp(x) / np.sum(np.exp(x))

model = NeuralNetwork(input_size=12*12, loss_function=cross_entropy)
model.add_layer(n_neurons=20, activation_function=relu, initialize="He")
model.add_layer(n_neurons=10, activation_function=softmax, initialize="Xavier")

model.train(data=ds["train"],
            learning_rate=0.1,
            epochs=1)
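For scale: with a 144-dimensional input, the first layer has \(20 \cdot 144 = 2880\) weights and 20 biases, and the second layer has \(10 \cdot 20 = 200\) weights and 10 biases. Since every parameter needs its own dual forward pass, one gradient computation for a single training example costs roughly 3110 passes through (part of) the network.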

Footnotes

  1. The decorators on the DualNumber class are there so that we don’t have to implement both left and right addition, but can simply declare the operator to be commutative. This also holds for multiplication. Subtraction is simply defined by x-y = x + (-y). Code for the decorators can be found in dual.structure_decorators.↩︎

  2. Named after Xavier Glorot and sometimes also referred to as Glorot Initialization. Still, Xavier Initialization is the more prominent name, and this is probably the only time I have ever seen an idea named after someone's first name.↩︎