
Statistical Inference

Statistical inference goes beyond yes/no decisions to estimate population parameters with quantified uncertainty. This file covers confidence intervals, power analysis, Monte Carlo methods, factor analysis, and experimental design: the tools for quantifying uncertainty, planning experiments, and drawing valid conclusions from data in ML.

  • Hypothesis testing gives you a yes/no decision: reject or fail to reject. Often you want something more informative: a range of plausible values for the parameter you are estimating. That is what confidence intervals provide.

  • A point estimate is a single number computed from your sample, like the sample mean \(\bar{x}\). It is your best guess for the population parameter, but on its own it gives no sense of how precise the estimate is.

  • A confidence interval wraps that point estimate with a range that reflects uncertainty. It takes the form:

\[\text{CI} = \bar{x} \pm \text{ME}\]
  • The margin of error (ME) depends on three things: how confident you want to be, how much variability is in the data, and how large your sample is:
\[\text{ME} = z^\ast \cdot \frac{\sigma}{\sqrt{n}}\]
  • Here \(z^\ast\) is the critical value from the normal distribution that matches your desired confidence level. For 95% confidence, \(z^\ast = 1.96\). For 99% confidence, \(z^\ast = 2.576\).
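You do not need a table to find the critical value. For a two-sided interval, \(z^\ast\) is the \((1 + \text{confidence})/2\) quantile of the standard normal. Here is a minimal sketch using `jax.scipy.stats.norm.ppf`. The helper names `z_critical` and `margin_of_error` are illustrative and do not come from any library:

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

def z_critical(confidence):
    """Two-sided critical value z* for a given confidence level."""
    # Leave (1 - confidence) / 2 probability in each tail, so z* is the
    # (1 + confidence) / 2 quantile of the standard normal distribution.
    return norm.ppf((1.0 + confidence) / 2.0)

def margin_of_error(confidence, sigma, n):
    """ME = z* * sigma / sqrt(n), assuming sigma is known."""
    return z_critical(confidence) * sigma / jnp.sqrt(n)

for conf in (0.90, 0.95, 0.99):
    print(f"{conf:.0%} -> z* = {float(z_critical(conf)):.3f}")
```

For 95% this gives 1.960, and for 99% it gives 2.576, which match the values quoted above.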

Confidence interval: point estimate with margin of error on either side

  • A 95% confidence interval means: if you repeated the experiment many times and built an interval each time, about 95% of those intervals would contain the true population parameter. It does not mean there is a 95% probability the parameter is in this specific interval. The parameter is fixed; the intervals are what vary.
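You can check the repeated-sampling interpretation directly by simulation. This sketch assumes a hypothetical population with true mean 170 cm and known \(\sigma = 8\) cm. It repeats the 50-person experiment 10,000 times and counts how often the 95% interval contains the true mean:

```python
import jax
import jax.numpy as jnp

# Hypothetical "true" population, chosen only for this simulation.
mu_true, sigma, n = 170.0, 8.0, 50
z_star = 1.96            # 95% confidence
n_experiments = 10_000

key = jax.random.PRNGKey(0)
# Each row is one repetition of the experiment: a fresh sample of n heights.
samples = mu_true + sigma * jax.random.normal(key, shape=(n_experiments, n))
x_bars = samples.mean(axis=1)
me = z_star * sigma / jnp.sqrt(n)

# An interval "covers" the parameter if mu_true lies between its endpoints.
covered = (x_bars - me <= mu_true) & (mu_true <= x_bars + me)
coverage = float(covered.mean())
print(f"Fraction of intervals containing mu: {coverage:.3f}")
```

The fraction should land close to 0.95. Any single interval either contains 170 or it does not. The 95% describes the procedure, not one particular interval.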

  • Worked example: You measure the heights of 50 people and find \(\bar{x} = 170\) cm. Assume the population standard deviation is known, \(\sigma = 8\) cm. Construct a 95% confidence interval.

\[\text{ME} = 1.96 \cdot \frac{8}{\sqrt{50}} = 1.96 \cdot 1.131 = 2.22 \text{ cm}\]
\[\text{CI} = [170 - 2.22, \; 170 + 2.22] = [167.78, \; 172.22]\]
  • You can say with 95% confidence that the true mean height lies between 167.78 and 172.22 cm.

  • When \(\sigma\) is unknown (the usual case), use the sample standard deviation \(s\) and the t-distribution instead:

\[\text{CI} = \bar{x} \pm t^\ast_{n-1} \cdot \frac{s}{\sqrt{n}}\]
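Here is a sketch of the t-interval on a synthetic sample. The 50 heights are randomly generated, not real measurements. It uses SciPy's `scipy.stats.t.ppf` to get the critical value \(t^\ast_{n-1}\):

```python
import jax
import jax.numpy as jnp
from scipy import stats

# Synthetic sample for illustration only: 50 heights drawn with mean 170, sd 8.
key = jax.random.PRNGKey(1)
n = 50
heights = 170.0 + 8.0 * jax.random.normal(key, shape=(n,))

x_bar = float(heights.mean())
s = float(heights.std(ddof=1))            # sample std (divides by n - 1)
t_star = stats.t.ppf(0.975, df=n - 1)     # two-sided 95% critical value

me = t_star * s / n**0.5
lower, upper = x_bar - me, x_bar + me
print(f"t* = {t_star:.4f}, 95% CI: [{lower:.2f}, {upper:.2f}]")
```

With 49 degrees of freedom, \(t^\ast \approx 2.01\). That is slightly larger than 1.96, so the interval is a little wider to account for estimating \(\sigma\) from the data.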
  • Wider intervals are more confident but less precise. Narrower intervals are more precise but less confident. You can narrow an interval without losing confidence by increasing the sample size.
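Rearranging the margin-of-error formula gives the sample size needed for a target margin: \(n = (z^\ast \sigma / \text{ME})^2\), rounded up. Here is a small sketch; the function name `n_for_margin` is illustrative:

```python
import math

def n_for_margin(target_me, sigma, z_star=1.96):
    """Smallest n whose z-based margin of error is at most target_me."""
    return math.ceil((z_star * sigma / target_me) ** 2)

sigma = 8.0
for target in (2.0, 1.0, 0.5):
    print(f"ME <= {target} cm needs n = {n_for_margin(target, sigma)}")
```

Each halving of the margin roughly quadruples the required sample size, because the margin shrinks in proportion to \(1/\sqrt{n}\).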

  • Power analysis helps you plan an experiment before you run it. The question is: how large a sample do I need to detect an effect of a given size with a specified power?

  • Recall from the previous file that power = \(1 - \beta\), the probability of correctly rejecting a false \(H_0\). A common target is 80% power.

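  • The required sample size for a two-sided one-sample z-test, detecting a shift \(\delta\) in the mean with significance \(\alpha\) and power \(1-\beta\), is:

\[n = \left(\frac{(z_{\alpha/2} + z_{\beta}) \cdot \sigma}{\delta}\right)^2\]
  • Here \(z_{\alpha/2}\) and \(z_{\beta}\) are upper-tail critical values of the standard normal. For example, to detect a 2 cm shift in mean height (\(\sigma = 8\)) with \(\alpha = 0.05\) and 80% power (\(z_{0.025} = 1.96\), \(z_{0.20} = 0.84\)):
\[n = \left(\frac{(1.96 + 0.84) \cdot 8}{2}\right)^2 = \left(\frac{22.4}{2}\right)^2 = 11.2^2 = 125.44\]
  • Sample sizes are rounded up, so you would need 126 people. If you compare two independent groups of equal size instead, the variance of the difference in means is \(2\sigma^2/n\). The formula therefore gains a factor of 2, giving \(2 \times 125.44 \approx 251\) people per group.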
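You can script the same calculation with exact normal quantiles instead of the rounded values 1.96 and 0.84. The sketch below uses SciPy's `scipy.stats.norm`, and the helper names are illustrative. `approx_power` inverts the formula to give the power at a chosen n, ignoring the negligible chance of rejecting in the wrong tail:

```python
import math
from scipy.stats import norm

def n_one_sample(delta, sigma, alpha=0.05, power=0.80):
    """Sample size for a two-sided one-sample z-test to detect a shift delta."""
    z_a = norm.ppf(1 - alpha / 2)   # z_{alpha/2}, upper-tail critical value
    z_b = norm.ppf(power)           # z_beta, since power = 1 - beta
    return ((z_a + z_b) * sigma / delta) ** 2

def approx_power(delta, sigma, n, alpha=0.05):
    """Approximate power of the same test at sample size n (ignores the far tail)."""
    z_a = norm.ppf(1 - alpha / 2)
    return norm.cdf(delta * math.sqrt(n) / sigma - z_a)

n_raw = n_one_sample(delta=2.0, sigma=8.0)
print(f"one-sample n = {n_raw:.2f} -> {math.ceil(n_raw)}")
print(f"two groups: {math.ceil(2 * n_raw)} per group")
print(f"power at n = 126: {approx_power(2.0, 8.0, 126):.3f}")
```

The quantiles here are unrounded (\(z_{0.20} \approx 0.8416\) rather than 0.84), so n comes out at about 125.58 instead of 125.44. The one-sample answer still rounds up to 126, but doubling for two groups gives 252 per group.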

  • Power analysis prevents two common mistakes: running an experiment too small to detect a real effect (underpowered), or wasting resources on an experiment far larger than necessary (overpowered).

  • Monte Carlo methods use random sampling to solve problems that are difficult or impossible to solve analytically. The core idea: if you cannot compute something exactly, simulate it many times and use the results as an approximation.

  • The name comes from the Monte Carlo casino, a nod to the role of randomness. These methods are workhorses in ML for tasks like estimating integrals, evaluating model uncertainty, and approximating complex distributions.

  • The general Monte Carlo recipe:

    • Define a domain of possible inputs
    • Generate random inputs from that domain
    • Evaluate a function on each input
    • Aggregate the results (average, count, etc.)
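The four-step recipe maps directly onto Monte Carlo integration. The sketch below estimates \(\int_0^1 x^2\,dx = 1/3\) by averaging \(f\) at uniform random points. The function `mc_integrate` is illustrative:

```python
import jax
import jax.numpy as jnp

def mc_integrate(f, a, b, n_samples, key):
    """Monte Carlo estimate of the integral of f over [a, b]."""
    # Steps 1-2: the domain is [a, b]; draw uniform random inputs from it.
    x = jax.random.uniform(key, shape=(n_samples,), minval=a, maxval=b)
    # Step 3: evaluate the function on each input.
    fx = f(x)
    # Step 4: aggregate. (b - a) times the average of f approximates the integral.
    return (b - a) * fx.mean()

key = jax.random.PRNGKey(0)
estimate = mc_integrate(lambda x: x**2, 0.0, 1.0, 100_000, key)
print(f"MC estimate: {float(estimate):.4f} (exact: {1/3:.4f})")
```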
  • A classic example is estimating \(\pi\). Imagine a square with side length 2, centred at the origin, with a circle of radius 1 inscribed inside it. The area of the square is 4, and the area of the circle is \(\pi\).

Square with inscribed circle, random points coloured by inside/outside

  • Drop random points uniformly in the square. The fraction that land inside the circle approximates \(\pi/4\):
\[\pi \approx 4 \times \frac{\text{points inside circle}}{\text{total points}}\]
  • A point \((x, y)\) is inside the circle if \(x^2 + y^2 \le 1\). On average, the estimate improves as you throw more points: its error shrinks roughly in proportion to \(1/\sqrt{N}\), where \(N\) is the number of points.
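Each point either lands inside the circle or it does not, so the fraction inside is a sample proportion. The confidence-interval machinery from the start of this file then gives its standard error. The sketch below reports the estimate with a 95% margin of error; the helper `estimate_pi` is illustrative:

```python
import jax
import jax.numpy as jnp

def estimate_pi(n_points, key):
    """Return the Monte Carlo estimate of pi and its standard error."""
    k1, k2 = jax.random.split(key)
    x = jax.random.uniform(k1, (n_points,), minval=-1.0, maxval=1.0)
    y = jax.random.uniform(k2, (n_points,), minval=-1.0, maxval=1.0)
    p_hat = jnp.mean(x**2 + y**2 <= 1.0)   # fraction of points inside the circle
    pi_hat = 4.0 * p_hat
    # Each point is a Bernoulli(p) trial, so Var(p_hat) = p(1 - p) / n.
    se = 4.0 * jnp.sqrt(p_hat * (1.0 - p_hat) / n_points)
    return float(pi_hat), float(se)

key = jax.random.PRNGKey(0)
for n in (1_000, 100_000, 1_000_000):
    pi_hat, se = estimate_pi(n, key)
    print(f"n = {n:>9,}: pi ~ {pi_hat:.4f} +/- {1.96 * se:.4f} (95% margin)")
```

Multiplying the number of points by 100 shrinks the margin by about a factor of 10. This is the \(1/\sqrt{N}\) rate at which Monte Carlo error decreases.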

  • In ML, Monte Carlo methods appear in:

    • Monte Carlo dropout: run inference multiple times with dropout enabled to estimate prediction uncertainty
    • MCMC (Markov Chain Monte Carlo): sample from complex posterior distributions in Bayesian models
    • Policy gradient methods: estimate gradients in reinforcement learning by sampling trajectories
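Monte Carlo dropout is straightforward to sketch. Keep dropout active at inference, run several stochastic forward passes, and treat the spread of the predictions as an uncertainty estimate. The two-layer network below has random, untrained weights and is purely hypothetical. In practice you would apply the same loop to a trained model:

```python
import jax
import jax.numpy as jnp

# Hypothetical tiny network with random weights, only to illustrate the mechanics.
key = jax.random.PRNGKey(0)
k_w1, k_w2, k_x, k_drop = jax.random.split(key, 4)
W1 = jax.random.normal(k_w1, (4, 32)) * 0.5
W2 = jax.random.normal(k_w2, (32, 1)) * 0.5
x = jax.random.normal(k_x, (1, 4))        # one input example

def forward(x, dropout_key, rate=0.2):
    h = jax.nn.relu(x @ W1)
    # Dropout stays ON at inference: zero units at random, rescale the survivors.
    keep = jax.random.bernoulli(dropout_key, 1.0 - rate, h.shape)
    h = jnp.where(keep, h / (1.0 - rate), 0.0)
    return (h @ W2)[0, 0]

# T stochastic forward passes, each with a fresh dropout mask.
T = 100
keys = jax.random.split(k_drop, T)
preds = jax.vmap(lambda k: forward(x, k))(keys)
print(f"mean prediction: {float(preds.mean()):.3f}")
print(f"predictive std (uncertainty): {float(preds.std()):.3f}")
```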
  • Factor analysis is a technique for discovering hidden (latent) variables that explain the correlations among observed variables. If 10 personality survey questions can be explained by 3 underlying traits (extraversion, agreeableness, conscientiousness), factor analysis finds those traits.

  • The model assumes each observed variable \(x_i\) is a linear combination of a few latent factors \(f_j\) plus noise:

\[x_i = \lambda_{i1} f_1 + \lambda_{i2} f_2 + \ldots + \lambda_{ik} f_k + \epsilon_i\]
  • The \(\lambda\) values are called factor loadings and tell you how strongly each observed variable relates to each factor. This connects directly to the matrix decompositions from Chapter 2; factor analysis is closely related to eigenvalue decomposition and SVD.
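The factor model implies a specific covariance structure. With independent standard-normal factors and independent noise of variance \(\psi_i\), \(\text{Cov}(x) = \Lambda\Lambda^\top + \text{diag}(\psi)\), and factor analysis fits \(\Lambda\) and \(\psi\) so that this matrix matches the observed covariance. The sketch below does not fit anything. It simulates data from hypothetical loadings, checks the implied covariance, and then looks at its eigenvalues:

```python
import jax
import jax.numpy as jnp

# Hypothetical loadings: 6 observed variables driven by k = 2 latent factors.
Lambda = jnp.array([[0.9, 0.0],
                    [0.8, 0.1],
                    [0.7, 0.2],
                    [0.1, 0.8],
                    [0.0, 0.9],
                    [0.2, 0.7]])
psi = jnp.full(6, 0.3)                 # noise variance of each epsilon_i

key = jax.random.PRNGKey(0)
k_f, k_e = jax.random.split(key)
n = 200_000
f = jax.random.normal(k_f, (n, 2))                     # latent factors, f ~ N(0, I)
eps = jax.random.normal(k_e, (n, 6)) * jnp.sqrt(psi)   # independent noise
X = f @ Lambda.T + eps                                 # x = Lambda f + eps, one row per sample

# The model implies Cov(x) = Lambda Lambda^T + diag(psi).
implied_cov = Lambda @ Lambda.T + jnp.diag(psi)
sample_cov = jnp.cov(X, rowvar=False)
print("max |sample - implied|:", float(jnp.abs(sample_cov - implied_cov).max()))

eigvals = jnp.linalg.eigvalsh(implied_cov)             # ascending order
print("eigenvalues:", eigvals.round(3).tolist())
```

With two factors and equal noise variances, exactly two eigenvalues rise above the noise level of 0.3, and the other four equal it. That gap is the eigenvalue connection mentioned above.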

  • Experimental design is the art of structuring an experiment so that you can draw valid conclusions. Poor design can make even a large dataset useless.

  • Key components of a well-designed experiment:

    • Independent variable (IV): what you manipulate (e.g. drug dose, model architecture)
    • Dependent variable (DV): what you measure (e.g. recovery time, accuracy)
    • Control group: receives no treatment (or a placebo), providing a baseline for comparison
    • Random assignment: participants are assigned to groups at random, which on average balances confounding variables across groups, including ones you did not measure
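Random assignment can be as simple as shuffling subject IDs and splitting the list. Here is a sketch for 20 hypothetical subjects:

```python
import jax
import jax.numpy as jnp

n_subjects = 20
key = jax.random.PRNGKey(0)

# Shuffle subject IDs, then put the first half in treatment and the rest in control.
order = jax.random.permutation(key, n_subjects)
treatment_ids = jnp.sort(order[: n_subjects // 2])
control_ids = jnp.sort(order[n_subjects // 2 :])
print("treatment:", treatment_ids.tolist())
print("control:  ", control_ids.tolist())
```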
  • Common experimental designs:

    • Completely randomised design: subjects are randomly assigned to treatment groups. Simple and effective when groups are comparable.
    • Randomised block design: subjects are first grouped into blocks (e.g. by age), then randomly assigned to treatments within each block. This reduces variability from the blocking factor, similar in spirit to stratified sampling.
    • Factorial design: tests multiple IVs simultaneously. A \(2 \times 3\) factorial design has 2 levels of one variable and 3 of another, giving 6 treatment combinations. This lets you detect interactions, where the effect of one variable depends on the level of another.
    • Crossover design: each subject receives all treatments in sequence (with washout periods in between). Every subject serves as their own control, reducing the effect of individual differences.
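A randomised block design shuffles within each block rather than across the whole sample. The sketch below uses 12 hypothetical subjects in three age blocks and two treatments, A and B:

```python
import jax
import jax.numpy as jnp

# Hypothetical subjects: an age-group block label (0 = young, 1 = middle, 2 = older).
blocks = jnp.array([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2])
treatments = ("A", "B")

key = jax.random.PRNGKey(0)
assignment = {}
for b in range(3):
    key, sub = jax.random.split(key)
    members = jnp.where(blocks == b)[0]             # subjects in this block
    shuffled = jax.random.permutation(sub, members)
    # Within each block, alternate treatments over the shuffled order,
    # so every block gets an equal number of A and B.
    for i, subject in enumerate(shuffled.tolist()):
        assignment[subject] = treatments[i % len(treatments)]

for b in range(3):
    ids = [s for s in sorted(assignment) if int(blocks[s]) == b]
    print(f"block {b}:", {s: assignment[s] for s in ids})
```

Every block ends up with two subjects on each treatment, so age differences cannot pile up in one treatment group.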
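  • In ML experiments, these principles are critical. When comparing models, control for random seed, dataset split, and hardware. Evaluating every model on the same cross-validation folds resembles a crossover design: each fold is tested under every treatment, so it serves as its own control. Ablation studies, which remove components to measure their contribution, borrow the logic of factorial designs. Removing one component at a time cannot reveal interactions, though; detecting those requires testing combinations of components.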
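You can enumerate a full factorial ablation with `itertools.product`. The component names and the `train_and_evaluate` call below are hypothetical placeholders:

```python
from itertools import product

# Hypothetical model components toggled in an ablation study.
components = ("data_augmentation", "dropout", "lr_warmup")

# Full 2 x 2 x 2 factorial design: every on/off combination (8 runs).
# A one-at-a-time ablation would only try the baseline plus 3 single removals.
configs = [dict(zip(components, flags))
           for flags in product((True, False), repeat=len(components))]

for cfg in configs:
    print(cfg)
    # score = train_and_evaluate(cfg, seed=0)  # hypothetical; keep seeds and splits fixed across runs
```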

Coding Tasks (use Colab or a notebook)

  1. Construct a 95% confidence interval for the height example, then experiment with different confidence levels and sample sizes.

    import jax.numpy as jnp
    
    x_bar = 170.0    # sample mean
    sigma = 8.0      # population std (known)
    n = 50           # sample size
    
    # Critical values for common confidence levels
    z_stars = {0.90: 1.645, 0.95: 1.960, 0.99: 2.576}
    
    for conf, z_star in z_stars.items():
        me = z_star * (sigma / jnp.sqrt(n))
        lower, upper = x_bar - me, x_bar + me
        print(f"{conf*100:.0f}% CI: [{lower:.2f}, {upper:.2f}]  (ME = {me:.2f})")
    

  2. Estimate \(\pi\) using Monte Carlo simulation. Plot how the estimate converges as you increase the number of points.

    import jax
    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    
    key = jax.random.PRNGKey(42)
    
    # Generate random points in [-1, 1] x [-1, 1]
    n_points = 100_000
    k1, k2 = jax.random.split(key)
    x = jax.random.uniform(k1, shape=(n_points,), minval=-1, maxval=1)
    y = jax.random.uniform(k2, shape=(n_points,), minval=-1, maxval=1)
    
    # Check which points are inside the unit circle
    inside = (x**2 + y**2) <= 1.0
    cumulative_inside = jnp.cumsum(inside)
    counts = jnp.arange(1, n_points + 1)
    pi_estimates = 4.0 * cumulative_inside / counts
    
    plt.figure(figsize=(10, 4))
    plt.plot(counts, pi_estimates, color="#3498db", alpha=0.7, linewidth=0.5)  # x-axis = number of points
    plt.axhline(y=jnp.pi, color="#e74c3c", linestyle="--", label=f"π = {jnp.pi:.6f}")
    plt.xlabel("Number of points")
    plt.ylabel("Estimate of π")
    plt.title("Monte Carlo estimation of π")
    plt.legend()
    plt.ylim(2.8, 3.5)
    plt.show()
    
    print(f"Final estimate: {pi_estimates[-1]:.6f}")
    print(f"True value:     {jnp.pi:.6f}")
    print(f"Error:          {abs(pi_estimates[-1] - jnp.pi):.6f}")
    

  3. Perform a simple power analysis: for a given effect size and standard deviation, compute the required sample size and verify it by simulation.

    import jax
    import jax.numpy as jnp
    from jax.scipy.stats import norm
    
    # Parameters
    delta = 2.0      # effect size (difference in means)
    sigma = 8.0      # population std (known, same in both groups)
    alpha = 0.05
    power_target = 0.80
    
    # Analytical sample size for a two-sample z-test with equal group sizes.
    # The difference of two means has variance 2 * sigma^2 / n, hence the factor of 2.
    z_alpha = 1.96   # two-tailed, alpha=0.05
    z_beta = 0.84    # power=0.80
    n_required = 2 * ((z_alpha + z_beta) * sigma / delta) ** 2
    n = int(jnp.ceil(n_required))
    print(f"Required n per group: {n}")
    
    # Verify by simulation: run all n_sims experiments at once instead of a Python loop
    key = jax.random.PRNGKey(7)
    n_sims = 5000
    k1, k2 = jax.random.split(key)
    group_a = jax.random.normal(k1, shape=(n_sims, n)) * sigma + 50
    group_b = jax.random.normal(k2, shape=(n_sims, n)) * sigma + 50 + delta
    
    se_diff = jnp.sqrt(2 * sigma**2 / n)   # standard error of the difference in means
    z = (group_b.mean(axis=1) - group_a.mean(axis=1)) / se_diff
    p = 2 * (1 - norm.cdf(jnp.abs(z)))     # two-sided p-values, one per experiment
    simulated_power = jnp.mean(p <= alpha)
    
    print(f"Simulated power: {float(simulated_power):.3f}")
    print(f"Target power:    {power_target:.3f}")
    

  4. Visualise how confidence interval width changes with sample size. This shows why collecting more data gives more precise estimates.

    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    
    sigma = 8.0
    z_star = 1.96  # 95% confidence
    
    sample_sizes = jnp.array([10, 20, 30, 50, 100, 200, 500, 1000], dtype=jnp.float32)
    margins = z_star * sigma / jnp.sqrt(sample_sizes)
    
    plt.figure(figsize=(8, 4))
    plt.bar([str(int(n)) for n in sample_sizes], margins, color="#3498db", alpha=0.7)
    plt.xlabel("Sample size")
    plt.ylabel("Margin of error (cm)")
    plt.title("95% CI margin of error shrinks with larger samples")
    plt.show()