
Sampling

Sampling determines how we collect data and directly controls the quality of every conclusion we draw. This file covers random, stratified, cluster, and systematic sampling, sampling distributions, the law of large numbers, and bootstrapping -- methods essential for training/test splits and dataset curation in ML.

  • In an ideal world, you would measure every single member of the group you care about. In practice, that is almost never possible. You cannot survey every voter, test every light bulb, or scan every patient. So you take a sample and use it to learn about the whole.

  • The population is the complete set of individuals or items you want to study. The sample is the subset you actually observe.

  • A parameter is a number that describes the population (e.g. the true average height of all adults in a country).

  • A statistic is a number computed from your sample (e.g. the average height of the 500 people you measured). Statistics are used to estimate parameters.

  • The quality of your conclusions depends entirely on how you select your sample. A biased sample leads to biased conclusions, no matter how sophisticated your analysis.

  • The sampling frame is the list of all individuals from which you actually draw your sample. Ideally this matches the population perfectly, but in practice there are gaps.

  • For instance, if you survey people by phone, you miss everyone without a phone. The gap between the frame and the population is called coverage error.

  • Sampling error is the natural discrepancy between a sample statistic and the population parameter.

  • Even a perfectly random sample will not match the population exactly. Larger samples reduce sampling error, but they cannot remove bias introduced by a flawed selection method.
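To make the parameter/statistic distinction and the effect of sample size concrete, here is a minimal JAX sketch. The population of 100,000 "heights" is simulated (the mean of 170 cm and standard deviation of 10 cm are arbitrary assumptions, not real data), and the helper `mean_abs_error` is a hypothetical name introduced for illustration. Samples are drawn with replacement, a close approximation to sampling without replacement when the sample is a tiny fraction of the population.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
key, pop_key = jax.random.split(key)

# Simulated population of 100,000 heights in cm (illustrative, not real data)
population = 170.0 + 10.0 * jax.random.normal(pop_key, shape=(100_000,))
parameter = population.mean()  # population parameter: the true mean height


def mean_abs_error(key, n, trials=1_000):
    """Average |statistic - parameter| over `trials` random samples of size n."""
    # Each row is one sample of size n, drawn with replacement
    samples = jax.random.choice(key, population, shape=(trials, n))
    statistics = samples.mean(axis=1)  # one sample mean (statistic) per row
    return jnp.abs(statistics - parameter).mean()


key, k_small, k_large = jax.random.split(key, 3)
err_small = mean_abs_error(k_small, 50)
err_large = mean_abs_error(k_large, 5_000)
print(f"parameter = {parameter:.2f} cm")
print(f"typical sampling error: n=50 -> {err_small:.3f} cm, n=5000 -> {err_large:.3f} cm")
```

The error never reaches zero, but a hundredfold increase in n should shrink it by roughly a factor of ten, consistent with the \(1/\sqrt{n}\) scaling of the standard error.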

  • There are two broad families of sampling: probability and non-probability.

  • Probability sampling means every member of the population has a known, nonzero chance of being selected. This lets you quantify uncertainty and generalise results.

  • Simple random sampling: every individual has an equal chance of being selected, and every possible sample of size \(n\) is equally likely. Think of putting every name in a hat and drawing blindly.

  • Stratified sampling: divide the population into non-overlapping groups (strata) based on a shared characteristic (e.g. age group, region), then randomly sample from each stratum. This guarantees representation from every group and reduces the variance of estimates when strata differ from each other but are fairly homogeneous inside.

  • Cluster sampling: divide the population into groups (clusters), randomly select some clusters, then include everyone in the chosen clusters. This is practical when the population is spread out geographically, like sampling entire schools rather than individual students across a district.

  • Systematic sampling: pick a random starting point, then select every \(k\)-th individual from the list. For example, start at person 7 and then take every 10th person (7, 17, 27, ...). Simple to implement, but it can introduce bias if the list has a hidden periodic pattern that lines up with \(k\) (e.g. a list of apartments in which every 10th unit is a corner unit).

Three probability sampling methods side by side: simple random, stratified, and cluster
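The four probability designs above can be sketched as different ways of choosing row indices from a population list. This is a minimal JAX illustration under assumed settings: a population of N = 1,000 indices, four equal-sized strata given by a hypothetical `labels` array, and 20 clusters of 50 consecutive indices. None of these sizes come from the text.

```python
import jax
import jax.numpy as jnp

N, n = 1_000, 100                            # population size and target sample size (assumed)
labels = jnp.repeat(jnp.arange(4), N // 4)   # hypothetical strata: 4 groups of 250

key = jax.random.PRNGKey(1)
k_srs, k_strat, k_clus, k_sys = jax.random.split(key, 4)

# Simple random sampling: n distinct indices, every subset of size n equally likely
srs_idx = jax.random.choice(k_srs, N, shape=(n,), replace=False)

# Stratified sampling: a simple random sample of n // 4 indices inside each stratum
strat_idx = jnp.concatenate([
    jax.random.choice(k, jnp.flatnonzero(labels == s), shape=(n // 4,), replace=False)
    for s, k in enumerate(jax.random.split(k_strat, 4))
])

# Cluster sampling: pick 2 of 20 clusters at random and keep every member of each
clusters = jnp.arange(N).reshape(20, 50)     # row c holds the indices in cluster c
chosen = jax.random.choice(k_clus, 20, shape=(2,), replace=False)
cluster_idx = clusters[chosen].ravel()

# Systematic sampling: random start in [0, k), then every k-th index
k_step = N // n                              # sampling interval k = 10
start = jax.random.randint(k_sys, (), 0, k_step)
sys_idx = start + k_step * jnp.arange(n)

for name, idx in [("simple random", srs_idx), ("stratified", strat_idx),
                  ("cluster", cluster_idx), ("systematic", sys_idx)]:
    print(f"{name:>13}: {idx.shape[0]} indices, first five {idx[:5].tolist()}")
```

Here the cluster sample has exactly 100 members only because the assumed clusters are equal-sized; in practice the cluster sample size depends on which clusters are drawn.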

  • Non-probability sampling does not give every member a known chance of selection. Results cannot be rigorously generalised, but these methods are often faster and cheaper.

  • Convenience sampling: select whoever is easiest to reach. Surveying people at a shopping mall is convenient but misses those who do not shop there.

  • Quota sampling: like stratified sampling, but without randomness. The researcher fills quotas (e.g. 50 men and 50 women) by picking accessible individuals from each group.

  • Snowball sampling: start with a few participants and ask them to recruit others. Useful for hard-to-reach populations (e.g. studying rare diseases), but heavily biased toward connected individuals.
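A short simulation shows why convenience sampling (and coverage error more generally) is dangerous: collecting more data from the wrong frame does not fix the bias. In this minimal JAX sketch, the population, the 30% share of mall shoppers, and their higher spending are all invented for illustration.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(3)
k_reach, k_noise, k_rand = jax.random.split(key, 3)

N = 50_000
# Hypothetical population: 30% shop at the mall and spend about 40 more per week
shops_at_mall = jax.random.bernoulli(k_reach, p=0.3, shape=(N,))
spending = 100.0 + 40.0 * shops_at_mall + 15.0 * jax.random.normal(k_noise, shape=(N,))
true_mean = spending.mean()

mall_idx = jnp.flatnonzero(shops_at_mall)    # the only people a mall survey can reach

results = []
for n in [100, 1_000, 10_000]:
    convenience = spending[mall_idx[:n]]     # the first n reachable people
    k_rand, sub = jax.random.split(k_rand)
    srs = jax.random.choice(sub, spending, shape=(n,), replace=False)  # simple random sample
    conv_err = float(convenience.mean() - true_mean)
    srs_err = float(srs.mean() - true_mean)
    results.append((n, conv_err, srs_err))
    print(f"n={n:>6}  convenience error={conv_err:+7.2f}  random-sample error={srs_err:+7.2f}")
```

The convenience error stays roughly constant as n grows, because it comes from who can be reached, while the random-sample error shrinks toward zero.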

  • Once you have a sampling method, a natural question arises: if I took a different sample, would I get a different statistic? Almost certainly yes. The sampling distribution is the distribution of a statistic (like the sample mean) across all possible samples of the same size.

  • Imagine drawing 1,000 different samples of 30 people and computing the mean height of each. Those 1,000 means form a distribution. Some will be a bit above the true population mean, some a bit below, and most will cluster around the true value.

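  • The standard deviation of this sampling distribution is called the standard error. For the sample mean, with population standard deviation \(\sigma\) and \(n\) independent observations, it is: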

\[SE = \frac{\sigma}{\sqrt{n}}\]
  • Notice that the standard error shrinks as \(n\) grows. Larger samples give more precise estimates. Quadrupling the sample size halves the standard error.
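In practice \(\sigma\) is unknown, so the standard error is usually estimated by plugging in the sample standard deviation \(s\), giving \(\widehat{SE} = s/\sqrt{n}\). A minimal JAX sketch, assuming a simulated sample of 100 values with true \(\sigma = 10\):

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(5)
sigma = 10.0                                   # true population SD (assumed for the simulation)
sample = 50.0 + sigma * jax.random.normal(key, shape=(100,))

n = sample.shape[0]
s = sample.std(ddof=1)                         # sample SD, an estimate of sigma
se_hat = s / jnp.sqrt(n)                       # estimated standard error of the sample mean
se_true = sigma / jnp.sqrt(n)                  # sigma / sqrt(n) = 1.0

# Quadrupling n halves the standard error: sigma / sqrt(4n) = (sigma / sqrt(n)) / 2
se_quadrupled = sigma / jnp.sqrt(4 * n)        # = 0.5
print(f"estimated SE = {se_hat:.3f}, true SE = {se_true:.3f}, SE at 4n = {se_quadrupled:.3f}")
```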

  • One of the most important results in statistics is the Central Limit Theorem (CLT). It says that, whatever the shape of the original population (provided its variance is finite), the distribution of sample means approaches a normal distribution as the sample size increases.

CLT: a skewed population produces normally distributed sample means

  • More precisely, if \(X_1, X_2, \ldots, X_n\) are independent observations from any distribution with mean \(\mu\) and finite variance \(\sigma^2\), then as \(n\) grows:
\[\bar{X} \approx \text{Normal}\!\left(\mu, \frac{\sigma^2}{n}\right)\]
  • The CLT is what makes most of inferential statistics work. It lets us use the normal distribution as an approximation even when the underlying data is not normal, as long as the sample is large enough.

  • How large is "large enough"? A common rule of thumb is \(n \ge 30\), but this depends on how non-normal the population is. For highly skewed distributions, you may need more. For roughly symmetric populations, even \(n = 10\) can be sufficient.

  • The CLT has three key conditions:

    • Independence: each observation must not influence the others
    • Finite variance: the population variance must exist (this rules out heavy-tailed distributions such as the Cauchy)
    • Identical distribution: all observations come from the same distribution
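The finite-variance condition matters. The standard Cauchy distribution has no finite variance (not even a finite mean), and the mean of \(n\) independent standard Cauchy draws is itself standard Cauchy, so averaging never narrows it. This minimal JAX sketch contrasts it with a standard normal population, measuring spread with the interquartile range (IQR), since the standard deviation of Cauchy draws is not meaningful. The helper `iqr` and the trial counts are illustrative choices.

```python
import jax
import jax.numpy as jnp


def iqr(x):
    """Interquartile range: a spread measure that exists even for the Cauchy."""
    return jnp.percentile(x, 75) - jnp.percentile(x, 25)


key = jax.random.PRNGKey(11)
trials = 2_000
spreads = {}
for n in [10, 100, 1_000]:
    key, k_norm, k_cauchy = jax.random.split(key, 3)
    # Each row is one sample of size n; take the mean of every row
    normal_means = jax.random.normal(k_norm, (trials, n)).mean(axis=1)
    cauchy_means = jax.random.cauchy(k_cauchy, (trials, n)).mean(axis=1)
    spreads[n] = (float(iqr(normal_means)), float(iqr(cauchy_means)))
    print(f"n={n:>5}  IQR of means: normal={spreads[n][0]:.3f}  cauchy={spreads[n][1]:.3f}")
```

The normal column should shrink roughly like \(1/\sqrt{n}\), while the Cauchy column should stay near 2 (the IQR of a standard Cauchy) for every \(n\).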

Coding Tasks (use Colab or a notebook)

  1. Demonstrate the CLT visually: draw samples from a highly skewed distribution, compute sample means, and watch the histogram of means become bell-shaped.

    import jax
    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    
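    key = jax.random.PRNGKey(0)
    key, pop_key = jax.random.split(key)  # never reuse a key once it has been consumed
    
    # Exponential distribution (very skewed)
    population = jax.random.exponential(pop_key, shape=(100_000,))
    
    fig, axes = plt.subplots(1, 4, figsize=(14, 3))
    sample_sizes = [1, 5, 30, 100]
    
    for ax, n in zip(axes, sample_sizes):
        key, subkey = jax.random.split(key)  # fresh key for each sample size
        # 2000 samples of size n (drawn with replacement) in one call; one mean per row
        means = jax.random.choice(subkey, population, shape=(2000, n)).mean(axis=1)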
        ax.hist(means, bins=40, color="#3498db", alpha=0.7, density=True)
        ax.set_title(f"n = {n}")
        ax.set_xlim(0, 4)
    
    fig.suptitle("CLT: sample means become normal as n increases", fontsize=13)
    plt.tight_layout()
    plt.show()
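The histograms can be backed up with a number. For an Exponential(1) population, \(\mu = 1\) and \(\sigma = 1\), so the CLT approximation gives \(\bar{X} \approx \text{Normal}(1, 1/n)\). The sketch below (the threshold 1.3 and \(n = 30\) are arbitrary choices) compares a simulated tail probability with the one the normal approximation predicts, using `jax.scipy.stats.norm.cdf`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

key = jax.random.PRNGKey(0)
n, trials, threshold = 30, 100_000, 1.3

# Exponential(1): mu = 1, sigma = 1, so the CLT predicts mean ~ Normal(1, 1/n)
means = jax.random.exponential(key, shape=(trials, n)).mean(axis=1)

simulated = (means > threshold).mean()        # fraction of sample means above the threshold
z = (threshold - 1.0) / (1.0 / jnp.sqrt(n))   # standardise using SE = sigma / sqrt(n)
clt_approx = 1.0 - norm.cdf(z)                # normal-approximation tail probability
print(f"simulated P(mean > {threshold}) = {simulated:.4f}, CLT approximation = {clt_approx:.4f}")
```

Expect the two numbers to be close but not identical: at \(n = 30\) the sample means of a strongly skewed population are still slightly skewed, which is why such populations may need larger samples.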
    

  2. Compare simple random sampling with stratified sampling. Create a population with distinct groups and show that stratified sampling gives lower variance in estimates.

    import jax
    import jax.numpy as jnp
    
    key = jax.random.PRNGKey(42)
    key, k_a, k_b = jax.random.split(key, 3)  # separate, never-reused keys for each group
    
    # Population: two distinct groups
    group_a = jax.random.normal(k_a, shape=(500,)) + 10   # mean ~10
    group_b = jax.random.normal(k_b, shape=(500,)) + 20   # mean ~20
    population = jnp.concatenate([group_a, group_b])
    
    # Simple random sampling: 1000 trials, sample size 20
    srs_means = []
    for i in range(1000):
        key, subkey = jax.random.split(key)
        sample = jax.random.choice(subkey, population, shape=(20,), replace=False)
        srs_means.append(sample.mean())
    srs_means = jnp.array(srs_means)
    
    # Stratified sampling: 10 from each group
    strat_means = []
    for i in range(1000):
        key, k1, k2 = jax.random.split(key, 3)
        s_a = jax.random.choice(k1, group_a, shape=(10,), replace=False)
        s_b = jax.random.choice(k2, group_b, shape=(10,), replace=False)
        strat_means.append(jnp.concatenate([s_a, s_b]).mean())
    strat_means = jnp.array(strat_means)
    
    print(f"Simple Random - Mean: {srs_means.mean():.3f}, Std: {srs_means.std():.3f}")
    print(f"Stratified    - Mean: {strat_means.mean():.3f}, Std: {strat_means.std():.3f}")
    print(f"Stratified sampling reduced variance by {(1 - strat_means.var()/srs_means.var())*100:.1f}%")
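The size of the reduction can be predicted before simulating. Ignoring the finite-population correction (a sample of 20 is only 2% of this population), the variance of the sample mean under simple random sampling is \(\sigma^2/n\), where \(\sigma^2\) includes the large between-group spread. Under stratified sampling with stratum weights \(W_h\) and per-stratum sample sizes \(n_h\) it is \(\sum_h W_h^2 \sigma_h^2 / n_h\), which involves only the within-group spread. A minimal sketch that regenerates a population of the same shape (two groups of 500 with unit variance, means 10 and 20):

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)
k_a, k_b = jax.random.split(key)
group_a = jax.random.normal(k_a, shape=(500,)) + 10   # stratum A, within-group variance ~1
group_b = jax.random.normal(k_b, shape=(500,)) + 20   # stratum B, within-group variance ~1
population = jnp.concatenate([group_a, group_b])

n, n_h, W_h = 20, 10, 0.5   # total sample, per-stratum sample, stratum weight (500/1000)

# Simple random sampling: Var(mean) = sigma^2 / n (sigma^2 includes between-group spread)
var_srs = population.var() / n
# Stratified, proportional allocation: Var(mean) = sum_h W_h^2 * sigma_h^2 / n_h
var_strat = W_h**2 * group_a.var() / n_h + W_h**2 * group_b.var() / n_h

print(f"predicted SD of mean: SRS {jnp.sqrt(var_srs):.3f}, stratified {jnp.sqrt(var_strat):.3f}")
print(f"predicted variance reduction: {(1 - var_strat / var_srs) * 100:.1f}%")
```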
    

  3. Explore how sample size affects standard error. Plot the standard error against sample size and confirm the \(1/\sqrt{n}\) relationship.

    import jax
    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    
    key = jax.random.PRNGKey(7)
    key, pop_key = jax.random.split(key)  # keep the population key separate from sampling keys
    population = jax.random.normal(pop_key, shape=(50_000,)) * 10 + 50  # sigma ~ 10
    
    sample_sizes = [5, 10, 20, 50, 100, 200, 500, 1000]
    std_errors = []
    
    for n in sample_sizes:
        key, subkey = jax.random.split(key)
        # 500 samples of size n, drawn with replacement; SE = std of their means
        means = jax.random.choice(subkey, population, shape=(500, n)).mean(axis=1)
        std_errors.append(means.std())
    
    plt.figure(figsize=(8, 4))
    plt.plot(sample_sizes, std_errors, "o-", color="#e74c3c", label="Observed SE")
    theoretical = population.std() / jnp.sqrt(jnp.array(sample_sizes, dtype=jnp.float32))
    plt.plot(sample_sizes, theoretical, "--", color="#3498db", label="σ/√n (theoretical)")
    plt.xlabel("Sample size (n)")
    plt.ylabel("Standard error")
    plt.legend()
    plt.title("Standard error shrinks with larger samples")
    plt.show()
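The plot can also be reduced to a single number. If \(SE = \sigma/\sqrt{n}\), then \(\log SE = \log\sigma - \tfrac{1}{2}\log n\), so a straight-line fit of \(\log SE\) against \(\log n\) should have slope close to \(-0.5\). This sketch draws each sample directly from a Normal(50, \(10^2\)) distribution instead of resampling a stored population (an assumed simplification) and uses `jnp.polyfit` for the fit.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(7)
sample_sizes = jnp.array([5, 10, 20, 50, 100, 200, 500, 1000])

std_errors = []
for n in sample_sizes.tolist():
    key, subkey = jax.random.split(key)
    # 2000 samples of size n from Normal(50, 10^2); SE = std of their means
    means = (50.0 + 10.0 * jax.random.normal(subkey, (2000, n))).mean(axis=1)
    std_errors.append(means.std())
std_errors = jnp.array(std_errors)

# Fit log SE = intercept + slope * log n; theory says slope = -0.5, intercept = log(10)
slope, intercept = jnp.polyfit(jnp.log(sample_sizes), jnp.log(std_errors), 1)
print(f"fitted slope = {slope:.3f} (theory -0.5), implied sigma = {jnp.exp(intercept):.2f} (theory 10)")
```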