Custom pipeline parallelism implementation for Llama3.1-70B enabling mechanistic interpretability research
This project presents a custom implementation of pipeline parallelism for the Llama3.1-70B model, designed specifically to enable mechanistic interpretability research. Unlike existing high-level frameworks such as vLLM and DeepSpeed, which abstract away layer-level access, our implementation provides direct access to individual transformer layers while maintaining efficient distributed inference across multiple GPUs.
- Layer-wise Access: Direct access to hidden states after each transformer layer
- Intervention Capabilities: Ability to modify representations at any pipeline stage
- Debugging Transparency: Complete visibility into tensor shapes and processing flow
- Custom Analysis: Freedom to implement custom probing and analysis tools
- Memory Efficiency: Enables research on large models without requiring prohibitive hardware
- 4-GPU Pipeline: Optimized for a 4-GPU configuration spanning all 80 transformer layers
The implementation targets a 4-GPU configuration where the Llama3.1-70B model (with 80 transformer layers) is distributed using pipeline parallelism:
| GPU | Components |
|---|---|
| GPU 0 | Layers 0-19 + Embeddings + Rotary Embeddings |
| GPU 1 | Layers 20-39 |
| GPU 2 | Layers 40-59 |
| GPU 3 | Layers 60-79 + Layer Norm + Language Model Head |
The device mapping assigns specific model components to GPUs while offloading unused components to disk. The model loading process bypasses CPU RAM entirely through Hugging Face Accelerate's dispatching mechanism (a loading sketch follows the list below):
- Meta Device Initialization: Model structure is initialized without allocating memory for weights
- Direct GPU Loading: Weights are loaded directly from disk to designated GPUs
- Disk Offloading: Unused components are offloaded to fast NVMe storage
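
A minimal sketch of this strategy, assuming the Hugging Face module names for Llama (`model.embed_tokens`, `model.layers.<i>`, `model.norm`, `lm_head`) and a hypothetical `build_device_map` helper; the project's actual `load_model_shard` may differ in details:

```python
# Sketch of the per-rank device map: keep this rank's slice of layers on its GPU
# and offload everything else to disk, following the 4-GPU layout in the table above.
# `build_device_map` is an illustrative helper, not the project's actual API.
import torch
from transformers import AutoModelForCausalLM

def build_device_map(local_rank, world_size=4, num_layers=80):
    per_rank = num_layers // world_size
    start, end = local_rank * per_rank, (local_rank + 1) * per_rank
    device_map = {
        # Embeddings (and the rotary-embedding buffers) live with the first stage.
        "model.embed_tokens": local_rank if local_rank == 0 else "disk",
        # Final norm and LM head live with the last stage.
        "model.norm": local_rank if local_rank == world_size - 1 else "disk",
        "lm_head": local_rank if local_rank == world_size - 1 else "disk",
    }
    for i in range(num_layers):
        device_map[f"model.layers.{i}"] = local_rank if start <= i < end else "disk"
    return device_map

# Passing a device_map makes from_pretrained initialize the model on the meta
# device (no weight allocation), then stream weights from disk straight to the
# designated GPU; "disk" entries are written to fast NVMe storage, never CPU RAM.
model = AutoModelForCausalLM.from_pretrained(
    "/path/to/llama-3.1-70b",
    device_map=build_device_map(local_rank=0),
    torch_dtype=torch.bfloat16,
    offload_folder="/path/to/nvme/offload",
)
```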
- Python 3.8 or higher
- CUDA-capable GPUs (4 GPUs recommended)
- PyTorch 2.0+
- Transformers 4.30+
- Accelerate 0.20+
```bash
git clone https://github.com/your-username/pipeline-parallel-llama.git
cd pipeline-parallel-llama
pip install -e .
pip install -r requirements.txt
```

```bash
# Run distributed inference with 4 GPUs
torchrun --nproc_per_node=4 -m pipeline_parallel_llama.cli inference \
    --model-path /path/to/llama-3.1-70b \
    --prompt "The first number in this list [34, 56, 78] is: "
```

Alternatively, drive the pipeline directly from Python:

```python
from pipeline_parallel_llama import (
    setup_distributed,
    load_model_shard,
    generate_pipeline
)

# Initialize distributed environment
rank, world_size, local_rank, device = setup_distributed()

# Load model shard for current GPU
model, tokenizer, start_layer, end_layer, offload_dir, rotary_emb = load_model_shard(
    model_path="/path/to/llama-3.1-70b",
    local_rank=local_rank,
    world_size=world_size
)

# Run pipeline inference
result = generate_pipeline(
    model=model,
    tokenizer=tokenizer,
    start_layer=start_layer,
    end_layer=end_layer,
    prompt="Your prompt here",
    rank=rank,
    world_size=world_size,
    device=device,
    rotary_emb=rotary_emb
)
```

One of the most significant technical challenges was handling rotary position embeddings when models are sharded across devices. The `LlamaRotaryEmbedding` module becomes inaccessible to downstream GPUs, causing pipeline failures.
Solution: We implemented a standalone rotary embedding accessible to all ranks:

```python
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

def create_rotary_embedding(config, device):
    """Create a standalone rotary embedding for computing position embeddings."""
    return LlamaRotaryEmbedding(config=config, device=device)
```
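
With such a helper, every rank can rebuild the cos/sin tables locally from the position ids instead of receiving them from rank 0. A minimal usage sketch, assuming the recent Hugging Face `LlamaRotaryEmbedding.forward(x, position_ids)` signature (which varies across transformers versions) and variables (`config`, `device`, `hidden_states`) supplied by the surrounding pipeline code:

```python
import torch

# Assumed available on every rank (see the snippets above and below):
#   config        - the model's LlamaConfig (e.g. model.config)
#   device        - this rank's CUDA device
#   hidden_states - activations for this stage, shape (batch, seq_len, hidden)
rotary_emb = create_rotary_embedding(config, device)

seq_len = hidden_states.shape[1]
position_ids = torch.arange(seq_len, device=device).unsqueeze(0)  # (1, seq_len)

# Returns the cos/sin tables, e.g. torch.Size([1, 19, 128]) as in the log below,
# so they never have to cross a GPU boundary.
cos, sin = rotary_emb(hidden_states, position_ids)
position_embeddings = (cos, sin)  # handed to every decoder layer on this rank
```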
The pipeline requires coordinated communication between GPUs. Each rank receives activations from the previous stage, processes them through its assigned layers, and forwards them to the next stage:

```python
import torch
import torch.distributed as dist

def send_tensors_to_next_rank(hidden_states, input_ids, position_embeddings,
                              batch_size, seq_len, device, dst_rank):
    """Send tensors to the next rank in the pipeline."""
    # Send shape information first
    shape_tensor = torch.tensor([batch_size, seq_len], dtype=torch.long, device=device)
    dist.send(shape_tensor, dst=dst_rank)
    # Send actual data
    dist.send(hidden_states.contiguous().detach(), dst=dst_rank)
    dist.send(input_ids.contiguous().detach(), dst=dst_rank)
    # ... position embeddings
```
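
The receiving side mirrors this protocol. Below is a sketch of a hypothetical `recv_tensors_from_prev_rank` counterpart (not the project's actual function); it assumes Llama3.1-70B's hidden size of 8192 and bfloat16 activations, since `dist.recv` needs pre-allocated buffers of matching shape and dtype:

```python
import torch
import torch.distributed as dist

def recv_tensors_from_prev_rank(device, src_rank, hidden_size=8192,
                                dtype=torch.bfloat16):
    """Receive tensors from the previous pipeline stage (illustrative sketch)."""
    # Receive shape information first so the buffers can be allocated.
    shape_tensor = torch.empty(2, dtype=torch.long, device=device)
    dist.recv(shape_tensor, src=src_rank)
    batch_size, seq_len = shape_tensor.tolist()

    # Receive the actual data in the same order it was sent.
    hidden_states = torch.empty(batch_size, seq_len, hidden_size, dtype=dtype, device=device)
    dist.recv(hidden_states, src=src_rank)

    input_ids = torch.empty(batch_size, seq_len, dtype=torch.long, device=device)
    dist.recv(input_ids, src=src_rank)

    # ... receive (or locally recompute via create_rotary_embedding) the
    # position embeddings, mirroring whatever the sender does.
    return hidden_states, input_ids
```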
Each GPU processes its assigned layers while maintaining complete access to intermediate representations:

```python
def forward_layers(model, hidden_states, input_ids, start_layer, end_layer,
                   rank, device, position_embeddings):
    """Forward pass through model layers on current rank."""
    # Position ids are fully determined by the sequence length, so each rank
    # rebuilds them locally from the received hidden states.
    position_ids = torch.arange(hidden_states.shape[1], device=device).unsqueeze(0)
    for i in range(start_layer, end_layer):
        layer = model.model.layers[i]
        layer_outputs = layer(
            hidden_states,
            position_ids=position_ids,
            position_embeddings=position_embeddings,
            # ... other parameters
        )
        hidden_states = layer_outputs[0]
        # Full visibility into each layer's output
    return hidden_states
```
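
Because the layer loop is plain Python, per-layer probing or interventions are one-line additions. A sketch of how hidden states could be captured or patched inside this loop (the `captured` dict and `intervention` callback are illustrative, not part of the released API):

```python
import torch

captured = {}  # layer index -> residual-stream activations for later analysis

def forward_layers_with_probe(model, hidden_states, start_layer, end_layer,
                              device, position_embeddings, intervention=None):
    """Like forward_layers, but records (and optionally edits) each layer's output."""
    position_ids = torch.arange(hidden_states.shape[1], device=device).unsqueeze(0)
    for i in range(start_layer, end_layer):
        layer = model.model.layers[i]
        hidden_states = layer(
            hidden_states,
            position_ids=position_ids,
            position_embeddings=position_embeddings,
        )[0]
        # Store a detached CPU copy of the post-layer residual stream.
        captured[i] = hidden_states.detach().to("cpu", torch.float32)
        # Optionally modify the representation before it flows onward.
        if intervention is not None:
            hidden_states = intervention(i, hidden_states)
    return hidden_states
```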
The implementation successfully demonstrates end-to-end inference with the following characteristics:

- Total inference time: 4.43 seconds for single-token generation
- Memory efficiency: ~34-36GB allocated per GPU (within 94.5GB capacity)
- Functional accuracy: Correctly generates expected outputs
Example log output from a 4-GPU run:

```
[R0] Input IDs: tensor([[128000, 791, 1176, ...]], device='cuda:0'), Shape: torch.Size([1, 19])
[R0] Embedded: torch.Size([1, 19, 8192])
[R0] Position embeddings computed: cos=torch.Size([1, 19, 128]), sin=torch.Size([1, 19, 128])
[R0] Processing layers 0-19...
[R1] Receiving from GPU 0, processing layers 20-39...
[R2] Receiving from GPU 1, processing layers 40-59...
[R3] Receiving from GPU 2, processing layers 60-79...
[R3] Generated token ID: 1958
[R0] Generated token: '34'
[R0] Final result: '34'
[R0] Time taken: 4.43s
```
This implementation provides several key advantages for interpretability research:
- Layer-wise Access: Direct access to hidden states after each transformer layer
- Intervention Capabilities: Ability to modify representations at any pipeline stage
- Debugging Transparency: Complete visibility into tensor shapes and processing flow
- Custom Analysis: Freedom to implement custom probing and analysis tools
- Memory Efficiency: Enables research on large models without requiring prohibitive hardware
Environment variables:

- `PIPELINE_LLAMA_MODEL_PATH`: Path to the Llama model directory
- `PIPELINE_LLAMA_PROMPT`: Default prompt for inference
- `LOCAL_RANK`: GPU rank (set automatically by `torchrun`)
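
For reference, a minimal sketch of reading these variables in Python (the defaults shown are illustrative):

```python
import os

model_path = os.environ.get("PIPELINE_LLAMA_MODEL_PATH", "/path/to/llama-3.1-70b")
prompt = os.environ.get("PIPELINE_LLAMA_PROMPT", "The first number in this list [34, 56, 78] is: ")
local_rank = int(os.environ["LOCAL_RANK"])  # injected per process by torchrun
```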
Adjust GPU memory limits in the model loading:

```python
max_memory = {0: "80GB", 1: "80GB", 2: "80GB", 3: "80GB"}
```

Current limitations:

- Single Token Generation: Current implementation focuses on single-token generation
- No KV Caching: Lacks key-value caching for multi-token generation efficiency
- Static Pipeline: Fixed 4-GPU configuration without dynamic load balancing
Planned future work:

- Implementation of dynamic key-value caching for multi-token generation
- Support for variable GPU configurations
- Integration with interpretability tools like TransformerLens
- Batched inference capabilities
- Memory optimization for longer sequences
We welcome contributions! Please see our Contributing Guidelines for details.
To set up a development environment and run the tests:

```bash
git clone https://github.com/your-username/pipeline-parallel-llama.git
cd pipeline-parallel-llama
pip install -e ".[dev]"
pytest tests/
```

If you use this work in your research, please cite:
```bibtex
@misc{guiomar2024pipeline,
  title={Reverse Engineering a Pipeline Parallel Llama3.1-70B with transformers, accelerate and torch.distributed},
  author={Guiomar, Gonçalo},
  year={2024},
  institution={ETH AI Center}
}
```

This project is licensed under the MIT License - see the LICENSE file for details.
- Author: Gonçalo Guiomar, ETH AI Center Fellow
- Institution: ETH AI Center
- Framework Dependencies: PyTorch, Transformers, Accelerate
For questions and support:
- Create an issue on GitHub
- Email: [email protected]
Note: This is a research prototype designed for mechanistic interpretability studies. While functional, it may require adaptation for production use cases.