Fundamentals of Statistics
Statistics provides the language for describing data and quantifying uncertainty. This file covers distributions, random variables, PMFs, PDFs, CDFs, expectation, variance, moments, and the central limit theorem: the concepts that underpin ML evaluation metrics and loss functions.
Statistics is the science of learning from data. You collect observations, summarise them, and draw conclusions, often about things you cannot measure directly.
Imagine you want to know the average height of every adult in a country. You cannot measure everyone, so you measure a sample and use statistics to make an informed guess about the whole population.
There are two main branches:
- Descriptive statistics: summarising data you already have (averages, charts, tables)
- Inferential statistics: using a sample to make claims about a larger group
The building block of statistics is the distribution, a description of how values are spread out. Everything else (averages, tests, predictions) flows from understanding distributions.
A frequency distribution counts how often each value (or range of values) appears in your data. Think of sorting exam scores into bins and counting how many students fall in each bin. The result is a histogram.
A probability distribution replaces raw counts with probabilities. Instead of "12 students scored between 70 and 80," it says "there is a 0.24 probability of scoring between 70 and 80" (12 out of a class of 50). For continuous data, the histogram's bars are replaced by a smooth curve.
The histogram on the left is built from actual data you collected. The smooth curve on the right is a mathematical model that describes the pattern behind the data. One is empirical, the other is theoretical.
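To make the step from counts to probabilities concrete, here is a minimal Python sketch (standard library only, rather than the JAX used in the coding tasks). The ten exam scores are hypothetical values invented for illustration, and the bin width of 10 is an assumption. The code builds a frequency distribution by counting scores per bin, then divides each count by the total to get an empirical probability distribution.

```python
from collections import Counter

# Hypothetical exam scores, invented for this illustration
scores = [72, 85, 64, 78, 91, 70, 58, 77, 83, 69]
bin_width = 10  # assumed bin width: 50-59, 60-69, 70-79, ...

# Frequency distribution: how many scores fall in each bin (keyed by bin start)
freq = Counter((s // bin_width) * bin_width for s in scores)

# Probability distribution: divide each count by the total number of scores
n = len(scores)
probs = {low: count / n for low, count in sorted(freq.items())}

for low, p in probs.items():
    print(f"{low}-{low + bin_width - 1}: count={freq[low]}, probability={p:.2f}")
```

Each probability is simply count divided by total, so the probabilities sum to 1 (up to floating-point rounding).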
To work with distributions mathematically, we need a way to assign numbers to outcomes. That is exactly what a random variable does.
A random variable is a function that maps each outcome of an experiment to a real number. Flip a coin: the outcome is "heads" or "tails," but a random variable \(X\) converts this to \(X(\text{heads}) = 1\) and \(X(\text{tails}) = 0\). Now we can do arithmetic.
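A random variable really is just a function, which a short Python sketch can show directly. Extending the single coin flip to three flips of a fair coin is an illustrative assumption: the hypothetical function `X` maps each outcome (a tuple of flips) to its number of heads, and adding up the probabilities of all outcomes that map to the same number gives the distribution of \(X\).

```python
from collections import defaultdict
from fractions import Fraction
from itertools import product

# Sample space: all 2**3 equally likely outcomes of three fair coin flips
outcomes = list(product(["heads", "tails"], repeat=3))

def X(outcome):
    """Random variable: maps an outcome to a real number (the count of heads)."""
    return sum(1 for flip in outcome if flip == "heads")

# Each outcome has probability 1/8; credit it to the number X assigns to it
pmf = defaultdict(Fraction)
for outcome in outcomes:
    pmf[X(outcome)] += Fraction(1, len(outcomes))

print(X(("heads", "tails", "heads")))  # 2
for value, p in sorted(pmf.items()):
    print(f"P(X = {value}) = {p}")
```

Once outcomes are numbers, arithmetic such as counting, summing, and averaging becomes possible, which is exactly what the rest of this file relies on.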
A discrete random variable takes on a countable set of values: the number of heads in 10 flips, the roll of a die, the number of emails you receive in an hour.
A continuous random variable can take any value in an interval: your exact height, the time until the next bus arrives, the temperature at noon.
The distinction matters because it changes how we compute probabilities. For discrete variables, we sum. For continuous variables, we integrate (recall integrals from Chapter 3).
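For a discrete random variable, the probability mass function (PMF) gives the probability of each specific value:

\[
p(x) = P(X = x), \qquad \sum_x p(x) = 1
\]

For a continuous random variable, the probability density function (PDF) \(f(x)\) does not give probabilities directly; integrating it over a range gives the probability of falling within that range. The probability of any single exact value is zero; only intervals have positive probability:

\[
P(a \le X \le b) = \int_a^b f(x)\,dx, \qquad \int_{-\infty}^{\infty} f(x)\,dx = 1
\]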
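The sum-versus-integrate distinction can be checked numerically. This sketch uses a fair die for the discrete case and, as a hypothetical model of the bus-waiting time mentioned earlier, an exponential distribution with an assumed rate of 0.1 per minute for the continuous case. The helper `integrate` is a simple midpoint-rule approximation written for this example, not a library function. Summing the PMF up to a value gives the cumulative distribution function (CDF), \(F(x) = P(X \le x)\); for the die, \(F(3) = 1/2\).

```python
import math
from fractions import Fraction

# Discrete case: PMF of a fair six-sided die, stored as exact fractions
pmf = {x: Fraction(1, 6) for x in range(1, 7)}
assert sum(pmf.values()) == 1  # a valid PMF sums to 1

# For a discrete variable we SUM: P(X <= 3), which is also the CDF value F(3)
p_at_most_3 = sum(p for x, p in pmf.items() if x <= 3)
print("P(X <= 3) =", p_at_most_3)  # P(X <= 3) = 1/2

# Continuous case (assumed model): exponential waiting time, rate 0.1 per minute
lam = 0.1

def pdf(x):
    """Exponential density f(x) = lam * exp(-lam * x) for x >= 0, else 0."""
    return lam * math.exp(-lam * x) if x >= 0 else 0.0

def integrate(f, a, b, n=10_000):
    """Midpoint-rule approximation of the integral of f over [a, b]."""
    h = (b - a) / n
    return h * sum(f(a + (i + 0.5) * h) for i in range(n))

# For a continuous variable we INTEGRATE: P(5 <= X <= 15)
numeric = integrate(pdf, 5, 15)
exact = math.exp(-lam * 5) - math.exp(-lam * 15)  # closed form for the exponential
print(f"P(5 <= X <= 15): numeric={numeric:.6f}, exact={exact:.6f}")

# A single exact value is an interval of zero width, so its probability is 0
print("P(X = 5) =", integrate(pdf, 5, 5))  # P(X = 5) = 0.0
```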
Now that we can assign numbers to outcomes, the most natural question is: what value do we expect on average?
Expectation (or expected value) is the weighted average of all possible values, where the weights are the probabilities. Think of it as the "centre of gravity" of the distribution.
If you roll a fair die many times, your average roll converges to 3.5. That is the expected value, even though you can never actually roll a 3.5.
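For a discrete random variable:

\[
E[X] = \sum_x x\,p(x)
\]

For a continuous random variable with PDF \(f(x)\) (using the integral from Chapter 3):

\[
E[X] = \int_{-\infty}^{\infty} x\,f(x)\,dx
\]

Example: a fair six-sided die has \(p(x) = 1/6\) for \(x = 1, 2, 3, 4, 5, 6\), so

\[
E[X] = \frac{1 + 2 + 3 + 4 + 5 + 6}{6} = \frac{21}{6} = 3.5
\]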
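A quick way to check the die example is to compute \(\sum_x x\,p(x)\) exactly and compare it with the average of many simulated rolls. This minimal sketch uses Python's `fractions` and `random` modules; the seed and the number of rolls are arbitrary choices.

```python
import random
from fractions import Fraction

# Fair six-sided die: p(x) = 1/6 for x = 1..6
pmf = {x: Fraction(1, 6) for x in range(1, 7)}

# E[X] = sum over x of x * p(x), computed exactly
expected = sum(x * p for x, p in pmf.items())
print("E[X] =", expected, "=", float(expected))  # E[X] = 7/2 = 3.5

# Simulation: the average of many rolls should approach E[X]
random.seed(0)  # fixed seed so the run is reproducible
rolls = [random.randint(1, 6) for _ in range(100_000)]
average = sum(rolls) / len(rolls)
print(f"average of {len(rolls):,} rolls: {average:.4f}")
```

The simulated average lands close to 3.5 but not exactly on it; it gets closer, on average, as the number of rolls grows.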
Expectation is linear: for any constants \(a\) and \(b\), \(E[aX + b] = aE[X] + b\). This property shows up constantly in ML loss functions; for example, scaling a loss by a constant scales its expected value by the same constant.
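Linearity can be verified exactly for a fair die. In this sketch the helper `expectation` is a hypothetical name that computes \(E[g(X)] = \sum_x g(x)\,p(x)\), and the constants \(a = 2\), \(b = 3\) are arbitrary. The last line is a contrast: linearity covers \(aX + b\), not nonlinear functions such as \(X^2\).

```python
from fractions import Fraction

# Fair six-sided die
pmf = {x: Fraction(1, 6) for x in range(1, 7)}

def expectation(pmf, g=lambda x: x):
    """E[g(X)]: the probability-weighted average of g(x) over all values x."""
    return sum(g(x) * p for x, p in pmf.items())

a, b = 2, 3                                  # arbitrary constants for the check
lhs = expectation(pmf, lambda x: a * x + b)  # E[aX + b], computed directly
rhs = a * expectation(pmf) + b               # a E[X] + b, using linearity
print(lhs, rhs, lhs == rhs)                  # 10 10 True

# Linearity does not extend to nonlinear functions: E[X^2] != (E[X])^2
print(expectation(pmf, lambda x: x**2), expectation(pmf) ** 2)  # 91/6 49/4
```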
Expectation tells us the centre, but it says nothing about how spread out the values are. To describe the full shape of a distribution, we need moments.
A moment is an expectation of a power of \(X\). The \(k\)-th raw moment is:

\[
\mu_k' = E[X^k]
\]
The first raw moment (\(k = 1\)) is just the mean: \(\mu_1' = E[X] = \mu\).
Raw moments are measured from zero. Often we care about deviation from the mean instead. The \(k\)-th central moment centres the measurement:

\[
\mu_k = E\left[(X - \mu)^k\right]
\]
The first central moment is always zero (deviations above and below the mean cancel). The second central moment is the variance.
To compare distributions on different scales, we standardise by dividing by the appropriate power of the standard deviation \(\sigma\):

\[
\tilde{\mu}_k = \frac{\mu_k}{\sigma^k} = \frac{E\left[(X - \mu)^k\right]}{\sigma^k}
\]

Each moment captures a different aspect of the distribution's shape:
- 1st raw moment (Mean): Where the distribution is centred, its balance point.
- 2nd central moment (Variance): How spread out values are around the mean. Higher variance means a wider distribution.
- 3rd standardised moment (Skewness): Whether the distribution leans left or right. A symmetric distribution has zero skewness.
- 4th standardised moment (Kurtosis): How heavy the tails are. Higher kurtosis means heavier tails, so extreme outliers are more likely.
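The three kinds of moment can be computed straight from their definitions. This sketch applies them to a fair six-sided die using exact fractions; the helper names `raw_moment` and `central_moment` are illustrative, not from any library. The standard deviation is converted to a float because \(\sqrt{35/12}\) is irrational.

```python
from fractions import Fraction

# Fair six-sided die as a PMF: value -> probability (exact fractions)
pmf = {x: Fraction(1, 6) for x in range(1, 7)}

def raw_moment(pmf, k):
    """k-th raw moment E[X^k]."""
    return sum(p * x**k for x, p in pmf.items())

def central_moment(pmf, k):
    """k-th central moment E[(X - mu)^k]."""
    mu = raw_moment(pmf, 1)
    return sum(p * (x - mu)**k for x, p in pmf.items())

mu = raw_moment(pmf, 1)                          # 1st raw moment = mean
var = central_moment(pmf, 2)                     # 2nd central moment = variance
sigma = float(var) ** 0.5                        # standard deviation
skew = float(central_moment(pmf, 3)) / sigma**3  # standardised 3rd moment
kurt = float(central_moment(pmf, 4)) / sigma**4  # standardised 4th moment

print("mean:", mu)                                    # mean: 7/2
print("E[X^2]:", raw_moment(pmf, 2))                  # E[X^2]: 91/6
print("1st central moment:", central_moment(pmf, 1))  # 1st central moment: 0
print("variance:", var)                               # variance: 35/12
print(f"skewness: {skew:.3f}, kurtosis: {kurt:.3f}")  # skewness: 0.000, kurtosis: 1.731
```

Two checks fall out of the output. The 1st central moment is exactly 0, and the variance equals \(E[X^2] - \mu^2 = 91/6 - 49/4 = 35/12\), an identity that links raw and central moments. The die's skewness is 0 because it is symmetric, and its kurtosis of about 1.73 is well below 3, so its tails are lighter than those of a normal distribution.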
Let us work through all four moments for a concrete dataset: \(X = \{2, 4, 4, 4, 5, 5, 7, 9\}\).
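Step 1: Mean (1st raw moment). Add the values and divide by the count:

\[
\mu = \frac{2 + 4 + 4 + 4 + 5 + 5 + 7 + 9}{8} = \frac{40}{8} = 5
\]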
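Step 2: Variance (2nd central moment). Subtract the mean from each value, square, then average:

\[
\sigma^2 = \frac{(-3)^2 + 3(-1)^2 + 2(0)^2 + 2^2 + 4^2}{8} = \frac{9 + 3 + 0 + 4 + 16}{8} = \frac{32}{8} = 4
\]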
The standard deviation is \(\sigma = \sqrt{4} = 2\).
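Step 3: Skewness (standardised 3rd central moment). Cube the deviations, average, divide by \(\sigma^3\):

\[
\text{Skewness} = \frac{1}{\sigma^3} \cdot \frac{(-3)^3 + 3(-1)^3 + 2(0)^3 + 2^3 + 4^3}{8} = \frac{1}{8} \cdot \frac{-27 - 3 + 0 + 8 + 64}{8} = \frac{5.25}{8} \approx 0.656
\]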
Positive skewness means the right tail is longer, which makes sense since 9 is far above the mean.
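Step 4: Kurtosis (standardised 4th central moment). Raise deviations to the 4th power, average, and divide by \(\sigma^4\):

\[
\text{Kurtosis} = \frac{1}{\sigma^4} \cdot \frac{(-3)^4 + 3(-1)^4 + 2(0)^4 + 2^4 + 4^4}{8} = \frac{1}{16} \cdot \frac{81 + 3 + 0 + 16 + 256}{8} = \frac{44.5}{16} \approx 2.781
\]

A normal distribution has kurtosis of 3 (called "mesokurtic"). Our value of 2.781 is close, suggesting tails similar to a normal distribution's, although eight data points is a very small sample. Values above 3 ("leptokurtic") signal heavier tails; below 3 ("platykurtic") signal lighter tails. Some formulas report excess kurtosis by subtracting 3, so our excess kurtosis would be \(-0.219\).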
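The hand arithmetic above can be double-checked in a few lines of standard-library Python. This sketch complements the JAX version in the coding tasks: it keeps the mean and central moments as exact fractions, and it assumes the population convention (dividing by \(n\)) used in the worked example.

```python
import math
import statistics
from fractions import Fraction

# Dataset from the worked example
data = [2, 4, 4, 4, 5, 5, 7, 9]
n = len(data)

mean = Fraction(sum(data), n)    # Step 1: 40/8
dev = [x - mean for x in data]   # deviations from the mean
var = sum(d**2 for d in dev) / n # Step 2: 32/8 (population variance, divides by n)
sigma = math.sqrt(var)           # standard deviation
m3 = sum(d**3 for d in dev) / n  # 3rd central moment: 42/8
m4 = sum(d**4 for d in dev) / n  # 4th central moment: 356/8
skew = float(m3) / sigma**3      # Step 3
kurt = float(m4) / sigma**4      # Step 4

print(f"mean={mean}, variance={var}, sigma={sigma}")              # mean=5, variance=4, sigma=2.0
print(f"skewness={skew:.3f}, kurtosis={kurt:.3f}, excess={kurt - 3:.3f}")

# Cross-check the variance with the standard library
print(statistics.pvariance(data), statistics.variance(data))
```

`statistics.pvariance` divides by \(n\), like the worked example, and returns 4; `statistics.variance` divides by \(n - 1\) (the sample variance) and returns \(32/7 \approx 4.571\). It is worth checking which convention a library uses before comparing numbers.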
Coding Tasks (use Colab or a notebook)
Compute the expected value of a loaded die where face 6 has probability 0.3 and all other faces share the remaining probability equally. Verify by simulating 100,000 rolls.
```python
import jax
import jax.numpy as jnp

# Loaded die: face 6 has p=0.3, others share 0.7 equally (0.7 / 5 = 0.14 each)
probs = jnp.array([0.14, 0.14, 0.14, 0.14, 0.14, 0.30])
faces = jnp.array([1, 2, 3, 4, 5, 6])

# Analytical expected value: sum of face * probability
ev = jnp.sum(faces * probs)
print(f"Expected value (formula): {ev:.4f}")

# Simulation: draw 100,000 rolls with the given probabilities
key = jax.random.PRNGKey(42)
rolls = jax.random.choice(key, faces, shape=(100_000,), p=probs)
print(f"Expected value (simulation): {rolls.mean():.4f}")
```
Compute all four moments (mean, variance, skewness, kurtosis) for the dataset from the worked example, then modify the data and observe how each moment changes.
```python
import jax.numpy as jnp

x = jnp.array([2, 4, 4, 4, 5, 5, 7, 9], dtype=jnp.float32)

mean = jnp.mean(x)
variance = jnp.mean((x - mean) ** 2)         # population variance (divides by n)
std = jnp.sqrt(variance)
skewness = jnp.mean(((x - mean) / std) ** 3)
kurtosis = jnp.mean(((x - mean) / std) ** 4)

print(f"Mean:     {mean:.3f}")
print(f"Variance: {variance:.3f}")
print(f"Std Dev:  {std:.3f}")
print(f"Skewness: {skewness:.3f}")
print(f"Kurtosis: {kurtosis:.3f}")
print(f"Excess K: {kurtosis - 3:.3f}")
```
Visualise a PMF and CDF side by side for a fair die roll. Try changing the probabilities to see how the shapes shift.
```python
import jax.numpy as jnp
import matplotlib.pyplot as plt

faces = jnp.array([1, 2, 3, 4, 5, 6])
pmf = jnp.ones(6) / 6  # fair die; try changing these (keep them summing to 1)!
cdf = jnp.cumsum(pmf)  # CDF is the running total of the PMF: F(x) = P(X ≤ x)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

ax1.bar(faces, pmf, color="#3498db", alpha=0.8)
ax1.set_title("PMF")
ax1.set_xlabel("Face")
ax1.set_ylabel("P(X = x)")
ax1.set_ylim(0, 0.5)

# where="post": the CDF jumps at each face and stays flat until the next one
ax2.step(faces, cdf, where="post", color="#e74c3c", linewidth=2)
ax2.set_title("CDF")
ax2.set_xlabel("Face")
ax2.set_ylabel("P(X ≤ x)")
ax2.set_ylim(0, 1.1)

plt.tight_layout()
plt.show()
```