Differential Calculus
Differential calculus captures instantaneous rates of change. This file covers limits, derivatives, differentiation rules, the chain rule (the foundation of backpropagation), and common derivatives used throughout ML.
-
In the previous chapters, we learned how to represent data as vectors and transform it with matrices. But many real-world phenomena are not static. A car accelerates, a stock price fluctuates, a neural network's loss changes as weights update. Calculus is the mathematics of change.
-
Calculus asks two questions: how fast is something changing right now (differential calculus), and how much has it accumulated over time (integral calculus)? This section tackles the "how fast" question.
-
Imagine you are driving and glance at your speedometer. It reads 60 km/h. That number is not the average speed of your entire trip; it is your speed at this exact instant. Differential calculus gives us the tools to compute such instantaneous rates of change.
-
But first, let us revisit the equation of a straight line: \(y = mx + b\).
-
This is the simplest relationship between two quantities.
- \(b\) is the y-intercept, where the line crosses the y-axis (the starting value when \(x = 0\)).
- \(m\) is the slope, the rate of change: for every 1 unit increase in \(x\), \(y\) changes by \(m\).
- If \(m = 3\), the line rises steeply; if \(m = 0\), the line is flat; if \(m = -2\), the line falls.
-
The slope is computed as \(m = \frac{\Delta y}{\Delta x} = \frac{y_2 - y_1}{x_2 - x_1}\), the ratio of "how much did \(y\) change" to "how much did \(x\) change."
-
Once you know \(m\) and \(b\), you can compute \(y\) for any \(x\).
-
For example, if \(m = 2\) and \(b = 3\), then at \(x = 5\): \(y = 2(5) + 3 = 13\).
-
The two parameters fully determine the line, and predicting any output is just plugging in.
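Here is a minimal sketch of these two steps in plain Python. It recovers the slope from two points with \(\frac{\Delta y}{\Delta x}\) and then predicts by plugging in. The helper names `slope` and `line` are illustrative choices, not part of any library.

```python
def slope(x1, y1, x2, y2):
    """Rise over run between two points: (y2 - y1) / (x2 - x1)."""
    return (y2 - y1) / (x2 - x1)


def line(x, m, b):
    """Evaluate the straight line y = m*x + b."""
    return m * x + b


# Two points on the line y = 2x + 3: (0, 3) and (1, 5)
m = slope(0.0, 3.0, 1.0, 5.0)  # 2.0
b = 3.0                        # value of y at x = 0, i.e. the intercept
print(line(5.0, m, b))         # 13.0

# Any other pair of points on the same line gives the same slope
print(slope(-4.0, line(-4.0, m, b), 10.0, line(10.0, m, b)))  # 2.0
```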
-
This idea generalises beyond lines. Any function is a rule that maps inputs to outputs, and once you know its formula (its parameters and shape), you can compute the output for any input and plot the result.
-
\(y = x^2\) gives a parabola, \(y = \sin(x)\) gives a wave, \(y = e^x\) gives exponential growth. Each formula defines a specific curve, and being comfortable reading a function as a shape is essential for everything that follows.
-
For a straight line, the slope is the same everywhere. But most interesting functions are curved, so the slope varies from point to point. Calculus gives us a way to find the slope at any single point on a curve.
-
We also need the concept of a limit. A limit describes what value a function approaches as its input gets closer and closer to some target, without necessarily reaching it. We write:

\[\lim_{x \to a} f(x) = L\]
-
This reads: "as \(x\) approaches \(a\), \(f(x)\) approaches \(L\)." The function does not need to actually equal \(L\) at \(x = a\). It just needs to get arbitrarily close.
-
For example, take \(f(x) = \frac{x^2 - 1}{x - 1}\). If you plug in \(x = 1\) directly, you get \(\frac{0}{0}\), which is undefined.
-
But try values close to 1: \(f(0.9) = 1.9\), \(f(0.99) = 1.99\), \(f(1.01) = 2.01\). The outputs are clearly heading towards 2.
-
Algebraically, we can see why: factor the numerator as \((x-1)(x+1)\), cancel the \((x-1)\) terms, and we get \(f(x) = x + 1\) for all \(x \neq 1\). So as \(x \to 1\), \(f(x) \to 2\).
-
The function has a hole at \(x = 1\), but the limit still exists.
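The numerical side of this argument is easy to reproduce. The sketch below uses plain Python floats and evaluates \(f\) on both sides of \(x = 1\). It illustrates the limit but does not prove it; the factoring argument is what does that.

```python
def f(x):
    """(x^2 - 1) / (x - 1), undefined at x = 1."""
    return (x**2 - 1) / (x - 1)


# Approach x = 1 from the left and from the right
for x in [0.9, 0.99, 0.999, 1.001, 1.01, 1.1]:
    print(f"f({x}) = {f(x):.6f}")

# f(1.0) itself is 0.0 / 0.0, which raises ZeroDivisionError in plain Python
```

The printed values close in on 2 from both sides, matching \(x + 1\).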
-
Limits are the foundation that everything else in calculus rests on.
-
The derivative of a function \(f(x)\) at a point \(x = a\) measures the instantaneous rate of change. Geometrically, it is the slope of the tangent line to the curve at that point.
To compute this slope, we start with two points on the curve and compute the slope of the line through them (a secant line). Then we slide the second point closer and closer to the first and see what slope the secant line approaches. This gives the derivative as the limit of the difference quotient:

\[f'(a) = \lim_{h \to 0} \frac{f(a+h) - f(a)}{h}\]
-
The numerator \(f(a+h) - f(a)\) is the change in output. The denominator \(h\) is the change in input. Their ratio is the average rate of change over a tiny interval. As \(h \to 0\), this average becomes the instantaneous rate.
-
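For example, let \(f(x) = x^2\). At \(x = 3\):

\[f'(3) = \lim_{h \to 0} \frac{(3+h)^2 - 9}{h} = \lim_{h \to 0} \frac{6h + h^2}{h} = \lim_{h \to 0} (6 + h) = 6\]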
-
So at \(x = 3\), the function \(x^2\) is increasing at a rate of 6 units of output per unit of input.
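The same limit can be watched numerically by shrinking \(h\). The function name `difference_quotient` is a hypothetical helper written for this sketch. Since \((3+h)^2 - 9 = 6h + h^2\), each printed slope should be \(6 + h\) up to floating-point rounding.

```python
def difference_quotient(f, a, h):
    """Slope of the secant line through (a, f(a)) and (a + h, f(a + h))."""
    return (f(a + h) - f(a)) / h


def square(x):
    return x**2


# Slide the second point towards a = 3 by shrinking h
for h in [1.0, 0.1, 0.01, 0.001, 1e-6]:
    print(f"h = {h:<8g} secant slope = {difference_quotient(square, 3.0, h):.6f}")
```

Do not push \(h\) all the way towards machine precision. For very small \(h\), the subtraction \(f(a+h) - f(a)\) loses most of its significant digits to rounding, and the estimate gets worse rather than better.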
-
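A function is differentiable at a point if this limit exists. For that to happen, the function must be defined in a neighbourhood around the point and continuous there (no jumps). Continuity alone is not enough, because the curve must also have no sharp corner at the point. This is why \(|x|\) is not differentiable at \(x = 0\).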
-
If you can draw the curve without lifting your pen and without any kinks, it is probably differentiable there.
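A sharp corner can be made visible with a rough numerical sketch (not a proof) that compares the difference quotient from the right with the one from the left. Where the derivative exists, both approach the same value. At the corner of \(|x|\) at \(x = 0\), they settle on \(+1\) and \(-1\) instead. The helper `one_sided_slopes` and the step `h = 1e-6` are arbitrary choices for illustration.

```python
def one_sided_slopes(f, a, h=1e-6):
    """Difference quotients approaching a from the left and from the right."""
    left = (f(a) - f(a - h)) / h
    right = (f(a + h) - f(a)) / h
    return left, right


print(one_sided_slopes(abs, 0.0))              # (-1.0, 1.0): the two sides disagree, no derivative
print(one_sided_slopes(lambda x: x**2, 0.0))   # both tiny: they agree on a slope of 0
```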
-
Computing derivatives from the limit definition every time would be tedious. Fortunately, a handful of rules let us differentiate almost any function quickly.
-
Constant rule: the derivative of a constant is zero. If \(f(x) = 5\), then \(f'(x) = 0\). A flat line has zero slope.
-
Power rule: the workhorse of differentiation. Bring the exponent down and reduce it by one:

\[\frac{d}{dx} x^n = n x^{n-1}\]
-
For example: \(\frac{d}{dx} x^3 = 3x^2\). The cubic becomes a quadratic. This works for any real exponent, including negatives and fractions: \(\frac{d}{dx} x^{-1} = -x^{-2}\) and \(\frac{d}{dx} \sqrt{x} = \frac{d}{dx} x^{1/2} = \frac{1}{2}x^{-1/2}\).
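The following sketch checks the power rule, including the negative and fractional exponents above, with JAX's automatic differentiation (`jax.grad`). The test point \(x = 2\) and the helper `power_rule` are arbitrary illustrative choices. JAX computes in 32-bit floats by default, so the two columns agree only up to small rounding differences.

```python
import jax


def power_rule(n, x):
    """Derivative of x**n according to the power rule: n * x**(n - 1)."""
    return n * x ** (n - 1)


x0 = 2.0
for n in [3.0, -1.0, 0.5]:
    # jax.grad builds the derivative of the scalar function x -> x**n
    auto = jax.grad(lambda x: x**n)(x0)
    print(f"n = {n:4}: jax.grad = {float(auto):.6f}   n*x^(n-1) = {power_rule(n, x0):.6f}")
```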
-
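Sum/Difference rule: differentiate term by term, \(\frac{d}{dx}[f(x) \pm g(x)] = f'(x) \pm g'(x)\).

Product rule: when two functions are multiplied, the derivative is not simply the product of the derivatives. Instead:

\[\frac{d}{dx}[f(x)\,g(x)] = f'(x)\,g(x) + f(x)\,g'(x)\]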
-
Think of it as: "the rate of change of the first times the second, plus the first times the rate of change of the second." For example, \(\frac{d}{dx}[x^2 \sin x] = 2x \sin x + x^2 \cos x\).
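You can check both the product rule and the mistake it warns against with `jax.grad`. In this sketch, `naive` is a deliberately wrong formula (the product of the two derivatives) that is included only for contrast.

```python
import jax
import jax.numpy as jnp

f = lambda x: x**2 * jnp.sin(x)

# Product rule with u = x^2 (u' = 2x) and v = sin x (v' = cos x): u'v + uv'
product_rule = lambda x: 2 * x * jnp.sin(x) + x**2 * jnp.cos(x)
# Wrong on purpose: multiplying the two derivatives, u' * v'
naive = lambda x: 2 * x * jnp.cos(x)

df = jax.grad(f)
for x in [0.5, 1.0, 2.0]:
    print(f"x = {x}: jax.grad = {float(df(x)):.5f}   "
          f"product rule = {float(product_rule(x)):.5f}   naive = {float(naive(x)):.5f}")
```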
-
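Quotient rule: for a ratio of functions (where \(g(x) \neq 0\)):

\[\frac{d}{dx}\left[\frac{f(x)}{g(x)}\right] = \frac{g(x)\,f'(x) - f(x)\,g'(x)}{[g(x)]^2}\]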
-
A useful mnemonic: "low d-high minus high d-low, over the square of what's below."
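The sketch below checks the quotient rule against `jax.grad` for \(\frac{\sin x}{x^2 + 1}\). This example is chosen, as an assumption of the sketch, because its denominator is never zero. The variable names follow the mnemonic: `high` is the numerator and `low` is the denominator.

```python
import jax
import jax.numpy as jnp

high = jnp.sin                   # numerator f(x) = sin x, with f'(x) = cos x
low = lambda x: x**2 + 1.0       # denominator g(x) = x^2 + 1, with g'(x) = 2x
h = lambda x: high(x) / low(x)


def quotient_rule(x):
    # (low * d-high - high * d-low) / low^2
    return (low(x) * jnp.cos(x) - high(x) * 2 * x) / low(x) ** 2


dh = jax.grad(h)
for x in [-1.0, 0.0, 1.5]:
    print(f"x = {x}: jax.grad = {float(dh(x)):.6f}   quotient rule = {float(quotient_rule(x)):.6f}")
```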
-
Chain rule: the most important rule for ML. When functions are composed (one inside another), the derivative is the product of the derivatives along the chain:

\[\frac{d}{dx} f(g(x)) = f'(g(x)) \cdot g'(x)\]

Think of it as peeling an onion. Differentiate the outer function (keeping the inner function untouched), then multiply by the derivative of the inner function.
-
For example, \(\frac{d}{dx} (3x + 1)^5 = 5(3x+1)^4 \cdot 3 = 15(3x+1)^4\). The outer function is \((\cdot)^5\) and the inner is \(3x+1\).
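You can confirm this worked example with `jax.grad`, which applies the chain rule internally. The sample points are arbitrary.

```python
import jax

g = lambda x: (3 * x + 1) ** 5
# Chain rule by hand: outer derivative 5(.)^4 evaluated at the inner 3x + 1, times inner derivative 3
chain_rule = lambda x: 5 * (3 * x + 1) ** 4 * 3

dg = jax.grad(g)
for x in [-0.5, 0.0, 0.2]:
    print(f"x = {x}: jax.grad = {float(dg(x)):.4f}   15(3x+1)^4 = {chain_rule(x):.4f}")
```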
-
The chain rule is the mathematical foundation of backpropagation in neural networks. A deep network is a long chain of composed functions. To compute how the loss changes with respect to each weight, we apply the chain rule repeatedly from the output layer back to the input, multiplying local derivatives at each step.
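The following sketch makes the backpropagation picture concrete with a hypothetical one-weight model. The model has a linear step \(z = wx\), a \(\tanh\) activation \(a = \tanh(z)\), and a squared-error loss \(L = (a - y)^2\). The backward pass multiplies the local derivatives \(\frac{\partial L}{\partial a}\), \(\frac{\partial a}{\partial z}\) and \(\frac{\partial z}{\partial w}\), and the result should agree with `jax.grad`. Real networks have many weights arranged in matrices, but the bookkeeping follows the same pattern.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    z = w * x            # linear step
    a = jnp.tanh(z)      # activation
    return (a - y) ** 2  # squared-error loss


def manual_grad(w, x, y):
    # Forward pass: keep the intermediate values
    z = w * x
    a = jnp.tanh(z)
    # Backward pass: local derivatives, multiplied from the loss back to w
    dL_da = 2 * (a - y)
    da_dz = 1 - a**2     # derivative of tanh(z) is 1 - tanh(z)^2
    dz_dw = x
    return dL_da * da_dz * dz_dw


w, x, y = 0.8, 1.5, 0.2  # arbitrary illustrative values
# jax.grad differentiates with respect to the first argument (w) by default
print(float(jax.grad(loss)(w, x, y)), float(manual_grad(w, x, y)))
```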
-
Here are the most common derivatives you will encounter. Each one can be derived from the limit definition, but knowing them by heart saves time:
| Function | Derivative | Notes |
|---|---|---|
| \(e^x\) | \(e^x\) | Equals its own derivative (up to a constant factor, the only function that does) |
| \(a^x\) | \(a^x \ln a\) | Generalises the exponential |
| \(\ln x\) | \(\frac{1}{x}\) | The natural logarithm |
| \(\log_a x\) | \(\frac{1}{x \ln a}\) | General logarithm |
| \(\sin x\) | \(\cos x\) | |
| \(\cos x\) | \(-\sin x\) | Note the negative sign |
| \(\tan x\) | \(\sec^2 x\) | \(\sec x = \frac{1}{\cos x}\) |
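
You can spot-check each row of the table with `jax.grad`. The base \(a = 2\) and the point \(x_0 = 0.7\) are arbitrary choices for this sketch, and \(x_0 > 0\) keeps the logarithms defined. Expect the two columns to agree up to 32-bit rounding.

```python
import jax
import jax.numpy as jnp

a = 2.0   # arbitrary base for a^x and log_a x
x0 = 0.7  # arbitrary test point, positive so ln and log_a are defined

# name -> (function, derivative as listed in the table)
table = {
    "e^x":     (jnp.exp,                           jnp.exp),
    "a^x":     (lambda x: a**x,                    lambda x: a**x * jnp.log(a)),
    "ln x":    (jnp.log,                           lambda x: 1 / x),
    "log_a x": (lambda x: jnp.log(x) / jnp.log(a), lambda x: 1 / (x * jnp.log(a))),
    "sin x":   (jnp.sin,                           jnp.cos),
    "cos x":   (jnp.cos,                           lambda x: -jnp.sin(x)),
    "tan x":   (jnp.tan,                           lambda x: 1 / jnp.cos(x) ** 2),
}

for name, (fn, dfn) in table.items():
    auto = float(jax.grad(fn)(x0))
    print(f"{name:8s} jax.grad: {auto: .6f}   table: {float(dfn(x0)): .6f}")
```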
-
The exponential function \(e^x\) is remarkable: it equals its own derivative, and up to a constant factor it is the only function that does (every solution of \(f' = f\) has the form \(Ce^x\)). This is one reason \(e\) appears everywhere in ML, from softmax activations to probability distributions.
-
L'Hopital's Rule handles limits that produce indeterminate forms like \(\frac{0}{0}\) or \(\frac{\infty}{\infty}\). When direct substitution gives one of these forms, you can differentiate the numerator and denominator separately and try the limit again:

\[\lim_{x \to a} \frac{f(x)}{g(x)} = \lim_{x \to a} \frac{f'(x)}{g'(x)}\]
-
Conditions: both \(f\) and \(g\) must be differentiable near \(a\), and \(g'(x) \neq 0\) near \(a\) (except possibly at \(a\) itself). The original limit must give an indeterminate form, and the limit of \(\frac{f'(x)}{g'(x)}\) must exist (or be \(\pm\infty\)).
-
For example: \(\lim_{x \to 0} \frac{\sin x}{x}\). Direct substitution gives \(\frac{0}{0}\). Applying L'Hopital's Rule: \(\lim_{x \to 0} \frac{\cos x}{1} = 1\). This limit is fundamental: it appears in signal processing and Fourier analysis.
-
You can apply the rule repeatedly if the result is still indeterminate. For instance, \(\lim_{x \to 0} \frac{1 - \cos x}{x^2}\) gives \(\frac{0}{0}\). First application: \(\lim_{x \to 0} \frac{\sin x}{2x}\), still \(\frac{0}{0}\). Second application: \(\lim_{x \to 0} \frac{\cos x}{2} = \frac{1}{2}\).
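You can sanity-check both limits above by evaluating the original ratios near \(x = 0\). This sketch uses Python's `math` module (64-bit floats) rather than JAX. It only shows the trend; the limits themselves come from L'Hopital's Rule.

```python
import math


def sin_ratio(x):
    """sin(x) / x, which is 0/0 at x = 0."""
    return math.sin(x) / x


def one_minus_cos_ratio(x):
    """(1 - cos x) / x^2, also 0/0 at x = 0."""
    return (1 - math.cos(x)) / x**2


for x in [0.1, 0.01, 0.001]:
    print(f"x = {x:<6}  sin(x)/x = {sin_ratio(x):.8f}   (1 - cos x)/x^2 = {one_minus_cos_ratio(x):.8f}")
```

Do not shrink \(x\) too far. Near \(x = 10^{-8}\), \(\cos x\) rounds to exactly 1 in 64-bit floating point, so \(1 - \cos x\) collapses to 0 and the numerical check breaks down. The analytical limit has no such problem.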
-
If two functions are differentiable, their sum, difference, product, composition, and quotient (where the denominator is non-zero) are also differentiable. This is why we can confidently differentiate complex expressions built from simple pieces.
Coding Tasks (use Colab or a notebook)
-
Visualise common functions. Plot \(x^2\), \(\sin(x)\), and \(e^x\) side by side to build intuition for how different formulas produce different shapes. Try changing parameters (e.g. \(2x^2\), \(\sin(2x)\)) and observe how the curves change.
```python
import jax.numpy as jnp
import matplotlib.pyplot as plt

# 300 evenly spaced inputs between -3 and 3
x = jnp.linspace(-3, 3, 300)
fig, axes = plt.subplots(1, 3, figsize=(12, 3))

axes[0].plot(x, x**2, color="#e74c3c")
axes[0].set_title("x² (parabola)")
axes[1].plot(x, jnp.sin(x), color="#3498db")
axes[1].set_title("sin(x) (wave)")
axes[2].plot(x, jnp.exp(x), color="#27ae60")
axes[2].set_title("eˣ (exponential)")

# Draw the x- and y-axes on every panel for reference
for ax in axes:
    ax.axhline(0, color="gray", linewidth=0.5)
    ax.axvline(0, color="gray", linewidth=0.5)

plt.tight_layout()
plt.show()
```
Use JAX's automatic differentiation to compute the derivative of \(f(x) = x^3 - 2x + 1\) at several points. Compare with the analytical derivative \(f'(x) = 3x^2 - 2\).
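Here is one possible solution sketch for this task. The sample points are arbitrary.

```python
import jax

f = lambda x: x**3 - 2*x + 1
df_exact = lambda x: 3*x**2 - 2    # analytical derivative
df_auto = jax.grad(f)              # derivative computed by JAX (expects float inputs)

for x in [-2.0, -0.5, 0.0, 1.0, 3.0]:
    print(f"x = {x:5.1f}   jax.grad: {float(df_auto(x)):8.4f}   3x^2 - 2: {df_exact(x):8.4f}")
```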
-
Verify the chain rule numerically. Define \(f(x) = \sin(x^2)\), compute its derivative via `jax.grad`, and compare with the analytical result \(2x\cos(x^2)\).
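Here is a possible sketch for this task. It compares the two derivatives over a small grid of points, with `jax.vmap` applying the scalar derivative from `jax.grad` to each point. The grid is an arbitrary choice.

```python
import jax
import jax.numpy as jnp

f = lambda x: jnp.sin(x**2)
# Chain rule: derivative of the outer sin is cos(x^2), times the inner derivative 2x
df_exact = lambda x: 2 * x * jnp.cos(x**2)
df_auto = jax.grad(f)

xs = jnp.linspace(-2.0, 2.0, 9)
auto_vals = jax.vmap(df_auto)(xs)   # vmap applies the scalar grad to each x
exact_vals = df_exact(xs)
print(jnp.max(jnp.abs(auto_vals - exact_vals)))  # expected to be at float32 round-off level
```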
Visualise the derivative. Plot \(f(x) = x^3 - 3x\) and its derivative \(f'(x) = 3x^2 - 3\) on the same graph. Notice that \(f'(x) = 0\) exactly at the peak and the valley of \(f\) (at \(x = -1\) and \(x = 1\)).
```python
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

f = lambda x: x**3 - 3*x
# jax.grad works on scalars; jax.vmap vectorises it to operate on an array of inputs at once
df = jax.vmap(jax.grad(f))

x = jnp.linspace(-2.5, 2.5, 200)
plt.plot(x, jax.vmap(f)(x), label="f(x)")
plt.plot(x, df(x), label="f'(x)", linestyle="--")
plt.axhline(0, color="gray", linewidth=0.5)
plt.legend()
plt.title("A function and its derivative")
plt.show()
```