Testing and Quality Assurance
Testing is how you know your code works, not just now but after every change. This file covers the practices that catch bugs before they reach production: the test pyramid, unit tests with pytest, mocking, testing ML-specific code, CI/CD pipelines, linting, formatting, and code review.
ML code is notoriously undertested. "It trains, therefore it works" is the prevailing attitude. This leads to silent bugs: a data loader that shuffles incorrectly, a loss function with a sign error, a preprocessing step that drops 5% of the data. These bugs do not crash your program. They just make your model quietly worse, and you waste weeks debugging metrics that "should be higher."
Testing is not overhead. It is the fastest way to move fast without breaking things.
The Test Pyramid
Tests are organised in layers, from fast and narrow to slow and broad:
- Unit tests (base): test individual functions and classes in isolation. Fast (milliseconds), numerous (hundreds to thousands). "Does `normalise_image` produce values in [0, 1]?"
- Integration tests (middle): test that components work together. Slower (seconds). "Does the data loader produce batches in the format the model expects?"
- End-to-end tests (top): test the full pipeline from input to output. Slow (minutes). "Does `python train.py --config test.yaml` complete without errors and produce a valid checkpoint?"
The pyramid shape means: write many unit tests, fewer integration tests, and a handful of end-to-end tests. Unit tests catch most bugs and run in seconds. End-to-end tests catch integration issues but are slow and fragile.
Unit Tests with pytest
- pytest is the de facto standard Python testing framework. A test is a function whose name starts with `test_`, in a file whose name starts with `test_`:
```python
# tests/test_utils.py
import numpy as np

from my_project.utils import normalise_image  # hypothetical module path


def test_normalise_image():
    image = np.array([0, 128, 255], dtype=np.uint8)
    result = normalise_image(image, mean=128, std=128)
    assert result.min() >= -1.0
    assert result.max() <= 1.0
    assert abs(result[1]) < 1e-6  # 128 normalised by mean=128 should be ~0


def test_normalise_empty():
    image = np.array([], dtype=np.uint8)
    result = normalise_image(image, mean=128, std=128)
    assert len(result) == 0
```
```bash
pytest tests/                  # run all tests
pytest tests/test_utils.py     # run one file
pytest -v                      # verbose output
pytest -x                      # stop on first failure
pytest -k "normalise"          # run tests matching name pattern
pytest --tb=short              # shorter tracebacks
```
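The tests above call `normalise_image`, but its body is not shown. A minimal sketch of one possible implementation, assuming it casts to `float32` before subtracting `mean` and dividing by `std`, together with one extra test that guards against a common bug (doing the subtraction in `uint8`, which wraps around instead of going negative):

```python
import numpy as np


def normalise_image(image: np.ndarray, mean: float, std: float) -> np.ndarray:
    """Hypothetical implementation: (image - mean) / std, computed in float32.

    Casting first matters: subtracting in uint8 would wrap around instead of
    producing negative values.
    """
    return (image.astype(np.float32) - mean) / std


def test_normalise_preserves_sign():
    image = np.array([0, 255], dtype=np.uint8)
    result = normalise_image(image, mean=128, std=128)
    assert result.dtype == np.float32
    assert result[0] < 0  # would be positive if the subtraction wrapped in uint8
```

With this definition in scope, the two earlier tests pass as well; `pytest -k normalise` runs all three.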
Fixtures
- Fixtures provide reusable setup for tests. Instead of repeating setup code in every test, define it once:
```python
import pytest
import torch

from my_project.models import SmallModel  # hypothetical import path


@pytest.fixture
def sample_dataset():
    """Create a small dataset for testing."""
    return {
        "inputs": torch.randn(10, 3, 32, 32),
        "labels": torch.randint(0, 10, (10,)),
    }


@pytest.fixture
def trained_model():
    """Load a small pretrained model."""
    model = SmallModel()
    # map_location="cpu" lets the test run on machines without a GPU (e.g. CI)
    state = torch.load("tests/fixtures/small_model.pt", map_location="cpu")
    model.load_state_dict(state)
    model.eval()
    return model


def test_model_output_shape(trained_model, sample_dataset):
    output = trained_model(sample_dataset["inputs"])
    assert output.shape == (10, 10)  # batch_size x num_classes
```
- Fixtures can have scopes: `scope="function"` (the default; a fresh value per test), `scope="module"` (once per file), and `scope="session"` (once per test run). Use `scope="session"` for expensive setup such as loading a model, as long as no test modifies the shared object.
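Besides fixtures you define yourself, pytest ships built-in ones. `tmp_path`, for example, gives each test a fresh temporary directory as a `pathlib.Path`, which is handy for checkpoint or metrics I/O. A minimal sketch, where `save_metrics` and `load_metrics` are hypothetical helpers standing in for your own code:

```python
import json
from pathlib import Path


def save_metrics(metrics: dict, path: Path) -> None:
    """Hypothetical helper: write metrics as JSON."""
    path.write_text(json.dumps(metrics))


def load_metrics(path: Path) -> dict:
    """Hypothetical helper: read metrics back from JSON."""
    return json.loads(path.read_text())


def test_metrics_roundtrip(tmp_path: Path):
    # pytest injects tmp_path: a unique, empty directory for this test
    path = tmp_path / "metrics.json"
    metrics = {"loss": 0.25, "accuracy": 0.9}
    save_metrics(metrics, path)
    assert path.exists()
    assert load_metrics(path) == metrics
```

Because each test gets its own directory, tests cannot see each other's files, and pytest manages cleanup of old temporary directories.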
Parametrised Tests
- Test the same function with multiple inputs without duplicating code:
```python
import pytest


@pytest.mark.parametrize("values,expected", [
    ([1, 2, 3], 6),
    ([], 0),
    ([-1, 1], 0),
    ([1000000, 1000000], 2000000),
])
def test_sum(values, expected):  # "values" avoids shadowing the built-in input()
    assert sum(values) == expected
```
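Two further `parametrize` features are worth knowing: `pytest.param(..., id=...)` gives a case a readable name in the test report, and stacking two `parametrize` decorators runs every combination of their values. A sketch using a hypothetical `pad_sequence` helper:

```python
import pytest


def pad_sequence(seq: list, length: int, pad_value: int = 0) -> list:
    """Hypothetical helper: right-pad (or truncate) seq to exactly `length` items."""
    return (list(seq) + [pad_value] * length)[:length]


@pytest.mark.parametrize("length", [0, 1, 5])
@pytest.mark.parametrize(
    "seq",
    [
        pytest.param([], id="empty"),
        pytest.param([7, 8, 9], id="three-items"),
    ],
)
def test_pad_sequence_length(seq, length):
    # Stacked decorators: 3 lengths x 2 sequences = 6 generated test cases
    result = pad_sequence(seq, length)
    assert len(result) == length
    assert result[: len(seq)] == seq[:length]  # original items come first
    assert all(v == 0 for v in result[len(seq):])  # the rest is padding
```

With `pytest -v`, the ids `empty` and `three-items` appear in the generated test names, which makes a failing combination easy to spot.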
Mocking and Patching
- Mocking replaces a real dependency with a fake one during testing. This lets you test a function in isolation, without needing a database, API, or GPU.
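```python
from unittest.mock import MagicMock, patch

from my_project.training.trainer import Trainer  # hypothetical module


def test_training_logs_metrics():
    mock_logger = MagicMock()
    # Patch wandb where the trainer module looks it up, so no real runs are created
    with patch("my_project.training.trainer.wandb"):
        trainer = Trainer(logger=mock_logger)
        trainer.train_one_epoch()

    # verify that the trainer logged metrics
    mock_logger.log.assert_called()
    # verify it logged a loss value (assumes keyword arguments, e.g. log(loss=...))
    call_args = mock_logger.log.call_args
    assert "loss" in call_args.kwargs
```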
- When to mock: external services (APIs, databases, cloud storage), expensive operations (GPU computation, large file I/O), and non-deterministic behaviour (random number generators, timestamps).
- When NOT to mock: your own code. If you mock everything, your tests verify that mocks behave as expected, not that your code works. Mock at the boundaries; test your logic directly.
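A self-contained sketch of mocking at the boundaries, using only `unittest.mock` from the standard library. `CheckpointUploader` and its `client.put_object` call are hypothetical stand-ins for code that talks to cloud storage. The test replaces the client with a `MagicMock` and pins the non-deterministic timestamp with `patch.object`, while the key-building logic itself runs for real:

```python
import time
from unittest.mock import MagicMock, patch


class CheckpointUploader:
    """Hypothetical class under test: uploads a file under a timestamped key."""

    def __init__(self, client):
        self.client = client  # e.g. a cloud-storage client (external service)

    def upload(self, filename: str) -> str:
        key = f"checkpoints/{int(time.time())}_{filename}"
        self.client.put_object(key=key, filename=filename)
        return key


def test_upload_uses_timestamped_key():
    fake_client = MagicMock()  # boundary: no real network calls
    # Boundary: freeze the clock so the key is predictable
    with patch.object(time, "time", return_value=1_700_000_000.0):
        key = CheckpointUploader(fake_client).upload("model.pt")
    # Our own logic (the key format) is checked directly, not mocked
    assert key == "checkpoints/1700000000_model.pt"
    fake_client.put_object.assert_called_once_with(key=key, filename="model.pt")
```

`patch.object` restores the real `time.time` when the `with` block exits, so other tests are unaffected.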
Testing ML Code
- ML code has unique testing challenges: outputs are probabilistic, training is slow, and "correct" is fuzzy.
Deterministic Seeds
- Set random seeds everywhere to make tests reproducible:
```python
import random

import numpy as np
import torch


def set_seed(seed: int = 42) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Trade speed for reproducibility: deterministic cuDNN kernels, no autotuning
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
```
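Global seeding works, but any extra random call elsewhere shifts the sequence a test sees. An alternative (a different technique from the global seeding above) is to pass an explicit NumPy `Generator` into functions that need randomness. A sketch with a hypothetical `random_crop` augmentation:

```python
import numpy as np


def random_crop(image: np.ndarray, size: int, rng: np.random.Generator) -> np.ndarray:
    """Hypothetical augmentation: crop a size x size patch at a random position."""
    h, w = image.shape[:2]
    top = rng.integers(0, h - size + 1)  # upper bound is exclusive
    left = rng.integers(0, w - size + 1)
    return image[top : top + size, left : left + size]


def test_random_crop_is_reproducible():
    image = np.arange(100).reshape(10, 10)
    crop_a = random_crop(image, 4, np.random.default_rng(123))
    crop_b = random_crop(image, 4, np.random.default_rng(123))
    assert crop_a.shape == (4, 4)
    assert np.array_equal(crop_a, crop_b)  # same seed, same crop
```

Because the generator is an argument, the test controls the randomness without touching global state, and unrelated random calls cannot change the result.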
Numerical Tolerances
- Floating-point comparisons require tolerances (chapter 13, IEEE 754):
```python
import numpy as np
import torch

# BAD: exact comparison can fail due to floating-point rounding
assert model_output == 0.5

# GOOD: approximate comparison
assert np.isclose(model_output, 0.5, atol=1e-5)

# For tensors
assert torch.allclose(output, expected, atol=1e-4)
```
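To see why tolerances are needed, and how `atol` and `rtol` interact, here is a small runnable check. `np.isclose(a, b)` tests `|a - b| <= atol + rtol * |b|` (defaults `rtol=1e-05`, `atol=1e-08`), so near zero only `atol` matters and at large magnitudes `rtol` dominates. `pytest.approx` is pytest's equivalent for use directly inside an `==` assertion:

```python
import numpy as np
import pytest

total = 0.1 + 0.2
exact = total == 0.3                  # False: total is 0.30000000000000004
close = bool(np.isclose(total, 0.3))  # True with the default tolerances

# Near zero the relative term vanishes, so atol alone decides
tiny_default = bool(np.isclose(1e-6, 0.0))           # False: 1e-6 > 1e-8
tiny_loose = bool(np.isclose(1e-6, 0.0, atol=1e-5))  # True

# At large magnitudes rtol dominates: 5 <= 1e-8 + 1e-5 * 1_000_005
large = bool(np.isclose(1_000_000.0, 1_000_005.0))   # True

# pytest.approx reads naturally inside an assert
assert total == pytest.approx(0.3)
assert 1e-6 == pytest.approx(0.0, abs=1e-5)
```

The practical rule: choose `atol` for quantities that can be near zero (losses, residuals) and rely on `rtol` for large ones (sums, logits).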
What to Test in ML
- Shape tests: verify that outputs have the expected dimensions.
```python
import torch

from my_project.models import MyModel  # hypothetical import path


def test_model_output_shape():
    model = MyModel(d_model=256, n_classes=10)
    x = torch.randn(8, 32, 256)  # batch=8, seq=32, dim=256
    output = model(x)
    assert output.shape == (8, 10)
```
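One shape bug worth a dedicated test is a stray `.squeeze()`, which silently drops the batch dimension when the batch size is 1. A framework-free sketch in NumPy, where `pooled_logits` is a hypothetical stand-in for a model's classification head and `pooled_logits_buggy` adds the `.squeeze()`:

```python
import numpy as np


def pooled_logits(x: np.ndarray, weight: np.ndarray) -> np.ndarray:
    """Hypothetical head: (batch, seq, dim) -> mean over seq -> (batch, n_classes)."""
    return x.mean(axis=1) @ weight


def pooled_logits_buggy(x: np.ndarray, weight: np.ndarray) -> np.ndarray:
    # Same maths, but .squeeze() also removes the batch axis when batch == 1
    return (x.mean(axis=1) @ weight).squeeze()


def shape_failures(head, batch_sizes=(1, 2, 8), seq=5, dim=16, n_classes=10):
    """Return (batch_size, actual_shape) for every batch size with a wrong output shape."""
    rng = np.random.default_rng(0)
    weight = rng.standard_normal((dim, n_classes))
    failures = []
    for batch in batch_sizes:
        out = head(rng.standard_normal((batch, seq, dim)), weight)
        if out.shape != (batch, n_classes):
            failures.append((batch, out.shape))
    return failures
```

`shape_failures(pooled_logits)` returns an empty list, while `shape_failures(pooled_logits_buggy)` reports only batch size 1, a failure that a test using only `batch=8` would never see.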
- Gradient flow: verify that gradients are non-zero for trainable parameters.
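```python
import torch
import torch.nn.functional as F

from my_project.models import MyModel  # hypothetical import path


def test_gradients_flow():
    model = MyModel()
    x = torch.randn(4, 3, 32, 32)
    y = torch.randint(0, 10, (4,))
    output = model(x)
    loss = F.cross_entropy(output, y)
    loss.backward()
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen parameters are expected to have no gradient
        assert param.grad is not None, f"No gradient for {name}"
        assert param.grad.abs().sum() > 0, f"Zero gradient for {name}"
```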
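Non-zero gradients do not prove the gradients are correct. When you write a gradient by hand (a custom loss or layer), a related check is to compare it against central finite differences, `(f(w + eps) - f(w - eps)) / (2 * eps)`. A NumPy sketch for a hypothetical mean-squared-error loss of a linear model, including a gradient with the kind of sign error described at the start of this file:

```python
import numpy as np


def mse_loss(w, x, y):
    """Hypothetical loss: mean squared error of a linear model x @ w."""
    return np.mean((x @ w - y) ** 2)


def mse_grad(w, x, y):
    """Analytic gradient of mse_loss with respect to w."""
    return 2.0 * x.T @ (x @ w - y) / len(y)


def mse_grad_buggy(w, x, y):
    # Sign error: this gradient points uphill
    return -mse_grad(w, x, y)


def grad_check_error(grad_fn, w, x, y, eps=1e-6):
    """Relative error ||analytic - numeric|| / (||analytic|| + ||numeric||)."""
    numeric = np.zeros_like(w)
    for i in range(w.size):
        step = np.zeros_like(w)
        step[i] = eps
        numeric[i] = (mse_loss(w + step, x, y) - mse_loss(w - step, x, y)) / (2 * eps)
    analytic = grad_fn(w, x, y)
    diff = np.linalg.norm(analytic - numeric)
    return float(diff / (np.linalg.norm(analytic) + np.linalg.norm(numeric)))


rng = np.random.default_rng(0)
x, y, w = rng.standard_normal((16, 4)), rng.standard_normal(16), rng.standard_normal(4)
```

On this data, `grad_check_error(mse_grad, w, x, y)` is tiny (rounding error only), while the buggy gradient scores close to 1.0 because every component has the wrong sign. The threshold you assert on (for example `< 1e-6`) is a judgment call that depends on `eps` and the float precision.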
- Overfit on one batch: a model should be able to memorise a single batch. If it cannot, something is fundamentally wrong.
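```python
import torch
import torch.nn.functional as F

from my_project.models import MyModel  # hypothetical import path
from tests.helpers import get_single_batch  # hypothetical helper returning (x, y)


def test_overfit_one_batch():
    model = MyModel()
    optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)
    x, y = get_single_batch()
    for _ in range(100):  # step count and threshold depend on the model; tune both
        loss = F.cross_entropy(model(x), y)
        loss.backward()
        optimiser.step()
        optimiser.zero_grad()
    assert loss.item() < 0.01, f"Cannot overfit one batch: loss={loss.item()}"
```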
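The same idea can be checked without a framework. In this NumPy sketch, a hypothetical linear regression with more features than examples (64 versus 8) can fit a single batch exactly, so gradient descent should drive the loss to almost zero. Flipping the sign of the update, a typical bug, makes the loss grow instead:

```python
import numpy as np


def fit_one_batch(x, y, steps, flip_sign=False):
    """Gradient descent on MSE for a linear model; return (initial_loss, final_loss)."""
    n = x.shape[0]
    # Step size 1/L, with L the largest eigenvalue of the MSE Hessian (2/n) x.T @ x;
    # its non-zero eigenvalues match those of the smaller n x n matrix (2/n) x @ x.T
    lr = 1.0 / np.linalg.eigvalsh((2.0 / n) * (x @ x.T)).max()
    w = np.zeros(x.shape[1])
    initial_loss = float(np.mean((x @ w - y) ** 2))
    for _ in range(steps):
        grad = (2.0 / n) * x.T @ (x @ w - y)
        w = w + lr * grad if flip_sign else w - lr * grad
    final_loss = float(np.mean((x @ w - y) ** 2))
    return initial_loss, final_loss


rng = np.random.default_rng(0)
x_batch = rng.standard_normal((8, 64))  # 8 examples, 64 features
y_batch = rng.standard_normal(8)
```

`fit_one_batch(x_batch, y_batch, steps=500)` should return a final loss many orders of magnitude below the initial one, while `flip_sign=True` (run for only a few steps, since the loss grows geometrically) returns a final loss above the initial one. The step size is derived from the data so the correct version converges; as with the PyTorch test, the threshold you assert on is a judgment call.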
- Data validation: verify data loading produces valid outputs.
```python
import torch

from my_project.data import MyDataset  # hypothetical import path


def test_dataset_basics():
    dataset = MyDataset("tests/fixtures/small_data.csv")
    assert len(dataset) > 0
    x, y = dataset[0]
    assert x.shape == (3, 224, 224)
    assert 0 <= y < 10
    assert not torch.isnan(x).any()
    assert not torch.isinf(x).any()
```
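When a failing assertion should say exactly which check failed, it helps to collect all problems first. A sketch of a hypothetical `find_batch_problems` validator in NumPy, assuming image-like inputs scaled to [0, 1] and integer class labels (adjust the checks to your data):

```python
import numpy as np


def find_batch_problems(x: np.ndarray, y: np.ndarray, num_classes: int) -> list:
    """Hypothetical validator: return human-readable problems (empty list means OK)."""
    problems = []
    if len(x) != len(y):
        problems.append(f"size mismatch: {len(x)} inputs vs {len(y)} labels")
    if np.isnan(x).any():
        problems.append(f"{int(np.isnan(x).sum())} NaN values in inputs")
    if np.isinf(x).any():
        problems.append(f"{int(np.isinf(x).sum())} inf values in inputs")
    finite = x[np.isfinite(x)]
    if finite.size and (finite.min() < 0.0 or finite.max() > 1.0):
        problems.append(f"inputs outside [0, 1]: min={finite.min()}, max={finite.max()}")
    bad_labels = y[(y < 0) | (y >= num_classes)]
    if bad_labels.size:
        problems.append(f"labels outside [0, {num_classes}): {sorted(set(bad_labels.tolist()))}")
    return problems


def test_batch_is_valid():
    rng = np.random.default_rng(0)
    x = rng.random((4, 3, 8, 8))  # uniform in [0, 1)
    y = rng.integers(0, 10, size=4)
    problems = find_batch_problems(x, y, num_classes=10)
    assert not problems, "\n".join(problems)
```

Passing the joined messages as the assertion message makes pytest print every problem at once, instead of stopping at the first failed check.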
- Determinism: same input + same seed → same output.
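```python
import torch

from my_project.models import MyModel  # hypothetical import path
from my_project.utils import set_seed  # hypothetical location of set_seed above


def test_determinism():
    model = MyModel()  # train mode by default, so any dropout uses the seeded RNG
    input_data = torch.randn(4, 3, 32, 32)
    set_seed(42)
    output1 = model(input_data)
    set_seed(42)
    output2 = model(input_data)
    assert torch.allclose(output1, output2)
```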
CI/CD Pipelines
- Continuous Integration (CI): automatically run tests on every commit or pull request (PR). When branch protection requires the CI check to pass, a PR with failing tests cannot be merged, which keeps broken code out of `main`.
- GitHub Actions example (`.github/workflows/ci.yml`):
```yaml
name: CI

on: [push, pull_request]

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: "3.11"
      - run: pip install -e ".[dev]"  # the "dev" extra must provide ruff, mypy, pytest
      - run: ruff check src/
      - run: mypy src/
      - run: pytest tests/ -v --tb=short
```
- Pre-commit hooks: run checks locally before each commit, catching issues before they reach CI. Configure them in `.pre-commit-config.yaml` and enable them with `pre-commit install`:
```yaml
# .pre-commit-config.yaml
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.3.0
    hooks:
      - id: ruff
        args: [--fix]
      - id: ruff-format
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.5.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
```
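Linting and Formatting
- Linting catches bugs and style issues without running the code. Formatting enforces consistent style automatically.
- Ruff: a fast Python linter and formatter that can replace flake8, isort, and black with a single tool. `ruff check src/` lints the code (add `--fix` to apply automatic fixes), and `ruff format src/` reformats it.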
- mypy: static type checker for Python. Catches type errors before runtime:
```bash
mypy src/
# src/model.py:42: error: Argument 1 to "forward" has incompatible type "int"; expected "Tensor"
```
- Type hints make code self-documenting and give mypy something to check:
```python
import torch
from torch import nn
from torch.utils.data import DataLoader


def train(
    model: nn.Module,
    dataloader: DataLoader,
    optimiser: torch.optim.Optimizer,
    num_epochs: int = 10,
) -> float:
    """Train model and return final loss."""
    ...
```
Code Review Best Practices
- For the author:
  - Self-review your diff before requesting review. You will catch obvious issues.
  - Keep PRs small and focused: one concern per PR.
  - Write a clear description: what changed, why, and how to test it.
  - Respond to every comment (even if just "done").
- For the reviewer:
  - Be kind. Critique the code, not the person: "This could be clearer", not "this is confusing".
  - Distinguish blocking issues (bugs, security) from suggestions (style, naming). Use labels: "nit:", "suggestion:", "blocking:".
  - Ask questions instead of making demands. "What happens if this list is empty?" is more helpful than "handle the empty case."
  - Review promptly. A PR waiting days for review blocks the author and encourages large, batched PRs, which are harder to review.