Text to Speech and Voice

Text-to-speech synthesis reverses the ASR pipeline, generating natural-sounding audio from written text. This file covers the TTS pipeline (text normalisation, G2P, acoustic models, vocoders), Tacotron, WaveNet, HiFi-GAN, voice cloning, voice conversion, and voice activity detection (VAD).

  • In file 01, we built the signal-processing toolkit: waveforms, spectrograms, mel filterbanks, and MFCCs. In file 02, we turned speech into text. Now we reverse the arrow: given text, synthesise natural-sounding speech. This is text-to-speech (TTS), a problem that also opens the door to voice conversion, voice cloning, and voice activity detection.

  • Think of TTS like a stage performance. The script is the text input. A director (the acoustic model) decides how each line should sound, its pitch, timing, emphasis. The orchestra (the vocoder) then performs the score, producing the actual sound waves the audience hears. Modern neural TTS replaces the stiff, robotic delivery of rule-based systems with performances that rival human speakers.

TTS pipeline: text is normalised, converted to phonemes, processed by an acoustic model to produce a mel spectrogram, then passed through a vocoder to generate the final waveform

  • Text-to-speech pipeline: the standard TTS pipeline has four stages: (1) text normalisation, (2) phoneme conversion, (3) acoustic model, and (4) vocoder. Some modern systems collapse stages 3 and 4 into a single end-to-end model, but the conceptual decomposition remains useful.

  • Text normalisation converts raw text into a pronounceable form. Abbreviations expand ("Dr." to "Doctor"), numbers become words ("1984" to "nineteen eighty-four"), currency symbols are verbalised ("$5" to "five dollars"), and URLs or special characters are handled. This stage is often rule-based with language-specific grammars, though neural normalisation models exist. Errors here propagate to every downstream stage: if "St." is read as "saint" instead of "street", the entire utterance is wrong.
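A rule-based normaliser can be sketched with a lookup table plus a few regex substitutions. This is a minimal illustration, not a production grammar: the abbreviation table and number verbaliser below are toy examples, and note that the ambiguous "St." case from the paragraph is resolved here unconditionally to "Street", which is exactly the kind of error a real system must avoid via context.

```python
import re

# Toy abbreviation table -- real systems use language-specific grammars,
# and ambiguous cases ("St." as "Saint" vs "Street") need context to resolve
ABBREVIATIONS = {"Dr.": "Doctor", "St.": "Street", "etc.": "et cetera"}

UNITS = ["zero", "one", "two", "three", "four",
         "five", "six", "seven", "eight", "nine"]
TEENS = ["ten", "eleven", "twelve", "thirteen", "fourteen",
         "fifteen", "sixteen", "seventeen", "eighteen", "nineteen"]
TENS = ["", "", "twenty", "thirty", "forty",
        "fifty", "sixty", "seventy", "eighty", "ninety"]

def verbalise_number(n: int) -> str:
    """Spell out 0-99 (toy coverage; real normalisers also handle years, ordinals, ...)."""
    if n < 10:
        return UNITS[n]
    if n < 20:
        return TEENS[n - 10]
    tens, unit = divmod(n, 10)
    return TENS[tens] + ("" if unit == 0 else " " + UNITS[unit])

def normalise(text: str) -> str:
    # Expand abbreviations first, then verbalise currency, then bare numbers
    for abbr, expansion in ABBREVIATIONS.items():
        text = text.replace(abbr, expansion)
    text = re.sub(r"\$(\d{1,2})\b",
                  lambda m: verbalise_number(int(m.group(1))) + " dollars", text)
    text = re.sub(r"\b(\d{1,2})\b",
                  lambda m: verbalise_number(int(m.group(1))), text)
    return text

print(normalise("Dr. Smith paid $5 for 42 stamps."))
```

The ordering matters: currency must be handled before bare digits, or "$5" would become "$five".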

  • Grapheme-to-phoneme (G2P) conversion maps normalised text to a phoneme sequence. English is notoriously irregular ("though", "through", "tough" all pronounce "ough" differently), so dictionary lookup (the CMU Pronouncing Dictionary) handles common words while a neural sequence-to-sequence model (chapter 06's encoder-decoder or chapter 07's transformer) handles out-of-vocabulary words. Languages with shallow orthographies (Spanish, Finnish) need only simple G2P. The output is typically an IPA (International Phonetic Alphabet) sequence or an equivalent internal phoneme set.
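The dictionary-plus-fallback structure can be sketched in a few lines. The lexicon entries and the letter-to-phone map below are illustrative stand-ins: a real system would load the full CMU dictionary and use a trained seq2seq model for the fallback.

```python
# Toy CMUdict-style lexicon (ARPAbet phones; entries illustrative)
LEXICON = {
    "cat":    ["K", "AE1", "T"],
    "though": ["DH", "OW1"],
    "tough":  ["T", "AH1", "F"],
}

# Crude letter-to-sound fallback standing in for a neural seq2seq G2P model
LETTER_TO_PHONE = {"a": "AE1", "b": "B", "d": "D", "g": "G",
                   "o": "OW1", "s": "S", "t": "T"}

def g2p(word: str) -> list[str]:
    """Dictionary lookup for in-vocabulary words, letter-to-sound fallback for OOV."""
    word = word.lower()
    if word in LEXICON:
        return LEXICON[word]
    return [LETTER_TO_PHONE.get(ch, "<unk>") for ch in word]

print(g2p("though"))  # lexicon hit
print(g2p("dat"))     # OOV: falls back to letter-to-sound
```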

  • Acoustic models consume the phoneme sequence and produce an intermediate acoustic representation, almost always a mel spectrogram (file 01). The mel spectrogram captures the spectral envelope at each time frame, which encodes the perceptually relevant information a vocoder needs to reconstruct the waveform. The acoustic model must decide timing (how long each phoneme lasts), pitch (fundamental frequency \(F_0\)), and energy (loudness).

  • Vocoders take the mel spectrogram and produce the raw audio waveform. This is an ill-posed inversion problem: many waveforms can produce the same spectrogram because phase information was discarded. Classical vocoders (Griffin-Lim, WORLD) use iterative or signal-model approaches, but neural vocoders now dominate in quality.

  • Vocoders: WaveNet (van den Oord et al., 2016) was the first neural vocoder to produce speech nearly indistinguishable from human recordings. It models the waveform autoregressively, predicting each sample \(x_t\) conditioned on all previous samples:

\[P(x) = \prod_{t=1}^{T} P(x_t \mid x_1, \ldots, x_{t-1}, c)\]
  • where \(c\) is the conditioning signal (mel spectrogram). Each sample is 16-bit, so a naive softmax over 65536 values is impractical. WaveNet uses mu-law companding to reduce this to 256 quantisation levels; later variants instead model each sample with a discretised mixture of logistics distribution.
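Mu-law companding is a two-line transform: compress amplitudes logarithmically, then quantise. A minimal sketch with \(\mu = 255\) (the value giving 256 levels):

```python
import jax.numpy as jnp

MU = 255  # gives 256 quantisation levels

def mu_law_encode(x, mu=MU):
    """Compand x in [-1, 1], then quantise to mu + 1 integer levels."""
    companded = jnp.sign(x) * jnp.log1p(mu * jnp.abs(x)) / jnp.log1p(mu)
    return jnp.round((companded + 1) / 2 * mu).astype(jnp.int32)

def mu_law_decode(q, mu=MU):
    """Invert the quantisation and the companding."""
    companded = 2 * q.astype(jnp.float32) / mu - 1
    return jnp.sign(companded) * ((1 + mu) ** jnp.abs(companded) - 1) / mu

x = jnp.linspace(-1, 1, 9)
x_hat = mu_law_decode(mu_law_encode(x))
# Error stays small even with only 256 levels, because the logarithmic
# companding allocates resolution where small amplitudes live
print(float(jnp.max(jnp.abs(x - x_hat))))
```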

  • WaveNet's core building block is the dilated causal convolution. Causal means filter weights only look at past samples (no future leakage). Dilated means the filter skips samples with exponentially increasing gaps: dilation factors \(1, 2, 4, 8, \ldots, 512\). This gives an exponentially large receptive field while keeping the parameter count linear.
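Both properties are easy to verify numerically. The sketch below implements a causal dilated convolution by left-padding (an assumption about implementation detail; frameworks do this with padded conv primitives) and computes the receptive field for the dilation schedule given above:

```python
import jax.numpy as jnp

def causal_dilated_conv(x, w, dilation):
    """y[t] = sum_i w[i] * x[t - i*dilation]; left-padding keeps it causal."""
    k = w.shape[0]
    pad = (k - 1) * dilation
    xp = jnp.pad(x, (pad, 0))
    T = x.shape[0]
    return sum(w[i] * xp[pad - i * dilation : pad - i * dilation + T]
               for i in range(k))

# Receptive field of a WaveNet stack: kernel size 2, dilations 1, 2, 4, ..., 512
dilations = [2 ** i for i in range(10)]
receptive_field = 1 + sum((2 - 1) * d for d in dilations)
print(receptive_field)  # 1024 samples from just 10 layers

# Causality check: an impulse at t=8 cannot affect any output before t=8
x = jnp.zeros(16).at[8].set(1.0)
y = causal_dilated_conv(x, jnp.array([0.5, 0.25]), dilation=4)
```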

  • The gated activation for each layer is:

\[z = \tanh(W_{f} \ast x) \odot \sigma(W_{g} \ast x)\]
  • where \(W_f\) and \(W_g\) are filter and gate convolution weights, \(\ast\) denotes dilated causal convolution, and \(\odot\) is element-wise multiplication. This gating mechanism (from chapter 06's LSTMs) allows the network to control information flow.

  • WaveNet produces exceptional quality but is painfully slow at inference: generating one second of 24 kHz audio requires 24000 sequential forward passes. This motivated all subsequent vocoder research.

  • WaveRNN (Kalchbrenner et al., 2018) replaces WaveNet's deep convolutional stack with a single-layer recurrent network. It splits each 16-bit sample into coarse (upper 8 bits) and fine (lower 8 bits) components, predicting each with a GRU (chapter 06). This dual softmax approach reduces computation significantly while maintaining high quality. WaveRNN is fast enough for real-time on mobile CPUs with careful kernel optimisation.

  • WaveGlow (Prenger et al., 2019) is a flow-based vocoder that avoids autoregressive generation entirely. It uses a sequence of invertible transformations (affine coupling layers, chapter 06's normalising flows) to map a simple Gaussian distribution to the waveform distribution. Training maximises the exact log-likelihood using the change-of-variables formula:

\[\log P(x) = \log P(z) + \sum_{i} \log \left| \det \frac{\partial f_i}{\partial f_{i-1}} \right|\]
  • where \(z = f(x)\) is the latent variable obtained by passing \(x\) through the flow. At inference, a sample \(z \sim \mathcal{N}(0, I)\) is drawn and pushed through the inverted flow in a single parallel pass. WaveGlow trades model size (large networks for the coupling layers) for generation speed.
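Why the log-determinant in the sum above is cheap to compute: an affine coupling layer transforms half the variables conditioned on the other half, so its Jacobian is triangular and \(\log|\det|\) is just the sum of the log-scales. A toy sketch (the conditioning network here is a placeholder, not WaveGlow's actual WN network):

```python
import jax.numpy as jnp

def coupling_forward(x, net):
    """Affine coupling: transform half of x conditioned on the other half.
    The Jacobian is triangular, so log|det| is the sum of log-scales."""
    xa, xb = jnp.split(x, 2)
    log_s, t = net(xa)
    return jnp.concatenate([xa, xb * jnp.exp(log_s) + t]), jnp.sum(log_s)

def coupling_inverse(z, net):
    """Exact inverse: xa passes through unchanged, so log_s and t are recoverable."""
    za, zb = jnp.split(z, 2)
    log_s, t = net(za)
    return jnp.concatenate([za, (zb - t) * jnp.exp(-log_s)])

# Toy conditioning network standing in for WaveGlow's WN transform
def toy_net(xa):
    return jnp.tanh(xa), 0.5 * xa

x = jnp.arange(8.0)
z, log_det = coupling_forward(x, toy_net)
x_rec = coupling_inverse(z, toy_net)
print(float(log_det))  # exact log|det| without forming any Jacobian
```

Invertibility is exact because the first half passes through unchanged, so the inverse can recompute the same scales and shifts.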

  • HiFi-GAN (Kong et al., 2020) uses a generative adversarial network to synthesise waveforms from mel spectrograms. The generator upsamples the mel spectrogram through a series of transposed convolutions, each followed by a multi-receptive field fusion (MRF) module. The MRF module applies multiple residual blocks with different kernel sizes and dilation rates in parallel, then sums their outputs. This allows the generator to capture patterns at multiple time scales simultaneously.

HiFi-GAN generator architecture: mel spectrogram input passes through transposed convolution upsampling layers, each followed by multi-receptive field fusion blocks that combine parallel residual stacks with different dilation patterns

  • HiFi-GAN uses two discriminator types. The multi-period discriminator (MPD) reshapes the 1D waveform into 2D by folding it at different periods (2, 3, 5, 7, 11), then applies 2D convolutions. This captures periodic structures at different fundamental frequencies. The multi-scale discriminator (MSD) operates on the raw waveform, 2x downsampled, and 4x downsampled versions, capturing patterns at different temporal resolutions.
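The MPD's period folding is a plain reshape. A minimal sketch (zero right-padding to a multiple of the period is an assumed convention):

```python
import jax.numpy as jnp

def fold_for_period(waveform, period):
    """Reshape a 1D waveform into (frames, period) for the MPD: samples one
    period apart share a column, so 2D convolutions see periodic structure."""
    pad = (-waveform.shape[0]) % period   # right-pad to a multiple of the period
    return jnp.pad(waveform, (0, pad)).reshape(-1, period)

wave = jnp.arange(10.0)
folded = {p: fold_for_period(wave, p) for p in (2, 3, 5)}
print(folded[3].shape)  # (4, 3): 10 samples padded to 12
```

The periods are prime (2, 3, 5, 7, 11) so that the sub-discriminators cover non-overlapping periodic structure.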

  • The training objective combines adversarial loss, mel spectrogram reconstruction loss (L1 distance between the mel spectrogram of synthesised and ground truth audio), and feature matching loss (L1 distance between intermediate discriminator features):

\[\mathcal{L}_G = \mathcal{L}_{\text{adv}}(G) + \lambda_{\text{mel}} \mathcal{L}_{\text{mel}}(G) + \lambda_{\text{fm}} \mathcal{L}_{\text{fm}}(G)\]
  • HiFi-GAN achieves synthesis quality comparable to WaveNet while being over 1000x faster, enabling real-time generation on a single GPU.

  • Neural source-filter (NSF) models combine traditional signal processing with neural networks. In the classical source-filter model, voiced speech is produced by a source excitation (periodic pulse train at the fundamental frequency \(F_0\)) passed through a vocal tract filter (the spectral envelope). NSF models replace the handcrafted filter with a neural network while keeping the explicit source signal. The input \(F_0\) contour provides fine pitch control that purely data-driven vocoders sometimes struggle with.

  • Acoustic models: Tacotron (Wang et al., 2017) was the first end-to-end neural TTS system that directly converted character sequences to mel spectrograms. It uses an encoder-decoder architecture with attention (chapter 07). The encoder processes the character/phoneme sequence with a convolution bank, highway network, and bidirectional GRU. The decoder is an autoregressive GRU that predicts mel frames one at a time, using the previous frame and the attention context as input.

  • Tacotron 2 (Shen et al., 2018) refines the architecture significantly. The encoder is a 3-layer 1D convolution stack followed by a bidirectional LSTM (chapter 06). The decoder is a 2-layer LSTM with location-sensitive attention, which conditions the attention mechanism not only on the encoder outputs and decoder state but also on the cumulative attention weights from previous steps. This prevents the common failure mode of attention skipping or repeating words.

Tacotron 2 architecture: character/phoneme encoder with convolution layers and BiLSTM, location-sensitive attention aligning to mel spectrogram frames, autoregressive decoder with stop token prediction

  • The location-sensitive attention energy for encoder position \(j\) at decoder step \(i\) is:
\[e_{i,j} = w^T \tanh(W_s s_{i-1} + W_h h_j + W_f f_{i,j} + b)\]
  • where \(s_{i-1}\) is the previous decoder state, \(h_j\) is the encoder output at position \(j\), and \(f_{i,j}\) is the location feature obtained by convolving the cumulative attention weights \(\sum_{k<i} \alpha_{k,j}\) with a 1D convolution filter. The attention weights are \(\alpha_{i,j} = \text{softmax}(e_{i,j})\).
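The energy formula translates directly to code. The sketch below computes one decoder step's attention weights; all dimensions and the single-channel location filter are illustrative simplifications (Tacotron 2 uses a multi-channel location convolution):

```python
import jax
import jax.numpy as jnp
import jax.random as jr

def location_attention(s_prev, H, cum_attn, W_s, W_h, W_f, w, b, loc_filter):
    """One decoder step of location-sensitive attention.
    The location feature f_j convolves the cumulative attention weights."""
    f = jnp.convolve(cum_attn, loc_filter, mode="same")               # (J,)
    hidden = jnp.tanh(s_prev @ W_s + H @ W_h + f[:, None] * W_f + b)  # (J, d)
    return jax.nn.softmax(hidden @ w)                                 # alpha_{i,j}

J, ds, dh, d = 6, 4, 5, 8
ks = jr.split(jr.PRNGKey(0), 6)
W_s, W_h = jr.normal(ks[0], (ds, d)) * 0.1, jr.normal(ks[1], (dh, d)) * 0.1
W_f, w, b = jr.normal(ks[2], (d,)) * 0.1, jr.normal(ks[3], (d,)), jnp.zeros(d)
s_prev, H = jr.normal(ks[4], (ds,)), jr.normal(ks[5], (J, dh))
alpha = location_attention(s_prev, H, jnp.zeros(J), W_s, W_h, W_f, w, b,
                           jnp.ones(3) / 3.0)
```

Feeding the cumulative weights back in is what tells the model "you have already attended here", discouraging skips and repeats.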

  • Tacotron 2's decoder also predicts a stop token probability at each step, indicating when the mel spectrogram is complete. The output mel spectrogram is then passed to a vocoder (originally WaveNet, later replaced by HiFi-GAN or similar).

  • The autoregressive nature of Tacotron 2 means synthesis speed is limited by the number of mel frames. For a typical 80-frame-per-second mel spectrogram, a 5-second utterance requires 400 sequential decoder steps.

  • FastSpeech (Ren et al., 2019) solves the speed problem with a non-autoregressive acoustic model. Instead of generating mel frames sequentially, FastSpeech generates all frames in parallel. The key challenge is determining how many mel frames each phoneme should produce, which FastSpeech handles with a duration predictor.

  • The duration predictor is a small convolutional network that predicts the integer duration (number of mel frames) for each phoneme. During training, ground-truth durations are extracted from a pre-trained autoregressive teacher model (Tacotron 2) using its attention alignments. During inference, the predicted durations are used to expand the phoneme-level hidden sequence to the frame level using a length regulator that simply repeats each phoneme's hidden representation for the predicted number of frames.
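The length regulator described above is essentially one call to `repeat`. A minimal sketch:

```python
import jax.numpy as jnp

def length_regulator(hidden, durations):
    """Expand phoneme-level hidden states to frame level: repeat each
    phoneme's vector for its predicted number of mel frames."""
    return jnp.repeat(hidden, durations, axis=0)

hidden = jnp.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])  # 3 phonemes, dim 2
durations = jnp.array([2, 1, 3])                          # frames per phoneme
frames = length_regulator(hidden, durations)              # shape (6, 2)
print(frames.shape)
```

Because total frame count is known up front, the decoder can then generate all frames in one parallel pass.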

  • FastSpeech 2 (Ren et al., 2021) improves on FastSpeech by removing the teacher-student distillation. It extracts ground-truth durations directly using forced alignment (from file 02's acoustic model frameworks) and adds explicit variance adaptors for pitch (\(F_0\)) and energy in addition to duration. Each adaptor is a small convolutional predictor whose output conditions the decoder:

\[ \begin{aligned} \hat{d}_i &= \text{DurationPredictor}(h_i) \\ \hat{p}_i &= \text{PitchPredictor}(h_i) \\ \hat{e}_i &= \text{EnergyPredictor}(h_i) \end{aligned} \]
  • where \(h_i\) is the encoder hidden state for phoneme \(i\). At training time, ground-truth values are used; at inference, the predicted values give explicit control over prosody. This controllability is a major advantage of FastSpeech 2: adjusting pitch, speed, or energy is as simple as scaling the predictor outputs.

  • FastSpeech 2 is typically 10-20x faster than Tacotron 2 at inference and avoids common autoregressive failure modes like word skipping, repetition, and attention collapse.

  • VITS (Kim et al., 2021) is an end-to-end TTS model that directly generates waveforms from text, eliminating the separate vocoder stage. VITS combines a conditional variational autoencoder (chapter 06) with normalising flows and adversarial training. The posterior encoder maps ground-truth mel spectrograms to a latent space, the prior encoder maps phonemes (through a transformer-based text encoder and duration predictor) to the same latent space, and the decoder (HiFi-GAN-based) generates waveforms from latent samples.

  • The training objective for VITS combines:

    • Reconstruction loss: the VAE forces the latent distribution to encode acoustic information
    • KL divergence: aligns the text-conditioned prior with the audio-conditioned posterior
    • Adversarial loss: discriminators ensure waveform quality
    • Duration loss: trains the stochastic duration predictor
  • VITS produces higher quality than two-stage systems (FastSpeech 2 + HiFi-GAN) because the acoustic model and vocoder are jointly optimised, avoiding the mismatch between predicted and ground-truth mel spectrograms that degrades two-stage systems.

  • VALL-E (Wang et al., 2023) radically reframes TTS as a language modelling problem over discrete audio tokens. It uses a neural audio codec (EnCodec) to represent speech as a sequence of discrete codes from multiple codebook levels. Given a text prompt and a 3-second enrollment utterance (also encoded as discrete tokens), VALL-E uses a transformer language model to predict the audio tokens autoregressively.

  • VALL-E uses two models: an autoregressive (AR) model that generates the first codebook level token-by-token, and a non-autoregressive (NAR) model that predicts the remaining codebook levels in parallel, conditioned on the first level and each other. This codec language model approach enables remarkable zero-shot voice cloning: a 3-second sample is enough to reproduce a speaker's voice, timbre, and even emotional tone.

  • StyleTTS (Li et al., 2022) and StyleTTS 2 disentangle speech into content and style components. A style encoder extracts a style vector from reference audio, capturing speaker identity, prosody, and recording conditions. During inference, style can be sampled from a learned prior distribution or transferred from a reference utterance. StyleTTS 2 uses diffusion models (chapter 08) for the style prior, generating diverse and natural prosody.

  • Kokoro (2024) is a lightweight, high-quality open-source TTS model notable for its small size (~82M parameters) and impressive naturalness. It uses a StyleTTS 2-inspired architecture with a diffusion-based style prior and a fine-tuned ISTFTNet vocoder that directly predicts STFT coefficients (from file 01) rather than raw waveform samples. Despite being a fraction of the size of models like VALL-E, Kokoro achieves near-human naturalness for English, Japanese, French, Korean, and Chinese, demonstrating that carefully curated training data and efficient architecture design can compete with brute-force scale. Kokoro's small footprint makes it practical for local and edge deployment.

  • Orpheus (Canopy Labs, 2025) is a family of open-source TTS models (1B and 3B parameters) built on the codec language model paradigm pioneered by VALL-E. Orpheus takes the idea further with an LLM backbone (fine-tuned Llama 3) that generates SNAC audio codec tokens directly. Its standout feature is human-like emotional expressiveness: it handles laughter, sighs, hesitations, and affective prosody with remarkable naturalness. Orpheus can be prompted with tags like [laugh] or [sigh] in the input text, giving fine-grained control over paralinguistic expression.

  • Dia (Nari Labs, 2025) is an open-source dialogue TTS model that generates realistic multi-speaker conversations from a single text transcript. Built on a 1.6B-parameter encoder-decoder transformer, Dia handles turn-taking, speaker-specific voices, and non-verbal cues (laughter, pauses) within a conversation. It also supports voice cloning from a short audio prompt, enabling zero-shot speaker generation in dialogue context.

  • Sesame CSM (Conversational Speech Model, 2025) focuses on natural multi-turn conversational speech. Rather than optimising for reading-style TTS, Sesame models the dynamics of real conversation: backchannels ("uh huh"), interruptions, rhythm changes between speakers, and emotional responsiveness. The model uses a transformer backbone conditioned on conversational context (both text and audio history), producing speech that adapts its style to the flow of the dialogue.

  • Fish Speech (Fish Audio, 2024) is an open-source TTS system that uses a dual autoregressive architecture: a large language model generates semantic tokens from text, and a smaller model converts these to VQGAN acoustic tokens, which are decoded into waveforms by a vocoder. Fish Speech supports zero-shot voice cloning from a 10-15 second reference and achieves low latency suitable for real-time applications. Its modular design allows swapping components (e.g., different vocoders) independently.

  • ChatTTS (2024) is an open-source conversational TTS model designed for dialogue applications like chatbots and virtual assistants. It generates natural, conversational-sounding speech with fine-grained control over prosodic features (laughter, pauses, filler words) using special tokens embedded in the text input. ChatTTS supports mixed Chinese-English synthesis and multi-speaker generation.

  • Bark (Suno, 2023) is a transformer-based open-source model that generates speech, music, and sound effects from text prompts. It uses a three-stage pipeline of transformer models (text → semantic tokens → coarse acoustic tokens → fine acoustic tokens) and supports voice cloning, multilingual synthesis, and non-speech audio like music and ambient sounds. Bark's generality comes at the cost of controllability — it is less precise than dedicated TTS systems but more flexible.

  • Parler-TTS (Hugging Face, 2024) takes a natural language description approach to voice control: instead of requiring a reference audio clip for style, the user provides a text description like "a female speaker with a warm, expressive voice in a quiet room." Parler-TTS is trained on annotated speech data where each utterance is paired with a natural language description of the speaking style, enabling intuitive control without any reference audio.

  • Neuphonic is an API-based TTS platform optimised for ultra-low-latency speech synthesis, targeting real-time voice agents and conversational AI applications. It achieves time-to-first-audio under 100 ms through a streaming architecture that begins generating audio before the full input text is available. Neuphonic focuses on the deployment and latency optimisation layer rather than novel model architecture, providing production-grade infrastructure around modern neural TTS.

  • KittenTTS is a compact, fast TTS model designed for efficiency and low-resource deployment. It prioritises minimal latency and small model size for edge and embedded applications, trading some naturalness for real-time performance on CPUs and mobile devices.

  • The modern TTS landscape is bifurcating into two paradigms: (1) codec language models (VALL-E, Orpheus, Fish Speech) that treat speech generation as next-token prediction over discrete audio codes, leveraging the scaling laws of LLMs; and (2) flow/diffusion-based models (VITS, StyleTTS 2, Kokoro) that generate continuous mel spectrograms or waveforms through iterative refinement. Codec LMs excel at zero-shot cloning and expressiveness; flow/diffusion models tend to be smaller and faster. Both are rapidly converging toward human-level naturalness.

  • Prosody modelling controls the "music" of speech: pitch, duration, energy, rhythm, and intonation. Without good prosody, synthesised speech sounds flat and robotic even if individual phonemes are clear. Think of prosody as the difference between a monotone GPS voice and an expressive audiobook narrator.

  • Pitch (fundamental frequency \(F_0\)) is the perceived highness or lowness of speech. It rises at the end of questions, falls at the end of statements, and varies continuously during emotional speech. \(F_0\) is extracted from audio using algorithms like CREPE (a neural pitch tracker) or YIN (autocorrelation-based, from file 01). In TTS, pitch is either predicted by the acoustic model (FastSpeech 2's pitch predictor) or implicitly learned (Tacotron 2).
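The autocorrelation idea behind YIN can be sketched in a few lines: find the lag at which the signal best matches a shifted copy of itself. This is the bare core idea only, without YIN's cumulative-mean normalisation or parabolic interpolation:

```python
import jax.numpy as jnp

def estimate_f0(frame, sr, f0_min=80.0, f0_max=400.0):
    """Estimate F0 from one frame via the autocorrelation peak
    within the plausible pitch-lag range."""
    frame = frame - jnp.mean(frame)
    ac = jnp.correlate(frame, frame, mode="full")[frame.shape[0] - 1:]
    lag_min, lag_max = int(sr / f0_max), int(sr / f0_min)
    best_lag = lag_min + int(jnp.argmax(ac[lag_min:lag_max]))
    return sr / best_lag

sr = 16000
t = jnp.arange(1024) / sr
f0 = estimate_f0(jnp.sin(2 * jnp.pi * 220.0 * t), sr)
print(f0)  # close to 220 Hz, limited by integer-lag resolution
```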

  • Duration determines the speaking rate and rhythm. Stressed syllables are longer, function words are shortened, and pauses mark phrase boundaries. Duration modelling is explicit in non-autoregressive models (FastSpeech) and implicit in autoregressive models (Tacotron's attention alignment determines duration).

  • Energy (loudness) carries emphasis. "I didn't say HE stole it" vs "I didn't say he STOLE it" have different meanings conveyed entirely through energy patterns.

  • Style embeddings capture higher-level prosodic patterns. The Global Style Token (GST) framework (Wang et al., 2018) learns a bank of style tokens (soft attention over a learned set of embeddings) that capture speaking styles like "excited", "sad", or "whispering". The style embedding is extracted from a reference utterance and added to the encoder output, allowing style transfer at inference.

  • Voice conversion (VC) changes the speaker identity of an utterance while preserving the linguistic content. Imagine recording yourself and having the output sound like a specific target speaker. VC requires disentangling speaker identity from content.

Voice conversion pipeline: source speech is decomposed into content representation and speaker embedding, the target speaker embedding replaces the source, and the decoder reconstructs speech in the target voice

  • Speaker embeddings (detailed further in file 04) encode speaker identity as a fixed-dimensional vector. These can come from a pre-trained speaker verification model (x-vectors, ECAPA-TDNN). In VC, the source speech is encoded into a content representation that is speaker-independent, then decoded with the target speaker embedding.

  • Disentangled representations separate speech into independent factors: content (phonemes), speaker identity, pitch, and rhythm. Approaches include:

    • Information bottleneck: compress the content representation so tightly that speaker information is lost (AutoVC)
    • Adversarial training: train a speaker classifier on the content representation and use gradient reversal to remove speaker information
    • Vector quantisation: VQ-VAE forces the content through a discrete bottleneck, which naturally strips speaker identity (since codebook entries represent phonetic categories, not speaker traits)
  • Voice cloning synthesises speech in a target speaker's voice. Multi-speaker TTS trains on data from many speakers, conditioning the model on a speaker embedding. At inference, a new speaker's embedding is extracted from enrollment audio and used to condition generation.

  • Few-shot voice cloning adapts to a new speaker using a small amount of data (a few minutes). The speaker encoder extracts an embedding from the enrollment audio, and the TTS model generates speech conditioned on this embedding. This is the approach used in SV2TTS (Jia et al., 2018): a separately trained speaker encoder, a Tacotron 2 synthesiser conditioned on the speaker embedding, and a WaveRNN vocoder.

  • Zero-shot voice cloning requires no adaptation at all: a single short utterance (3-30 seconds) is enough. VALL-E achieves this by treating the enrollment audio as a prompt for the language model. The model learns to continue generating in the same voice because it was trained on large-scale multi-speaker data where voice consistency within an utterance is the statistical norm.

  • Voice activity detection (VAD) answers a simple binary question at each time frame: is someone speaking or not? Despite its simplicity, VAD is a critical preprocessing step for ASR (file 02), speaker diarisation (file 04), and noise reduction (file 05). A good VAD reduces computation by skipping silence and improves accuracy by preventing noise from being processed as speech.

  • Classical VAD uses energy thresholding (speech is louder than silence), zero-crossing rate (speech has characteristic crossing patterns), and spectral features. These fail in noisy environments where the signal-to-noise ratio is low.
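A classical energy-plus-ZCR VAD fits in a dozen lines. The thresholds below are illustrative; the same fixed energy threshold is exactly what breaks down at low SNR:

```python
import jax.numpy as jnp

def classical_vad(signal, frame_length=400, hop=160, energy_thresh=0.01):
    """Frame-level energy/ZCR VAD (the classical baseline).
    Fails at low SNR because noise energy crosses the fixed threshold."""
    n_frames = 1 + (len(signal) - frame_length) // hop
    frames = jnp.stack([signal[i * hop : i * hop + frame_length]
                        for i in range(n_frames)])
    energy = jnp.mean(frames ** 2, axis=1)
    # Zero-crossing rate: fraction of adjacent samples with a sign change
    zcr = jnp.mean(jnp.abs(jnp.diff(jnp.sign(frames), axis=1)) > 0, axis=1)
    # Speech: enough energy and a plausible (not noise-like) crossing rate
    return (energy > energy_thresh) & (zcr < 0.5)

sr = 16000
t = jnp.arange(sr) / sr
speech_like = jnp.sin(2 * jnp.pi * 200.0 * t)   # periodic, energetic
signal = jnp.concatenate([jnp.zeros(sr), speech_like])
decisions = classical_vad(signal)               # False for silence, True for tone
```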

  • Neural VAD models treat the problem as frame-level binary classification. A small RNN or CNN takes acoustic features (log mel energies from file 01) and predicts speech/non-speech probabilities.

  • WebRTC VAD (Google) is a classic lightweight VAD using a GMM-based classifier on simple spectral features. It operates at four aggressiveness levels (0-3) and is extremely fast, but struggles with music, non-speech vocalisations, and low-SNR environments. It remains widely used as a baseline due to its zero-dependency simplicity.

  • Silero VAD (Silero Team, 2021) is the de facto standard neural VAD for production use. Its architecture is a small stack of depthwise separable 1D convolutions (chapter 08's MobileNet idea applied to audio) followed by a single LSTM layer for temporal context, with a final linear head producing a speech probability per frame. The entire model is under 2MB (~1M parameters) and processes audio in 30-100 ms chunks.

    • Input: raw 16 kHz audio (no manual feature extraction — the convolutional front-end learns its own features from the waveform directly).
    • Windowed stateful inference: the LSTM hidden state carries over between chunks, so the model handles streaming audio without reprocessing the full history. Each call processes a 30, 60, or 100 ms chunk and returns a speech probability in \([0, 1]\).
    • Adaptive thresholding: rather than a single fixed threshold, Silero VAD uses separate start and end thresholds with a minimum speech/silence duration, preventing rapid toggling on noisy boundaries. A speech segment must exceed the start threshold for a minimum duration before being confirmed, and silence must persist below the end threshold before the segment is closed.
    • Performance: Silero VAD runs at 1-2% real-time factor on CPU (processing 1 second of audio takes ~10-20 ms), making it suitable for edge devices, mobile phones, and real-time pipelines. It significantly outperforms WebRTC VAD on noisy and music-heavy audio while remaining small enough for on-device deployment.
    • Silero VAD is commonly used as the front-end for Whisper (file 02) to segment long audio into utterance-level chunks before transcription, and for speaker diarisation pipelines (file 04) to identify speech regions before extracting speaker embeddings.
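The adaptive-thresholding post-processing described above (separate start/end thresholds plus minimum durations) can be sketched as a small state machine over per-frame probabilities. The threshold and duration values here are illustrative, not Silero's actual defaults:

```python
def hysteresis_segments(probs, start_thresh=0.6, end_thresh=0.35,
                        min_speech_frames=3, min_silence_frames=5):
    """Convert frame probabilities to (start, end) speech segments with
    hysteresis: confirm speech only after a run above the start threshold,
    close it only after a run below the end threshold."""
    segments, in_speech, run, seg_start = [], False, 0, 0
    for i, p in enumerate(probs):
        if not in_speech:
            run = run + 1 if p >= start_thresh else 0
            if run >= min_speech_frames:
                in_speech, seg_start, run = True, i - run + 1, 0
        else:
            run = run + 1 if p < end_thresh else 0
            if run >= min_silence_frames:
                segments.append((seg_start, i - run + 1))
                in_speech, run = False, 0
    if in_speech:
        segments.append((seg_start, len(probs)))
    return segments

# 10 confident speech frames, then silence, then a 2-frame blip (rejected)
probs = [0.1] * 5 + [0.9] * 10 + [0.2] * 10 + [0.9] * 2 + [0.1] * 5
print(hysteresis_segments(probs))  # [(5, 15)]
```

The gap between the two thresholds is what prevents rapid toggling when the probability hovers near a single cut-off.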
  • Acoustic activity detection (AAD) generalises VAD to detect any acoustic activity, not just speech. This is useful in smart home devices, security systems, and wildlife monitoring. AAD models detect events like glass breaking, dogs barking, or alarms, often using the audio classification frameworks described in file 04.

  • Evaluation metrics for TTS measure both objective quality and subjective naturalness:

    • Mean Opinion Score (MOS): human listeners rate naturalness on a 1-5 scale. The gold standard, but expensive and slow.
    • Mel cepstral distortion (MCD): measures the distance between synthesised and reference mel cepstra. Lower is better, but does not always correlate with perception.
    • PESQ / POLQA: standardised perceptual evaluation metrics originally designed for telephony.
    • Speaker similarity: cosine similarity between speaker embeddings of synthesised and reference audio (relevant for voice cloning).
    • Intelligibility: measured by feeding synthesised audio through an ASR system (file 02) and computing WER.
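Of these, MCD is the easiest to compute directly. A minimal sketch of the standard per-frame formula \(\text{MCD} = \frac{10}{\ln 10}\sqrt{2\sum_d (c_d - c_d')^2}\), averaged over frames with the 0th (energy) coefficient excluded by convention:

```python
import jax.numpy as jnp

def mel_cepstral_distortion(c_ref, c_syn):
    """MCD in dB between two mel-cepstra sequences of shape (frames, coeffs).
    The 0th coefficient (overall energy) is conventionally excluded."""
    diff = c_ref[:, 1:] - c_syn[:, 1:]
    per_frame = (10.0 / jnp.log(10.0)) * jnp.sqrt(2.0 * jnp.sum(diff ** 2, axis=1))
    return jnp.mean(per_frame)

c_ref = jnp.zeros((4, 13))
c_syn = jnp.zeros((4, 13)).at[:, 1].set(0.1)  # small error in one coefficient
print(float(mel_cepstral_distortion(c_ref, c_syn)))
```

Note the caveat from the bullet above: identical-sounding utterances with different timing can have high MCD, so it is usually reported alongside MOS.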

Coding Tasks (use CoLab or notebook)

  • Task 1: Griffin-Lim vocoder from mel spectrogram. Implement the Griffin-Lim iterative phase reconstruction algorithm to convert a mel spectrogram back to a waveform. This demonstrates the vocoder problem and why neural vocoders are needed.
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

# Generate a synthetic waveform (sum of harmonics simulating a vowel)
sr = 16000
duration = 1.0
t = jnp.linspace(0, duration, int(sr * duration))
f0 = 220.0  # fundamental frequency
waveform = (
    0.6 * jnp.sin(2 * jnp.pi * f0 * t) +
    0.3 * jnp.sin(2 * jnp.pi * 2 * f0 * t) +
    0.1 * jnp.sin(2 * jnp.pi * 3 * f0 * t)
)

# Compute STFT
n_fft = 1024
hop_length = 256
window = jnp.hanning(n_fft)

def stft(signal, n_fft, hop_length, window):
    """Compute Short-Time Fourier Transform."""
    n_frames = 1 + (len(signal) - n_fft) // hop_length
    frames = jnp.stack([
        signal[i * hop_length : i * hop_length + n_fft] * window
        for i in range(n_frames)
    ])
    return jnp.fft.rfft(frames, n=n_fft)

def istft(stft_matrix, hop_length, window, length):
    """Compute inverse STFT with overlap-add and window-sum normalisation."""
    n_fft = (stft_matrix.shape[1] - 1) * 2
    n_frames = stft_matrix.shape[0]
    frames = jnp.fft.irfft(stft_matrix, n=n_fft)
    frames = frames * window[None, :]
    output = jnp.zeros(length)
    window_sum = jnp.zeros(length)
    for i in range(n_frames):
        start = i * hop_length
        end = start + n_fft
        if end <= length:
            output = output.at[start:end].add(frames[i])
            window_sum = window_sum.at[start:end].add(window ** 2)
    # Divide by the summed squared window to undo analysis/synthesis windowing
    return output / jnp.maximum(window_sum, 1e-8)

# Forward STFT
S = stft(waveform, n_fft, hop_length, window)
magnitude = jnp.abs(S)

# Mel filterbank
n_mels = 80
mel_low = 0.0
mel_high = 2595 * jnp.log10(1 + (sr / 2) / 700)
mel_points = jnp.linspace(mel_low, mel_high, n_mels + 2)
hz_points = 700 * (10 ** (mel_points / 2595) - 1)
freq_bins = jnp.floor((n_fft + 1) * hz_points / sr).astype(int)

mel_filterbank = jnp.zeros((n_mels, n_fft // 2 + 1))
for m in range(n_mels):
    f_left = freq_bins[m]
    f_center = freq_bins[m + 1]
    f_right = freq_bins[m + 2]
    for k in range(f_left, f_center):
        mel_filterbank = mel_filterbank.at[m, k].set(
            (k - f_left) / max(f_center - f_left, 1)
        )
    for k in range(f_center, f_right):
        mel_filterbank = mel_filterbank.at[m, k].set(
            (f_right - k) / max(f_right - f_center, 1)
        )

# To mel and back (pseudo-inverse)
mel_spec = magnitude @ mel_filterbank.T
magnitude_reconstructed = mel_spec @ jnp.linalg.pinv(mel_filterbank.T)
magnitude_reconstructed = jnp.maximum(magnitude_reconstructed, 1e-7)

# Griffin-Lim algorithm
def griffin_lim(magnitude, n_iter, hop_length, window, signal_length):
    """Iterative phase reconstruction."""
    n_fft = (magnitude.shape[1] - 1) * 2
    key = jax.random.PRNGKey(42)
    phase = jax.random.uniform(key, magnitude.shape, minval=-jnp.pi, maxval=jnp.pi)

    for _ in range(n_iter):
        complex_spec = magnitude * jnp.exp(1j * phase)
        signal = istft(complex_spec, hop_length, window, signal_length)
        reanalysis = stft(signal, n_fft, hop_length, window)
        phase = jnp.angle(reanalysis)

    complex_spec = magnitude * jnp.exp(1j * phase)
    return istft(complex_spec, hop_length, window, signal_length)

reconstructed = griffin_lim(magnitude_reconstructed, n_iter=60, hop_length=hop_length,
                            window=window, signal_length=len(waveform))

# Plot comparison
fig, axes = plt.subplots(3, 1, figsize=(12, 8))

axes[0].plot(t[:1000], waveform[:1000], color='#3498db', linewidth=0.8)
axes[0].set_title('Original Waveform')
axes[0].set_ylabel('Amplitude')

axes[1].imshow(jnp.log1p(mel_spec.T), aspect='auto', origin='lower', cmap='magma')
axes[1].set_title('Mel Spectrogram (intermediate representation)')
axes[1].set_ylabel('Mel bin')

axes[2].plot(t[:1000], reconstructed[:1000], color='#e74c3c', linewidth=0.8)
axes[2].set_title('Griffin-Lim Reconstructed Waveform (60 iterations)')
axes[2].set_xlabel('Time (s)')
axes[2].set_ylabel('Amplitude')

plt.tight_layout()
plt.show()

# Measure reconstruction error
n = min(len(waveform), len(reconstructed))
mse = jnp.mean((waveform[:n] - reconstructed[:n]) ** 2)
print(f"MSE between original and reconstructed: {mse:.6f}")
print("Note: artifacts stem from mel compression (lost frequency detail) and estimated phase.")
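Waveform MSE is a harsh metric here: Griffin-Lim recovers a plausible phase, not the original one, so two perceptually similar signals can show large sample-wise error. A phase-invariant alternative is spectral convergence, sketched below (the `frame_magnitude` helper and its framing parameters are illustrative, not part of the code above):

```python
import jax.numpy as jnp

def frame_magnitude(x, n_fft=512, hop=128):
    """Magnitude spectrogram via a simple Hann-windowed framed rFFT."""
    window = jnp.hanning(n_fft)
    n_frames = 1 + (len(x) - n_fft) // hop
    frames = jnp.stack([x[i * hop : i * hop + n_fft] * window
                        for i in range(n_frames)])
    return jnp.abs(jnp.fft.rfft(frames, axis=-1))

def spectral_convergence(reference, estimate):
    """|| |S_ref| - |S_est| ||_F / || |S_ref| ||_F: a phase-invariant error."""
    n = min(len(reference), len(estimate))
    S_ref = frame_magnitude(reference[:n])
    S_est = frame_magnitude(estimate[:n])
    return jnp.linalg.norm(S_ref - S_est) / jnp.linalg.norm(S_ref)

# A sine and its sign-flipped copy: waveform MSE is large, but the
# magnitude spectra are identical, so spectral convergence is ~0
t = jnp.arange(16000) / 16000.0
a = jnp.sin(2 * jnp.pi * 440 * t)
b = -a
print(f"Waveform MSE:         {jnp.mean((a - b) ** 2):.4f}")
print(f"Spectral convergence: {spectral_convergence(a, b):.6f}")
```

Reporting both metrics makes it clear how much of the waveform error is phase disagreement rather than spectral distortion.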
  • Task 2: Duration predictor (FastSpeech-style). Train a small convolutional model that maps phoneme embeddings to per-phoneme durations. This duration predictor is the core component enabling non-autoregressive TTS.
import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt

# Simulate phoneme sequences with ground-truth durations
# In real TTS, durations come from forced alignment or teacher attention
def generate_synthetic_data(key, n_samples=200, max_phonemes=30, embed_dim=64):
    """Generate synthetic phoneme embeddings and durations."""
    keys = jr.split(key, 4)
    lengths = jr.randint(keys[0], (n_samples,), 5, max_phonemes)

    all_embeddings = []
    all_durations = []
    all_masks = []

    for i in range(n_samples):
        L = int(lengths[i])
        # Shared embedding table: every sample reuses the same per-position
        # vectors, which serve as stand-in phoneme identities
        emb = jr.normal(keys[1], (max_phonemes, embed_dim))
        # Durations: vowels (even indices) are longer, consonants shorter
        base_dur = jnp.where(jnp.arange(max_phonemes) % 2 == 0, 8.0, 4.0)
        noise = jr.normal(jr.fold_in(keys[2], i), (max_phonemes,)) * 1.5
        dur = jnp.clip(base_dur + noise, 1.0, 20.0).astype(jnp.float32)
        mask = (jnp.arange(max_phonemes) < L).astype(jnp.float32)

        all_embeddings.append(emb)
        all_durations.append(dur * mask)
        all_masks.append(mask)

    return (jnp.stack(all_embeddings), jnp.stack(all_durations),
            jnp.stack(all_masks))

key = jr.PRNGKey(42)
embeddings, durations, masks = generate_synthetic_data(key)

# Duration predictor: 2-layer 1D convolution + linear projection
def init_duration_predictor(key, embed_dim=64, hidden_dim=128, kernel_size=3):
    """Initialise duration predictor weights."""
    keys = jr.split(key, 4)
    scale1 = jnp.sqrt(2.0 / (embed_dim * kernel_size))
    scale2 = jnp.sqrt(2.0 / (hidden_dim * kernel_size))
    params = {
        'conv1_w': jr.normal(keys[0], (kernel_size, embed_dim, hidden_dim)) * scale1,
        'conv1_b': jnp.zeros(hidden_dim),
        'conv2_w': jr.normal(keys[1], (kernel_size, hidden_dim, hidden_dim)) * scale2,
        'conv2_b': jnp.zeros(hidden_dim),
        'linear_w': jr.normal(keys[2], (hidden_dim, 1)) * jnp.sqrt(2.0 / hidden_dim),
        'linear_b': jnp.zeros(1),
    }
    return params

def duration_predictor(params, x):
    """Predict log-durations from phoneme embeddings. x: (batch, seq, embed)."""
    # Conv layer 1 with ReLU
    h = jax.lax.conv_general_dilated(
        x.transpose(0, 2, 1),  # (batch, embed, seq)
        params['conv1_w'].transpose(2, 1, 0),  # (out, in, kernel)
        window_strides=(1,), padding='SAME'
    ).transpose(0, 2, 1) + params['conv1_b']  # back to (batch, seq, hidden)
    h = jax.nn.relu(h)

    # Conv layer 2 with ReLU
    h = jax.lax.conv_general_dilated(
        h.transpose(0, 2, 1),
        params['conv2_w'].transpose(2, 1, 0),
        window_strides=(1,), padding='SAME'
    ).transpose(0, 2, 1) + params['conv2_b']
    h = jax.nn.relu(h)

    # Linear projection to scalar
    log_dur = (h @ params['linear_w'] + params['linear_b']).squeeze(-1)
    return log_dur

# Loss: MSE on log-durations (standard in FastSpeech)
def loss_fn(params, embeddings, durations, masks):
    log_dur_pred = duration_predictor(params, embeddings)
    log_dur_true = jnp.log(jnp.clip(durations, 1.0, None))
    sq_err = (log_dur_pred - log_dur_true) ** 2 * masks
    return jnp.sum(sq_err) / jnp.sum(masks)

grad_fn = jax.jit(jax.value_and_grad(loss_fn))

# Training loop
params = init_duration_predictor(jr.PRNGKey(0))
lr = 1e-3
losses = []

for epoch in range(300):
    loss_val, grads = grad_fn(params, embeddings, durations, masks)
    params = jax.tree.map(lambda p, g: p - lr * g, params, grads)
    losses.append(float(loss_val))

# Evaluate on a sample
log_dur_pred = duration_predictor(params, embeddings[:1])
dur_pred = jnp.exp(log_dur_pred[0])
dur_true = durations[0]
mask = masks[0]
valid_len = int(jnp.sum(mask))

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

axes[0].plot(losses, color='#3498db', linewidth=1.5)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('MSE Loss (log-duration)')
axes[0].set_title('Duration Predictor Training')
axes[0].set_yscale('log')

x_pos = jnp.arange(valid_len)
width = 0.35
axes[1].bar(x_pos - width/2, dur_true[:valid_len], width, color='#27ae60',
            label='Ground truth', alpha=0.8)
axes[1].bar(x_pos + width/2, dur_pred[:valid_len], width, color='#e74c3c',
            label='Predicted', alpha=0.8)
axes[1].set_xlabel('Phoneme index')
axes[1].set_ylabel('Duration (frames)')
axes[1].set_title('Duration Prediction vs Ground Truth')
axes[1].legend()

plt.tight_layout()
plt.show()
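In FastSpeech, the predicted durations feed a length regulator, which repeats each phoneme's hidden state by its rounded duration to expand the phoneme-level sequence to frame level for the decoder. A minimal sketch, assuming integer durations:

```python
import jax.numpy as jnp

def length_regulate(phoneme_hidden, durations):
    """Repeat each phoneme vector by its duration in frames.

    phoneme_hidden: (n_phonemes, dim); durations: (n_phonemes,) ints.
    Returns (sum(durations), dim), the frame-level decoder input.
    """
    return jnp.repeat(phoneme_hidden, durations, axis=0)

hidden = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # 3 phonemes, dim 2
durs = jnp.array([2, 3, 1])                                # frames per phoneme
frames = length_regulate(hidden, durs)
print(frames.shape)  # (6, 2): 2 + 3 + 1 frames
```

Note that under `jit`, `jnp.repeat` with array-valued repeats requires a static `total_repeat_length`; practical implementations pad to a maximum frame count.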
  • Task 3: Simple neural vocoder with upsampling convolutions. Build a minimal HiFi-GAN-style generator that upsamples a mel spectrogram to a waveform using transposed convolutions and residual blocks.
import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt

def init_residual_block(key, channels, kernel_size, dilation):
    """Initialise a dilated residual convolution block."""
    k1, k2 = jr.split(key)
    scale = jnp.sqrt(2.0 / (channels * kernel_size))
    return {
        'conv1_w': jr.normal(k1, (kernel_size, channels, channels)) * scale,
        'conv1_b': jnp.zeros(channels),
        'conv2_w': jr.normal(k2, (kernel_size, channels, channels)) * scale,
        'conv2_b': jnp.zeros(channels),
        'dilation': dilation
    }

def residual_block(params, x):
    """x: (batch, time, channels). Dilated conv residual block with LeakyReLU."""
    h = jax.nn.leaky_relu(x, negative_slope=0.1)
    # First conv is dilated (via rhs_dilation) to widen the receptive field
    h = jax.lax.conv_general_dilated(
        h.transpose(0, 2, 1),
        params['conv1_w'].transpose(2, 1, 0),
        window_strides=(1,),
        padding='SAME',
        rhs_dilation=(params['dilation'],)
    ).transpose(0, 2, 1) + params['conv1_b']
    h = jax.nn.leaky_relu(h, negative_slope=0.1)
    h = jax.lax.conv_general_dilated(
        h.transpose(0, 2, 1),
        params['conv2_w'].transpose(2, 1, 0),
        window_strides=(1,),
        padding='SAME'
    ).transpose(0, 2, 1) + params['conv2_b']
    return x + h

def init_generator(key, n_mels=80, upsample_rates=(8, 8, 4),
                   channels=128):
    """Initialise a minimal HiFi-GAN-style generator."""
    keys = jr.split(key, 10)
    params = {}

    # Input projection: mel bins -> channels
    params['input_w'] = jr.normal(keys[0], (7, n_mels, channels)) * 0.02
    params['input_b'] = jnp.zeros(channels)

    # Upsample blocks (transposed convolutions)
    in_ch = channels
    for i, rate in enumerate(upsample_rates):
        k_size = rate * 2
        scale = jnp.sqrt(2.0 / (in_ch * k_size))
        out_ch = in_ch // 2
        params[f'up{i}_w'] = jr.normal(keys[i+1], (k_size, in_ch, out_ch)) * scale
        params[f'up{i}_b'] = jnp.zeros(out_ch)
        # Residual blocks at each scale
        params[f'res{i}_0'] = init_residual_block(jr.fold_in(keys[i+4], 0),
                                                    out_ch, 3, 1)
        params[f'res{i}_1'] = init_residual_block(jr.fold_in(keys[i+4], 1),
                                                    out_ch, 3, 3)
        in_ch = out_ch

    # Output projection to mono waveform
    params['output_w'] = jr.normal(keys[8], (7, in_ch, 1)) * 0.02
    params['output_b'] = jnp.zeros(1)
    params['upsample_rates'] = upsample_rates

    return params

def generator_forward(params, mel):
    """mel: (batch, time, n_mels) -> waveform: (batch, time * prod(rates), 1)."""
    # Input projection
    h = jax.lax.conv_general_dilated(
        mel.transpose(0, 2, 1),
        params['input_w'].transpose(2, 1, 0),
        window_strides=(1,), padding='SAME'
    ).transpose(0, 2, 1) + params['input_b']

    for i, rate in enumerate(params['upsample_rates']):
        h = jax.nn.leaky_relu(h, negative_slope=0.1)
        # Upsample via transposed convolution. jax.lax.conv_transpose defaults
        # to the TF layout (lhs 'NHC', rhs 'HIO'), which already matches our
        # (batch, time, channels) activations and (kernel, in, out) weights
        h = jax.lax.conv_transpose(
            h, params[f'up{i}_w'], strides=(rate,), padding='SAME'
        ) + params[f'up{i}_b']
        # Residual blocks
        h = residual_block(params[f'res{i}_0'], h)
        h = residual_block(params[f'res{i}_1'], h)

    h = jax.nn.leaky_relu(h, negative_slope=0.1)
    out = jax.lax.conv_general_dilated(
        h.transpose(0, 2, 1),
        params['output_w'].transpose(2, 1, 0),
        window_strides=(1,), padding='SAME'
    ).transpose(0, 2, 1) + params['output_b']

    return jnp.tanh(out)

# Create a synthetic mel spectrogram (simulating a vowel)
n_mels = 80
n_frames = 50
mel = jnp.zeros((1, n_frames, n_mels))
# Add energy in low-frequency mel bins (simulating formants)
mel = mel.at[:, :, 5:15].set(1.0)
mel = mel.at[:, :, 20:25].set(0.6)

# Initialise and run generator
key = jr.PRNGKey(42)
params = init_generator(key, n_mels=n_mels, upsample_rates=(8, 8, 4),
                         channels=128)
waveform = generator_forward(params, mel)

print(f"Input mel shape:  {mel.shape}")
print(f"Output waveform shape: {waveform.shape}")
print(f"Upsample factor: 8 * 8 * 4 = {8 * 8 * 4}x")

fig, axes = plt.subplots(2, 1, figsize=(12, 6))

axes[0].imshow(mel[0].T, aspect='auto', origin='lower', cmap='magma')
axes[0].set_title('Input Mel Spectrogram')
axes[0].set_ylabel('Mel bin')
axes[0].set_xlabel('Frame')

waveform_np = waveform[0, :, 0]
axes[1].plot(waveform_np[:2000], color='#9b59b6', linewidth=0.5)
axes[1].set_title('Generator Output Waveform (untrained - random noise)')
axes[1].set_ylabel('Amplitude')
axes[1].set_xlabel('Sample')

plt.tight_layout()
plt.show()
print("Note: The output is noise because the generator is untrained.")
print("In practice, adversarial + mel loss training shapes this into speech.")
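For orientation, HiFi-GAN's generator objective combines a least-squares adversarial term with an L1 loss on mel spectrograms of real and generated audio (weighted by λ_mel = 45 in the paper). The sketch below substitutes a toy pooled-magnitude feature for the mel spectrogram so it runs standalone; `toy_features` and the hard-coded discriminator scores are placeholders, not the real pipeline:

```python
import jax
import jax.numpy as jnp

def toy_features(wav, pool=256):
    """Placeholder for a mel spectrogram: average-pooled absolute amplitude."""
    n = (wav.shape[-1] // pool) * pool
    return jnp.abs(wav[..., :n]).reshape(-1, pool).mean(-1)

def generator_loss(fake_wav, real_wav, disc_fake_scores, lambda_mel=45.0):
    """Least-squares adversarial loss plus weighted L1 feature loss."""
    adv = jnp.mean((disc_fake_scores - 1.0) ** 2)  # LSGAN: push D(fake) toward 1
    feat = jnp.mean(jnp.abs(toy_features(fake_wav) - toy_features(real_wav)))
    return adv + lambda_mel * feat

key = jax.random.PRNGKey(0)
fake = jax.random.normal(key, (8192,))            # untrained generator output
real = jnp.sin(jnp.linspace(0.0, 200.0, 8192))    # stand-in ground truth
scores = jnp.array([0.2, 0.1, 0.3])               # pretend discriminator outputs
loss = generator_loss(fake, real, scores)
print(f"Generator loss: {float(loss):.3f}")
```

The heavy mel weighting is what keeps early training stable: the adversarial term alone gives the generator almost no gradient signal toward intelligible speech.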
  • Task 4: Voice activity detection with a simple RNN. Train a small GRU-based VAD model on synthetic audio features to classify frames as speech or silence.
import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt

# Generate synthetic log-mel energy features with speech/silence labels
def generate_vad_data(key, n_sequences=100, n_frames=200, n_features=40):
    """Simulate log-mel features: speech regions are higher energy with structure."""
    keys = jr.split(key, 5)
    all_features = []
    all_labels = []

    for i in range(n_sequences):
        k = jr.fold_in(keys[0], i)
        k1, k2, k3 = jr.split(k, 3)

        # Random speech/silence pattern
        label = jnp.zeros(n_frames)
        n_segments = jr.randint(k1, (), 2, 6)
        for seg in range(int(n_segments)):
            start = jr.randint(jr.fold_in(k2, seg), (), 0, n_frames - 20)
            length = jr.randint(jr.fold_in(k3, seg), (), 10, 50)
            end = jnp.minimum(start + length, n_frames)
            label = label.at[int(start):int(end)].set(1.0)

        # Features: speech frames have higher energy + spectral structure
        noise = jr.normal(jr.fold_in(keys[1], i), (n_frames, n_features)) * 0.3
        speech_pattern = jnp.outer(label, jnp.exp(-jnp.arange(n_features) / 15.0))
        features = speech_pattern * 2.0 + noise + 0.1

        all_features.append(features)
        all_labels.append(label)

    return jnp.stack(all_features), jnp.stack(all_labels)

key = jr.PRNGKey(123)
features, labels = generate_vad_data(key)
train_features, train_labels = features[:80], labels[:80]
test_features, test_labels = features[80:], labels[80:]

# Simple GRU-based VAD model
def init_vad_model(key, input_dim=40, hidden_dim=64):
    keys = jr.split(key, 6)
    scale_ih = jnp.sqrt(2.0 / input_dim)
    scale_hh = jnp.sqrt(2.0 / hidden_dim)
    return {
        'W_z': jr.normal(keys[0], (input_dim, hidden_dim)) * scale_ih,
        'U_z': jr.normal(keys[1], (hidden_dim, hidden_dim)) * scale_hh,
        'b_z': jnp.zeros(hidden_dim),
        'W_r': jr.normal(keys[2], (input_dim, hidden_dim)) * scale_ih,
        'U_r': jr.normal(keys[3], (hidden_dim, hidden_dim)) * scale_hh,
        'b_r': jnp.zeros(hidden_dim),
        'W_h': jr.normal(keys[4], (input_dim, hidden_dim)) * scale_ih,
        'U_h': jr.normal(keys[5], (hidden_dim, hidden_dim)) * scale_hh,
        'b_h': jnp.zeros(hidden_dim),
        'W_out': jr.normal(jr.fold_in(keys[0], 99), (hidden_dim, 1)) * 0.1,
        'b_out': jnp.zeros(1),
    }

def gru_step(params, h, x):
    """Single GRU step."""
    z = jax.nn.sigmoid(x @ params['W_z'] + h @ params['U_z'] + params['b_z'])
    r = jax.nn.sigmoid(x @ params['W_r'] + h @ params['U_r'] + params['b_r'])
    h_tilde = jnp.tanh(x @ params['W_h'] + (r * h) @ params['U_h'] + params['b_h'])
    h_new = (1 - z) * h + z * h_tilde
    return h_new

def vad_forward(params, x):
    """x: (batch, time, features) -> logits: (batch, time)."""
    batch_size = x.shape[0]
    hidden_dim = params['W_z'].shape[1]
    h0 = jnp.zeros((batch_size, hidden_dim))

    def step(h, x_t):
        h_new = gru_step(params, h, x_t)
        logit = (h_new @ params['W_out'] + params['b_out']).squeeze(-1)
        return h_new, logit

    # Scan over time instead of a Python loop: identical maths, but jit
    # compiles one step rather than unrolling all frames
    _, logits = jax.lax.scan(step, h0, x.transpose(1, 0, 2))
    return logits.T  # (time, batch) -> (batch, time)

def bce_loss(params, features, labels):
    """Binary cross-entropy loss for VAD."""
    logits = vad_forward(params, features)
    probs = jax.nn.sigmoid(logits)
    probs = jnp.clip(probs, 1e-7, 1 - 1e-7)
    loss = -(labels * jnp.log(probs) + (1 - labels) * jnp.log(1 - probs))
    return jnp.mean(loss)

grad_fn = jax.jit(jax.value_and_grad(bce_loss))

# Training
params = init_vad_model(jr.PRNGKey(0))
lr = 5e-3
losses = []

for epoch in range(200):
    loss_val, grads = grad_fn(params, train_features, train_labels)
    params = jax.tree.map(lambda p, g: p - lr * g, params, grads)
    losses.append(float(loss_val))
    if epoch % 50 == 0:
        print(f"Epoch {epoch}: loss = {loss_val:.4f}")

# Evaluate on test set
test_logits = vad_forward(params, test_features)
test_preds = (jax.nn.sigmoid(test_logits) > 0.5).astype(jnp.float32)
accuracy = jnp.mean(test_preds == test_labels)
print(f"\nTest accuracy: {accuracy:.4f}")

# Visualise a test example
idx = 0
fig, axes = plt.subplots(3, 1, figsize=(14, 7))

axes[0].imshow(test_features[idx].T, aspect='auto', origin='lower', cmap='magma')
axes[0].set_title('Log-Mel Energy Features')
axes[0].set_ylabel('Mel bin')

axes[1].fill_between(range(200), test_labels[idx], alpha=0.4, color='#27ae60',
                     label='Ground truth')
axes[1].plot(jax.nn.sigmoid(test_logits[idx]), color='#e74c3c',
             linewidth=1.5, label='Predicted probability')
axes[1].axhline(0.5, color='gray', linestyle='--', linewidth=0.8)
axes[1].set_ylabel('Speech probability')
axes[1].legend()
axes[1].set_title('VAD Predictions')

axes[2].fill_between(range(200), test_labels[idx], alpha=0.4, color='#27ae60',
                     label='Ground truth')
axes[2].fill_between(range(200), test_preds[idx], alpha=0.4, color='#f39c12',
                     label='Predicted (threshold=0.5)')
axes[2].set_ylabel('Speech / Silence')
axes[2].set_xlabel('Frame')
axes[2].legend()
axes[2].set_title('VAD Binary Decision')

plt.tight_layout()
plt.show()
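Raw per-frame decisions tend to flicker at speech boundaries; deployed VAD systems smooth them, for example with a sliding majority vote or a hangover scheme. A minimal majority-vote sketch (the window width here is an illustrative choice, not from the model above):

```python
import jax.numpy as jnp

def smooth_vad(preds, width=5):
    """Sliding majority vote over `width` frames; removes isolated flips."""
    pad = width // 2
    padded = jnp.pad(preds, (pad, pad), mode='edge')
    # Stack shifted views so column j holds the window centred on frame j
    windows = jnp.stack([padded[i:i + len(preds)] for i in range(width)])
    return (windows.mean(axis=0) > 0.5).astype(jnp.float32)

raw = jnp.array([0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0],
                dtype=jnp.float32)
smoothed = smooth_vad(raw)
print(raw)
print(smoothed)  # isolated spike at frame 2 removed, one-frame gap at 8 filled
```

A hangover scheme (keeping the speech state on for a few extra frames after it drops) is asymmetric by design: clipping word endings is far more damaging downstream than passing a little extra silence.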