Motivation
When I told my family that I was "training a neural network", they looked at me with a blank stare and asked what that even entails. I realized that many concepts in AI have become buzzwords and many people may not really understand what it all means. This post is meant for people like my family, to cut through the buzzword noise and understand what is actually happening, to some extent. Once you go under the hood a bit, AI might not be so scary after all.
Introduction
What exactly does it mean to train a neural network or AI model? I'll attempt to explain it as simply as I can in this section. We can think of a neural network as a function \(f\) that maps inputs \(x\) to outputs \(y\).
If you think back to Algebra 1, such a function might look like
\[ f(x) = mx + b. \]
Here, \(m\) and \(b\) would be the parameters of the model. The architecture of a neural network simply means its exact mathematical formula. Training a model with this architecture boils down to finding the \(m\) and \(b\) such that when we plug in inputs \(x\), we get outputs \(\hat{y}\) (pronounced "y-hat") that most closely match the actual ground truth outputs \(y\). In ML, ground truth means the "correct" or "target" value: the value that we ideally want our model to predict. \(\hat{y}\) is the value our model actually predicts, which may or may not be accurate.
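To make this concrete, here is a minimal Python sketch of such a model. The data points and the parameter values are made up for illustration; they are not from any real dataset.

```python
def f(x, m, b):
    """A tiny 'model': a straight line with parameters m (slope) and b (intercept)."""
    return m * x + b

# Hypothetical data that happens to follow the line y = 2x + 1.
xs = [0.0, 1.0, 2.0, 3.0]   # inputs x
ys = [1.0, 3.0, 5.0, 7.0]   # ground truth outputs y

# Predictions (y-hat) from a poor guess at the parameters...
guess_predictions = [f(x, m=1.0, b=0.0) for x in xs]
# ...and from the parameters that fit this data exactly.
ideal_predictions = [f(x, m=2.0, b=1.0) for x in xs]

print(guess_predictions)  # [0.0, 1.0, 2.0, 3.0]
print(ideal_predictions)  # [1.0, 3.0, 5.0, 7.0]
```

With the poor guess, every prediction misses its target; with \(m = 2\) and \(b = 1\), the predictions match the ground truth exactly. Training is the search for that second set of parameters.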
To recap, let's rephrase the buzzwordy sounding phrase "training a neural network" into what is actually happening: we are looking for the parameters, \(\theta\) (shorthand for all of the model's parameters, like \(m\) and \(b\) from earlier), of our model \(f\) that best predict the ground truth output, \(y\), on our input \(x\). Huh, maybe that isn't any less buzzwordy? Hopefully it makes sense.
Now, how do we find those optimal parameters, like \(m\) and \(b\) from earlier? If you think back to Calculus 1, you might remember using the first derivative test to find the minima or maxima of a function. Earlier, we said that we want to find the parameters that best predict the output. That kind of sounds like an optimization problem. Maybe if we rigorously define what "best" means as a function of the parameters, we can then use calculus techniques to find the parameters that optimize the "best" function.
Loss Functions
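This is where the concept of a loss function comes in. Let's say the output of our neural network is
\[ \hat{y} = f(x; \theta), \]
where the notation \(f(x; \theta)\) just means "the model \(f\) applied to the input \(x\), using the parameters \(\theta\)".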
The ground truth output value is \(y\). Then, we can define \(L(\hat{y}, y)\) to be the distance between our model's output and the actual target output. In machine learning, this concept is usually called a loss function, objective function, cost function, or error function.
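As a small illustration, here is one possible loss function in Python. The choice of squared error, the averaging over several examples, and the numbers are assumptions made for this sketch; many other loss functions exist.

```python
def squared_error(y_hat, y):
    """Distance between one prediction and its target (squared so it is never negative)."""
    return (y_hat - y) ** 2

def mean_loss(predictions, targets):
    """Average the per-example losses into a single number."""
    total = sum(squared_error(p, t) for p, t in zip(predictions, targets))
    return total / len(targets)

# Hypothetical targets and two sets of predictions.
targets = [1.0, 3.0, 5.0, 7.0]
poor_predictions = [0.0, 1.0, 2.0, 3.0]
perfect_predictions = [1.0, 3.0, 5.0, 7.0]

print(mean_loss(poor_predictions, targets))     # 7.5
print(mean_loss(perfect_predictions, targets))  # 0.0
```

The worse the predictions, the larger the number; a perfect model scores zero.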
Note that
\[ L(\hat{y}, y) = L(f(x; \theta), y), \]
so we now have a function of our neural network parameters \(\theta\) that measures how good our neural network is. With calculus, we can optimize this function. Usually, the loss/objective/cost/error function measures how wrong our model is. Therefore, we usually want to minimize this function (so that our model is as little wrong as possible). As long as we construct our neural network architecture and loss function appropriately, we should be able to take the gradient, or derivative, of the loss function with respect to the neural network parameters and see where the gradient equals zero:
\[ \nabla_\theta L(f(x; \theta), y) = 0. \]
This is where a minimum or maximum would occur. Gradient descent is an algorithm that searches for such a point iteratively: it repeatedly nudges the parameters a small step in the direction that decreases the loss, until the gradient is close to zero. Usually, the neural networks and loss functions we choose in practice are too complex to be solved in closed form, so an iterative solution is needed to find the optimal parameters.
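Here is a minimal sketch of gradient descent on a toy loss with a single parameter. The loss \((\theta - 3)^2\), the starting point, the step size (often called the learning rate), and the number of steps are all made up for illustration.

```python
def loss(theta):
    """Toy loss with its minimum at theta = 3."""
    return (theta - 3.0) ** 2

def gradient(theta):
    """Derivative of the toy loss with respect to theta."""
    return 2.0 * (theta - 3.0)

theta = 0.0          # arbitrary starting guess
learning_rate = 0.1  # how big a step to take each time

for step in range(100):
    # Move a small step against the gradient, i.e. downhill on the loss.
    theta = theta - learning_rate * gradient(theta)

print(theta)            # very close to 3.0
print(gradient(theta))  # very close to 0.0
```

For this toy loss we could have solved "gradient equals zero" by hand. The point of the loop is that the same recipe still works when solving by hand is not possible.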
Training Recap
So, to recap again, "training a neural network" simply means:
- Define a mathematical formula for your neural network, something like \(f(x) = mx + b\), but usually more complicated to capture more complex, nonlinear relationships.
- Define a loss function that can tell you how wrong your neural network is. A simple one might be \((y - \hat{y})^2\), which is the squared difference between the target value and your network's predicted value (squaring keeps the loss from going negative, so overshooting and undershooting both count as errors). Ideally, minimizing this loss function (or, for some objective functions, maximizing it) should lead you to the optimal parameters of your neural network.
- Use techniques built on Calculus 1, like gradient descent, to find the parameters that minimize/maximize the loss/objective function. This requires taking the derivative of the loss function with respect to the parameters of your network. The standard algorithm for computing those derivatives is known as backpropagation. Importantly, your loss function and neural network should be differentiable functions.
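The three steps above can be sketched end to end in a few lines of Python. This is only an illustration under stated assumptions: the data is made up, the loss is the mean squared error, and the derivatives are worked out by hand with the chain rule, which is the job backpropagation automates in real networks.

```python
# Step 1: define the model.
def f(x, m, b):
    return m * x + b

# Step 2: define the loss (mean squared error, an assumed choice).
def loss(m, b, xs, ys):
    return sum((f(x, m, b) - y) ** 2 for x, y in zip(xs, ys)) / len(xs)

# Step 3a: derivatives of the loss with respect to each parameter.
def gradients(m, b, xs, ys):
    n = len(xs)
    dm = sum(2.0 * (f(x, m, b) - y) * x for x, y in zip(xs, ys)) / n
    db = sum(2.0 * (f(x, m, b) - y) for x, y in zip(xs, ys)) / n
    return dm, db

# Hypothetical data that follows the line y = 2x + 1.
xs = [0.0, 1.0, 2.0, 3.0]
ys = [1.0, 3.0, 5.0, 7.0]

# Step 3b: gradient descent, starting from an arbitrary guess.
m, b = 0.0, 0.0
learning_rate = 0.05
for step in range(2000):
    dm, db = gradients(m, b, xs, ys)
    m -= learning_rate * dm
    b -= learning_rate * db

print(round(m, 4), round(b, 4))  # 2.0 1.0
```

Starting from \(m = 0\) and \(b = 0\), the loop recovers the line that generated the data. Real neural networks follow the same pattern with far more parameters and a more complicated formula for \(f\).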