Optimisation
Optimisation is the mathematical core of model training -- finding the parameters that minimise a loss function. This file covers critical points, convexity, gradient descent, Newton's method, constrained optimisation with Lagrange multipliers, and the optimisers (SGD, Adam) that power modern deep learning.
-
Training a neural network, fitting a regression line, tuning hyperparameters: at the core of almost every ML algorithm is an optimisation problem.
-
We have some function (a loss, a cost, an objective) and we want to find the inputs that make it as small (or large) as possible.
-
Before optimising, we need to understand zeros (or roots) of functions. A zero of \(f(x)\) is a value \(x\) where \(f(x) = 0\). Graphically, these are the x-intercepts.
-
For example, \(f(x) = x^2 - 3x + 2 = (x-1)(x-2)\) has zeros at \(x = 1\) and \(x = 2\). Between the zeros, the function is negative (\(f(1.5) = -0.25\)); outside the zeros, it is positive. The zeros divide the number line into regions where the function has constant sign.
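This is easy to check numerically; a quick sketch using NumPy's `np.roots` on the coefficient form of the quadratic:

```python
import numpy as np

# Roots of f(x) = x^2 - 3x + 2, given by its coefficients [1, -3, 2]
roots = np.roots([1, -3, 2])
print(sorted(roots))  # zeros at x = 1 and x = 2

# Sign of f in each region the zeros create
f = lambda x: x**2 - 3*x + 2
print(f(0.5) > 0, f(1.5) < 0, f(2.5) > 0)  # positive, negative, positive
```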
-
The multiplicity of a zero is how many times the corresponding factor appears.
-
At a simple zero (multiplicity 1), the graph crosses the x-axis. At a double zero (multiplicity 2), the graph touches the x-axis but bounces back without crossing, appearing "flat" at that point.
-
Finding zeros matters because the zeros of the derivative \(f'(x)\) are the critical points of \(f(x)\), the candidates for maxima and minima.
-
At a maximum or minimum, the tangent line is flat (slope = 0), so \(f'(x) = 0\).
-
But not every critical point is a maximum or minimum. A point where \(f'(x) = 0\) could also be an inflection point (like \(x = 0\) for \(f(x) = x^3\)), where the function flattens momentarily but does not change direction.
-
The second derivative test resolves this. At a critical point \(x = c\) where \(f'(c) = 0\):
- If \(f''(c) > 0\): the curve is concave up (like a bowl), so \(c\) is a local minimum.
- If \(f''(c) < 0\): the curve is concave down (like a hill), so \(c\) is a local maximum.
- If \(f''(c) = 0\): the test is inconclusive; higher derivatives or other methods are needed.
-
For example, \(f(x) = x^3 - 3x\). The derivative is \(f'(x) = 3x^2 - 3 = 3(x-1)(x+1)\), so critical points are at \(x = -1\) and \(x = 1\). The second derivative is \(f''(x) = 6x\). At \(x = -1\): \(f''(-1) = -6 < 0\) (local max). At \(x = 1\): \(f''(1) = 6 > 0\) (local min).
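The same calculation can be automated with JAX by nesting `jax.grad`, which is a useful sanity check on the hand derivation:

```python
import jax

f = lambda x: x**3 - 3*x
df = jax.grad(f)       # f'(x) = 3x^2 - 3
d2f = jax.grad(df)     # f''(x) = 6x

for c in (-1.0, 1.0):
    # f'(c) should be 0 at both critical points;
    # the sign of f''(c) classifies them
    print(c, float(df(c)), float(d2f(c)))
# f''(-1) = -6 < 0: local max; f''(1) = 6 > 0: local min
```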
-
A function is convex if the line segment between any two points on its graph lies above (or on) the graph. Think of it as a bowl shape, curving upward everywhere. For a twice-differentiable \(f\), this is equivalent to \(f''(x) \geq 0\) for all \(x\).
-
Convexity is powerful because convex functions have a remarkable property: every local minimum is also the global minimum. There are no deceptive local valleys to get trapped in. If you roll a ball into a convex bowl, it will always reach the bottom.
-
A function is concave (curving downward) if \(-f\) is convex. Points where the function transitions between concave and convex are inflection points, occurring where \(f''(x) = 0\).
-
Newton's method finds zeros of functions (and by extension, critical points of their derivatives) using tangent lines. Starting from an initial guess \(x_0\), it iteratively refines:

\[ x_{n+1} = x_n - \frac{f(x_n)}{f'(x_n)} \]
-
The idea: at \(x_n\), draw the tangent line and find where it crosses the x-axis. That crossing point becomes \(x_{n+1}\). For well-behaved functions with a good starting point, Newton's method converges very quickly (quadratically, meaning the number of correct digits roughly doubles each step).
-
For example, to find \(\sqrt{5}\) (a zero of \(f(x) = x^2 - 5\)): \(f'(x) = 2x\), so \(x_{n+1} = x_n - \frac{x_n^2 - 5}{2x_n}\). Starting at \(x_0 = 2\): \(x_1 = 2.25\), \(x_2 = 2.2361\ldots\), which is already accurate to four decimal places.
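A minimal sketch of this iteration in plain Python:

```python
# Newton's method for a zero of f(x) = x^2 - 5
f = lambda x: x**2 - 5
df = lambda x: 2*x

x = 2.0  # initial guess
for n in range(4):
    x = x - f(x) / df(x)
    print(n + 1, x)
# x_1 = 2.25, x_2 = 2.2361..., then rapidly converges to sqrt(5)
```

Four iterations already agree with \(\sqrt{5}\) to machine precision, illustrating the quadratic convergence.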
-
Newton's method can fail if the initial guess is far from the root, if \(f'(x) = 0\) near the root, or if the function has inflection points nearby. It also requires computing the derivative, which may be expensive.
-
For optimisation (finding minima instead of zeros), we apply Newton's method to \(f'(x) = 0\), which gives the update:

\[ x_{n+1} = x_n - \frac{f'(x_n)}{f''(x_n)} \]
-
In multiple dimensions, this becomes \(\mathbf{x}_{n+1} = \mathbf{x}_n - H^{-1} \nabla f(\mathbf{x}_n)\), where \(H\) is the Hessian matrix. This is the second-order Taylor approximation from the previous file in action: approximate the function as a quadratic, jump to the minimum of that quadratic, repeat.
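A sketch of the multivariate update with JAX, using `jax.hessian` and a linear solve rather than an explicit matrix inverse (the test function here is my own example):

```python
import jax
import jax.numpy as jnp

def f(x):
    return (x[0] - 3)**2 + 2*(x[1] + 1)**2

grad_f = jax.grad(f)
hess_f = jax.hessian(f)

x = jnp.array([0.0, 0.0])
for _ in range(3):
    # Solve H step = grad instead of forming H^{-1} explicitly
    step = jnp.linalg.solve(hess_f(x), grad_f(x))
    x = x - step
print(x)  # converges to the minimum at (3, -1)
```

Because this \(f\) is exactly quadratic, the quadratic model is exact and Newton's method lands on the minimum in a single step.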
-
Lagrange multipliers solve constrained optimisation: find the optimum of \(f(x, y)\) subject to a constraint \(g(x, y) = c\). Instead of searching all of \(\mathbb{R}^n\), we are restricted to the set where the constraint holds (a curve or surface).
-
The key insight is geometric: at the constrained optimum, the gradient of \(f\) must be parallel to the gradient of \(g\). If they were not parallel, we could move along the constraint in a direction that still improves \(f\), so we would not be at the optimum yet.
-
We introduce a new variable \(\lambda\) (the Lagrange multiplier) and define the Lagrangian:

\[ \mathcal{L}(x, y, \lambda) = f(x, y) - \lambda \,(g(x, y) - c) \]
-
Setting all partial derivatives of \(\mathcal{L}\) to zero gives a system of equations whose solutions are the candidate constrained optima:

\[ \frac{\partial \mathcal{L}}{\partial x} = 0, \quad \frac{\partial \mathcal{L}}{\partial y} = 0, \quad \frac{\partial \mathcal{L}}{\partial \lambda} = 0 \]
-
For example, maximise \(f(x,y) = x^2 y\) subject to \(x^2 + y^2 = 1\). The Lagrangian is \(\mathcal{L} = x^2 y - \lambda(x^2 + y^2 - 1)\). Taking partials:

\[ \frac{\partial \mathcal{L}}{\partial x} = 2xy - 2\lambda x = 0, \quad \frac{\partial \mathcal{L}}{\partial y} = x^2 - 2\lambda y = 0, \quad \frac{\partial \mathcal{L}}{\partial \lambda} = -(x^2 + y^2 - 1) = 0 \]
-
From the first equation (assuming \(x \neq 0\)): \(\lambda = y\). Substituting into the second: \(x^2 = 2y^2\). Combined with the constraint: \(2y^2 + y^2 = 1\), so \(y = \frac{1}{\sqrt{3}}\) (taking the positive root, since the maximum requires \(y > 0\)). The maximum value is \(f = x^2 y = 2y^3 = \frac{2}{3\sqrt{3}}\).
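This result can be verified numerically by parameterising the unit circle as \((\cos t, \sin t)\) and scanning \(t\) (a brute-force check, not the Lagrange method itself):

```python
import numpy as np

# Maximise x^2 y on the unit circle via the parameterisation
# (x, y) = (cos t, sin t)
t = np.linspace(0, 2*np.pi, 200001)
vals = np.cos(t)**2 * np.sin(t)

print(vals.max())           # numerical maximum on the circle
print(2 / (3*np.sqrt(3)))   # analytic value 2/(3*sqrt(3)) ≈ 0.3849
```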
-
For inequality constraints (\(g(x,y) \leq c\) instead of \(= c\)), the Karush-Kuhn-Tucker (KKT) conditions generalise Lagrange multipliers. The constraint is either active (binding, treated as equality) or inactive (the solution lies in the interior and the constraint is irrelevant).
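As a sketch of an active inequality constraint, here is a toy problem of my own (not from the text) solved with SciPy's SLSQP solver, which handles KKT-style constraints; note that SciPy's `ineq` convention requires the constraint function to be non-negative:

```python
import numpy as np
from scipy.optimize import minimize

# Minimise (x-2)^2 + (y-2)^2 subject to x + y <= 2.
# The unconstrained optimum (2, 2) violates the constraint, so the
# constraint is active at the solution: x + y = 2, giving (1, 1).
res = minimize(
    lambda p: (p[0] - 2)**2 + (p[1] - 2)**2,
    x0=np.zeros(2),
    method="SLSQP",
    constraints=[{"type": "ineq", "fun": lambda p: 2 - p[0] - p[1]}],
)
print(res.x)  # ≈ [1, 1], on the boundary of the feasible region
```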
-
In practice, we rarely optimise by hand. Here are the main algorithmic families:
-
First-order methods (use only gradient): gradient descent, stochastic gradient descent (SGD), Adam. These are cheap per step but can converge slowly, especially on ill-conditioned problems.
-
Second-order methods (use gradient and Hessian): Newton's method converges fast but computing and inverting the Hessian is expensive (\(O(n^3)\) for \(n\) parameters). Quasi-Newton methods (like BFGS and L-BFGS) approximate the Hessian using only gradient information, achieving faster convergence than first-order methods without the full cost of second-order methods.
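In practice, quasi-Newton methods are a one-liner via SciPy's `minimize`; a sketch on the Rosenbrock function, a classic ill-conditioned test problem (my choice of example):

```python
import numpy as np
from scipy.optimize import minimize

# Rosenbrock function: narrow curved valley, minimum at (1, 1)
def rosen(p):
    x, y = p
    return (1 - x)**2 + 100*(y - x**2)**2

# L-BFGS builds a low-memory Hessian approximation from gradient history
res = minimize(rosen, x0=np.array([-1.0, 1.0]), method="L-BFGS-B")
print(res.x)  # ≈ [1, 1]
```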
-
Conjugate gradient: efficient for large sparse systems, using only matrix-vector products instead of storing the full Hessian.
-
Gauss-Newton and Levenberg-Marquardt: specialised for least-squares problems (common in regression), approximating the Hessian via the Jacobian.
-
Natural gradient descent: accounts for the geometry of the parameter space using the Fisher information matrix, which can be more effective for probabilistic models.
-
The choice of optimiser depends on the problem. For deep learning, first-order methods (especially Adam) dominate because the number of parameters is enormous (millions to billions), making Hessian computation impractical. For smaller problems with smooth objectives, second-order methods can be dramatically faster.
Coding Tasks (use Colab or a notebook)
-
Implement Newton's method to find \(\sqrt{7}\) (a zero of \(f(x) = x^2 - 7\)). Observe the rapid convergence.
-
Use gradient descent to minimise \(f(x, y) = (x - 3)^2 + (y + 1)^2\). The minimum is at \((3, -1)\). Experiment with different learning rates.
```python
import jax
import jax.numpy as jnp

def f(params):
    x, y = params
    return (x - 3)**2 + (y + 1)**2

grad_f = jax.grad(f)

params = jnp.array([0.0, 0.0])
lr = 0.1
for i in range(20):
    g = grad_f(params)
    params = params - lr * g
    if i % 5 == 0 or i == 19:
        print(f"step {i:2d}: ({params[0]:.4f}, {params[1]:.4f}) loss={f(params):.6f}")
```
-
Solve a constrained optimisation problem numerically. Maximise \(f(x,y) = xy\) subject to \(x + y = 10\) by parameterising \(y = 10 - x\) and finding the optimum of the single-variable function.
```python
import jax
import jax.numpy as jnp

# Substitute the constraint: y = 10 - x, so f = x(10 - x) = 10x - x²
f = lambda x: x * (10 - x)
df = jax.grad(f)

# Gradient ascent (we want the maximum, so step in the +gradient direction)
x = 1.0
lr = 0.1
for i in range(20):
    x = x + lr * df(x)

print(f"x={x:.4f}, y={10-x:.4f}, f={f(x):.4f}")  # should be x=5, y=5, f=25
```