
Statistical Inference

Statistical inference goes beyond yes/no decisions to estimate population parameters with quantified uncertainty. This file covers confidence intervals, power analysis, Monte Carlo methods, factor analysis, and experimental design -- the bridge between raw data and reliable conclusions in ML.

  • Hypothesis testing gives you a yes/no decision: reject or fail to reject. But often you want something more informative: a range of plausible values for the parameter you are estimating. That is what confidence intervals provide.

  • A point estimate is a single number computed from your sample, like the sample mean \(\bar{x}\). It is your best guess for the population parameter, but on its own it gives no sense of how precise the estimate is.

  • A confidence interval wraps that point estimate with a range that reflects uncertainty. It takes the form:

\[\text{CI} = \bar{x} \pm \text{ME}\]
  • The margin of error (ME) depends on three things: how confident you want to be, how much variability is in the data, and how large your sample is:
\[\text{ME} = z^\ast \cdot \frac{\sigma}{\sqrt{n}}\]
  • Here \(z^\ast\) is the critical value from the normal distribution that matches your desired confidence level. For 95% confidence, \(z^\ast = 1.96\). For 99% confidence, \(z^\ast = 2.576\).
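These critical values are simply quantiles of the standard normal distribution. As a quick sketch, they can be recovered with SciPy's `norm.ppf`:

```python
from scipy.stats import norm

# z* is the quantile that leaves alpha/2 probability in each tail
for conf in (0.90, 0.95, 0.99):
    alpha = 1 - conf
    z_star = norm.ppf(1 - alpha / 2)
    print(f"{conf:.0%} confidence -> z* = {z_star:.3f}")
```

Running this prints 1.645, 1.960, and 2.576, matching the table values above.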

[Figure: confidence interval shown as a point estimate with a margin of error on either side]

  • A 95% confidence interval means: if you repeated the experiment many times and built an interval each time, about 95% of those intervals would contain the true population parameter. It does not mean there is a 95% probability the parameter is in this specific interval. The parameter is fixed; the intervals are what vary.
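This repeated-sampling interpretation can be checked directly by simulation. A minimal sketch in JAX, assuming a hypothetical population with \(\mu = 170\) and \(\sigma = 8\): draw many samples, build a 95% interval from each, and count how often the true mean is covered.

```python
import jax
import jax.numpy as jnp

mu, sigma, n = 170.0, 8.0, 50     # hypothetical population and sample size
n_experiments = 2000
z_star = 1.96

# One row per repeated experiment
key = jax.random.PRNGKey(0)
samples = jax.random.normal(key, shape=(n_experiments, n)) * sigma + mu

means = samples.mean(axis=1)
me = z_star * sigma / jnp.sqrt(n)
covered = (means - me <= mu) & (mu <= means + me)
print(f"Fraction of intervals covering mu: {covered.mean():.3f}")
```

The fraction should land close to 0.95. Any individual interval either contains \(\mu\) or it does not; the 95% describes the procedure, not one interval.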

  • Worked example: You measure the heights of 50 people and find \(\bar{x} = 170\) cm with \(\sigma = 8\) cm. Construct a 95% confidence interval.

\[\text{ME} = 1.96 \cdot \frac{8}{\sqrt{50}} = 1.96 \cdot 1.131 = 2.22 \text{ cm}\]
\[\text{CI} = [170 - 2.22, \; 170 + 2.22] = [167.78, \; 172.22]\]
  • You can say with 95% confidence that the true mean height lies between 167.78 and 172.22 cm.

  • When \(\sigma\) is unknown (the usual case), use the sample standard deviation \(s\) and the t-distribution instead:

\[\text{CI} = \bar{x} \pm t^\ast_{n-1} \cdot \frac{s}{\sqrt{n}}\]
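A sketch of the t-based interval, using SciPy for the critical value and a hypothetical sample of 25 heights:

```python
import numpy as np
from scipy import stats

# Hypothetical sample of 25 heights (cm); sigma is unknown
rng = np.random.default_rng(0)
sample = rng.normal(170, 8, size=25)

x_bar = sample.mean()
s = sample.std(ddof=1)                   # sample standard deviation
n = len(sample)
t_star = stats.t.ppf(0.975, df=n - 1)    # 95% two-sided critical value, n-1 df

me = t_star * s / np.sqrt(n)
print(f"t* = {t_star:.3f}")              # larger than z* = 1.96 for small n
print(f"95% CI: [{x_bar - me:.2f}, {x_bar + me:.2f}]")
```

For small samples \(t^\ast\) is noticeably larger than 1.96, widening the interval to account for the extra uncertainty in \(s\); as \(n\) grows, the t-distribution approaches the normal.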
  • Wider intervals are more confident but less precise. Narrower intervals are more precise but less confident. You can narrow an interval without losing confidence by increasing the sample size.

  • Power analysis helps you plan an experiment before you run it. The question is: how large a sample do I need to detect an effect of a given size with a specified power?

  • Recall from the previous file that power = \(1 - \beta\), the probability of correctly rejecting a false \(H_0\). A common target is 80% power.

  • The required sample size for a one-sample z-test detecting a mean shift \(\delta\) with significance \(\alpha\) and power \(1-\beta\) is:

\[n = \left(\frac{(z_{\alpha/2} + z_{\beta}) \cdot \sigma}{\delta}\right)^2\]
  • For example, to detect a 2 cm difference in mean height (\(\sigma = 8\)) with \(\alpha = 0.05\) and 80% power (\(z_{0.025} = 1.96\), \(z_{0.20} = 0.84\)):
\[n = \left(\frac{(1.96 + 0.84) \cdot 8}{2}\right)^2 = \left(\frac{22.4}{2}\right)^2 = 11.2^2 \approx 126\]
  • You would need about 126 people.

  • Power analysis prevents two common mistakes: running an experiment too small to detect a real effect (underpowered), or wasting resources on an experiment far larger than necessary (overpowered).

  • Monte Carlo methods use random sampling to solve problems that are difficult or impossible to solve analytically. The core idea: if you cannot compute something exactly, simulate it many times and use the results as an approximation.

  • The name comes from the Monte Carlo casino, a nod to the role of randomness. These methods are workhorses in ML for tasks like estimating integrals, evaluating model uncertainty, and approximating complex distributions.

  • The general Monte Carlo recipe:

    • Define a domain of possible inputs
    • Generate random inputs from that domain
    • Evaluate a function on each input
    • Aggregate the results (average, count, etc.)
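The recipe above can be sketched on a problem with a known answer, estimating \(\int_0^1 \sin(x)\,dx = 1 - \cos(1)\):

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# 1. Domain: the interval [0, 1]
# 2. Generate random inputs from it
xs = jax.random.uniform(key, shape=(100_000,))

# 3. Evaluate the function on each input
values = jnp.sin(xs)

# 4. Aggregate: the sample mean approximates the integral
estimate = values.mean()
print(f"Estimate: {estimate:.4f}, exact: {1 - jnp.cos(1.0):.4f}")
```

With 100,000 points the estimate lands within about a thousandth of the exact value 0.4597.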
  • A classic example is estimating \(\pi\). Imagine a square with side length 2, centred at the origin, with a circle of radius 1 inscribed inside it. The area of the square is 4, and the area of the circle is \(\pi\).

[Figure: square with inscribed circle, random points coloured by inside/outside]

  • Drop random points uniformly in the square. The fraction that land inside the circle approximates \(\pi/4\):
\[\pi \approx 4 \times \frac{\text{points inside circle}}{\text{total points}}\]
  • A point \((x, y)\) is inside the circle if \(x^2 + y^2 \le 1\). The more points you throw, the closer your estimate gets to the true value of \(\pi\).

  • In ML, Monte Carlo methods appear in:

    • Monte Carlo dropout: run inference multiple times with dropout enabled to estimate prediction uncertainty
    • MCMC (Markov Chain Monte Carlo): sample from complex posterior distributions in Bayesian models
    • Policy gradient methods: estimate gradients in reinforcement learning by sampling trajectories
  • Factor analysis is a technique for discovering hidden (latent) variables that explain the correlations among observed variables. If 10 personality survey questions can be explained by 3 underlying traits (extraversion, agreeableness, conscientiousness), factor analysis finds those traits.

  • The model assumes each observed variable \(x_i\) is a linear combination of a few latent factors \(f_j\) plus noise:

\[x_i = \lambda_{i1} f_1 + \lambda_{i2} f_2 + \ldots + \lambda_{ik} f_k + \epsilon_i\]
  • The \(\lambda\) values are called factor loadings and tell you how strongly each observed variable relates to each factor. This connects directly to the matrix decompositions from Chapter 2; factor analysis is closely related to eigenvalue decomposition and SVD.
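To make the model concrete, here is a sketch that generates data from it: two hypothetical latent factors, six observed variables, and loadings chosen at random. The correlations among the observed variables are induced entirely by the shared factors.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)

n, k, p = 5000, 2, 6                                                   # samples, factors, observed vars
loadings = jax.random.uniform(k1, shape=(p, k), minval=-1, maxval=1)   # lambda_ij
factors = jax.random.normal(k2, shape=(n, k))                          # f_j
noise = 0.3 * jax.random.normal(k3, shape=(n, p))                      # epsilon_i

# x_i = sum_j lambda_ij f_j + epsilon_i, for each observed variable
x = factors @ loadings.T + noise

# Correlations among observed variables, all driven by the shared factors
print(jnp.corrcoef(x.T).round(2))
```

Factor analysis works in the opposite direction: given only `x`, it tries to recover loadings that reproduce this correlation structure.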

  • Experimental design is the art of structuring an experiment so that you can draw valid conclusions. Poor design can make even a large dataset useless.

  • Key components of a well-designed experiment:

    • Independent variable (IV): what you manipulate (e.g. drug dose, model architecture)
    • Dependent variable (DV): what you measure (e.g. recovery time, accuracy)
    • Control group: receives no treatment (or a placebo), providing a baseline for comparison
    • Random assignment: participants are assigned to groups randomly, which balances out confounding variables you did not measure
  • Common experimental designs:

    • Completely randomised design: subjects are randomly assigned to treatment groups. Simple and effective when groups are comparable.
    • Randomised block design: subjects are first grouped into blocks (e.g. by age), then randomly assigned to treatments within each block. This reduces variability from the blocking factor, similar in spirit to stratified sampling.
    • Factorial design: tests multiple IVs simultaneously. A \(2 \times 3\) factorial design has 2 levels of one variable and 3 of another, giving 6 treatment combinations. This lets you detect interactions, where the effect of one variable depends on the level of another.
    • Crossover design: each subject receives all treatments in sequence (with washout periods in between). Every subject serves as their own control, reducing the effect of individual differences.
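As a sketch of how random assignment works in a randomised block design, suppose 24 hypothetical subjects fall into 3 age blocks of 8, and within each block half are assigned to each of 2 treatments:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
blocks = jnp.repeat(jnp.arange(3), 8)          # block label for each of 24 subjects
assignments = jnp.zeros(24, dtype=jnp.int32)

for b in range(3):
    key, subkey = jax.random.split(key)
    idx = jnp.where(blocks == b)[0]            # subjects in this block
    perm = jax.random.permutation(subkey, idx) # shuffle within the block
    assignments = assignments.at[perm[:4]].set(0).at[perm[4:]].set(1)

# Each block contributes equally to both treatment groups
for b in range(3):
    counts = jnp.bincount(assignments[blocks == b], length=2)
    print(f"Block {b}: treatment counts = {counts}")
```

Every block contributes equally to both treatment groups, so block-level differences (here, age) cannot masquerade as treatment effects.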
  • In ML experiments, these principles are critical. When comparing models, you should control for random seed, dataset split, and hardware. Cross-validation echoes the crossover idea: each data point takes a turn in both the training and test roles. Ablation studies, where you remove one component at a time, follow the logic of factorial designs.

Coding Tasks (use Colab or a notebook)

  1. Construct a 95% confidence interval for the height example, then experiment with different confidence levels and sample sizes.

    import jax.numpy as jnp
    
    x_bar = 170.0    # sample mean
    sigma = 8.0      # population std (known)
    n = 50           # sample size
    
    # Critical values for common confidence levels
    z_stars = {0.90: 1.645, 0.95: 1.960, 0.99: 2.576}
    
    for conf, z_star in z_stars.items():
        me = z_star * (sigma / jnp.sqrt(n))
        lower, upper = x_bar - me, x_bar + me
        print(f"{conf*100:.0f}% CI: [{lower:.2f}, {upper:.2f}]  (ME = {me:.2f})")
    

  2. Estimate \(\pi\) using Monte Carlo simulation. Plot how the estimate converges as you increase the number of points.

    import jax
    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    
    key = jax.random.PRNGKey(42)
    
    # Generate random points in [-1, 1] x [-1, 1]
    n_points = 100_000
    k1, k2 = jax.random.split(key)
    x = jax.random.uniform(k1, shape=(n_points,), minval=-1, maxval=1)
    y = jax.random.uniform(k2, shape=(n_points,), minval=-1, maxval=1)
    
    # Check which points are inside the unit circle
    inside = (x**2 + y**2) <= 1.0
    cumulative_inside = jnp.cumsum(inside)
    counts = jnp.arange(1, n_points + 1)
    pi_estimates = 4.0 * cumulative_inside / counts
    
    plt.figure(figsize=(10, 4))
    plt.plot(pi_estimates, color="#3498db", alpha=0.7, linewidth=0.5)
    plt.axhline(y=jnp.pi, color="#e74c3c", linestyle="--", label=f"π = {jnp.pi:.6f}")
    plt.xlabel("Number of points")
    plt.ylabel("Estimate of π")
    plt.title("Monte Carlo estimation of π")
    plt.legend()
    plt.ylim(2.8, 3.5)
    plt.show()
    
    print(f"Final estimate: {pi_estimates[-1]:.6f}")
    print(f"True value:     {jnp.pi:.6f}")
    print(f"Error:          {abs(pi_estimates[-1] - jnp.pi):.6f}")
    

  3. Perform a simple power analysis: for a given effect size and standard deviation, compute the required sample size and verify it by simulation.

    import jax
    import jax.numpy as jnp
    from jax.scipy.stats import norm
    
    # Parameters
    delta = 2.0      # effect size (shift in the mean)
    sigma = 8.0      # population std
    alpha = 0.05
    power_target = 0.80
    
    # Analytical sample size (one-sample z-test)
    z_alpha = 1.96   # two-tailed, alpha=0.05
    z_beta = 0.84    # power=0.80
    n_required = ((z_alpha + z_beta) * sigma / delta) ** 2
    print(f"Required n: {n_required:.0f}")
    
    # Verify by simulation: test H0: mu = 50 when the true mean is 50 + delta
    key = jax.random.PRNGKey(7)
    n = int(jnp.ceil(n_required))
    n_sims = 5000
    
    samples = jax.random.normal(key, shape=(n_sims, n)) * sigma + 50 + delta
    se = sigma / jnp.sqrt(n)
    z = (samples.mean(axis=1) - 50) / se
    p = 2 * (1 - norm.cdf(jnp.abs(z)))
    power_sim = (p <= alpha).mean()
    
    print(f"Simulated power: {power_sim:.3f}")
    print(f"Target power:    {power_target:.3f}")
    

  4. Visualise how confidence interval width changes with sample size. This shows why collecting more data gives more precise estimates.

    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    
    sigma = 8.0
    z_star = 1.96  # 95% confidence
    
    sample_sizes = jnp.array([10, 20, 30, 50, 100, 200, 500, 1000], dtype=jnp.float32)
    margins = z_star * sigma / jnp.sqrt(sample_sizes)
    
    plt.figure(figsize=(8, 4))
    plt.bar([str(int(n)) for n in sample_sizes], margins, color="#3498db", alpha=0.7)
    plt.xlabel("Sample size")
    plt.ylabel("Margin of error (cm)")
    plt.title("95% CI margin of error shrinks with larger samples")
    plt.show()