
Convolutional Networks

Convolutional neural networks learn spatial feature hierarchies directly from pixel data, replacing hand-designed filters with gradient-optimised ones. This file covers convolution mechanics (stride, padding, dilation), receptive fields, pooling, regularisation and augmentation, the landmark architectures that defined image classification (LeNet, AlexNet, VGG, Inception, ResNet, DenseNet), efficient designs (MobileNet, EfficientNet, ShuffleNet), transfer learning, and techniques for visualising what a CNN has learned.

  • In file 01, we hand-designed filters for edge detection, blurring, and corner detection. The natural question is: can we learn the optimal filters from data? That is exactly what convolutional neural networks (CNNs) do.

  • Instead of choosing filter weights by hand, CNNs learn them via gradient descent (chapter 06), discovering features that are directly useful for the task at hand.

  • In chapter 06, we introduced the convolution operation, CNN basics, and the idea of filter learning. Here we go deeper into the architectural innovations that made CNNs the dominant paradigm in computer vision for over a decade.

  • Recall the core convolution operation: a filter \(K\) of size \(k \times k\) slides over the input feature map, computing a dot product at each position (chapter 06). The output size is controlled by three hyperparameters:

    • Stride: how many pixels the filter moves between positions. Stride 1 means the filter shifts one pixel at a time. Stride 2 means it shifts two pixels, halving the spatial dimensions. Strided convolution is an alternative to pooling for downsampling.
    • Padding: adding zeros around the input border. "Same" padding (\(p = \lfloor k/2 \rfloor\)) preserves spatial dimensions for odd \(k\) and stride 1. "Valid" padding (\(p = 0\)) reduces them.
    • Dilation: inserting gaps between filter elements. A \(k \times k\) filter with dilation \(d\) spans \(d(k-1)+1\) pixels per side, so a 3x3 filter with dilation 2 covers a 5x5 region using only 9 parameters. Dilated convolutions expand the receptive field without adding parameters or multiplications per output position.
  • The output spatial size after convolution:

\[\text{out} = \left\lfloor \frac{\text{in} - k + 2p}{s} \right\rfloor + 1\]
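  • where \(\text{in}\) is the input size, \(k\) is the kernel size, \(p\) is padding, and \(s\) is stride. This formula applies independently to height and width. For a dilated convolution with dilation \(d\), replace \(k\) with the effective kernel size \(d(k-1)+1\).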
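  • The formula is easy to check numerically. The sketch below uses a hypothetical helper (conv_out_size is our own name) that extends the formula with a dilation argument and compares it against the output shape of JAX's lax.conv_general_dilated on a 32x32 input; the test configurations are arbitrary.

```python
import jax.numpy as jnp
from jax import lax

def conv_out_size(n_in, k, p=0, s=1, d=1):
    """Output length along one spatial axis.

    Dilation d spreads the k taps over d*(k-1)+1 pixels, so that
    effective size takes the place of k in the formula.
    """
    k_eff = d * (k - 1) + 1
    return (n_in - k_eff + 2 * p) // s + 1

# Cross-check against JAX on a single 32x32, single-channel image (NCHW input, OIHW kernel).
x = jnp.ones((1, 1, 32, 32))
for k, p, s, d in [(3, 1, 1, 1), (3, 1, 2, 1), (5, 0, 1, 1), (3, 2, 1, 2)]:
    w = jnp.ones((1, 1, k, k))
    y = lax.conv_general_dilated(x, w, window_strides=(s, s),
                                 padding=[(p, p), (p, p)], rhs_dilation=(d, d))
    print(f"k={k} p={p} s={s} d={d}: JAX {y.shape[-1]}, formula {conv_out_size(32, k, p, s, d)}")
```

The four rows give 32, 16, 28 and 32: the last one shows a dilated 3x3 kernel behaving like a 5x5 kernel, so padding 2 is needed to keep the size.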

  • The receptive field of a neuron is the region of the original input that can influence its value.

    • Early layers have small receptive fields (they see local patterns like edges).
    • Deeper layers have larger receptive fields (they see larger structures like object parts).
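  • The receptive field grows with each layer. A stride-1 layer with a \(k \times k\) kernel adds \(k - 1\) pixels. If the layers before it have strides \(s_1, \ldots, s_{l-1}\), the increment is \((k-1)\prod_{i<l} s_i\), and a dilation \(d\) multiplies the increment by \(d\).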

Receptive field growing across layers: layer 1 neurons see a 3x3 patch, layer 2 neurons see a 5x5 patch, layer 3 neurons see a 7x7 patch of the original input
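  • One way to see the growth rule empirically, as a sketch under simplifying assumptions (all-ones 3x3 weights, no nonlinearity, an arbitrary 15x15 input): the gradient of one output pixel with respect to the input is non-zero exactly on that pixel's receptive field, so counting the non-zero columns of the gradient measures its width. The function names are our own.

```python
import jax
import jax.numpy as jnp
from jax import lax

def stack_3x3(x, n_layers):
    """Apply n_layers of 3x3, stride-1, 'SAME' convs with all-ones weights (no nonlinearity)."""
    w = jnp.ones((1, 1, 3, 3))
    for _ in range(n_layers):
        x = lax.conv(x, w, (1, 1), 'SAME')
    return x

def empirical_rf(n_layers, size=15):
    """Width of the input region that influences the centre output pixel."""
    c = size // 2
    centre_out = lambda x: stack_3x3(x, n_layers)[0, 0, c, c]
    g = jax.grad(centre_out)(jnp.ones((1, 1, size, size)))
    touched = g[0, 0] != 0          # input pixels with non-zero influence
    return int(touched.any(axis=0).sum())

for n in (1, 2, 3):
    print(n, empirical_rf(n))       # 3, 5, 7: grows by k - 1 = 2 per layer
```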

  • Pooling layers reduce spatial dimensions while retaining the most important information.

    • Max pooling takes the maximum value in each window, preserving the strongest activation (the most prominent feature).
    • Average pooling takes the mean, smoothing the feature map. With either operation, a 2x2 window with stride 2 halves both spatial dimensions.
  • Global Average Pooling (GAP) averages the entire spatial extent of each channel into a single number, producing a vector of length equal to the number of channels. GAP replaces the fully connected layers at the end of many modern architectures, drastically reducing parameter count and acting as a structural regulariser.
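  • A minimal JAX sketch of 2x2 pooling and GAP on a single (H, W, C) feature map (the helper names are our own). The last two lines compare classifier-head sizes for an assumed VGG-style 7x7x512 final map and 1,000 classes, which is where GAP's parameter savings come from.

```python
import jax.numpy as jnp

def pool2x2(x, reduce_fn):
    """2x2, stride-2 pooling of an (H, W, C) map: reshape into non-overlapping windows, then reduce."""
    H, W, C = x.shape
    windows = x[:H // 2 * 2, :W // 2 * 2].reshape(H // 2, 2, W // 2, 2, C)
    return reduce_fn(windows, axis=(1, 3))

def global_avg_pool(x):
    """GAP: average each channel over all spatial positions -> vector of length C."""
    return x.mean(axis=(0, 1))

x = jnp.arange(4 * 4 * 3, dtype=jnp.float32).reshape(4, 4, 3)
print(pool2x2(x, jnp.max).shape)    # (2, 2, 3)
print(pool2x2(x, jnp.mean).shape)   # (2, 2, 3)
print(global_avg_pool(x).shape)     # (3,)

# Classifier-head weights for a 7x7x512 final map and 1,000 classes (biases ignored):
print(7 * 7 * 512 * 1000)  # flatten + fully connected: 25,088,000
print(512 * 1000)          # GAP + fully connected:        512,000
```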

  • Batch Normalisation (BatchNorm) normalises activations within each mini-batch to have zero mean and unit variance, then applies a learnable scale and shift (chapter 06). In CNNs, BatchNorm is applied per-channel: statistics are computed across the batch and spatial dimensions for each channel independently. It stabilises training, allows higher learning rates, and acts as a mild regulariser.
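  • The per-channel statistics are easiest to see in code. This is a sketch of training-mode BatchNorm only: it assumes an (N, H, W, C) layout and omits the running mean and variance that a real layer tracks for use at inference time. The function name is our own.

```python
import jax
import jax.numpy as jnp

def batchnorm_train(x, gamma, beta, eps=1e-5):
    """Training-mode BatchNorm for an (N, H, W, C) batch.

    Mean and variance are computed per channel over the batch and spatial axes (N, H, W);
    gamma and beta are the learnable per-channel scale and shift, each of shape (C,).
    """
    mean = x.mean(axis=(0, 1, 2), keepdims=True)
    var = x.var(axis=(0, 1, 2), keepdims=True)
    x_hat = (x - mean) / jnp.sqrt(var + eps)
    return gamma * x_hat + beta

x = 3.0 + 2.0 * jax.random.normal(jax.random.PRNGKey(0), (8, 5, 5, 4))
y = batchnorm_train(x, gamma=jnp.ones(4), beta=jnp.zeros(4))
print(y.mean(axis=(0, 1, 2)))  # approximately 0 for each of the 4 channels
print(y.std(axis=(0, 1, 2)))   # approximately 1 for each channel
```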

  • Dropout (chapter 06) randomly zeroes neurons during training.

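  • In CNNs, spatial dropout (Dropout2d in PyTorch) drops entire feature map channels rather than individual pixels. This is more effective because neighbouring pixels in a feature map are highly correlated: a dropped pixel's information is still carried by its neighbours, so pixel-level dropout barely regularises.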
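  • A sketch of spatial dropout in JAX, assuming an (N, H, W, C) layout and the common "inverted dropout" scaling; the function name is our own. The key design choice is the shape of the random mask, (N, 1, 1, C), which broadcasts one keep/drop decision over every pixel of a channel.

```python
import jax
import jax.numpy as jnp

def spatial_dropout(key, x, rate):
    """Drop whole channels of an (N, H, W, C) tensor during training.

    One Bernoulli draw per (example, channel) is broadcast over H and W, so a channel
    is either kept entirely or zeroed entirely. Survivors are scaled by 1/(1 - rate)
    so the expected activation is unchanged.
    """
    N, _, _, C = x.shape
    keep = jax.random.bernoulli(key, 1.0 - rate, (N, 1, 1, C))
    return jnp.where(keep, x / (1.0 - rate), 0.0)

x = jnp.ones((2, 4, 4, 8))
y = spatial_dropout(jax.random.PRNGKey(0), x, rate=0.5)
print(y[0].max(axis=(0, 1)))  # each channel is 0.0 (dropped) or 2.0 (kept and rescaled)
```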

  • Data augmentation artificially expands the training set by applying random transformations to each image during training: horizontal flips, random crops, rotations, colour jitter (adjusting brightness, contrast, saturation, hue), and cutout (masking random rectangular patches). The network sees each image in many different forms, forcing it to learn transformation-invariant features rather than memorising specific pixel patterns.
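  • A minimal sketch of three of these augmentations for one (H, W, C) image: horizontal flip, random crop after zero-padding, and cutout. The 4-pixel padding and the 8x8 cutout square are hypothetical defaults, not values from this text, and the function name is our own. Because every random choice comes from an explicit key, the function can be mapped over a batch with jax.vmap.

```python
import jax
import jax.numpy as jnp
from jax import lax

def augment(key, img, pad=4, cutout=8):
    """Random horizontal flip, random crop and cutout for one (H, W, C) image."""
    k_flip, k_y, k_x, k_cy, k_cx = jax.random.split(key, 5)
    H, W, C = img.shape
    # Horizontal flip with probability 0.5 (reverse the width axis).
    img = jnp.where(jax.random.bernoulli(k_flip), img[:, ::-1, :], img)
    # Random crop: zero-pad by `pad` pixels, then cut an H x W window at a random offset.
    padded = jnp.pad(img, ((pad, pad), (pad, pad), (0, 0)))
    y0 = jax.random.randint(k_y, (), 0, 2 * pad + 1)
    x0 = jax.random.randint(k_x, (), 0, 2 * pad + 1)
    img = lax.dynamic_slice(padded, (y0, x0, 0), (H, W, C))
    # Cutout: zero a cutout x cutout square around a random pixel (clipped at the border).
    half = cutout // 2
    cy = jax.random.randint(k_cy, (), 0, H)
    cx = jax.random.randint(k_cx, (), 0, W)
    rows = jnp.arange(H)[:, None]
    cols = jnp.arange(W)[None, :]
    hole = (rows >= cy - half) & (rows < cy + half) & (cols >= cx - half) & (cols < cx + half)
    return jnp.where(hole[:, :, None], 0.0, img)

keys = jax.random.split(jax.random.PRNGKey(0), 4)
batch = jax.vmap(augment)(keys, jnp.ones((4, 32, 32, 3)))  # independent augmentation per image
print(batch.shape)   # (4, 32, 32, 3)
```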

  • Advanced augmentation strategies include Mixup (blending two images and their labels: \(\tilde{x} = \lambda x_i + (1-\lambda) x_j\), \(\tilde{y} = \lambda y_i + (1-\lambda) y_j\), with \(\lambda \sim \text{Beta}(\alpha, \alpha)\)), CutMix (pasting a rectangular patch from one image onto another and mixing labels in proportion to the pasted area), and RandAugment (applying \(N\) augmentations sampled uniformly from a fixed set, all at a shared magnitude \(M\), so the whole policy has just two hyperparameters).
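  • A sketch of Mixup and CutMix for a single pair of (H, W, C) images with one-hot labels; the function names are our own. The Mixup \(\alpha = 0.2\) is an assumed default, and CutMix here draws \(\lambda\) uniformly (a Beta(1, 1) distribution). CutMix labels are mixed by the fraction of pixels actually pasted, which can be smaller than \(1 - \lambda\) when the box is clipped at the image border.

```python
import jax
import jax.numpy as jnp

def mixup(key, x_i, y_i, x_j, y_j, alpha=0.2):
    """Mixup: blend images and one-hot labels with the same lambda ~ Beta(alpha, alpha)."""
    lam = jax.random.beta(key, alpha, alpha)
    return lam * x_i + (1 - lam) * x_j, lam * y_i + (1 - lam) * y_j

def cutmix(key, x_i, y_i, x_j, y_j):
    """CutMix: paste a random box from x_j into x_i and mix labels by pasted area."""
    H, W, _ = x_i.shape
    k_lam, k_y, k_x = jax.random.split(key, 3)
    lam = jax.random.uniform(k_lam)          # Beta(1, 1), i.e. uniform on [0, 1)
    r = jnp.sqrt(1.0 - lam)                  # side fraction, so box area ~ (1 - lam) * H * W
    cy = jax.random.uniform(k_y) * H         # random box centre
    cx = jax.random.uniform(k_x) * W
    rows = jnp.arange(H)[:, None] + 0.5      # pixel centres
    cols = jnp.arange(W)[None, :] + 0.5
    in_box = (jnp.abs(rows - cy) < r * H / 2) & (jnp.abs(cols - cx) < r * W / 2)
    x = jnp.where(in_box[:, :, None], x_j, x_i)
    lam_eff = 1.0 - in_box.mean()            # share of pixels still taken from x_i
    return x, lam_eff * y_i + (1 - lam_eff) * y_j

xi, xj = jnp.zeros((8, 8, 1)), jnp.ones((8, 8, 1))
yi, yj = jnp.array([1.0, 0.0]), jnp.array([0.0, 1.0])
print(mixup(jax.random.PRNGKey(0), xi, yi, xj, yj)[1])
print(cutmix(jax.random.PRNGKey(1), xi, yi, xj, yj)[1])
```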

  • The history of CNN architectures is a story of progressively deeper, more efficient designs, each solving a problem that limited its predecessor.

  • LeNet-5 (LeCun et al., 1998) is the classic early CNN, designed for handwritten digit recognition: two convolutional layers followed by three fully connected layers, with average pooling (subsampling) and tanh activations. It demonstrated that learned filters could outperform hand-designed features, but it was tiny by modern standards (about 60K parameters).

  • AlexNet (Krizhevsky et al., 2012) won the 2012 ImageNet Large Scale Visual Recognition Challenge (ILSVRC) by a massive margin (about 15% top-5 error versus 26% for the runner-up), igniting the deep learning revolution. Key innovations: ReLU activation (instead of tanh, which saturates and slows gradient-based training), dropout for regularisation, data augmentation, and training on GPUs. Five convolutional layers, three fully connected layers, 60 million parameters.

  • VGG (Simonyan and Zisserman, 2014) showed that stacking many 3x3 filters works better than using larger filters. Two stacked 3x3 filters have the same receptive field as one 5x5 filter but fewer parameters (\(2 \times 3^2 = 18\) vs \(5^2 = 25\) per input-output channel pair) and an extra nonlinearity. VGG-16 and VGG-19 (16 and 19 weight layers) are still widely used as feature extractors. The architecture is remarkably simple: five convolution blocks with increasing channels (64, 128, 256, 512, 512), each followed by max pooling.

VGG architecture: stacked 3x3 conv blocks with increasing channel depth (64→128→256→512), max pooling between blocks, ending with fully connected layers
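  • The parameter counts quoted for VGG can be reproduced with a little arithmetic. This plain-Python sketch assumes the standard VGG-16 layout (13 3x3 conv layers in five blocks, then three fully connected layers on a 224x224 input) and counts weights plus biases; the function name is our own. It also shows that almost 90% of the parameters sit in the fully connected layers, which is exactly what GAP-based heads later removed.

```python
# VGG-16 configuration: numbers are 3x3 conv output channels, 'M' is 2x2 max pooling.
VGG16 = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
         512, 512, 512, 'M', 512, 512, 512, 'M']

def vgg_param_count(cfg, in_ch=3, n_classes=1000, fc=4096, final_hw=7):
    total, c = 0, in_ch
    for v in cfg:
        if v == 'M':
            continue                      # pooling has no parameters
        total += 3 * 3 * c * v + v        # 3x3 kernel weights + biases
        c = v
    conv_total = total
    flat = final_hw * final_hw * c        # 224x224 input -> 7x7x512 after five poolings
    for n_in, n_out in [(flat, fc), (fc, fc), (fc, n_classes)]:
        total += n_in * n_out + n_out     # fully connected weights + biases
    return conv_total, total

conv_params, all_params = vgg_param_count(VGG16)
print(f"conv layers: {conv_params:,}")   # 14,714,688
print(f"total:       {all_params:,}")    # 138,357,544
```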

  • GoogLeNet/Inception (Szegedy et al., 2014) introduced the Inception module: instead of choosing a single filter size, run 1x1, 3x3, and 5x5 convolutions and a 3x3 max pooling branch in parallel, concatenate their outputs along the channel dimension, and let the network decide which scale is most useful. 1x1 convolutions are used as bottlenecks before the larger filters to reduce computation. GoogLeNet won the ILSVRC 2014 classification task, ahead of VGG, with roughly 20x fewer parameters (about 6.8M vs 138M for VGG-16).

Inception module: four parallel branches (1×1, 3×3, 5×5, and pooling) with 1×1 bottlenecks, concatenated along the channel dimension

  • The Inception module captures features at multiple scales simultaneously. A 1x1 filter captures cross-channel patterns at a single position, a 3x3 captures local texture, a 5x5 captures larger structures, and the pooling branch passes on the strongest local responses. Concatenation combines all of these views into one rich representation.
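  • A minimal JAX sketch of one Inception module with its four parallel branches, assuming an NHWC layout and random weights; the helper names are our own. The channel counts (64, 96 to 128, 16 to 32, and a 32-channel pool projection on a 28x28x192 input) are assumed to mirror the first Inception module of GoogLeNet. The final lines count multiplications for the 5x5 branch with and without its 1x1 bottleneck.

```python
import jax
import jax.numpy as jnp
from jax import lax

DN = ('NHWC', 'HWIO', 'NHWC')   # activations (N, H, W, C), kernels (k, k, C_in, C_out)

def conv_relu(x, w):
    """Stride-1 'SAME' convolution followed by ReLU."""
    return jax.nn.relu(lax.conv_general_dilated(x, w, (1, 1), 'SAME', dimension_numbers=DN))

def maxpool3x3(x):
    """3x3 max pooling, stride 1, 'SAME' padding (keeps H and W)."""
    init = jnp.array(-jnp.inf, dtype=x.dtype)
    return lax.reduce_window(x, init, lax.max, (1, 3, 3, 1), (1, 1, 1, 1), 'SAME')

def inception(x, p):
    b1 = conv_relu(x, p['1x1'])
    b2 = conv_relu(conv_relu(x, p['3x3_reduce']), p['3x3'])   # 1x1 bottleneck, then 3x3
    b3 = conv_relu(conv_relu(x, p['5x5_reduce']), p['5x5'])   # 1x1 bottleneck, then 5x5
    b4 = conv_relu(maxpool3x3(x), p['pool_proj'])              # pooling branch + 1x1 projection
    return jnp.concatenate([b1, b2, b3, b4], axis=-1)          # stack branches along channels

def init(key, c_in=192):
    shapes = {'1x1': (1, 1, c_in, 64), '3x3_reduce': (1, 1, c_in, 96), '3x3': (3, 3, 96, 128),
              '5x5_reduce': (1, 1, c_in, 16), '5x5': (5, 5, 16, 32), 'pool_proj': (1, 1, c_in, 32)}
    keys = jax.random.split(key, len(shapes))
    return {name: 0.05 * jax.random.normal(k, s) for k, (name, s) in zip(keys, shapes.items())}

x = jnp.ones((1, 28, 28, 192))
print(inception(x, init(jax.random.PRNGKey(0))).shape)   # (1, 28, 28, 256) = 64 + 128 + 32 + 32

# Multiplications for the 5x5 branch on a 28x28x192 input producing 32 channels:
direct = 28 * 28 * 5 * 5 * 192 * 32                     # 120,422,400
bottlenecked = 28 * 28 * (192 * 16 + 5 * 5 * 16 * 32)   # 12,443,648
print(direct, bottlenecked, round(direct / bottlenecked, 1))
```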

  • ResNet (He et al., 2016) solved the degradation problem: deeper networks performed worse than shallower ones, not because of overfitting, but because they were harder to optimise. The solution is the skip connection (residual connection):

\[\text{output} = F(x) + x\]
  • The layer learns the residual \(F(x) = \text{output} - x\). If the optimal transformation is close to the identity (which He et al. argue is plausible for many layers of a very deep network), learning a near-zero residual is much easier than learning the full mapping from scratch. Skip connections also provide a direct path for gradients, mitigating vanishing gradients. ResNet trained networks with 152 layers on ImageNet, far deeper than previous models.

ResNet block: input x passes through two conv layers to produce F(x), then the skip connection adds x back, giving output F(x) + x

  • When the input and output dimensions differ (due to stride or channel change), a projection shortcut applies a 1x1 convolution to \(x\) to match dimensions: \(\text{output} = F(x) + W_s x\).
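  • A sketch of a basic residual block in JAX (NHWC layout, BatchNorm omitted, helper names our own). Following the original ResNet basic block, a ReLU is applied after the addition. The two calls illustrate both cases from the text: with an all-zero second convolution the block reduces to the identity, and a stride-2, 64-to-128-channel block needs the 1x1 projection \(W_s\) on the shortcut.

```python
import jax
import jax.numpy as jnp
from jax import lax

DN = ('NHWC', 'HWIO', 'NHWC')

def conv(x, w, stride=1):
    return lax.conv_general_dilated(x, w, (stride, stride), 'SAME', dimension_numbers=DN)

def residual_block(x, p, stride=1):
    """Basic block, BatchNorm omitted: ReLU(F(x) + shortcut(x))."""
    h = jax.nn.relu(conv(x, p['w1'], stride))    # first 3x3 conv (may downsample)
    f = conv(h, p['w2'])                         # second 3x3 conv -> F(x)
    # Identity shortcut when shapes match; otherwise a 1x1 projection W_s x.
    shortcut = conv(x, p['ws'], stride) if 'ws' in p else x
    return jax.nn.relu(f + shortcut)

k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
x = jnp.abs(jax.random.normal(k1, (1, 32, 32, 64)))

# Identity shortcut: if F(x) = 0 (all-zero second conv), the block passes x straight through.
p_id = {'w1': 0.05 * jax.random.normal(k2, (3, 3, 64, 64)), 'w2': jnp.zeros((3, 3, 64, 64))}
print(bool(jnp.allclose(residual_block(x, p_id), x)))   # True (x >= 0, so the final ReLU keeps it)

# Projection shortcut: stride 2 and 64 -> 128 channels, so x must be projected to match F(x).
p_proj = {'w1': 0.05 * jax.random.normal(k2, (3, 3, 64, 128)),
          'w2': 0.05 * jax.random.normal(k3, (3, 3, 128, 128)),
          'ws': 0.05 * jax.random.normal(k4, (1, 1, 64, 128))}
print(residual_block(x, p_proj, stride=2).shape)        # (1, 16, 16, 128)
```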

  • The bottleneck block (used in ResNet-50 and deeper) uses three convolutions: 1x1 to reduce channels (e.g. 256 to 64), 3x3 for spatial processing at the reduced width, and 1x1 to expand channels back. On a 256-channel stream this is far cheaper than two full-width 3x3 convolutions, which is what makes 50-, 101-, and 152-layer networks affordable.
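  • Counting weights makes the saving concrete. This is plain-Python arithmetic (weights equal multiplications per output position, with biases and BatchNorm ignored); the 256/64 widths match the first stage of ResNet-50, and the function names are our own.

```python
def basic_block_weights(c):
    """Two 3x3 convolutions at width c."""
    return 2 * (3 * 3 * c * c)

def bottleneck_weights(c_outer, c_inner):
    """1x1 reduce (c_outer -> c_inner), 3x3 at c_inner, 1x1 expand back to c_outer."""
    return c_outer * c_inner + 3 * 3 * c_inner * c_inner + c_inner * c_outer

print(f"{basic_block_weights(256):,}")      # 1,179,648  two 3x3 convs on 256 channels
print(f"{bottleneck_weights(256, 64):,}")   # 69,632     bottleneck on a 256-channel stream
print(f"{basic_block_weights(64):,}")       # 73,728     two 3x3 convs on 64 channels
```

The bottleneck keeps a 256-channel stream for roughly the cost of a basic block at 64 channels.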

  • DenseNet (Huang et al., 2017) takes the skip connection idea further: every layer is connected to every subsequent layer within a dense block. Layer \(l\) receives the feature maps from all preceding layers as input: \(x_l = H_l([x_0, x_1, \ldots, x_{l-1}])\), where \([\cdot]\) denotes concatenation along the channel dimension. This encourages feature reuse, strengthens gradient flow, and needs fewer parameters than a comparable ResNet, because each layer adds only a small number of new feature maps (the growth rate) while reusing everything computed before it.

DenseNet dense block: every layer receives feature maps from all preceding layers via concatenation, creating dense connectivity for maximum feature reuse
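  • A sketch of a dense block in JAX, with each \(H_l\) simplified to a single 3x3 convolution plus ReLU (DenseNet itself uses BatchNorm-ReLU-Conv and, in its compact variants, an extra 1x1 bottleneck); the function names are our own. Each layer emits a fixed number of new channels, the growth rate, so after \(L\) layers a block with \(c_0\) input channels outputs \(c_0 + L \cdot \text{growth}\) channels. The sizes used here are arbitrary.

```python
import jax
import jax.numpy as jnp
from jax import lax

DN = ('NHWC', 'HWIO', 'NHWC')

def dense_block(x, weights):
    """x_l = H_l([x_0, ..., x_{l-1}]) with H_l = 3x3 conv + ReLU (no BatchNorm)."""
    features = [x]
    for w in weights:
        inp = jnp.concatenate(features, axis=-1)   # all earlier maps, stacked along channels
        out = lax.conv_general_dilated(inp, w, (1, 1), 'SAME', dimension_numbers=DN)
        features.append(jax.nn.relu(out))          # each layer adds `growth` new channels
    return jnp.concatenate(features, axis=-1)

def init_dense_block(key, c0, growth, n_layers):
    keys = jax.random.split(key, n_layers)
    # Layer l sees c0 + l * growth input channels and emits `growth` channels.
    return [0.05 * jax.random.normal(k, (3, 3, c0 + l * growth, growth))
            for l, k in enumerate(keys)]

c0, growth, n_layers = 16, 12, 4
ws = init_dense_block(jax.random.PRNGKey(0), c0, growth, n_layers)
y = dense_block(jnp.ones((1, 8, 8, c0)), ws)
print(y.shape)   # (1, 8, 8, 64): 16 + 4 * 12 channels
```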

  • Efficient architectures target deployment on mobile devices and edge hardware, where compute, memory, and energy are constrained.

  • MobileNet (Howard et al., 2017) replaces standard convolutions with depthwise separable convolutions, which factorise the operation into two steps:

    1. Depthwise convolution: apply a single \(k \times k\) filter per input channel (no cross-channel interaction)
    2. Pointwise convolution: apply 1x1 convolutions to combine information across channels
  • A standard \(k \times k\) convolution with \(C_{\text{in}}\) input channels and \(C_{\text{out}}\) output channels costs \(k^2 \cdot C_{\text{in}} \cdot C_{\text{out}}\) multiplications per spatial position. Depthwise separable convolution costs \(k^2 \cdot C_{\text{in}} + C_{\text{in}} \cdot C_{\text{out}}\), which is a fraction \(\frac{1}{C_{\text{out}}} + \frac{1}{k^2}\) of the standard cost. When \(C_{\text{out}} \gg k^2\) the reduction approaches \(k^2\): for a 3x3 filter it is roughly 8 to 9x.

Depthwise separable convolution: depthwise step applies one k×k filter per channel, then pointwise 1×1 convolutions mix channels — same output shape, ~9× fewer operations
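  • In JAX the depthwise step does not need a special operation: lax.conv_general_dilated with feature_group_count equal to the number of input channels gives each channel its own filter. A sketch assuming an NHWC layout (function names are our own), together with the per-position cost comparison from the text:

```python
import jax
import jax.numpy as jnp
from jax import lax

DN = ('NHWC', 'HWIO', 'NHWC')

def depthwise_separable(x, dw, pw):
    """x: (N, H, W, C_in); dw: (k, k, 1, C_in); pw: (1, 1, C_in, C_out)."""
    c_in = x.shape[-1]
    # Depthwise step: feature_group_count=C_in gives each input channel its own k x k filter.
    h = lax.conv_general_dilated(x, dw, (1, 1), 'SAME', dimension_numbers=DN,
                                 feature_group_count=c_in)
    # Pointwise step: an ordinary 1x1 convolution mixes the channels.
    return lax.conv_general_dilated(h, pw, (1, 1), 'SAME', dimension_numbers=DN)

def mults_per_position(k, c_in, c_out):
    standard = k * k * c_in * c_out
    separable = k * k * c_in + c_in * c_out
    return standard, separable

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k1, (1, 16, 16, 32))
y = depthwise_separable(x, jax.random.normal(k2, (3, 3, 1, 32)), jax.random.normal(k3, (1, 1, 32, 64)))
print(y.shape)                       # (1, 16, 16, 64)
std, sep = mults_per_position(3, 32, 64)
print(std, sep, std / sep)           # ratio = 1 / (1/C_out + 1/k^2)
```

With \(C_{\text{out}} = 64\) the reduction is about 7.9x; it approaches 9x only as \(C_{\text{out}}\) grows.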

  • MobileNet-V2 introduced the inverted residual block: expand channels with a 1x1 convolution, apply depthwise convolution in the expanded space, then project back down with a 1x1 convolution that has no activation (a "linear bottleneck"). The skip connection joins the narrow (bottleneck) ends of the block, inverting the ResNet bottleneck pattern, whose skip connection joins the wide ends. The expansion ratio is typically 6.
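  • A sketch of an inverted residual block in JAX (NHWC layout, BatchNorm omitted, random weights, helper names our own). It uses ReLU6 after the expansion and depthwise steps and no activation after the projection, and adds the skip connection only when the stride is 1 and the channel count is unchanged.

```python
import jax
import jax.numpy as jnp
from jax import lax

DN = ('NHWC', 'HWIO', 'NHWC')

def conv(x, w, stride=1, groups=1):
    return lax.conv_general_dilated(x, w, (stride, stride), 'SAME',
                                    dimension_numbers=DN, feature_group_count=groups)

def inverted_residual(x, p, stride=1):
    """Expand (1x1) -> depthwise 3x3 -> linear projection (1x1)."""
    c_in = x.shape[-1]
    h = jax.nn.relu6(conv(x, p['expand']))                          # narrow -> wide (t * c_in)
    h = jax.nn.relu6(conv(h, p['dw'], stride, groups=h.shape[-1]))  # per-channel spatial filter
    h = conv(h, p['project'])                     # wide -> narrow, no activation
    if stride == 1 and h.shape[-1] == c_in:
        h = h + x                                 # skip joins the narrow ends
    return h

def init(key, c_in, c_out, t=6):
    k1, k2, k3 = jax.random.split(key, 3)
    c_mid = t * c_in
    return {'expand': 0.1 * jax.random.normal(k1, (1, 1, c_in, c_mid)),
            'dw': 0.1 * jax.random.normal(k2, (3, 3, 1, c_mid)),
            'project': 0.1 * jax.random.normal(k3, (1, 1, c_mid, c_out))}

x = jnp.ones((1, 14, 14, 24))
print(inverted_residual(x, init(jax.random.PRNGKey(0), 24, 24)).shape)            # (1, 14, 14, 24), with skip
print(inverted_residual(x, init(jax.random.PRNGKey(1), 24, 32), stride=2).shape)  # (1, 7, 7, 32), no skip
```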

  • EfficientNet (Tan and Le, 2019) introduced compound scaling: instead of scaling only depth, only width, or only resolution independently, scale all three dimensions together using a fixed ratio. Given a scaling coefficient \(\phi\):

\[\text{depth}: d = \alpha^\phi, \quad \text{width}: w = \beta^\phi, \quad \text{resolution}: r = \gamma^\phi\]
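  • subject to \(\alpha \cdot \beta^2 \cdot \gamma^2 \approx 2\) with \(\alpha, \beta, \gamma \ge 1\). FLOPs scale linearly with depth but quadratically with width and with resolution, so under this constraint total computation roughly doubles per unit increase in \(\phi\). A small grid search on the baseline network finds \(\alpha = 1.2\), \(\beta = 1.1\), \(\gamma = 1.15\). Starting from the EfficientNet-B0 baseline (itself found by neural architecture search), EfficientNet-B1 through B7 scale up progressively, achieving state-of-the-art ImageNet accuracy at the time of publication with far fewer parameters and FLOPs than previous models.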

EfficientNet compound scaling: scaling width, depth, or resolution alone vs scaling all three together with a single coefficient φ
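  • The scaling rule is plain arithmetic. This sketch only illustrates how \(\phi\) drives the three multipliers and the FLOPs factor; it does not reproduce the exact published B1-B7 configurations.

```python
alpha, beta, gamma = 1.2, 1.1, 1.15   # depth, width, resolution base ratios

# FLOPs grow linearly with depth and quadratically with width and with resolution.
base = alpha * beta**2 * gamma**2
print(f"FLOPs factor per unit of phi: {base:.3f}")   # 1.920

for phi in range(4):
    d, w, r = alpha**phi, beta**phi, gamma**phi
    print(f"phi={phi}: depth x{d:.2f}  width x{w:.2f}  resolution x{r:.2f}  FLOPs x{base**phi:.2f}")
```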

  • ShuffleNet (Zhang et al., 2018) reduces the cost of 1x1 convolutions (which dominate in MobileNet-style architectures) by using group convolutions followed by a channel shuffle. Group convolutions split channels into groups and convolve within each group independently, but this prevents cross-group information flow. The shuffle operation rearranges channels between groups, restoring the information mixing at negligible cost.
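  • A sketch of the channel shuffle and a 1x1 group convolution in JAX (NHWC layout; the 240 channels and 3 groups are arbitrary choices, and the function names are our own). The shuffle is just a reshape to (groups, channels per group), a transpose, and a flatten, so that consecutive output channels come from different groups.

```python
import jax
import jax.numpy as jnp
from jax import lax

def channel_shuffle(x, groups):
    """(N, H, W, g * n) -> reshape (g, n) -> swap to (n, g) -> flatten back to g * n channels."""
    N, H, W, C = x.shape
    x = x.reshape(N, H, W, groups, C // groups)
    x = jnp.swapaxes(x, 3, 4)
    return x.reshape(N, H, W, C)

def group_conv1x1(x, w, groups):
    """1x1 group convolution: w has shape (1, 1, C_in // groups, C_out)."""
    return lax.conv_general_dilated(x, w, (1, 1), 'SAME',
                                    dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
                                    feature_group_count=groups)

x = jnp.arange(6, dtype=jnp.float32).reshape(1, 1, 1, 6)   # channel i holds the value i
print(channel_shuffle(x, groups=2).ravel())                 # channel order 0, 3, 1, 4, 2, 5

feat = jax.random.normal(jax.random.PRNGKey(0), (1, 8, 8, 240))
w = 0.1 * jax.random.normal(jax.random.PRNGKey(1), (1, 1, 80, 240))
mixed = channel_shuffle(group_conv1x1(feat, w, groups=3), groups=3)   # group conv, then shuffle
print(mixed.shape)                                                   # (1, 8, 8, 240)
print(240 * 240, 3 * 80 * 240)   # 57600 vs 19200 weights: grouping divides the cost by 3
```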

  • Transfer learning is the practice of taking a model trained on one task and adapting it to a different task. In computer vision, this almost always means starting from a model pre-trained on ImageNet (1.4 million images, 1,000 classes) and adapting to a domain-specific dataset (medical images, satellite images, manufacturing defects).

  • Feature extraction: freeze all convolutional layers, remove the final classification head, and train only a new head on top. The frozen layers act as a generic feature extractor. This works well when the target domain is similar to ImageNet and the target dataset is small.

  • Fine-tuning: unfreeze some or all convolutional layers and train with a small learning rate. The pre-trained weights serve as a starting point rather than fixed features. Fine-tuning typically starts by unfreezing only the later layers (which capture high-level, task-specific features) and optionally unfreezing earlier layers as well.
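  • Freezing and per-group learning rates can be sketched in plain JAX. Here a single dense layer is a hypothetical stand-in for a pretrained convolutional backbone (no real pretrained weights are loaded), and the learning rates are illustrative. jax.lax.stop_gradient blocks gradients into the backbone for feature extraction; fine-tuning removes it and gives the backbone a smaller learning rate than the new head.

```python
import jax
import jax.numpy as jnp

def forward(params, x, freeze_backbone):
    backbone = params['backbone']
    if freeze_backbone:
        # Feature extraction: treat the backbone as constant, so no gradient reaches it.
        backbone = jax.lax.stop_gradient(backbone)
    feats = jax.nn.relu(x @ backbone['w'])   # hypothetical stand-in for pretrained features
    return feats @ params['head']['w']       # new task-specific classification head

def loss(params, x, y, freeze_backbone):
    """Softmax cross-entropy with one-hot labels y."""
    logits = forward(params, x, freeze_backbone)
    return -jnp.mean(jnp.sum(jax.nn.log_softmax(logits) * y, axis=-1))

def sgd_step(params, grads, lrs):
    """Plain SGD with a separate learning rate for each parameter group."""
    return {g: jax.tree_util.tree_map(lambda p, d: p - lrs[g] * d, params[g], grads[g])
            for g in params}

k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
params = {'backbone': {'w': 0.3 * jax.random.normal(k1, (16, 32))},
          'head': {'w': 0.1 * jax.random.normal(k2, (32, 3))}}
x = jax.random.normal(k3, (8, 16))
y = jax.nn.one_hot(jax.random.randint(k4, (8,), 0, 3), 3)

# Feature extraction: backbone gradients are exactly zero; only the head is updated.
g = jax.grad(loss)(params, x, y, True)
print(float(jnp.abs(g['backbone']['w']).max()))   # 0.0
params = sgd_step(params, g, {'backbone': 0.0, 'head': 0.1})

# Fine-tuning: gradients now reach the backbone, which gets a 10x smaller learning rate.
g = jax.grad(loss)(params, x, y, False)
print(float(jnp.abs(g['backbone']['w']).max()))   # non-zero
params = sgd_step(params, g, {'backbone': 0.01, 'head': 0.1})
```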

  • Transfer learning works because the early layers of a CNN learn universal features (edges, textures, colours) that are useful across tasks, while later layers learn task-specific features. A network trained to classify animals still has useful edge detectors for classifying buildings.

  • Visualising CNNs reveals what the network has learned and helps debug unexpected behaviour.

  • Activation maps (feature maps) show the output of each filter for a given input image. Early layer activations look like edge maps; deeper layers produce increasingly abstract, spatially coarse activations.

  • Grad-CAM (Gradient-weighted Class Activation Mapping, Selvaraju et al., 2017) highlights the regions of the input image that were most important for the model's prediction. It works by:

    1. Computing the gradient of the target class score \(y^c\) (taken before the softmax) with respect to the feature maps of the last convolutional layer (using the chain rule from chapter 03)
    2. Global average pooling these gradients to get per-channel importance weights
    3. Computing a weighted combination of the feature maps and applying ReLU
\[L_{\text{Grad-CAM}} = \text{ReLU}\!\left(\sum_k \alpha_k A^k\right), \quad \alpha_k = \frac{1}{Z} \sum_i \sum_j \frac{\partial y^c}{\partial A^k_{ij}}\]
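  • where \(A^k\) is the \(k\)-th feature map, \(A^k_{ij}\) its value at spatial position \((i, j)\), \(Z\) the number of spatial positions, \(\alpha_k\) the importance weight for channel \(k\), and \(y^c\) the score for class \(c\). The result is a coarse heatmap at the resolution of the last convolutional layer, which is upsampled to the input size to show which regions drove the classification. ReLU is applied because we are interested only in features that have a positive influence on the class.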

Grad-CAM: input image of a dog, feature maps from the last conv layer, gradient-weighted combination, and the resulting heatmap overlaid on the original image highlighting the dog's face
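  • The three steps map directly onto jax.grad. A sketch assuming the last-layer feature maps are available as an (H, W, K) array and the rest of the network is a function from those maps to pre-softmax class scores; the 7x7x8 maps and the GAP-plus-linear head below are hypothetical, and the heatmap is upsampled with bilinear interpolation to an assumed 224x224 input size.

```python
import jax
import jax.numpy as jnp

def grad_cam(head_fn, A, c, out_hw=None):
    """Grad-CAM for one image.

    A: last-conv feature maps, shape (H, W, K); head_fn maps A to pre-softmax class scores;
    c: target class index; out_hw: optional (height, width) to upsample the heatmap to.
    """
    dyc_dA = jax.grad(lambda a: head_fn(a)[c])(A)         # dy^c / dA^k_ij, shape (H, W, K)
    alpha = dyc_dA.mean(axis=(0, 1))                      # (1/Z) sum over i, j, Z = H * W
    cam = jax.nn.relu(jnp.einsum('hwk,k->hw', A, alpha))  # ReLU(sum_k alpha_k A^k)
    if out_hw is not None:
        cam = jax.image.resize(cam, out_hw, method='bilinear')
    return cam

kA, kW = jax.random.split(jax.random.PRNGKey(0))
A = jax.nn.relu(jax.random.normal(kA, (7, 7, 8)))
W = jax.random.normal(kW, (8, 5))
head = lambda a: a.mean(axis=(0, 1)) @ W      # GAP + linear layer over 5 classes
heatmap = grad_cam(head, A, c=2, out_hw=(224, 224))
print(heatmap.shape)   # (224, 224)
```

With a GAP-plus-linear head like this one, every gradient \(\partial y^c / \partial A^k_{ij}\) equals \(W_{kc} / Z\), so the heatmap is proportional to \(\text{ReLU}(\sum_k W_{kc} A^k)\); deeper heads make the weights depend on the input.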

  • Feature inversion reconstructs an input image from its feature representation by optimising a random image to match the target features (using gradient descent on the pixel values). This reveals what information the network retains at each layer. Early layers reconstruct near-perfect images; deeper layers produce recognisable but distorted images, showing that fine spatial detail is lost while semantic content is preserved.
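  • A toy sketch of feature inversion: one random convolutional layer with ReLU stands in for a trained network layer (a hypothetical setup), and plain gradient descent on the pixels, with an illustrative step size of 5.0, minimises the squared difference between the features of the current image and those of a target image. Practical implementations also add image priors such as smoothness regularisers, which are omitted here.

```python
import jax
import jax.numpy as jnp
from jax import lax

def features(x, w):
    """Hypothetical stand-in for a trained layer: one conv (8 random 3x3 filters) + ReLU."""
    return jax.nn.relu(lax.conv(x, w, (1, 1), 'SAME'))

def inversion_loss(x, target_feats, w):
    return jnp.mean((features(x, w) - target_feats) ** 2)

k_w, k_img, k_init = jax.random.split(jax.random.PRNGKey(0), 3)
w = 0.3 * jax.random.normal(k_w, (8, 1, 3, 3))           # OIHW kernel
target_img = jax.random.uniform(k_img, (1, 1, 16, 16))   # NCHW image whose features we match
target_feats = features(target_img, w)

# One gradient-descent step on the pixel values (the weights stay fixed).
step = jax.jit(lambda x: x - 5.0 * jax.grad(inversion_loss)(x, target_feats, w))

x = jax.random.normal(k_init, (1, 1, 16, 16))            # start from a random image
loss_start = inversion_loss(x, target_feats, w)
for _ in range(300):
    x = step(x)
loss_end = inversion_loss(x, target_feats, w)
print(f"feature loss: {float(loss_start):.4f} -> {float(loss_end):.4f}")
print(f"pixel error vs target image: {float(jnp.mean((x - target_img) ** 2)):.4f}")
```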

  • Deep Dream and neural style transfer are creative applications of feature visualisation. Deep Dream maximises the activation of neurons at a chosen layer to produce surreal, pattern-amplified images. Neural style transfer optimises a target image to match the content features (from a deep layer) of one image and the style features (Gram matrix of filter activations, which captures texture statistics) of another.
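  • A sketch of the Gram matrix and the two style-transfer losses on a single (H, W, C) feature map; the function names are our own. The normalisation (dividing by the number of positions) is one common choice, and in a full implementation the feature maps come from several layers of a pretrained network. The example checks the property that makes Gram matrices texture statistics: shuffling spatial positions leaves them unchanged while the content loss changes.

```python
import jax
import jax.numpy as jnp

def gram_matrix(F):
    """F: (H, W, C) feature maps -> (C, C) channel co-activations, averaged over positions."""
    H, W, C = F.shape
    Fm = F.reshape(H * W, C)          # one row per spatial position
    return Fm.T @ Fm / (H * W)

def content_loss(F_gen, F_content):
    return jnp.mean((F_gen - F_content) ** 2)

def style_loss(F_gen, F_style):
    return jnp.sum((gram_matrix(F_gen) - gram_matrix(F_style)) ** 2)

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
F = jax.random.normal(k1, (8, 8, 4))
perm = jax.random.permutation(k2, 64)
F_shuffled = F.reshape(64, 4)[perm].reshape(8, 8, 4)   # same features, different layout
print(style_loss(F_shuffled, F))     # ~0: identical texture statistics
print(content_loss(F_shuffled, F))   # clearly > 0: different spatial arrangement
```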

Coding Tasks (use Colab or a notebook)

  1. Implement a simple CNN from scratch in JAX with two convolutional layers, max pooling, and a classification head. Train it on a synthetic 2D pattern classification task.

    import jax
    import jax.numpy as jnp
    import jax.lax as lax
    
    def conv2d(x, kernel, stride=1):
        """2D convolution of one (H, W) image with one (k, k) filter, 'SAME' padding."""
        # lax.conv expects NCHW input and OIHW kernel: add batch/channel axes, then drop them.
        return lax.conv(x[None, None], kernel[None, None], (stride, stride), 'SAME')[0, 0]
    
    def max_pool(x, size=2):
        """Non-overlapping size x size max pooling (crops rows/columns that don't fill a window)."""
        H, W = x.shape
        x = x[:H//size*size, :W//size*size]
        return x.reshape(H//size, size, W//size, size).max(axis=(1, 3))
    
    def init_cnn(key):
        k1, k2, k3 = jax.random.split(key, 3)
        return {
            'conv1': jax.random.normal(k1, (5, 5)) * 0.3,
            'conv2': jax.random.normal(k2, (3, 3)) * 0.3,
            'fc_w': jax.random.normal(k3, (64, 1)) * 0.1,  # 32x32 input -> two 2x2 pools -> 8x8 = 64
            'fc_b': jnp.zeros(1),
        }
    
    def forward_cnn(params, img):
        # Conv1 -> ReLU -> Pool: (32, 32) -> (16, 16)
        h = jax.nn.relu(conv2d(img, params['conv1']))
        h = max_pool(h)
        # Conv2 -> ReLU -> Pool: (16, 16) -> (8, 8)
        h = jax.nn.relu(conv2d(h, params['conv2']))
        h = max_pool(h)
        # Flatten the 8x8 map to 64 features and classify
        logit = (h.ravel() @ params['fc_w'] + params['fc_b']).squeeze()
        return jax.nn.sigmoid(logit)
    
    # Synthetic data: class 0 = low-frequency stripes, class 1 = high-frequency checkerboard
    def make_data(key, n=200):
        x, y = jnp.meshgrid(jnp.linspace(0, 4*jnp.pi, 32), jnp.linspace(0, 4*jnp.pi, 32))
        images, labels = [], []
        for i in range(n):
            k1, key = jax.random.split(key)
            noise = jax.random.normal(k1, (32, 32)) * 0.1
            if i < n // 2:
                images.append(jnp.sin(x) + noise)
                labels.append(0)
            else:
                images.append(jnp.sin(4 * x) * jnp.sin(4 * y) + noise)
                labels.append(1)
        return jnp.stack(images), jnp.array(labels, dtype=jnp.float32)
    
    images, labels = make_data(jax.random.PRNGKey(42))
    params = init_cnn(jax.random.PRNGKey(0))
    
    def loss_fn(params, img, label):
        """Binary cross-entropy; the small constant avoids log(0)."""
        pred = forward_cnn(params, img)
        return -(label * jnp.log(pred + 1e-7) + (1 - label) * jnp.log(1 - pred + 1e-7))
    
    loss_and_grad = jax.jit(jax.value_and_grad(loss_fn))  # compile once, reuse every step
    lr = 0.01
    key = jax.random.PRNGKey(1)
    
    for epoch in range(5):
        # The data are sorted by class, so shuffle every epoch to avoid long same-class runs
        key, sub = jax.random.split(key)
        total_loss = 0.0
        for idx in jax.random.permutation(sub, len(images)):
            loss, grads = loss_and_grad(params, images[idx], labels[idx])
            params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)  # SGD step
            total_loss += float(loss)
        print(f"Epoch {epoch}: loss = {total_loss / len(images):.4f}")
    
    # Accuracy on the training set (this toy task has no separate test split)
    probs = jax.vmap(forward_cnn, in_axes=(None, 0))(params, images)
    acc = jnp.mean((probs > 0.5).astype(jnp.float32) == labels)
    print(f"Training accuracy: {float(acc):.2%}")
    

  2. Visualise how different filter sizes affect the receptive field. Show that two stacked 3x3 filters cover the same receptive field as one 5x5 filter but with fewer parameters.

    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    
    def compute_receptive_field(layers):
        """Compute receptive field size from a list of (kernel_size, stride) tuples."""
        rf = 1  # start with 1 pixel
        stride_product = 1
        for k, s in layers:
            rf += (k - 1) * stride_product
            stride_product *= s
        return rf
    
    # Compare architectures
    configs = {
        'Single 5x5': [(5, 1)],
        'Two 3x3':    [(3, 1), (3, 1)],
        'Three 3x3':  [(3, 1), (3, 1), (3, 1)],
        'Single 7x7': [(7, 1)],
        '3x3 stride 2 + 3x3': [(3, 2), (3, 1)],
    }
    
    print(f"{'Config':<25} {'RF':>4} {'Params (per channel)':>20}")
    print('-' * 55)
    for name, layers in configs.items():
        rf = compute_receptive_field(layers)
        # Parameters: sum of k^2 for each layer (per input-output channel pair)
        params = sum(k * k for k, s in layers)
        print(f"{name:<25} {rf:>4} {params:>20}")
    
    # Visualise receptive fields
    fig, axes = plt.subplots(1, 3, figsize=(14, 4))
    for ax, (name, rf_size) in zip(axes, [('5x5 filter', 5), ('Two 3x3 filters', 5), ('Three 3x3 filters', 7)]):
        grid = jnp.zeros((9, 9))
        c = 4  # centre
        half = rf_size // 2
        grid = grid.at[c-half:c+half+1, c-half:c+half+1].set(1.0)
        ax.imshow(grid, cmap='Blues', vmin=0, vmax=1)
        ax.set_title(f'{name}\nRF = {rf_size}x{rf_size}')
        ax.set_xticks(range(9)); ax.set_yticks(range(9))
        ax.grid(True, alpha=0.3)
    plt.suptitle('Receptive Field Comparison')
    plt.tight_layout(); plt.show()
    

  3. Implement Grad-CAM from scratch. Given a pre-built simple CNN, compute the gradient-weighted activation map for a specific class and visualise it as a heatmap.

    import jax
    import jax.numpy as jnp
    import jax.lax as lax
    import matplotlib.pyplot as plt
    
    def conv_features(params, img):
        """The "last conv layer": edge-padded 'same' convolution + ReLU -> activation map A (H, W)."""
        k = params['conv'].shape[0]
        img_pad = jnp.pad(img, k // 2, mode='edge')
        # lax.conv computes a cross-correlation: slide the kernel over the image and sum products
        A = lax.conv(img_pad[None, None], params['conv'][None, None], (1, 1), 'VALID')[0, 0]
        return jnp.maximum(0, A)  # ReLU
    
    def class_score(params, A):
        """Head: global average pool -> dense -> pre-sigmoid class score y^c."""
        return A.mean() * params['w'] + params['b']
    
    def simple_cnn(params, img):
        """Returns both the prediction and the last conv activations."""
        A = conv_features(params, img)
        return jax.nn.sigmoid(class_score(params, A)), A
    
    # Create test image: bright block on the left (class indicator) and a dim patch on the right
    img = jnp.zeros((32, 32))
    img = img.at[8:24, 4:16].set(1.0)
    img = img.at[5:10, 20:28].set(0.3)
    
    key = jax.random.PRNGKey(42)
    params = {
        'conv': jax.random.normal(key, (5, 5)) * 0.3,
        'w': jnp.array(2.0),
        'b': jnp.array(-0.5),
    }
    
    # Grad-CAM: differentiate the class score with respect to the activation map A, not the image
    pred, act_map = simple_cnn(params, img)
    dy_dA = jax.grad(class_score, argnums=1)(params, act_map)
    
    # Weight = global average of the gradients (one channel here, so alpha is a single number).
    # With a GAP + dense head, dy/dA = w / (H*W) everywhere, so the map is the rescaled activations.
    alpha = dy_dA.mean()
    grad_cam = jnp.maximum(0, alpha * act_map)  # ReLU
    grad_cam = (grad_cam - grad_cam.min()) / (grad_cam.max() - grad_cam.min() + 1e-8)
    
    fig, axes = plt.subplots(1, 3, figsize=(14, 4))
    axes[0].imshow(img, cmap='gray'); axes[0].set_title('Input Image'); axes[0].axis('off')
    axes[1].imshow(act_map, cmap='viridis'); axes[1].set_title('Activation Map'); axes[1].axis('off')
    axes[2].imshow(img, cmap='gray', alpha=0.6)
    axes[2].imshow(grad_cam, cmap='jet', alpha=0.4)
    axes[2].set_title(f'Grad-CAM (pred={float(pred):.2f})'); axes[2].axis('off')
    plt.tight_layout(); plt.show()
    

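  4. Compare depthwise separable convolution with standard convolution. Count the parameters and multiply-accumulate operations (MACs) for both, confirm that both produce outputs of the same shape, and check that the depthwise separable result equals a standard convolution whose kernel is restricted to the factorised form depthwise × pointwise. That restriction is where the savings come from, and also why the two are not interchangeable in general.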

    import jax
    import jax.numpy as jnp
    
    def standard_conv(x, kernel):
        """Standard convolution: (H, W, C_in) * (k, k, C_in, C_out) -> (H, W, C_out), zero 'same' padding."""
        H, W, C_in = x.shape
        k, _, _, C_out = kernel.shape
        pad = k // 2
        x_pad = jnp.pad(x, ((pad, pad), (pad, pad), (0, 0)), mode='constant')
        out = jnp.zeros((H, W, C_out))
        for i in range(H):
            for j in range(W):
                patch = x_pad[i:i+k, j:j+k, :]  # (k, k, C_in)
                # Every output channel mixes all k*k*C_in values in the patch
                out = out.at[i, j].set(jnp.einsum('abc,abcd->d', patch, kernel))
        return out
    
    def depthwise_separable_conv(x, dw_kernel, pw_kernel):
        """Depthwise separable: depthwise (k, k, C_in) then pointwise (C_in, C_out)."""
        H, W, C_in = x.shape
        k = dw_kernel.shape[0]
        pad = k // 2
        x_pad = jnp.pad(x, ((pad, pad), (pad, pad), (0, 0)), mode='constant')
    
        # Depthwise: one k x k filter per channel, no mixing across channels
        dw_out = jnp.zeros((H, W, C_in))
        for i in range(H):
            for j in range(W):
                patch = x_pad[i:i+k, j:j+k, :]  # (k, k, C_in)
                dw_out = dw_out.at[i, j].set(jnp.sum(patch * dw_kernel, axis=(0, 1)))
    
        # Pointwise: a 1x1 conv is a matrix multiply across channels at every pixel
        return dw_out @ pw_kernel
    
    # Setup
    H, W, C_in, C_out, k = 8, 8, 16, 32, 3
    key = jax.random.PRNGKey(42)
    k1, k2, k3, k4 = jax.random.split(key, 4)
    
    x = jax.random.normal(k1, (H, W, C_in))
    std_kernel = jax.random.normal(k2, (k, k, C_in, C_out)) * 0.1
    dw_kernel = jax.random.normal(k3, (k, k, C_in)) * 0.1
    pw_kernel = jax.random.normal(k4, (C_in, C_out)) * 0.1
    
    # Compare (biases ignored; MACs = multiply-accumulate operations)
    std_params = k * k * C_in * C_out
    dw_params = k * k * C_in + C_in * C_out
    
    std_macs = H * W * k * k * C_in * C_out
    dw_macs = H * W * (k * k * C_in + C_in * C_out)
    
    print(f"Standard conv:            {std_params:>8,} params,  {std_macs:>10,} MACs")
    print(f"Depthwise separable conv: {dw_params:>8,} params,  {dw_macs:>10,} MACs")
    print(f"Parameter reduction:      {std_params / dw_params:.1f}x")
    print(f"MAC reduction:            {std_macs / dw_macs:.1f}x")
    
    std_out = standard_conv(x, std_kernel)
    ds_out = depthwise_separable_conv(x, dw_kernel, pw_kernel)
    print(f"\nStandard output shape:      {std_out.shape}")
    print(f"Depthwise sep output shape: {ds_out.shape}")
    
    # The separable conv is a standard conv whose kernel is forced into the product form
    # K[a, b, i, o] = dw[a, b, i] * pw[i, o]; the random std_kernel above has no such structure.
    factorised_kernel = dw_kernel[:, :, :, None] * pw_kernel[None, None, :, :]
    print("Matches standard conv with factorised kernel:",
          bool(jnp.allclose(standard_conv(x, factorised_kernel), ds_out, atol=1e-5)))