@@ -100,7 +100,6 @@ def __init__(
100100 prefix : str = "" ,
101101 quant_config : Optional [QuantizationConfig ] = None ,
102102 alt_stream : Optional [torch .cuda .Stream ] = None ,
103- fuse_wk_and_weights_proj : bool = False ,
104103 ):
105104 super ().__init__ ()
106105 self .hidden_size = hidden_size
@@ -111,7 +110,6 @@ def __init__(
111110 self .q_lora_rank = q_lora_rank
112111 self .layer_id = layer_id
113112 self .alt_stream = alt_stream
114- self .fuse_wk_and_weights_proj = fuse_wk_and_weights_proj
115113 if is_cuda ():
116114 self .sm_count = deep_gemm .get_num_sms ()
117115 self .half_device_sm_count = ceil_align (self .sm_count // 2 , 8 )
@@ -123,28 +121,22 @@ def __init__(
123121 quant_config = quant_config ,
124122 prefix = add_prefix ("wq_b" , prefix ),
125123 )
126- if self .fuse_wk_and_weights_proj :
127- self .fused_wk_and_weights_proj = ReplicatedLinear (
128- self .hidden_size ,
129- self .head_dim + self .n_heads ,
130- bias = False ,
131- prefix = add_prefix ("fused_wk_and_weights_proj" , prefix ),
132- )
133- else :
134- self .wk = ReplicatedLinear (
135- self .hidden_size ,
136- self .head_dim ,
137- bias = False ,
138- quant_config = quant_config ,
139- prefix = add_prefix ("wk" , prefix ),
140- )
141- # NOTE: weight_proj is not quantized
142- self .weights_proj = ReplicatedLinear (
143- self .hidden_size ,
144- self .n_heads ,
145- bias = False ,
146- prefix = add_prefix ("weights_proj" , prefix ),
147- )
124+
125+ self .wk = ReplicatedLinear (
126+ self .hidden_size ,
127+ self .head_dim ,
128+ bias = False ,
129+ quant_config = quant_config ,
130+ prefix = add_prefix ("wk" , prefix ),
131+ )
132+ # NOTE: weights_proj in the checkpoint is stored in bf16, while the parameters here are stored in fp32 for convenience
133+ self .weights_proj = ReplicatedLinear (
134+ self .hidden_size ,
135+ self .n_heads ,
136+ bias = False ,
137+ params_dtype = torch .float32 ,
138+ prefix = add_prefix ("weights_proj" , prefix ),
139+ )
148140 self .k_norm = LayerNorm (self .head_dim , dtype = torch .float32 )
149141 self .rotary_emb = get_rope_wrapper (
150142 rope_head_dim ,
@@ -172,7 +164,6 @@ def _get_q_k_bf16(
172164 positions : torch .Tensor ,
173165 enable_dual_stream : bool ,
174166 ):
175- weights = None
176167 if enable_dual_stream :
177168 current_stream = torch .cuda .current_stream ()
178169 self .alt_stream .wait_stream (current_stream )
@@ -189,12 +180,7 @@ def _get_q_k_bf16(
189180 )
190181 with torch .cuda .stream (self .alt_stream ):
191182 # TODO we should also put DeepGEMM half SM here?
192- if self .fuse_wk_and_weights_proj :
193- key , weights = self .fused_wk_and_weights_proj (x )[0 ].split (
194- [self .head_dim , self .n_heads ], dim = - 1
195- )
196- else :
197- key , _ = self .wk (x )
183+ key , _ = self .wk (x )
198184 key = self .k_norm (key )
199185
200186 k_rope , _ = torch .split (
@@ -207,17 +193,10 @@ def _get_q_k_bf16(
207193 else :
208194 query , _ = self .wq_b (q_lora )
209195 query = rearrange (query , "l (h d) -> l h d" , d = self .head_dim )
210-
211196 q_rope , _ = torch .split (
212197 query , [self .rope_head_dim , self .head_dim - self .rope_head_dim ], dim = - 1
213198 )
214-
215- if self .fuse_wk_and_weights_proj :
216- key , weights = self .fused_wk_and_weights_proj (x )[0 ].split (
217- [self .head_dim , self .n_heads ], dim = - 1
218- )
219- else :
220- key , _ = self .wk (x )
199+ key , _ = self .wk (x )
221200 key = self .k_norm (key )
222201 k_rope , _ = torch .split (
223202 key , [self .rope_head_dim , self .head_dim - self .rope_head_dim ], dim = - 1
@@ -240,21 +219,16 @@ def _get_q_k_bf16(
240219 query = rotate_activation (query )
241220 key = rotate_activation (key )
242221
243- return query , key , weights
222+ return query , key
244223
245224 def _get_k_bf16 (
246225 self ,
247226 x : torch .Tensor ,
248227 positions : torch .Tensor ,
249228 enable_dual_stream : bool ,
250229 ):
251- # Compute only key, skip query and weights (weights is discarded if fused)
252- if self .fuse_wk_and_weights_proj :
253- key , _ = self .fused_wk_and_weights_proj (x )[0 ].split (
254- [self .head_dim , self .n_heads ], dim = - 1
255- )
256- else :
257- key , _ = self .wk (x )
230+ # Compute only key, skip query
231+ key , _ = self .wk (x )
258232 key = self .k_norm (key )
259233 k_rope , _ = torch .split (
260234 key , [self .rope_head_dim , self .head_dim - self .rope_head_dim ], dim = - 1
@@ -606,9 +580,7 @@ def forward_cuda(
606580 return_indices ,
607581 )
608582
609- query , key , weights = self ._get_q_k_bf16 (
610- q_lora , x , positions , enable_dual_stream
611- )
583+ query , key = self ._get_q_k_bf16 (q_lora , x , positions , enable_dual_stream )
612584
613585 if enable_dual_stream :
614586 current_stream = torch .cuda .current_stream ()
@@ -635,8 +607,7 @@ def forward_cuda(
635607 index_k_scale = k_scale ,
636608 )
637609
638- if not self .fuse_wk_and_weights_proj :
639- weights , _ = self .weights_proj (x )
610+ weights , _ = self .weights_proj (x .float ())
640611 weights = self ._get_logits_head_gate (weights , q_scale )
641612
642613 if is_cuda ():
@@ -801,7 +772,7 @@ def forward_npu(
801772 past_key_states = forward_batch .token_to_kv_pool .get_index_k_buffer (layer_id )
802773
803774 x = x .view (- 1 , self .hidden_size )
804- weights = self .weights_proj (x )[0 ]
775+ weights = self .weights_proj (x . float () )[0 ]
805776 block_table = (
806777 block_table [: actual_seq_lengths_q .size ()[0 ]] if is_prefill else block_table
807778 )