@@ -279,9 +279,22 @@ def _restack_from_views(self) -> None:
279279 device = torch .cuda .current_device ()
280280 for attr_name in ("experts_gate_up_weight" , "experts_down_weight" ):
281281 old_param = getattr (self , attr_name )
282- new_data = torch .empty_like (old_param , device = device )
283- torch .nn .init .normal_ (new_data , mean = 0.0 , std = self .initializer_range )
284- setattr (self , attr_name , nn .Parameter (new_data ))
282+ if isinstance (old_param .data , DTensor ):
283+ # FSDP2 has sharded this param; materialize the local shard on CUDA
284+ # and reconstruct the DTensor wrapper so FSDP2 can manage it.
285+ local_data = old_param .data .to_local ()
286+ new_local = torch .empty (local_data .shape , dtype = local_data .dtype , device = device )
287+ torch .nn .init .normal_ (new_local , mean = 0.0 , std = self .initializer_range )
288+ new_dtensor = DTensor .from_local (
289+ new_local ,
290+ device_mesh = old_param .data .device_mesh ,
291+ placements = old_param .data .placements ,
292+ )
293+ setattr (self , attr_name , nn .Parameter (new_dtensor ))
294+ else :
295+ new_data = torch .empty_like (old_param , device = device )
296+ torch .nn .init .normal_ (new_data , mean = 0.0 , std = self .initializer_range )
297+ setattr (self , attr_name , nn .Parameter (new_data ))
285298
286299 # Re-sync views to point to the new stacked parameter
287300 self ._sync_expert_views ()
@@ -298,13 +311,15 @@ def _sync_expert_views(self) -> None:
298311 gate_up_w = self .experts_gate_up_weight
299312 if isinstance (gate_up_w , DTensor ):
300313 gate_up_w = gate_up_w .to_local ()
301- for i in range (self .num_local_experts ):
314+ num_local = gate_up_w .shape [0 ]
315+ for i in range (num_local ):
302316 object .__setattr__ (self .experts_gate_up , f"weight{ i } " , gate_up_w [i ])
303317
304318 down_w = self .experts_down_weight
305319 if isinstance (down_w , DTensor ):
306320 down_w = down_w .to_local ()
307- for i in range (self .num_local_experts ):
321+ num_local_down = down_w .shape [0 ]
322+ for i in range (num_local_down ):
308323 object .__setattr__ (self .experts_down , f"weight{ i } " , down_w [i ])
309324
310325 def set_ep_group (self , ep_group : dist .ProcessGroup , ep_mesh : DeviceMesh ) -> None :
@@ -503,12 +518,20 @@ def __init__(
503518 self ._fp8_recipe : transformer_engine .common .recipe .Recipe | None = fp8_recipe
504519 self ._fp4_recipe : transformer_engine .common .recipe .Recipe | None = fp4_recipe
505520
506- if fp8_recipe is not None and self .config .layer_precision is None :
507- if fp4_recipe is not None :
521+ if self .config .layer_precision is None :
522+ if fp8_recipe is not None and fp4_recipe is not None :
508523 raise RuntimeError ("Both FP8 and FP4 recipes provided, but no layer precision provided." )
509-
510- warnings .warn ("No layer precision provided, using FP8 recipe for all layers." , UserWarning )
511- self .config .layer_precision = ["fp8" ] * self .config .num_hidden_layers
524+ if fp8_recipe is not None :
525+ warnings .warn ("No layer precision provided, using FP8 recipe for all layers." , UserWarning )
526+ self .config .layer_precision = ["fp8" ] * self .config .num_hidden_layers
527+ elif fp4_recipe is not None :
528+ raise RuntimeError (
529+ "FP4 recipe provided but no layer_precision configured. "
530+ "Set layer_precision explicitly when using FP4."
531+ )
532+
533+ if self .config .layer_precision is not None and "fp4" in self .config .layer_precision and fp4_recipe is None :
534+ raise RuntimeError ("layer_precision contains 'fp4' entries but no fp4_recipe was provided." )
512535
513536 self .embed_tokens = nn .Embedding (config .vocab_size , config .hidden_size , self .padding_idx , dtype = config .dtype )
514537
@@ -857,6 +880,10 @@ def _unpad_input(hidden_states, attention_mask, unused_mask=None):
857880class HFInferenceParams (InferenceParams ):
858881 """Extension of the InferenceParams class to support HF generate() and beam search."""
859882
883+ # Required by transformers >= 5.4 _valid_auto_compile_criteria(); this
884+ # custom TE-based cache is not compatible with torch.compile generate().
885+ is_compileable = False
886+
860887 def get_seq_length (self , layer_idx : int = 0 ) -> int :
861888 """Return the current cached sequence length.
862889
0 commit comments