Statistical Inference
Statistical inference goes beyond yes/no decisions to estimate population parameters with quantified uncertainty. This file covers point and interval estimation, confidence intervals, power analysis, Monte Carlo methods, factor analysis, and experimental design -- the bridge between raw data and well-run experiments in ML.
-
Hypothesis testing gives you a yes/no decision: reject or fail to reject. But often you want something more informative: a range of plausible values for the parameter you are estimating. That is what confidence intervals provide.
-
A point estimate is a single number computed from your sample, like the sample mean \(\bar{x}\). It is your best guess for the population parameter, but on its own it gives no sense of how precise the estimate is.
-
A confidence interval wraps that point estimate with a range that reflects uncertainty. It takes the form:

\[\text{point estimate} \pm \text{margin of error}\]

The margin of error (ME) depends on three things: how confident you want to be, how much variability is in the data, and how large your sample is:

\[\text{ME} = z^\ast \cdot \frac{\sigma}{\sqrt{n}}\]

Here \(z^\ast\) is the critical value from the normal distribution that matches your desired confidence level. For 95% confidence, \(z^\ast = 1.96\). For 99% confidence, \(z^\ast = 2.576\).
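These critical values come from the normal quantile function, so you can recover them yourself rather than memorising a table. A quick check using SciPy's `norm.ppf` (any normal quantile implementation would do):

```python
from scipy.stats import norm

# z* for a two-sided interval at confidence level c is the (1 + c)/2 quantile
for conf in (0.90, 0.95, 0.99):
    z_star = norm.ppf((1 + conf) / 2)
    print(f"{conf:.2f} -> z* = {z_star:.3f}")
```

This prints 1.645, 1.960, and 2.576, matching the values quoted above.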
-
A 95% confidence interval means: if you repeated the experiment many times and built an interval each time, about 95% of those intervals would contain the true population parameter. It does not mean there is a 95% probability the parameter is in this specific interval. The parameter is fixed; the intervals are what vary.
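This "repeated experiments" interpretation is easy to verify by simulation. A minimal sketch, assuming heights are normal with a true mean we choose ourselves (possible only because we are simulating):

```python
import jax
import jax.numpy as jnp

mu, sigma, n = 170.0, 8.0, 50   # true parameters, known here because we simulate
z_star = 1.96                   # 95% confidence
n_trials = 2000

# Draw n_trials independent samples of size n and build a CI from each
key = jax.random.PRNGKey(0)
samples = jax.random.normal(key, shape=(n_trials, n)) * sigma + mu
means = samples.mean(axis=1)
me = z_star * sigma / jnp.sqrt(n)

# Fraction of intervals that contain the true mean
covered = (means - me <= mu) & (mu <= means + me)
coverage = jnp.mean(covered)
print(f"Fraction of intervals containing mu: {coverage:.3f}")  # ≈ 0.95
```

About 95% of the 2000 intervals contain \(\mu\); any single interval either does or does not.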
-
Worked example: You measure the heights of 50 people and find \(\bar{x} = 170\) cm with \(\sigma = 8\) cm. Construct a 95% confidence interval.
\[\text{ME} = 1.96 \cdot \frac{8}{\sqrt{50}} \approx 1.96 \cdot 1.131 \approx 2.22 \text{ cm}\]

\[170 \pm 2.22 \;\Rightarrow\; [167.78,\ 172.22]\]
You can say with 95% confidence that the true mean height lies between 167.78 and 172.22 cm.
-
When \(\sigma\) is unknown (the usual case), use the sample standard deviation \(s\) and the t-distribution instead:
\[\bar{x} \pm t^\ast_{n-1} \cdot \frac{s}{\sqrt{n}}\]
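A small sketch of the t-based interval, using SciPy's `t.ppf` for the critical value and a made-up sample of eight heights (the data values are illustrative only):

```python
import numpy as np
from scipy.stats import t

# Hypothetical small sample of heights (cm)
data = np.array([168.0, 172.5, 165.1, 174.2, 169.8, 171.3, 166.7, 173.0])
n = len(data)
x_bar = data.mean()
s = data.std(ddof=1)             # sample std: n - 1 in the denominator
t_star = t.ppf(0.975, df=n - 1)  # 95% critical value with n - 1 degrees of freedom
me = t_star * s / np.sqrt(n)
print(f"95% CI: [{x_bar - me:.2f}, {x_bar + me:.2f}]")
```

With only 7 degrees of freedom, \(t^\ast \approx 2.36\) rather than 1.96, so the interval is wider than the z-based one -- the price of estimating \(\sigma\) from the data.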
Wider intervals are more confident but less precise. Narrower intervals are more precise but less confident. You can narrow an interval without losing confidence by increasing the sample size.
-
Power analysis helps you plan an experiment before you run it. The question is: how large a sample do I need to detect an effect of a given size with a specified power?
-
Recall from the previous file that power = \(1 - \beta\), the probability of correctly rejecting a false \(H_0\). A common target is 80% power.
-
The required sample size per group for a z-test detecting a difference \(\delta\) with significance \(\alpha\) and power \(1-\beta\) is:

\[n = \left( \frac{(z_{\alpha/2} + z_{\beta}) \, \sigma}{\delta} \right)^2\]

For example, to detect a 2 cm difference in mean height (\(\sigma = 8\)) with \(\alpha = 0.05\) and 80% power (\(z_{0.025} = 1.96\), \(z_{0.20} = 0.84\)):

\[n = \left( \frac{(1.96 + 0.84) \cdot 8}{2} \right)^2 = 11.2^2 = 125.44\]
You would need about 126 people per group.
-
Power analysis prevents two common mistakes: running an experiment too small to detect a real effect (underpowered), or wasting resources on an experiment far larger than necessary (overpowered).
-
Monte Carlo methods use random sampling to solve problems that are difficult or impossible to solve analytically. The core idea: if you cannot compute something exactly, simulate it many times and use the results as an approximation.
-
The name comes from the Monte Carlo casino, a nod to the role of randomness. These methods are workhorses in ML for tasks like estimating integrals, evaluating model uncertainty, and approximating complex distributions.
-
The general Monte Carlo recipe:
- Define a domain of possible inputs
- Generate random inputs from that domain
- Evaluate a function on each input
- Aggregate the results (average, count, etc.)
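The four steps above can be sketched as a small generic helper. The function name and setup here are ours, not a standard API; it estimates an expectation \(E[f(X)]\) by averaging:

```python
import jax
import jax.numpy as jnp

def monte_carlo_mean(f, sampler, n_samples, key):
    """Estimate E[f(X)] by averaging f over random draws of X."""
    xs = sampler(key, n_samples)  # steps 1-2: generate random inputs from the domain
    values = f(xs)                # step 3: evaluate the function on each input
    return values.mean()          # step 4: aggregate

# Example: E[X^2] for X ~ Uniform(0, 1), whose true value is 1/3
key = jax.random.PRNGKey(0)
sampler = lambda key, n: jax.random.uniform(key, shape=(n,))
estimate = monte_carlo_mean(lambda x: x**2, sampler, 100_000, key)
print(f"Estimate: {estimate:.4f} (true value 1/3 ≈ 0.3333)")
```

The same skeleton covers the \(\pi\) example below: the domain is the square, the function is the inside-the-circle indicator, and the aggregate is a count turned into a fraction.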
-
A classic example is estimating \(\pi\). Imagine a square with side length 2, centred at the origin, with a circle of radius 1 inscribed inside it. The area of the square is 4, and the area of the circle is \(\pi\).
Drop random points uniformly in the square. The fraction that land inside the circle approximates \(\pi/4\):

\[\frac{\text{points inside circle}}{\text{total points}} \approx \frac{\pi}{4} \quad\Rightarrow\quad \pi \approx 4 \cdot \frac{\text{points inside}}{\text{total points}}\]
A point \((x, y)\) is inside the circle if \(x^2 + y^2 \le 1\). The more points you throw, the closer your estimate gets to the true value of \(\pi\).
-
In ML, Monte Carlo methods appear in:
- Monte Carlo dropout: run inference multiple times with dropout enabled to estimate prediction uncertainty
- MCMC (Markov Chain Monte Carlo): sample from complex posterior distributions in Bayesian models
- Policy gradient methods: estimate gradients in reinforcement learning by sampling trajectories
-
Factor analysis is a technique for discovering hidden (latent) variables that explain the correlations among observed variables. If 10 personality survey questions can be explained by 3 underlying traits (extraversion, agreeableness, conscientiousness), factor analysis finds those traits.
-
The model assumes each observed variable \(x_i\) is a linear combination of a few latent factors \(f_j\) plus noise:
\[x_i = \lambda_{i1} f_1 + \lambda_{i2} f_2 + \cdots + \lambda_{ik} f_k + \epsilon_i\]
The \(\lambda\) values are called factor loadings and tell you how strongly each observed variable relates to each factor. This connects directly to the matrix decompositions from Chapter 2; factor analysis is closely related to eigenvalue decomposition and SVD.
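As a sketch of what this looks like in practice -- assuming scikit-learn's `FactorAnalysis` and entirely synthetic data generated from two known factors:

```python
import numpy as np
from sklearn.decomposition import FactorAnalysis

rng = np.random.default_rng(0)
n_samples, n_factors, n_observed = 500, 2, 6

# Synthetic data: 6 observed variables driven by 2 latent factors plus noise
f = rng.normal(size=(n_samples, n_factors))          # latent factors
loadings = rng.normal(size=(n_factors, n_observed))  # true lambda values
x = f @ loadings + 0.1 * rng.normal(size=(n_samples, n_observed))

fa = FactorAnalysis(n_components=n_factors, random_state=0)
fa.fit(x)
print("Estimated loadings (recoverable only up to rotation and sign):")
print(fa.components_.round(2))
```

Note the caveat in the output: factor loadings are identified only up to a rotation, which is why rotated solutions (e.g. varimax) are common in practice.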
-
Experimental design is the art of structuring an experiment so that you can draw valid conclusions. Poor design can make even a large dataset useless.
-
Key components of a well-designed experiment:
- Independent variable (IV): what you manipulate (e.g. drug dose, model architecture)
- Dependent variable (DV): what you measure (e.g. recovery time, accuracy)
- Control group: receives no treatment (or a placebo), providing a baseline for comparison
- Random assignment: participants are assigned to groups randomly, which balances out confounding variables you did not measure
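Random assignment itself is a one-liner: shuffle the participant indices and split. A minimal sketch with hypothetical participant IDs 0-19:

```python
import jax

n_participants = 20
key = jax.random.PRNGKey(3)
order = jax.random.permutation(key, n_participants)  # random shuffle of 0..19
treatment = order[: n_participants // 2]             # first half -> treatment
control = order[n_participants // 2 :]               # second half -> control
print("Treatment group:", sorted(treatment.tolist()))
print("Control group:  ", sorted(control.tolist()))
```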
-
Common experimental designs:
- Completely randomised design: subjects are randomly assigned to treatment groups. Simple and effective when groups are comparable.
- Randomised block design: subjects are first grouped into blocks (e.g. by age), then randomly assigned to treatments within each block. This reduces variability from the blocking factor, similar in spirit to stratified sampling.
- Factorial design: tests multiple IVs simultaneously. A \(2 \times 3\) factorial design has 2 levels of one variable and 3 of another, giving 6 treatment combinations. This lets you detect interactions, where the effect of one variable depends on the level of another.
- Crossover design: each subject receives all treatments in sequence (with washout periods in between). Every subject serves as their own control, reducing the effect of individual differences.
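The treatment combinations of a factorial design are just the Cartesian product of the factor levels. A sketch for the \(2 \times 3\) case (the factor names are made up for illustration):

```python
from itertools import product

doses = ["low", "high"]                        # 2 levels of one IV
architectures = ["cnn", "rnn", "transformer"]  # 3 levels of another IV
cells = list(product(doses, architectures))    # every treatment combination
for dose, arch in cells:
    print(dose, arch)
print(f"{len(cells)} treatment combinations")  # 2 x 3 = 6
```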
-
In ML experiments, these principles are critical. When comparing models, you should control for random seed, dataset split, and hardware. Cross-validation is a form of crossover design. Ablation studies, where you remove one component at a time, follow the logic of factorial designs.
Coding Tasks (use Colab or a notebook)
-
Construct a 95% confidence interval for the height example, then experiment with different confidence levels and sample sizes.
```python
import jax.numpy as jnp

x_bar = 170.0  # sample mean
sigma = 8.0    # population std (known)
n = 50         # sample size

# Critical values for common confidence levels
z_stars = {0.90: 1.645, 0.95: 1.960, 0.99: 2.576}

for conf, z_star in z_stars.items():
    me = z_star * (sigma / jnp.sqrt(n))
    lower, upper = x_bar - me, x_bar + me
    print(f"{conf*100:.0f}% CI: [{lower:.2f}, {upper:.2f}] (ME = {me:.2f})")
```
Estimate \(\pi\) using Monte Carlo simulation. Plot how the estimate converges as you increase the number of points.
```python
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

key = jax.random.PRNGKey(42)

# Generate random points in [-1, 1] x [-1, 1]
n_points = 100_000
k1, k2 = jax.random.split(key)
x = jax.random.uniform(k1, shape=(n_points,), minval=-1, maxval=1)
y = jax.random.uniform(k2, shape=(n_points,), minval=-1, maxval=1)

# Check which points are inside the unit circle
inside = (x**2 + y**2) <= 1.0

# Running estimate of pi after each point
cumulative_inside = jnp.cumsum(inside)
counts = jnp.arange(1, n_points + 1)
pi_estimates = 4.0 * cumulative_inside / counts

plt.figure(figsize=(10, 4))
plt.plot(pi_estimates, color="#3498db", alpha=0.7, linewidth=0.5)
plt.axhline(y=jnp.pi, color="#e74c3c", linestyle="--", label=f"π = {jnp.pi:.6f}")
plt.xlabel("Number of points")
plt.ylabel("Estimate of π")
plt.title("Monte Carlo estimation of π")
plt.legend()
plt.ylim(2.8, 3.5)
plt.show()

print(f"Final estimate: {pi_estimates[-1]:.6f}")
print(f"True value:     {jnp.pi:.6f}")
print(f"Error:          {abs(pi_estimates[-1] - jnp.pi):.6f}")
```
Perform a simple power analysis: for a given effect size and standard deviation, compute the required sample size and verify it by simulation.
```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Parameters
delta = 2.0          # effect size (difference in means)
sigma = 8.0          # population std
alpha = 0.05
power_target = 0.80

# Analytical sample size
z_alpha = 1.96  # two-tailed, alpha = 0.05
z_beta = 0.84   # power = 0.80
n_required = ((z_alpha + z_beta) * sigma / delta) ** 2
print(f"Required n per group: {n_required:.0f}")

# Verify by simulation (vectorised over simulations)
key = jax.random.PRNGKey(7)
n = int(jnp.ceil(n_required))
n_sims = 5000
k1, k2 = jax.random.split(key)
group_a = jax.random.normal(k1, shape=(n_sims, n)) * sigma + 50
group_b = jax.random.normal(k2, shape=(n_sims, n)) * sigma + 50 + delta

# Two-sample z-test for each simulated experiment
pooled_se = jnp.sqrt(2 * sigma**2 / n)
z = (group_b.mean(axis=1) - group_a.mean(axis=1)) / pooled_se
p = 2 * (1 - norm.cdf(jnp.abs(z)))

power_sim = (p <= alpha).mean()
print(f"Simulated power: {power_sim:.3f}")
print(f"Target power:    {power_target:.3f}")
```
Visualise how confidence interval width changes with sample size. This shows why collecting more data gives more precise estimates.
```python
import jax.numpy as jnp
import matplotlib.pyplot as plt

sigma = 8.0
z_star = 1.96  # 95% confidence

sample_sizes = jnp.array([10, 20, 30, 50, 100, 200, 500, 1000], dtype=jnp.float32)
margins = z_star * sigma / jnp.sqrt(sample_sizes)

plt.figure(figsize=(8, 4))
plt.bar([str(int(n)) for n in sample_sizes], margins, color="#3498db", alpha=0.7)
plt.xlabel("Sample size")
plt.ylabel("Margin of error (cm)")
plt.title("95% CI margin of error shrinks with larger samples")
plt.show()
```