     check_numerics,
     flatten_items,
     get_or_none,
+    maybe_shard,
     save_and_offload_only_these_names_regex,
     shapes,
     split_prng_key,
-    with_sharding_constraint,
 )


@@ -1560,18 +1560,18 @@ class Config(BaseLayer.Config):
         logit_sink: Optional[bool] = None

         # Partition spec for query ([batch, seq, q_heads, head_dim]) after input projections.
-        q_partition_spec: Optional[PartitionSpec] = None
+        q_partition_spec: Optional[Sequence[Union[str, Sequence[str], None]]] = None

         # Partition spec for key ([batch, seq, kv_heads, head_dim]) after input projections.
         # Follows `q_partition_spec` if None.
-        k_partition_spec: Optional[PartitionSpec] = None
+        k_partition_spec: Optional[Sequence[Union[str, Sequence[str], None]]] = None

         # Partition spec for value ([batch, seq, kv_heads, head_dim]) after input projections.
         # Follows `q_partition_spec` if None.
-        v_partition_spec: Optional[PartitionSpec] = None
+        v_partition_spec: Optional[Sequence[Union[str, Sequence[str], None]]] = None

         # Partition spec for output ([batch, seq, hidden_dim]) after output projections.
-        o_partition_spec: Optional[PartitionSpec] = None
+        o_partition_spec: Optional[Sequence[Union[str, Sequence[str], None]]] = None

     def __init__(self, cfg: Config, *, parent: Module):
         super().__init__(cfg, parent=parent)
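Note that the spec fields now take a plain sequence of mesh axis names (one entry per tensor dimension) rather than a `jax.sharding.PartitionSpec`. For illustration only, with a hypothetical `attn_cfg` and axis names matching the defaults used by `set_attention_partition_specs` later in this diff:

# Illustrative sketch: shard q/k/v activations [batch, seq, heads, head_dim] and the
# output activation [batch, seq, hidden_dim] over batch/seq/model mesh axes.
attn_cfg.q_partition_spec = (("data", "fsdp"), "seq", "model", None)
attn_cfg.o_partition_spec = (("data", "fsdp"), "seq", "model")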
@@ -1736,12 +1736,9 @@ def _forward_for_mode(
             time_step = cached_states["time_step"]
             query_positions = query_positions + time_step[:, None]  # [batch, steps]
         q_proj, k_proj, v_proj = self.i_proj(query, query_positions=query_positions, **kv_kwargs)
-        if cfg.q_partition_spec:
-            q_proj = with_sharding_constraint(q_proj, cfg.q_partition_spec)
-        if cfg.q_partition_spec or cfg.k_partition_spec:
-            k_proj = with_sharding_constraint(k_proj, cfg.k_partition_spec or cfg.q_partition_spec)
-        if cfg.q_partition_spec or cfg.v_partition_spec:
-            v_proj = with_sharding_constraint(v_proj, cfg.v_partition_spec or cfg.q_partition_spec)
+        q_proj = maybe_shard(q_proj, cfg.q_partition_spec)
+        k_proj = maybe_shard(k_proj, cfg.k_partition_spec or cfg.q_partition_spec)
+        v_proj = maybe_shard(v_proj, cfg.v_partition_spec or cfg.q_partition_spec)

         if cfg.scale_kv_before_cache_update:
             if has_external_kv_state:
@@ -1844,8 +1841,7 @@ def _forward_for_mode(

         # [batch, target_length, output_dim].
         o_proj = self.o_proj(context)
-        if cfg.o_partition_spec:
-            o_proj = with_sharding_constraint(o_proj, cfg.o_partition_spec)
+        o_proj = maybe_shard(o_proj, cfg.o_partition_spec)
         outputs = self._remat_name(o_proj, "o_proj")
         self._add_tensor_stats("o_proj_outputs", outputs)
         return_aux = return_aux or set()
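The q/k/v/o constraints above now route through `maybe_shard` instead of per-tensor `if spec: with_sharding_constraint(...)` checks. A minimal sketch of the assumed behavior of such a helper; the real `maybe_shard` is imported from `axlearn.common.utils` and may differ in details (e.g. how it handles meshes or `PartitionSpec` inputs):

from typing import Optional, Sequence, Union

import jax
from jax.sharding import PartitionSpec


def maybe_shard(
    x: jax.Array,
    partition_spec: Optional[Sequence[Union[str, Sequence[str], None]]],
) -> jax.Array:
    # Skip the constraint entirely when no spec is configured.
    if partition_spec is None:
        return x
    # Constrain the intermediate activation; intended to run under jit with a mesh in scope.
    return jax.lax.with_sharding_constraint(x, PartitionSpec(*partition_spec))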
@@ -3608,15 +3604,17 @@ def extend_step(
 def set_attention_partition_specs(
     cfg: MultiheadAttention.Config,
     *,
+    batch_axis_names: Union[str, Sequence[str]] = ("data", "fsdp"),
     fsdp_axis_names: Union[str, Sequence[str]] = "fsdp",
     tp_axis_names: Union[str, Sequence[str]] = "model",
+    seq_axis_names: Union[str, Sequence[str]] = "seq",
+    set_attn_activation_specs: bool = False,
 ):
     """Sets `cfg` to shard attention weights over both fsdp and tp axes.

     Args:
         cfg: A MultiheadAttention layer config to apply sharding spec to.
-        fsdp_axis_names: Axis name(s) over which we shard fully-sharded-data-parallel tensors.
-        tp_axis_names: Axis name(s) over which we shard tensor-parallel tensors.
+        **kwargs: See `set_double_shard_weights_config`.
     """
     # Shard weights.
     input_linear_cfg = cfg.input_linear
@@ -3625,6 +3623,10 @@ def set_attention_partition_specs(
     input_linear_cfg.layer.param_partition_spec = (fsdp_axis_names, tp_axis_names, None)
     cfg.output_linear.param_partition_spec = (fsdp_axis_names, tp_axis_names, None)

+    if set_attn_activation_specs:
+        cfg.q_partition_spec = (batch_axis_names, seq_axis_names, tp_axis_names, None)
+        cfg.o_partition_spec = (batch_axis_names, seq_axis_names, tp_axis_names)
+

 def set_feed_forward_partition_specs(
     cfg: TransformerFeedForwardLayer.Config,
@@ -3638,10 +3640,7 @@ def set_feed_forward_partition_specs(

     Args:
         cfg: A TransformerFeedForwardLayer layer config to apply sharding spec to.
-        batch_axis_names: Axis name(s) over which we shard the batch dimension of output tensors.
-        fsdp_axis_names: Axis name(s) over which we shard fully-sharded-data-parallel tensors.
-        tp_axis_names: Axis name(s) over which we shard tensor-parallel tensors.
-        seq_axis_names: Axis name(s) over which we shard sequence-parallel tensors.
+        **kwargs: See `set_double_shard_weights_config`.
     """
     # Shard weights.
     cfg.linear1.param_partition_spec = (fsdp_axis_names, tp_axis_names)
@@ -3658,6 +3657,7 @@ def set_double_shard_weights_config(
     fsdp_axis_names: Union[str, Sequence[str]] = "fsdp",
     tp_axis_names: Union[str, Sequence[str]] = "model",
     seq_axis_names: Union[str, Sequence[str]] = "seq",
+    set_attn_activation_specs: bool = False,
 ):
     """Sets `cfg` to shard FFN and attention weights over both fsdp and tp axes.

@@ -3667,32 +3667,35 @@ def set_double_shard_weights_config(
         fsdp_axis_names: Axis name(s) over which we shard fully-sharded-data-parallel tensors.
         tp_axis_names: Axis name(s) over which we shard tensor-parallel tensors.
         seq_axis_names: Axis name(s) over which we shard sequence-parallel tensors.
+        set_attn_activation_specs: Whether to also set activation partition specs for the
+            q/k/v/o projections. This may be required for some complex sharding cases.
     """

     # pytype: disable=attribute-error
     if not isinstance(cfg, Sequence):
         cfg = [cfg]

+    axis_names = dict(
+        batch_axis_names=batch_axis_names,
+        fsdp_axis_names=fsdp_axis_names,
+        tp_axis_names=tp_axis_names,
+        seq_axis_names=seq_axis_names,
+    )
+
     for layer_cfg in cfg:
         set_attention_partition_specs(
             layer_cfg.self_attention.attention,
-            fsdp_axis_names=fsdp_axis_names,
-            tp_axis_names=tp_axis_names,
+            set_attn_activation_specs=set_attn_activation_specs,
+            **axis_names,
         )
         if layer_cfg.cross_attention is not None:
             set_attention_partition_specs(
                 layer_cfg.cross_attention.attention,
-                fsdp_axis_names=fsdp_axis_names,
-                tp_axis_names=tp_axis_names,
+                set_attn_activation_specs=set_attn_activation_specs,
+                **axis_names,
             )
         if isinstance(layer_cfg.feed_forward, TransformerFeedForwardLayer.Config):
-            set_feed_forward_partition_specs(
-                layer_cfg.feed_forward,
-                batch_axis_names=batch_axis_names,
-                fsdp_axis_names=fsdp_axis_names,
-                tp_axis_names=tp_axis_names,
-                seq_axis_names=seq_axis_names,
-            )
+            set_feed_forward_partition_specs(layer_cfg.feed_forward, **axis_names)
     # pytype: enable=attribute-error


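As a usage sketch, with `stacked_layer_cfgs` as a hypothetical placeholder for a sequence of transformer layer configs (not defined in this diff); the axis names are the defaults from the signature above:

set_double_shard_weights_config(
    stacked_layer_cfgs,
    batch_axis_names=("data", "fsdp"),
    fsdp_axis_names="fsdp",
    tp_axis_names="model",
    seq_axis_names="seq",
    # Also constrain q/o projection activations via the new attention activation specs.
    set_attn_activation_specs=True,
)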