
Commit c3de745

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 174cfcb commit c3de745

2 files changed: +33 additions, -28 deletions

auto_round_extension/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 from auto_round_extension.qlinear import QuantLinear, QuantLinearGPTQ, QuantLinearAWQ
+
 qlinear_classes = (QuantLinear, QuantLinearGPTQ)
 
 awq_classes = (QuantLinearAWQ,)

auto_round_extension/qlinear.py

Lines changed: 32 additions & 28 deletions

@@ -13,12 +13,15 @@
 # limitations under the License.
 
 import math
+
 import torch
 import torch.nn as nn
+
 from auto_round.utils import convert_dtype_torch2str, logger
 
 try:
     import auto_round_kernel as ark
+
     ARK_INSTALLED = True
 except:
     ARK_INSTALLED = False
@@ -31,6 +34,7 @@
 
 AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
 
+
 def unpack_awq(qweight: torch.Tensor, qzeros: torch.Tensor, bits: int):
     shifts = torch.arange(0, 32, bits, device="cpu")
 
@@ -51,6 +55,7 @@ def unpack_awq(qweight: torch.Tensor, qzeros: torch.Tensor, bits: int):
 
     return iweights, izeros
 
+
 def reverse_awq_order(iweights: torch.Tensor, izeros: torch.Tensor, bits: int):
     reverse_order_tensor = torch.arange(
         iweights.shape[-1],
@@ -66,6 +71,7 @@ def reverse_awq_order(iweights: torch.Tensor, izeros: torch.Tensor, bits: int):
     iweights = iweights[:, reverse_order_tensor]
     return iweights, izeros
 
+
 def convert_dtype_torch2str(dtype):
     if dtype == torch.int8:
         return "int8"
@@ -80,14 +86,13 @@ def convert_dtype_torch2str(dtype):
     else:
         assert False, "Unsupported pytorch dtype {} to str dtype".format(dtype)
 
+
 class QuantLinearAWQ(nn.Module):
     QUANT_TYPE = "ark_awq"
 
     def __init__(self, w_bit, group_size, in_features, out_features, bias, zero_point, dev):
         super().__init__()
-        assert (
-            ARK_INSTALLED
-        ), "Please install auto_round_kernel package."
+        assert ARK_INSTALLED, "Please install auto_round_kernel package."
 
         self.use_bf16 = ark.check_isa_supported("AMX")
 
@@ -176,9 +181,7 @@ def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, ze
 
     @torch.no_grad()
     def forward(self, x):
-        assert ARK_INSTALLED, (
-            "ARK kernels could not be loaded. "
-        )
+        assert ARK_INSTALLED, "ARK kernels could not be loaded. "
 
         input_dtype = x.dtype
         out_shape = x.shape[:-1] + (self.out_features,)
@@ -214,6 +217,7 @@ def extra_repr(self) -> str:
             self.group_size,
         )
 
+
 class QuantLinear(nn.Module):
     QUANT_TYPE = "ark_gptq_nozp"
     ZP_BIAS = 0
@@ -234,9 +238,7 @@ def __init__(
 
         if bits not in [2, 4, 8]:
             raise NotImplementedError("Only 2, 4,8 bits are supported for ARK.")
-        assert (
-            ARK_INSTALLED
-        ), "Please install auto_round_kernel."
+        assert ARK_INSTALLED, "Please install auto_round_kernel."
 
         self.infeatures = infeatures
         self.outfeatures = outfeatures
@@ -275,9 +277,8 @@ def __init__(
         self.kernel_switch_threshold = kernel_switch_threshold
         self.trainable = trainable
 
-
     def post_init(self):
-        assert self.qweight.device.type in ["cpu", 'xpu']
+        assert self.qweight.device.type in ["cpu", "xpu"]
         # intweight: k x n, zeros: k / group_size x n
         intweight, zeros = unpack_to_8bit_signed(self.qweight, self.qzeros, self.bits, self.ZP_BIAS)
         if zeros is None:
@@ -289,7 +290,7 @@ def post_init(self):
            zeros = (zeros.to(torch.int32) - (2 ** (self.bits - 1))).to(torch.int8)
         else:
             zeros -= 2 ** (self.bits - 1)
-        if self.qweight.device.type != 'cpu':
+        if self.qweight.device.type != "cpu":
             assert not self.asym
         if not self.asym:
             intweight -= 2 ** (self.bits - 1)
@@ -299,21 +300,20 @@ def post_init(self):
         if self.asym:
             intweight = (intweight.to(torch.int32) - (2 ** (self.bits - 1))).to(torch.int8)
 
-
         logger.debug(
             f"ARK repack quantized weight: K:{intweight.shape[0]}, N:{intweight.shape[1]}, weight_dtype:{BITS_DTYPE_MAPPING[self.bits]}, scale_dtype:fp32, compute_dtype:fp32, group_size:{self.group_size}"
         )
 
-        if self.qweight.device.type == 'xpu':
-            self.sdt = 'fp16'
-            self.cdt = 'fp16'
+        if self.qweight.device.type == "xpu":
+            self.sdt = "fp16"
+            self.cdt = "fp16"
             scales = self.scales.to(torch.float16).contiguous()
         else:
-            self.sdt = 'fp32'
-            self.cdt = 'fp32'
+            self.sdt = "fp32"
+            self.cdt = "fp32"
             scales = self.scales.float().contiguous()
         self.wdt = BITS_DTYPE_MAPPING[self.bits]
-
+
         self.qweight = ark.repack_quantized_weight(
             intweight.contiguous(),
             scales,
@@ -325,11 +325,10 @@ def post_init(self):
             self.wdt,
             # scale_dtype
             self.sdt,
-
             self.asym,
             self.group_size,
         )
-
+
         # self.revert_wei = torch.zeros(self.infeatures, self.outfeatures, dtype=scales.dtype, device=self.qweight.device)
         # # print(packw, packw.device, packw.dtype)
         # ark.dequantize_packed_weight(
@@ -338,20 +337,20 @@ def post_init(self):
         self.qzeros = torch.empty(0)
         self.scales = torch.empty(0)
         if self.bias is not None:
-            if self.bias.device.type == 'cpu':
+            if self.bias.device.type == "cpu":
                 self.bias = self.bias.to(torch.float32)
             else:
                 self.bias = self.bias.to(torch.float16)
 
     def forward(self, x: torch.Tensor):
         raw_input_dtype = x.dtype
-        if x.device.type == 'cpu':
+        if x.device.type == "cpu":
             odt = torch.float32
             if raw_input_dtype != torch.float32:
                 x = x.to(torch.float32)
         else:
             odt = x.dtype
-
+
         out_shape = x.shape[:-1] + (self.outfeatures,)
         x = x.view(-1, x.shape[-1])  # convert xd to 2d
         out_2d_shape = x.shape[:-1] + (self.outfeatures,)
@@ -367,16 +366,18 @@ def forward(self, x: torch.Tensor):
             self.wdt,  # weight_dtype
             self.sdt,  # scale_dtype
             self.asym,
-            self.group_size
+            self.group_size,
         )
-        if x.device.type == 'xpu':
+        if x.device.type == "xpu":
             outputs = outputs + bias
         return outputs.to(raw_input_dtype).view(out_shape)
 
+
 class QuantLinearGPTQ(QuantLinear):
     QUANT_TYPE = "ark_gptq"
     ZP_BIAS = 1
 
+
 @torch.no_grad()
 def unpack_to_8bit_signed(qweight, qzeros, bits, gptq_bias=1):
     wf = torch.tensor(list(range(0, 32, bits)), dtype=torch.int32, device=qweight.device).unsqueeze(0)
@@ -407,6 +408,7 @@ def unpack_to_8bit_signed(qweight, qzeros, bits, gptq_bias=1):
 
     return weight, zeros
 
+
 # Copied from qlinear_marlin.py
 @torch.no_grad()
 def dequantize_weight(qweight, qzeros, scales, bits):
@@ -416,16 +418,18 @@ def dequantize_weight(qweight, qzeros, scales, bits):
     if unpacked_qzeros is not None:
         unpacked_qzeros = unpacked_qzeros.repeat_interleave(group_size, dim=0)
     else:
-        unpacked_qzeros = torch.full_like(scales, 8 if bits == 4 else 128, dtype=torch.int32, device = qweight.device)
+        unpacked_qzeros = torch.full_like(scales, 8 if bits == 4 else 128, dtype=torch.int32, device=qweight.device)
     unpacked_qweight = (unpacked_qweight - unpacked_qzeros) * scales
 
     return unpacked_qweight, unpacked_qzeros
 
+
 def ark_post_init(model):
     for _, submodule in model.named_modules():
         if isinstance(submodule, QuantLinear):
             submodule.post_init()
 
     return model
 
-__all__ = ["QuantLinear", 'QuantLinearGPTQ', 'QuantLinearAWQ']
+
+__all__ = ["QuantLinear", "QuantLinearGPTQ", "QuantLinearAWQ"]
