    Train Your Large Model on Multiple GPUs with Pipeline Parallelism

    By Admin | December 30, 2025 | 6 Mins Read


    import dataclasses
    import os

    import datasets
    import tokenizers
    import torch
    import torch.distributed as dist
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim.lr_scheduler as lr_scheduler
    import tqdm
    from torch import Tensor
    from torch.distributed.checkpoint import load, save
    from torch.distributed.checkpoint.state_dict import StateDictOptions, get_state_dict, set_state_dict
    from torch.distributed.pipelining import PipelineStage, ScheduleGPipe

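    # Note: this script relies on fairly recent PyTorch features (nn.RMSNorm,
    # torch.distributed.pipelining, and the enable_gqa flag of
    # scaled_dot_product_attention), so PyTorch 2.5 or newer is assumed.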
     

     

    # Build the model
    @dataclasses.dataclass
    class LlamaConfig:
        """Define Llama model hyperparameters."""
        vocab_size: int = 50000  # Size of the tokenizer vocabulary
        max_position_embeddings: int = 2048  # Maximum sequence length
        hidden_size: int = 768  # Dimension of hidden layers
        intermediate_size: int = 4 * 768  # Dimension of the MLP's hidden layer
        num_hidden_layers: int = 12  # Number of transformer layers
        num_attention_heads: int = 12  # Number of attention heads
        num_key_value_heads: int = 3  # Number of key-value heads for GQA

     

     

    class RotaryPositionEncoding(nn.Module):
        """Rotary position encoding (RoPE)."""

        def __init__(self, dim: int, max_position_embeddings: int) -> None:
            """Initialize the RotaryPositionEncoding module.

            Args:
                dim: The hidden dimension of the input tensor to which RoPE is applied
                max_position_embeddings: The maximum sequence length of the input tensor
            """
            super().__init__()
            self.dim = dim
            self.max_position_embeddings = max_position_embeddings
            # save cosine and sine matrices as buffers, not parameters
            self.register_buffer("cos", torch.empty(max_position_embeddings, dim))
            self.register_buffer("sin", torch.empty(max_position_embeddings, dim))
            self.reset_parameters()

        def reset_parameters(self) -> None:
            """Compute the tables of angles n * theta_i and fill the cos/sin buffers.

            Defined as reset_parameters() so the tables are rebuilt after the model
            is materialized from the meta device with to_empty(), which leaves
            buffers uninitialized.
            """
            N = 10_000.0
            inv_freq = 1.0 / (N ** (torch.arange(0, self.dim, 2) / self.dim))
            inv_freq = torch.cat((inv_freq, inv_freq), dim=-1)
            position = torch.arange(self.max_position_embeddings)
            sinusoid_inp = torch.outer(position, inv_freq)
            self.cos = sinusoid_inp.cos().to(self.cos.device)
            self.sin = sinusoid_inp.sin().to(self.sin.device)

        def forward(self, x: Tensor) -> Tensor:
            """Apply RoPE to tensor x.

            Args:
                x: Input tensor of shape (batch_size, seq_length, num_heads, head_dim)

            Returns:
                Output tensor of shape (batch_size, seq_length, num_heads, head_dim)
            """
            batch_size, seq_len, num_heads, head_dim = x.shape
            dtype = x.dtype
            # reshape the cosine and sine matrices to 4D tensors with the same dtype as x
            cos = self.cos.to(dtype)[:seq_len].view(1, seq_len, 1, -1)
            sin = self.sin.to(dtype)[:seq_len].view(1, seq_len, 1, -1)
            # apply RoPE to x using the rotate-half formulation
            x1, x2 = x.chunk(2, dim=-1)
            rotated = torch.cat((-x2, x1), dim=-1)
            output = (x * cos) + (rotated * sin)
            return output

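    The RoPE module keeps the (batch, seq, heads, head_dim) layout unchanged and only mixes pairs of channels within each head. A minimal standalone shape check, purely illustrative and assuming the class and imports above:

    rope = RotaryPositionEncoding(dim=64, max_position_embeddings=128)
    x = torch.randn(2, 16, 12, 64)  # (batch_size, seq_length, num_heads, head_dim)
    print(rope(x).shape)            # torch.Size([2, 16, 12, 64])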
     

     

    class LlamaAttention(nn.Module):
        """Grouped-query attention with rotary embeddings."""

        def __init__(self, config: LlamaConfig) -> None:
            super().__init__()
            self.hidden_size = config.hidden_size
            self.num_heads = config.num_attention_heads
            self.head_dim = self.hidden_size // self.num_heads
            self.num_kv_heads = config.num_key_value_heads  # GQA: H_kv < H_q

            # hidden_size must be divisible by num_heads
            assert (self.head_dim * self.num_heads) == self.hidden_size

            # Linear layers for Q, K, V projections
            self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
            self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
            self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
            self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        def forward(self, hidden_states: Tensor, rope: RotaryPositionEncoding) -> Tensor:
            bs, seq_len, dim = hidden_states.size()

            # Project inputs to Q, K, V
            query_states = self.q_proj(hidden_states).view(bs, seq_len, self.num_heads, self.head_dim)
            key_states = self.k_proj(hidden_states).view(bs, seq_len, self.num_kv_heads, self.head_dim)
            value_states = self.v_proj(hidden_states).view(bs, seq_len, self.num_kv_heads, self.head_dim)

            # Apply rotary position embeddings
            query_states = rope(query_states)
            key_states = rope(key_states)

            # Transpose tensors from BSHD to BHSD layout for scaled_dot_product_attention
            query_states = query_states.transpose(1, 2)
            key_states = key_states.transpose(1, 2)
            value_states = value_states.transpose(1, 2)

            # Use PyTorch's optimized attention implementation
            # (is_causal=True is incompatible with passing an explicit attention mask)
            attn_output = F.scaled_dot_product_attention(
                query_states,
                key_states,
                value_states,
                is_causal=True,
                dropout_p=0.0,
                enable_gqa=True,
            )

            # Transpose the output from BHSD back to BSHD, reshape to 3D, then project
            attn_output = attn_output.transpose(1, 2).reshape(bs, seq_len, self.hidden_size)
            attn_output = self.o_proj(attn_output)
            return attn_output

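    The enable_gqa flag is what lets the 12 query heads above share only 3 key/value heads: scaled_dot_product_attention broadcasts each KV head across a group of query heads. A small standalone sketch of that behaviour, assuming PyTorch 2.5 or newer and using toy sizes:

    q = torch.randn(1, 12, 16, 64)  # (batch, query heads, seq, head_dim)
    k = torch.randn(1, 3, 16, 64)   # (batch, kv heads, seq, head_dim)
    v = torch.randn(1, 3, 16, 64)
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
    print(out.shape)                # torch.Size([1, 12, 16, 64]), one output per query head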
     

     

    class LlamaMLP(nn.Module):
        """Feed-forward network with SwiGLU activation."""

        def __init__(self, config: LlamaConfig) -> None:
            super().__init__()
            # Two parallel projections for SwiGLU
            self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
            self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
            self.act_fn = F.silu  # SiLU activation, used to form the SwiGLU gate
            # Project back to hidden size
            self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)

        def forward(self, x: Tensor) -> Tensor:
            # SwiGLU: multiply the activated gate projection with the up projection
            gate = self.act_fn(self.gate_proj(x))
            up = self.up_proj(x)
            return self.down_proj(gate * up)

     

     

    class LlamaDecoderLayer(nn.Module):
        """Single transformer layer for a Llama model."""

        def __init__(self, config: LlamaConfig) -> None:
            super().__init__()
            self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=1e-5)
            self.self_attn = LlamaAttention(config)
            self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=1e-5)
            self.mlp = LlamaMLP(config)

        def forward(self, hidden_states: Tensor, rope: RotaryPositionEncoding) -> Tensor:
            # First residual block: self-attention
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
            attn_outputs = self.self_attn(hidden_states, rope=rope)
            hidden_states = attn_outputs + residual

            # Second residual block: MLP
            residual = hidden_states
            hidden_states = self.post_attention_layernorm(hidden_states)
            hidden_states = self.mlp(hidden_states) + residual
            return hidden_states

     

     

    class LlamaModel(nn.Module):
        """The full Llama model without any pretraining heads."""

        def __init__(self, config: LlamaConfig) -> None:
            super().__init__()
            self.rope = RotaryPositionEncoding(
                config.hidden_size // config.num_attention_heads,
                config.max_position_embeddings,
            )

            self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
            # A ModuleDict (rather than a ModuleList) so that individual layers can be
            # set to None when the model is partitioned into pipeline stages
            self.layers = nn.ModuleDict({
                str(i): LlamaDecoderLayer(config) for i in range(config.num_hidden_layers)
            })
            self.norm = nn.RMSNorm(config.hidden_size, eps=1e-5)

        def forward(self, input_ids: Tensor) -> Tensor:
            # Convert input token IDs to embeddings (skipped on stages without the embedding)
            if self.embed_tokens is not None:
                hidden_states = self.embed_tokens(input_ids)
            else:
                hidden_states = input_ids
            # Process through all transformer layers present on this stage, then the final norm
            for n in range(len(self.layers)):
                if self.layers[str(n)] is not None:
                    hidden_states = self.layers[str(n)](hidden_states, self.rope)
            if self.norm is not None:
                hidden_states = self.norm(hidden_states)
            # Return the final hidden states
            return hidden_states

     

     

    class LlamaForPretraining(nn.Module):
        """Llama model with a language-modeling head for pretraining."""

        def __init__(self, config: LlamaConfig) -> None:
            super().__init__()
            self.base_model = LlamaModel(config)
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        def forward(self, input_ids: Tensor) -> Tensor:
            hidden_states = self.base_model(input_ids)
            if self.lm_head is not None:
                hidden_states = self.lm_head(hidden_states)
            return hidden_states

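    As a quick sanity check of the size implied by the default configuration, the full model can be instantiated on the meta device, which allocates no real memory; with the defaults above this works out to roughly 179M parameters. An illustrative snippet, not part of the training script:

    with torch.device("meta"):
        _probe = LlamaForPretraining(LlamaConfig())
    n_params = sum(p.numel() for p in _probe.parameters())
    print(f"{n_params / 1e6:.1f}M parameters")  # ~179M with the default LlamaConfig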
     

     

    # Map-style dataset that yields fixed-length, padded token sequences
    class PretrainingDataset(torch.utils.data.Dataset):
        def __init__(self, dataset: datasets.Dataset, tokenizer: tokenizers.Tokenizer,
                     seq_length: int, device: torch.device = None):
            self.dataset = dataset
            self.tokenizer = tokenizer
            self.device = device
            self.seq_length = seq_length
            self.bot = tokenizer.token_to_id("[BOT]")
            self.eot = tokenizer.token_to_id("[EOT]")
            self.pad = tokenizer.token_to_id("[PAD]")

        def __len__(self):
            return len(self.dataset)

        def __getitem__(self, index):
            """Get a sequence of token ids from the dataset. [BOT] and [EOT] tokens
            are added. The sequence is clipped and padded to the target length.
            """
            seq = self.dataset[index]["text"]
            tokens: list[int] = [self.bot] + self.tokenizer.encode(seq).ids + [self.eot]
            # pad to the target sequence length (plus one for the shifted target)
            toklen = len(tokens)
            if toklen < self.seq_length + 1:
                pad_length = self.seq_length + 1 - toklen
                tokens += [self.pad] * pad_length
            # return input ids and next-token targets shifted by one position
            x = torch.tensor(tokens[:self.seq_length], dtype=torch.int64, device=self.device)
            y = torch.tensor(tokens[1:self.seq_length + 1], dtype=torch.int64, device=self.device)
            return x, y

     

     

    def load_checkpoint(model: nn.Module, optimizer: torch.optim.Optimizer) -> None:
        dist.barrier()
        model_state, optimizer_state = get_state_dict(
            model, optimizer, options=StateDictOptions(full_state_dict=True),
        )
        load(
            {"model": model_state, "optimizer": optimizer_state},
            checkpoint_id="checkpoint-dist",
        )
        set_state_dict(
            model, optimizer,
            model_state_dict=model_state, optim_state_dict=optimizer_state,
            options=StateDictOptions(broadcast_from_rank0=True, full_state_dict=True),
        )
        dist.barrier()


    def save_checkpoint(model: nn.Module, optimizer: torch.optim.Optimizer) -> None:
        dist.barrier()
        model_state, optimizer_state = get_state_dict(
            model, optimizer, options=StateDictOptions(full_state_dict=True),
        )
        save(
            {"model": model_state, "optimizer": optimizer_state},
            checkpoint_id="checkpoint-dist",
        )
        dist.barrier()

     

     

    # Load the tokenizer and dataset
    tokenizer = tokenizers.Tokenizer.from_file("bpe_50K.json")
    dataset = datasets.load_dataset("HuggingFaceFW/fineweb", "sample-10BT", split="train")

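    The script loads a pre-trained byte-pair-encoding tokenizer from bpe_50K.json, but how that file was produced is not shown. Purely as an illustration, a file like it could be created with the tokenizers library along these lines; the pre-tokenizer choice and the exact special-token list are assumptions, apart from [PAD], [BOT], and [EOT], which PretrainingDataset relies on:

    from tokenizers import Tokenizer, models, pre_tokenizers, trainers

    bpe = Tokenizer(models.BPE(unk_token="[UNK]"))
    bpe.pre_tokenizer = pre_tokenizers.ByteLevel()
    trainer = trainers.BpeTrainer(
        vocab_size=50000,
        special_tokens=["[PAD]", "[BOT]", "[EOT]", "[UNK]"],
    )
    bpe.train_from_iterator((row["text"] for row in dataset), trainer=trainer)
    bpe.save("bpe_50K.json")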
     

    # Initialize the distributed environment
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = dist.get_world_size()
    device = torch.device(f"cuda:{local_rank}")
    print(f"World size {world_size}, rank {rank}, local rank {local_rank}. Using {device}")
    assert world_size == 3, f"This script is designed for 3 GPUs, got {world_size}"

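    Because the script reads LOCAL_RANK from the environment and asserts a world size of exactly 3, it is presumably meant to be launched with torchrun on a single machine with three GPUs; the script file name below is only a placeholder:

    # torchrun --standalone --nproc_per_node=3 pipeline_pretrain.py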
     

    # Create the pretraining model with the default config on the meta device to prevent OOM
    with torch.device("meta"):
        model_config = LlamaConfig()
        model = LlamaForPretraining(model_config)
        # Partition the model by removing the layers this rank does not own
        num_layers = model_config.num_hidden_layers
        partition = [num_layers // 3, 2 * num_layers // 3, num_layers]
        if rank == 0:
            # from the embedding to 1/3 of the decoder layers
            for n in range(partition[0], partition[2]):
                model.base_model.layers[str(n)] = None
            model.base_model.norm = None
            model.lm_head = None
        elif rank == 1:
            # from 1/3 to 2/3 of the decoder layers
            model.base_model.embed_tokens = None
            for n in range(0, partition[0]):
                model.base_model.layers[str(n)] = None
            for n in range(partition[1], partition[2]):
                model.base_model.layers[str(n)] = None
            model.base_model.norm = None
            model.lm_head = None
        elif rank == 2:
            # from 2/3 to the end of the decoder layers, plus the final norm layer and LM head
            model.base_model.embed_tokens = None
            for n in range(partition[1]):
                model.base_model.layers[str(n)] = None
        else:
            raise ValueError(f"Invalid rank: {rank}")

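    # With the default 12-layer config, the resulting stages are:
    #   rank 0: embed_tokens + decoder layers 0-3
    #   rank 1: decoder layers 4-7
    #   rank 2: decoder layers 8-11 + final norm + lm_head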
     

     

    # Move the model from the meta device to the CUDA device, then initialize the weights
    def reset_all_weights(model: nn.Module) -> None:
        @torch.no_grad()
        def weight_reset(m: nn.Module):
            reset_parameters = getattr(m, "reset_parameters", None)
            if callable(reset_parameters):
                m.reset_parameters()

        # Applies fn recursively to the model itself and all of model.children()
        model.apply(fn=weight_reset)


    model.to_empty(device=device)
    reset_all_weights(model)
    model.train()
    stage = PipelineStage(model, stage_index=rank, num_stages=world_size, device=device)

     

    # Training parameters
    epochs = 3
    learning_rate = 1e-3
    batch_size = 64
    seq_length = 512
    num_warmup_steps = 1000
    PAD_TOKEN_ID = tokenizer.token_to_id("[PAD]")

    # DataLoader, optimizer, scheduler, and loss function.
    # Note: no shuffling and no DistributedSampler, so every rank iterates over the
    # same batches in the same order, which is what the pipeline stages require.
    dataset = PretrainingDataset(dataset, tokenizer, seq_length, device)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
    )
    num_training_steps = len(dataloader) * epochs
    print(f"Number of training steps: {num_training_steps} = {len(dataloader)} * {epochs}")

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=learning_rate, betas=(0.9, 0.99), eps=1e-8, weight_decay=0.1,
    )
    warmup_scheduler = lr_scheduler.LinearLR(
        optimizer,
        start_factor=0.1, end_factor=1.0, total_iters=num_warmup_steps,
    )
    cosine_scheduler = lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=num_training_steps - num_warmup_steps,
        eta_min=0,
    )
    scheduler = lr_scheduler.SequentialLR(
        optimizer,
        schedulers=[warmup_scheduler, cosine_scheduler],
        milestones=[num_warmup_steps],
    )

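    To see what the combined warmup-plus-cosine schedule looks like without touching the real training objects, a throwaway optimizer can be stepped through a scaled-down version of the same construction (the numbers here are arbitrary and only for illustration):

    _opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=1e-3)
    _warmup = lr_scheduler.LinearLR(_opt, start_factor=0.1, end_factor=1.0, total_iters=10)
    _cosine = lr_scheduler.CosineAnnealingLR(_opt, T_max=90, eta_min=0)
    _sched = lr_scheduler.SequentialLR(_opt, schedulers=[_warmup, _cosine], milestones=[10])
    for step in range(100):
        _opt.step()
        _sched.step()
        if step % 20 == 0:
            print(step, round(_sched.get_last_lr()[0], 6))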
     

    # If the checkpoint-dist directory exists, load the checkpoint into the model and optimizer.
    # Note: you should also restore the epoch and step counters to resume correctly.
    if os.path.exists("checkpoint-dist"):
        load_checkpoint(model, optimizer)

    # Create the pipeline schedule
    def loss_fn(logits: Tensor, target_ids: Tensor) -> Tensor:
        logits = logits.view(-1, logits.size(-1))
        target_ids = target_ids.view(-1)
        return F.cross_entropy(logits, target_ids, ignore_index=PAD_TOKEN_ID)

    n_microbatches = 4  # number of microbatches each batch is split into
    schedule = ScheduleGPipe(stage, n_microbatches=n_microbatches, loss_fn=loss_fn)

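    # With batch_size = 64 and n_microbatches = 4, each microbatch carries 16 sequences;
    # GPipe pushes the microbatches through the three stages back-to-back so the GPUs
    # overlap work instead of waiting for a full batch to clear each stage.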
     

    # Start training
    for epoch in range(epochs):
        pbar = tqdm.tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}", disable=(rank != world_size - 1))
        for batch_id, batch in enumerate(pbar):
            if batch_id % 1000 == 0:
                save_checkpoint(model, optimizer)
            # zero gradients before the forward pass, since no explicit backward call follows
            optimizer.zero_grad(set_to_none=True)
            # get the batched data
            input_ids, target_ids = batch
            if rank == 0:
                # first stage: feed the input token ids
                schedule.step(input_ids)
            elif rank == world_size - 1:
                # last stage: provide the targets and collect one loss per microbatch
                losses = []
                logits = schedule.step(target=target_ids, losses=losses)
                with torch.no_grad():
                    pbar.set_postfix(loss=sum(losses).item() / len(losses))
            else:
                # middle stages: just run their part of the pipeline
                schedule.step()

            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
        pbar.close()

     

    # Save the model
    save_checkpoint(model, optimizer)

    # Clean up the distributed environment
    dist.destroy_process_group()

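    The checkpoint written above is in PyTorch's sharded distributed-checkpoint format: a single checkpoint-dist directory containing files from every rank. If a single consolidated file is wanted later, for example for single-device inference, recent PyTorch versions include a converter; a hedged sketch, with an arbitrary output file name:

    from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

    # Consolidate the sharded "checkpoint-dist" directory into one torch.save file
    dcp_to_torch_save("checkpoint-dist", "checkpoint.pt")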


    Source link
