A Tiny Deep Learning Framework from Scratch
Arslan Ashraf
In this post, we present a tiny deep learning framework that can perform binary classification. The framework loosely follows the TensorFlow API.
The key idea is to use reverse-mode automatic differentiation [1] to compute the derivative of any variable with respect to any of its inputs or parameters. To do this efficiently, we build a computational graph of the neural network function. A single backward pass through this graph then gives the derivatives of a scalar output with respect to all of its inputs and parameters at once.
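As a hand-worked illustration, consider \( f(x, y) = xy + x \), a function chosen only for this example and not part of the framework. Reverse mode first evaluates the function forward while keeping every intermediate value. It then starts from \( \partial f / \partial f = 1 \) and applies the chain rule from the output back toward the inputs. Because x reaches f along two paths, its two contributions are summed.

```python
def f_and_grads(x: float, y: float):
    """Forward and reverse pass for f(x, y) = x * y + x, unrolled by hand."""
    # Forward pass: compute and keep every intermediate value.
    a = x * y          # a = x * y
    f = a + x          # f = a + x

    # Backward pass: derivatives of f are propagated from the output back.
    df_df = 1.0
    df_da = df_df * 1.0      # f = a + x  ->  df/da = 1
    df_dx = df_df * 1.0      # direct path: f = a + x  ->  df/dx = 1
    df_dx += df_da * y       # path through a: da/dx = y (accumulated)
    df_dy = df_da * x        # path through a: da/dy = x
    return f, df_dx, df_dy


print(f_and_grads(2.0, 3.0))  # (8.0, 4.0, 2.0)
```

A graph of Variable objects automates this bookkeeping so that it works for arbitrary functions, not only ones unrolled by hand.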
To represent the function as a graph, each node is an instance of a Python class called Variable, which holds a numerical value such as a weight or a bias. Each Variable also keeps track of its parent nodes (themselves Variable objects) and of the operation that was applied to those parents to produce its own value.
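A key property of this graph is that every node produced by an operation has either one or two parent nodes, because each operation is either unary (such as the sigmoid) or binary (such as addition or multiplication). Each Variable can also compute the partial derivatives of its operation with respect to its parents and apply the chain rule to pass gradients back to them.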
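A minimal sketch of such a node might look like the following. The attribute and method names (value, parents, op, grad, local_grads, propagate) and the choice of three operations are assumptions for illustration, not necessarily what the framework's repository uses. Here local_grads returns the partial derivative of a node with respect to each parent, and propagate performs one application of the chain rule.

```python
import math


class Variable:
    """Hypothetical graph node: a value, its parents, and the op that produced it."""

    def __init__(self, value, parents=(), op=None):
        self.value = float(value)
        self.parents = parents  # () for leaves, one parent for unary ops, two for binary ops
        self.op = op            # "add", "mul", "sigmoid", or None for leaves
        self.grad = 0.0         # derivative of the final output w.r.t. this node

    def __add__(self, other):
        return Variable(self.value + other.value, (self, other), "add")

    def __mul__(self, other):
        return Variable(self.value * other.value, (self, other), "mul")

    def sigmoid(self):
        return Variable(1.0 / (1.0 + math.exp(-self.value)), (self,), "sigmoid")

    def local_grads(self):
        """Partial derivative of this node with respect to each parent."""
        if self.op == "add":        # d(a + b)/da = 1, d(a + b)/db = 1
            return (1.0, 1.0)
        if self.op == "mul":        # d(a * b)/da = b, d(a * b)/db = a
            a, b = self.parents
            return (b.value, a.value)
        if self.op == "sigmoid":    # d sigmoid(a)/da = s * (1 - s), where s is the output
            s = self.value
            return (s * (1.0 - s),)
        return ()                   # leaves have no parents

    def propagate(self):
        """One chain-rule step: add (local derivative * self.grad) to each parent."""
        for parent, local in zip(self.parents, self.local_grads()):
            parent.grad += local * self.grad


# Usage: build z = w * x, then take one backward step from z.
w, x = Variable(0.5), Variable(2.0)
z = w * x
z.grad = 1.0           # seed: dz/dz = 1
z.propagate()
print(w.grad, x.grad)  # 2.0 0.5
```

Gradients are accumulated with += rather than assigned, because a node used by several children must receive the sum of their contributions.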
For example, the function below describes a standard neural network operation: a single neuron with a sigmoid activation. The graph below, rendered with Graphviz, is the one the framework builds for this function with \( n = 2 \).
\[ f(\mathbf{w}, b) = \operatorname{sigmoid} \Big( \sum_{i = 1}^{n} w_i x_i + b \Big) \]
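Working through the chain rule by hand for this function gives the gradients that backpropagation through the graph should reproduce. Write \( z = \sum_{i=1}^{n} w_i x_i + b \) for the pre-activation and \( \sigma(z) = 1/(1 + e^{-z}) \) for the sigmoid, and use the standard identity \( \sigma'(z) = \sigma(z)\big(1 - \sigma(z)\big) \):

\[ \frac{\partial f}{\partial w_i} = \frac{\partial f}{\partial z}\,\frac{\partial z}{\partial w_i} = \sigma(z)\big(1 - \sigma(z)\big)\, x_i, \qquad \frac{\partial f}{\partial b} = \frac{\partial f}{\partial z}\,\frac{\partial z}{\partial b} = \sigma(z)\big(1 - \sigma(z)\big) \]

Each factor in these products is the local derivative along one edge of the graph, which is why the graph can compute them by multiplying local derivatives along the path from the output back to each parameter.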
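To perform backpropagation efficiently, visiting each node exactly once, we need to know which nodes depend on which other nodes: a node can pass its gradient on to its parents only after it has received the gradient contributions from every node that uses it. To obtain an order that respects these dependencies, we perform a topological sort on the model graph.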
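One common way to compute such an order is a depth-first search that appends a node only after all of its parents have been appended. For brevity, the sketch below assumes the graph is stored as a dictionary mapping each node's name to the names of its parents, rather than as Variable objects. The intermediate node names m1, m2, s, and z are made up for illustration. The graph is the single neuron above with \( n = 2 \).

```python
def topological_sort(output, parents):
    """Return the nodes that `output` depends on, each placed after all of its parents."""
    order, visited = [], set()

    def visit(node):
        if node in visited:
            return
        visited.add(node)
        for p in parents.get(node, ()):  # leaves have no entry, hence no parents
            visit(p)
        order.append(node)               # appended only after all of its parents

    visit(output)
    return order


# Graph of f = sigmoid(w1*x1 + w2*x2 + b); each node maps to its parent nodes.
graph = {
    "m1": ("w1", "x1"),  # m1 = w1 * x1
    "m2": ("w2", "x2"),  # m2 = w2 * x2
    "s": ("m1", "m2"),   # s  = m1 + m2
    "z": ("s", "b"),     # z  = s + b
    "f": ("z",),         # f  = sigmoid(z)
}

order = topological_sort("f", graph)
print(order)  # ['w1', 'x1', 'm1', 'w2', 'x2', 'm2', 's', 'b', 'z', 'f']
# Backpropagation visits the nodes in reverse: output first, leaves last.
print(list(reversed(order)))
```

The recursive visit is the easiest version to read; for very deep graphs an iterative version avoids Python's recursion limit. Python 3.9 and later also ship graphlib.TopologicalSorter, which produces an equivalent parents-first ordering.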
Topological Sort
The image below shows the topologically sorted list of nodes for the graph above.
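Backpropagation seeds the output's gradient with 1 and then walks the sorted list in reverse, so every node has received the contributions from all of its children before it passes gradient on to its parents. The sketch below is a compact, hypothetical Variable class, not the repository's code, that stores each node's local derivatives when the node is created. It builds the \( n = 2 \) neuron and checks the result against the closed-form sigmoid gradient \( f(1 - f)\,x_i \), where f is the neuron's output.

```python
import math


class Variable:
    """Hypothetical graph node that records its local derivatives at creation time."""

    def __init__(self, value, parents=(), local=()):
        self.value = float(value)
        self.parents = parents  # zero, one, or two parent Variables
        self.local = local      # d(self)/d(parent) for each parent
        self.grad = 0.0

    def __add__(self, other):
        return Variable(self.value + other.value, (self, other), (1.0, 1.0))

    def __mul__(self, other):
        return Variable(self.value * other.value, (self, other), (other.value, self.value))

    def sigmoid(self):
        s = 1.0 / (1.0 + math.exp(-self.value))
        return Variable(s, (self,), (s * (1.0 - s),))

    def backward(self):
        # Topologically sort the graph ending at this node (parents before children).
        order, seen = [], set()

        def visit(node):
            if node not in seen:
                seen.add(node)
                for p in node.parents:
                    visit(p)
                order.append(node)

        visit(self)
        self.grad = 1.0                   # d(self)/d(self)
        for node in reversed(order):      # children are processed before their parents
            for p, d in zip(node.parents, node.local):
                p.grad += d * node.grad   # chain rule, summed over all paths
        # Note: grads accumulate, so reset them to 0 before calling backward() again.


# Single neuron with n = 2: f = sigmoid(w1*x1 + w2*x2 + b)
x1, x2 = Variable(1.0), Variable(-2.0)
w1, w2, b = Variable(0.5), Variable(0.25), Variable(0.1)
f = (w1 * x1 + w2 * x2 + b).sigmoid()
f.backward()

# Compare with the closed form df/dw_i = f * (1 - f) * x_i; each pair should match.
expected = f.value * (1.0 - f.value)
print(w1.grad, expected * x1.value)
print(w2.grad, expected * x2.value)
```

Storing local derivatives when each node is created keeps backward() to a single loop over the sorted nodes, at the cost of a little extra memory per node.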
To test that the framework works as intended, we train a small model to perform binary classification on the Iris dataset, using only the samples from two of its three species: setosa and versicolor.
To reproduce this, run python main.py from the root directory of the GitHub repository; this trains a model on the Iris dataset. The images below show the training loss over 30 epochs.
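The shape of such a training loop can be sketched in plain Python, though the repository's main.py may differ. This sketch makes several assumptions. It uses only petal length and petal width (so \( n = 2 \)) from six rows of the Iris dataset rather than all 100 setosa and versicolor samples, plus one virginica row to show the filtering. It also assumes binary cross-entropy as the loss, full-batch gradient descent with a learning rate of 0.1, and the closed-form gradient of a sigmoid neuron in place of the framework's graph. The train function is hypothetical.

```python
import math

# Six Iris rows plus one virginica row: (petal length cm, petal width cm, species).
rows = [
    (1.4, 0.2, "setosa"), (1.4, 0.2, "setosa"), (1.3, 0.2, "setosa"),
    (4.7, 1.4, "versicolor"), (4.5, 1.5, "versicolor"), (4.9, 1.5, "versicolor"),
    (6.0, 2.5, "virginica"),
]

# Keep only the two species used for binary classification: setosa -> 0, versicolor -> 1.
labels = {"setosa": 0.0, "versicolor": 1.0}
data = [((pl, pw), labels[s]) for pl, pw, s in rows if s in labels]


def train(data, epochs=30, lr=0.1):
    """Full-batch gradient descent on one sigmoid neuron with mean binary cross-entropy."""
    w, b = [0.0, 0.0], 0.0
    losses = []
    for _ in range(epochs):
        gw, gb, loss = [0.0, 0.0], 0.0, 0.0
        for x, y in data:
            p = 1.0 / (1.0 + math.exp(-(w[0] * x[0] + w[1] * x[1] + b)))
            loss -= y * math.log(p) + (1.0 - y) * math.log(1.0 - p)
            # For sigmoid + binary cross-entropy, d(loss)/dz simplifies to (p - y).
            gw[0] += (p - y) * x[0]
            gw[1] += (p - y) * x[1]
            gb += p - y
        n = len(data)
        w = [w[0] - lr * gw[0] / n, w[1] - lr * gw[1] / n]
        b -= lr * gb / n
        losses.append(loss / n)  # loss measured before this epoch's update
    return w, b, losses


w, b, losses = train(data)
print(f"epoch 1 loss {losses[0]:.4f}, epoch 30 loss {losses[-1]:.4f}")
```

With all weights starting at zero, every prediction is 0.5, so the first recorded loss is \( \ln 2 \approx 0.6931 \).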
References
[1] Murphy, Kevin P. (2022). Probabilistic Machine Learning: An Introduction. MIT Press.
[2] Karpathy, Andrej. https://www.youtube.com/watch?v=VMj-3S1tku0
[3] https://www.youtube.com/watch?v=dB-u77Y5a6A