Skip to content

Commit 49e8426

Browse files
committed
Mixtral recipes
1 parent 86c8329 commit 49e8426

File tree

78 files changed

+13575
-10
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

78 files changed

+13575
-10
lines changed

bionemo-recipes/models/mixtral/modeling_mixtral_te.py

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -279,9 +279,22 @@ def _restack_from_views(self) -> None:
279279
device = torch.cuda.current_device()
280280
for attr_name in ("experts_gate_up_weight", "experts_down_weight"):
281281
old_param = getattr(self, attr_name)
282-
new_data = torch.empty_like(old_param, device=device)
283-
torch.nn.init.normal_(new_data, mean=0.0, std=self.initializer_range)
284-
setattr(self, attr_name, nn.Parameter(new_data))
282+
if isinstance(old_param.data, DTensor):
283+
# FSDP2 has sharded this param; materialize the local shard on CUDA
284+
# and reconstruct the DTensor wrapper so FSDP2 can manage it.
285+
local_data = old_param.data.to_local()
286+
new_local = torch.empty(local_data.shape, dtype=local_data.dtype, device=device)
287+
torch.nn.init.normal_(new_local, mean=0.0, std=self.initializer_range)
288+
new_dtensor = DTensor.from_local(
289+
new_local,
290+
device_mesh=old_param.data.device_mesh,
291+
placements=old_param.data.placements,
292+
)
293+
setattr(self, attr_name, nn.Parameter(new_dtensor))
294+
else:
295+
new_data = torch.empty_like(old_param, device=device)
296+
torch.nn.init.normal_(new_data, mean=0.0, std=self.initializer_range)
297+
setattr(self, attr_name, nn.Parameter(new_data))
285298

286299
# Re-sync views to point to the new stacked parameter
287300
self._sync_expert_views()
@@ -298,13 +311,15 @@ def _sync_expert_views(self) -> None:
298311
gate_up_w = self.experts_gate_up_weight
299312
if isinstance(gate_up_w, DTensor):
300313
gate_up_w = gate_up_w.to_local()
301-
for i in range(self.num_local_experts):
314+
num_local = gate_up_w.shape[0]
315+
for i in range(num_local):
302316
object.__setattr__(self.experts_gate_up, f"weight{i}", gate_up_w[i])
303317

304318
down_w = self.experts_down_weight
305319
if isinstance(down_w, DTensor):
306320
down_w = down_w.to_local()
307-
for i in range(self.num_local_experts):
321+
num_local_down = down_w.shape[0]
322+
for i in range(num_local_down):
308323
object.__setattr__(self.experts_down, f"weight{i}", down_w[i])
309324

310325
def set_ep_group(self, ep_group: dist.ProcessGroup, ep_mesh: DeviceMesh) -> None:
@@ -503,12 +518,20 @@ def __init__(
503518
self._fp8_recipe: transformer_engine.common.recipe.Recipe | None = fp8_recipe
504519
self._fp4_recipe: transformer_engine.common.recipe.Recipe | None = fp4_recipe
505520

506-
if fp8_recipe is not None and self.config.layer_precision is None:
507-
if fp4_recipe is not None:
521+
if self.config.layer_precision is None:
522+
if fp8_recipe is not None and fp4_recipe is not None:
508523
raise RuntimeError("Both FP8 and FP4 recipes provided, but no layer precision provided.")
509-
510-
warnings.warn("No layer precision provided, using FP8 recipe for all layers.", UserWarning)
511-
self.config.layer_precision = ["fp8"] * self.config.num_hidden_layers
524+
if fp8_recipe is not None:
525+
warnings.warn("No layer precision provided, using FP8 recipe for all layers.", UserWarning)
526+
self.config.layer_precision = ["fp8"] * self.config.num_hidden_layers
527+
elif fp4_recipe is not None:
528+
raise RuntimeError(
529+
"FP4 recipe provided but no layer_precision configured. "
530+
"Set layer_precision explicitly when using FP4."
531+
)
532+
533+
if self.config.layer_precision is not None and "fp4" in self.config.layer_precision and fp4_recipe is None:
534+
raise RuntimeError("layer_precision contains 'fp4' entries but no fp4_recipe was provided.")
512535

513536
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, dtype=config.dtype)
514537

@@ -857,6 +880,10 @@ def _unpad_input(hidden_states, attention_mask, unused_mask=None):
857880
class HFInferenceParams(InferenceParams):
858881
"""Extension of the InferenceParams class to support HF generate() and beam search."""
859882

883+
# Required by transformers >= 5.4 _valid_auto_compile_criteria(); this
884+
# custom TE-based cache is not compatible with torch.compile generate().
885+
is_compileable = False
886+
860887
def get_seq_length(self, layer_idx: int = 0) -> int:
861888
"""Return the current cached sequence length.
862889
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# syntax=docker/dockerfile:1.4
2+
FROM nvcr.io/nvidia/pytorch:26.02-py3
3+
4+
RUN --mount=type=cache,target=/root/.cache/pip \
5+
--mount=type=bind,source=requirements.txt,target=/requirements.txt \
6+
PIP_CONSTRAINT= pip install -r /requirements.txt
7+
8+
WORKDIR /workspace/bionemo
9+
COPY . .
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# TransformerEngine-accelerated Mixtral training with a native PyTorch training loop
2+
3+
This folder demonstrates how to train TE-accelerated Mixtral with a native PyTorch training loop using FSDP2 for
4+
distributed training. The recipe mirrors the structure and conventions of `llama3_native_te`, and includes a Lingua-style
5+
configuration for natural-language pre-training on DCLM Baseline 1.0.
6+
7+
## Commands
8+
9+
Single GPU sanity run:
10+
11+
```bash
12+
python train_fsdp2.py --config-name L0_sanity
13+
```
14+
15+
Single GPU Lingua smoke run:
16+
17+
```bash
18+
python train_fsdp2.py --config-name L2_lingua_8x1B num_train_steps=20 checkpoint.ckpt_dir=./checkpoints
19+
```
20+
21+
Cluster or multi-GPU run:
22+
23+
```bash
24+
torchrun --standalone --nproc_per_node=2 train_fsdp2.py --config-name L2_lingua_8x1B
25+
```
26+
27+
## Notes
28+
29+
- The Lingua config uses the `meta-llama/Meta-Llama-3-8B` tokenizer and streams `mlfoundations/dclm-baseline-1.0`.
30+
- `expert_parallel_size` remains `1` in this v1 recipe so it matches the existing Llama3 Lingua recipe structure.
31+
- Use `HF_TOKEN` for Hugging Face access and `WANDB_KEY` for Weights & Biases logging.

0 commit comments

Comments (0)