# limitations under the License.

import math
+
import torch
import torch.nn as nn
+
from auto_round.utils import convert_dtype_torch2str, logger

try:
    import auto_round_kernel as ark
+
    ARK_INSTALLED = True
except:
    ARK_INSTALLED = False

AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]

+
def unpack_awq(qweight: torch.Tensor, qzeros: torch.Tensor, bits: int):
    shifts = torch.arange(0, 32, bits, device="cpu")

@@ -51,6 +55,7 @@ def unpack_awq(qweight: torch.Tensor, qzeros: torch.Tensor, bits: int):

    return iweights, izeros

+
def reverse_awq_order(iweights: torch.Tensor, izeros: torch.Tensor, bits: int):
    reverse_order_tensor = torch.arange(
        iweights.shape[-1],
@@ -66,6 +71,7 @@ def reverse_awq_order(iweights: torch.Tensor, izeros: torch.Tensor, bits: int):
    iweights = iweights[:, reverse_order_tensor]
    return iweights, izeros

+
def convert_dtype_torch2str(dtype):
    if dtype == torch.int8:
        return "int8"
@@ -80,14 +86,13 @@ def convert_dtype_torch2str(dtype):
    else:
        assert False, "Unsupported pytorch dtype {} to str dtype".format(dtype)

+
class QuantLinearAWQ(nn.Module):
    QUANT_TYPE = "ark_awq"

    def __init__(self, w_bit, group_size, in_features, out_features, bias, zero_point, dev):
        super().__init__()
-        assert (
-            ARK_INSTALLED
-        ), "Please install auto_round_kernel package."
+        assert ARK_INSTALLED, "Please install auto_round_kernel package."

        self.use_bf16 = ark.check_isa_supported("AMX")

@@ -176,9 +181,7 @@ def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, ze

    @torch.no_grad()
    def forward(self, x):
-        assert ARK_INSTALLED, (
-            "ARK kernels could not be loaded. "
-        )
+        assert ARK_INSTALLED, "ARK kernels could not be loaded. "

        input_dtype = x.dtype
        out_shape = x.shape[:-1] + (self.out_features,)
@@ -214,6 +217,7 @@ def extra_repr(self) -> str:
            self.group_size,
        )

+
class QuantLinear(nn.Module):
    QUANT_TYPE = "ark_gptq_nozp"
    ZP_BIAS = 0
@@ -234,9 +238,7 @@ def __init__(

        if bits not in [2, 4, 8]:
            raise NotImplementedError("Only 2, 4,8 bits are supported for ARK.")
-        assert (
-            ARK_INSTALLED
-        ), "Please install auto_round_kernel."
+        assert ARK_INSTALLED, "Please install auto_round_kernel."

        self.infeatures = infeatures
        self.outfeatures = outfeatures
@@ -275,9 +277,8 @@ def __init__(
        self.kernel_switch_threshold = kernel_switch_threshold
        self.trainable = trainable

-
    def post_init(self):
-        assert self.qweight.device.type in ["cpu", 'xpu']
+        assert self.qweight.device.type in ["cpu", "xpu"]
        # intweight: k x n, zeros: k / group_size x n
        intweight, zeros = unpack_to_8bit_signed(self.qweight, self.qzeros, self.bits, self.ZP_BIAS)
        if zeros is None:
@@ -289,7 +290,7 @@ def post_init(self):
            zeros = (zeros.to(torch.int32) - (2 ** (self.bits - 1))).to(torch.int8)
        else:
            zeros -= 2 ** (self.bits - 1)
-        if self.qweight.device.type != 'cpu':
+        if self.qweight.device.type != "cpu":
            assert not self.asym
        if not self.asym:
            intweight -= 2 ** (self.bits - 1)
@@ -299,21 +300,20 @@ def post_init(self):
        if self.asym:
            intweight = (intweight.to(torch.int32) - (2 ** (self.bits - 1))).to(torch.int8)

-
        logger.debug(
            f"ARK repack quantized weight: K:{intweight.shape[0]}, N:{intweight.shape[1]}, weight_dtype:{BITS_DTYPE_MAPPING[self.bits]}, scale_dtype:fp32, compute_dtype:fp32, group_size:{self.group_size}"
        )

-        if self.qweight.device.type == 'xpu':
-            self.sdt = 'fp16'
-            self.cdt = 'fp16'
+        if self.qweight.device.type == "xpu":
+            self.sdt = "fp16"
+            self.cdt = "fp16"
            scales = self.scales.to(torch.float16).contiguous()
        else:
-            self.sdt = 'fp32'
-            self.cdt = 'fp32'
+            self.sdt = "fp32"
+            self.cdt = "fp32"
            scales = self.scales.float().contiguous()
        self.wdt = BITS_DTYPE_MAPPING[self.bits]
-
+
        self.qweight = ark.repack_quantized_weight(
            intweight.contiguous(),
            scales,
@@ -325,11 +325,10 @@ def post_init(self):
            self.wdt,
            # scale_dtype
            self.sdt,
-
            self.asym,
            self.group_size,
        )
-
+
        # self.revert_wei = torch.zeros(self.infeatures, self.outfeatures, dtype=scales.dtype, device=self.qweight.device)
        # # print(packw, packw.device, packw.dtype)
        # ark.dequantize_packed_weight(
@@ -338,20 +337,20 @@ def post_init(self):
        self.qzeros = torch.empty(0)
        self.scales = torch.empty(0)
        if self.bias is not None:
-            if self.bias.device.type == 'cpu':
+            if self.bias.device.type == "cpu":
                self.bias = self.bias.to(torch.float32)
            else:
                self.bias = self.bias.to(torch.float16)

    def forward(self, x: torch.Tensor):
        raw_input_dtype = x.dtype
-        if x.device.type == 'cpu':
+        if x.device.type == "cpu":
            odt = torch.float32
            if raw_input_dtype != torch.float32:
                x = x.to(torch.float32)
        else:
            odt = x.dtype
-
+
        out_shape = x.shape[:-1] + (self.outfeatures,)
        x = x.view(-1, x.shape[-1])  # convert xd to 2d
        out_2d_shape = x.shape[:-1] + (self.outfeatures,)
@@ -367,16 +366,18 @@ def forward(self, x: torch.Tensor):
            self.wdt,  # weight_dtype
            self.sdt,  # scale_dtype
            self.asym,
-            self.group_size
+            self.group_size,
        )
-        if x.device.type == 'xpu':
+        if x.device.type == "xpu":
            outputs = outputs + bias
        return outputs.to(raw_input_dtype).view(out_shape)

+
class QuantLinearGPTQ(QuantLinear):
    QUANT_TYPE = "ark_gptq"
    ZP_BIAS = 1

+
@torch.no_grad()
def unpack_to_8bit_signed(qweight, qzeros, bits, gptq_bias=1):
    wf = torch.tensor(list(range(0, 32, bits)), dtype=torch.int32, device=qweight.device).unsqueeze(0)
@@ -407,6 +408,7 @@ def unpack_to_8bit_signed(qweight, qzeros, bits, gptq_bias=1):

    return weight, zeros

+
# Copied from qlinear_marlin.py
@torch.no_grad()
def dequantize_weight(qweight, qzeros, scales, bits):
@@ -416,16 +418,18 @@ def dequantize_weight(qweight, qzeros, scales, bits):
    if unpacked_qzeros is not None:
        unpacked_qzeros = unpacked_qzeros.repeat_interleave(group_size, dim=0)
    else:
-        unpacked_qzeros = torch.full_like(scales, 8 if bits == 4 else 128, dtype=torch.int32, device=qweight.device)
+        unpacked_qzeros = torch.full_like(scales, 8 if bits == 4 else 128, dtype=torch.int32, device=qweight.device)
    unpacked_qweight = (unpacked_qweight - unpacked_qzeros) * scales

    return unpacked_qweight, unpacked_qzeros

+
def ark_post_init(model):
    for _, submodule in model.named_modules():
        if isinstance(submodule, QuantLinear):
            submodule.post_init()

    return model

-__all__ = ["QuantLinear", 'QuantLinearGPTQ', 'QuantLinearAWQ']
+
+__all__ = ["QuantLinear", "QuantLinearGPTQ", "QuantLinearAWQ"]
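
Note (not part of the diff): a minimal, self-contained sketch of the int32 packing that unpack_awq and unpack_to_8bit_signed undo with their shifts / wf tensors, i.e. bits-wide values stored low-to-high inside one 32-bit word and recovered by shifting and masking. It deliberately ignores the AWQ column reordering (AWQ_REVERSE_ORDER) and any zero-point handling; the values and shapes are illustrative only.

import torch

bits = 4
vals = torch.tensor([1, 7, 3, 15, 0, 9, 2, 5], dtype=torch.int32)  # eight 4-bit values

# pack: value i occupies its own bits-wide slot of a single int32, lowest slot first,
# matching the shift pattern torch.arange(0, 32, bits) used in the unpack helpers above
shifts = torch.arange(0, 32, bits)
packed = (vals.to(torch.int64) << shifts).sum().to(torch.int32)

# unpack: shift each slot back down and mask off its neighbours
unpacked = ((packed >> shifts) & (2**bits - 1)).to(torch.int32)
assert torch.equal(unpacked, vals)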