diff --git a/src/pquant/core/activations_quantizer.py b/src/pquant/core/activations_quantizer.py index 31d8f02..3d742ec 100644 --- a/src/pquant/core/activations_quantizer.py +++ b/src/pquant/core/activations_quantizer.py @@ -1,4 +1,6 @@ import keras +import torch +import torch.nn as nn from hgq.quantizer import Quantizer from keras import ops from keras.ops import convert_to_tensor, maximum, minimum, tanh @@ -136,6 +138,75 @@ def call(self, x): return x +class QuantizedPooling(nn.Module): + + def __init__(self, config, layer): + super().__init__() + self.f = torch.tensor(config.quantization_parameters.default_fractional_bits) + self.i = torch.tensor(config.quantization_parameters.default_integer_bits) + self.overflow = "SAT_SYM" if config.quantization_parameters.use_symmetric_quantization else "SAT" + self.config = config + self.hgq_heterogeneous = config.quantization_parameters.hgq_heterogeneous + self.is_pretraining = True + self.use_high_granularity_quantization = config.quantization_parameters.use_high_granularity_quantization + self.pooling = layer + self.hgq_gamma = config.quantization_parameters.hgq_gamma + + def build(self, input_shape): + if self.use_high_granularity_quantization: + if self.hgq_heterogeneous: + self.hgq = Quantizer( + k0=1.0, + i0=self.i, + f0=self.f, + round_mode="RND", + overflow_mode=self.overflow, + q_type="kif", + homogeneous_axis=(0,), + ) + + else: + self.hgq = Quantizer( + k0=1.0, + i0=self.i, + f0=self.f, + round_mode="RND", + overflow_mode=self.overflow, + q_type="kif", + heterogeneous_axis=(), + ) + self.hgq.build(input_shape) + else: + self.quantizer = get_fixed_quantizer(round_mode="RND", overflow_mode=self.overflow) + + def set_activation_bits(self, i, f): + self.i = torch.tensor(i) + self.f = torch.tensor(f) + + def post_pre_train_function(self): + self.is_pretraining = False + + def hgq_loss(self): + if self.is_pretraining: + return 0.0 + return ( + torch.sum(self.hgq.quantizer.i) + torch.sum(self.hgq.quantizer.f) + ) * 
self.config.quantization_parameters.hgq_gamma + + def quantize(self, x): + if not hasattr(self, "hgq") or not hasattr(self, "quantizer"): + self.build(x.shape) + if self.use_high_granularity_quantization: + x = self.hgq(x) + else: + x = self.quantizer(x, k=torch.tensor(1.0), i=self.i, f=self.f, training=True) + return x + + def forward(self, x): + x = self.pooling(x) + return self.quantize(x) + + def hard_sigmoid(x): """Computes hard_sigmoid function that saturates between 0 and 1.""" x = 0.5 * x + 0.5 diff --git a/src/pquant/core/backend_interface.py b/src/pquant/core/backend_interface.py new file mode 100644 index 0000000..dbb9409 --- /dev/null +++ b/src/pquant/core/backend_interface.py @@ -0,0 +1,55 @@ +from abc import ABC, abstractmethod + + +class BackendInterface(ABC): + @abstractmethod + def add_default_layer_quantization_pruning_to_config(self, model, config): + pass + + @abstractmethod + def iterative_train(self, model, config, train_func, valid_func, **kwargs): + pass + + @abstractmethod + def remove_pruning_from_model(self, model, config): + pass + + @abstractmethod + def add_compression_layers(self, model, config, input_shape=None): + pass + + @abstractmethod + def post_epoch_functions(self, model, epoch, total_epochs, **kwargs): + pass + + @abstractmethod + def post_pretrain_functions(self, model, config): + pass + + @abstractmethod + def pre_epoch_functions(self, model, epoch, total_epochs): + pass + + @abstractmethod + def pre_finetune_functions(self, model): + pass + + @abstractmethod + def save_weights_functions(self, model): + pass + + @abstractmethod + def get_layer_keep_ratio(self, model): + pass + + @abstractmethod + def get_model_losses(self, model, losses): + pass + + def call_post_round_functions(self, model, rewind, rounds, r): + if rewind == "round": + self.rewind_weights_functions(model) + elif rewind == "post-ticket-search" and r == rounds - 1: + self.rewind_weights_functions(model) + else: + self.post_round_functions(model) diff --git 
a/src/pquant/core/compressed_layers.py b/src/pquant/core/compressed_layers.py index 0a02ea9..06d2db4 100644 --- a/src/pquant/core/compressed_layers.py +++ b/src/pquant/core/compressed_layers.py @@ -1,106 +1,595 @@ import keras +import torch +import torch.nn as nn +import torch.nn.functional as F +from hgq.quantizer import Quantizer +from keras import ops +from keras.layers import Layer +from quantizers import get_fixed_quantizer +from pquant.core.utils import get_backend, get_pruning_layer -def add_default_layer_quantization_pruning_to_config(model, config): - if keras.backend.backend() == "torch": - from pquant.core.torch_impl.compressed_layers_torch import ( - add_default_layer_quantization_pruning_to_config_torch, - ) - return add_default_layer_quantization_pruning_to_config_torch(model, config) - else: - from pquant.core.tf_impl.compressed_layers_tf import ( - add_default_layer_quantization_pruning_to_config_tf, - ) +# Compressed Layers for PyTorch +class CompressedLayerBase(nn.Module): + def __init__(self, config, layer, layer_type): + super().__init__() + self.f_weight = torch.tensor(config.quantization_parameters.default_fractional_bits) + self.i_weight = torch.tensor(config.quantization_parameters.default_integer_bits) + self.f_bias = torch.tensor(config.quantization_parameters.default_fractional_bits) + self.i_bias = torch.tensor(config.quantization_parameters.default_integer_bits) + self.weight = nn.Parameter(layer.weight.clone()) + self.pruning_layer = get_pruning_layer(config=config, layer_type=layer_type) + self.pruning_method = config.pruning_parameters.pruning_method + self.overflow = "SAT_SYM" if config.quantization_parameters.use_symmetric_quantization else "SAT" + self.quantizer = get_fixed_quantizer(overflow_mode=self.overflow) + self.hgq_heterogeneous = config.quantization_parameters.hgq_heterogeneous - return add_default_layer_quantization_pruning_to_config_tf(model, config) + self.bias = nn.Parameter(layer.bias.clone()) if layer.bias is not 
None else None + self.init_weight = self.weight.clone() + self.pruning_first = config.training_parameters.pruning_first + self.enable_quantization = config.quantization_parameters.enable_quantization + self.use_high_granularity_quantization = config.quantization_parameters.use_high_granularity_quantization + self.enable_pruning = config.pruning_parameters.enable_pruning + self.hgq_gamma = config.quantization_parameters.hgq_gamma + def build(self, input_shape): + if self.use_high_granularity_quantization: + if self.hgq_heterogeneous: + self.hgq_weight = Quantizer( + k0=1.0, + i0=self.i_weight, + f0=self.f_weight, + round_mode="RND", + overflow_mode=self.overflow, + q_type="kif", + homogeneous_axis=(), + ) + self.hgq_weight.build(self.weight.shape) + if self.bias is not None: + self.hgq_bias = Quantizer( + k0=1.0, + i0=self.i_bias, + f0=self.f_bias, + round_mode="RND", + overflow_mode=self.overflow, + q_type="kif", + homogeneous_axis=(), + ) + self.hgq_bias.build(self.bias.shape) + else: + self.hgq_weight = Quantizer( + k0=1.0, + i0=self.i_weight, + f0=self.f_weight, + round_mode="RND", + overflow_mode=self.overflow, + q_type="kif", + heterogeneous_axis=(), + ) + self.hgq_weight.build(self.weight.shape) + if self.bias is not None: + self.hgq_bias = Quantizer( + k0=1.0, + i0=self.i_bias, + f0=self.f_bias, + round_mode="RND", + overflow_mode=self.overflow, + q_type="kif", + heterogeneous_axis=(), + ) + self.hgq_bias.build(self.bias.shape) -def add_compression_layers(model, config, input_shape): - if keras.backend.backend() == "torch": - from pquant.core.torch_impl.compressed_layers_torch import ( - add_compression_layers_torch, + def save_weights(self): + self.init_weight = self.weight.clone() + + def rewind_weights(self): + self.weight.data = self.init_weight.clone() + + def hgq_loss(self): + if self.pruning_layer.is_pretraining: + return 0.0 + loss = (torch.sum(self.hgq_weight.quantizer.i) + torch.sum(self.hgq_weight.quantizer.f)) * self.hgq_gamma + if self.bias is 
not None: + loss += (torch.sum(self.hgq_bias.quantizer.i) + torch.sum(self.hgq_bias.quantizer.f)) * self.hgq_gamma + return loss + + def quantize(self, weight, bias): + if self.enable_quantization: + if self.use_high_granularity_quantization: + weight = self.hgq_weight(weight) + bias = None if bias is None else self.hgq_bias(bias) + else: + weight = self.quantizer(weight, k=torch.tensor(1.0), i=self.i_weight, f=self.f_weight, training=True) + bias = ( + None + if bias is None + else self.quantizer(bias, k=torch.tensor(1.0), i=self.i_bias, f=self.f_bias, training=True) + ) + return weight, bias + + def prune(self, weight): + if self.enable_pruning: + weight = self.pruning_layer(weight) + return weight + + def prune_and_quantize(self, weight, bias): + if self.pruning_first: + weight = self.prune(weight) + weight, bias = self.quantize(weight, bias) + else: + weight, bias = self.quantize(weight, bias) + weight = self.prune(weight) + return weight, bias + + def forward(self, x): + weight, bias = self.prune_and_quantize(self.weight, self.bias) + if self.pruning_method == "wanda": + self.pruning_layer.collect_input(x, self.weight, self.training) + x = F.linear(x, weight, bias) + if self.pruning_method == "activation_pruning": + self.pruning_layer.collect_output(x, self.training) + return x + + +class CompressedLayerLinear(CompressedLayerBase): + def __init__(self, config, layer, layer_type): + super().__init__(config, layer, layer_type) + self.in_features = layer.in_features + self.out_features = layer.out_features + + def forward(self, x): + weight, bias = self.prune_and_quantize(self.weight, self.bias) + if self.pruning_method == "wanda": + self.pruning_layer.collect_input(x, self.weight, self.training) + x = F.linear(x, weight, bias) + if self.pruning_method == "activation_pruning": + self.pruning_layer.collect_output(x, self.training) + return x + + +class CompressedLayerConv2d(CompressedLayerBase): + def __init__(self, config, layer, layer_type): + 
super().__init__(config, layer, layer_type) + self.stride = layer.stride + self.dilation = layer.dilation + self.padding = layer.padding + self.groups = layer.groups + self.in_channels = layer.in_channels + self.out_channels = layer.out_channels + self.kernel_size = layer.kernel_size + self.padding_mode = layer.padding_mode + + def forward(self, x): + weight, bias = self.prune_and_quantize(self.weight, self.bias) + if self.pruning_method == "wanda": + self.pruning_layer.collect_input(x, weight, self.training) + x = F.conv2d( + input=x, + weight=weight, + bias=bias, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + groups=self.groups, ) + if self.pruning_method == "activation_pruning": + self.pruning_layer.collect_output(x, self.training) + return x - return add_compression_layers_torch(model, config, input_shape) - else: - from pquant.core.tf_impl.compressed_layers_tf import add_compression_layers_tf - return add_compression_layers_tf(model, config, input_shape) +class CompressedLayerConv1d(CompressedLayerBase): + def __init__(self, config, layer, layer_type): + super().__init__(config, layer, layer_type) + self.stride = layer.stride + self.dilation = layer.dilation + self.padding = layer.padding + self.groups = layer.groups + self.in_channels = layer.in_channels + self.out_channels = layer.out_channels + self.kernel_size = layer.kernel_size + self.padding_mode = layer.padding_mode -def get_layer_keep_ratio(model): - if keras.backend.backend() == "torch": - from pquant.core.torch_impl.compressed_layers_torch import ( - get_layer_keep_ratio_torch, + def forward(self, x): + weight, bias = self.prune_and_quantize(self.weight, self.bias) + if self.pruning_method == "wanda": + self.pruning_layer.collect_input(x, self.weight, self.training) + x = F.conv1d( + input=x, + weight=weight, + bias=bias, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + groups=self.groups, ) + if self.pruning_method == "activation_pruning": + 
self.pruning_layer.collect_output(x, self.training) + return x - return get_layer_keep_ratio_torch(model) - else: - from pquant.core.tf_impl.compressed_layers_tf import get_layer_keep_ratio_tf - return get_layer_keep_ratio_tf(model) +# Compressed layers for TF +class CompressedLayerBase(keras.layers.Layer): + def __init__(self, config, layer, layer_type): + super().__init__() + i_bits = config.quantization_parameters.default_integer_bits + f_bits = config.quantization_parameters.default_fractional_bits + self.i_weight = ops.convert_to_tensor(i_bits) + self.f_weight = ops.convert_to_tensor(f_bits) + self.i_bias = ops.convert_to_tensor(i_bits) + self.f_bias = ops.convert_to_tensor(f_bits) + self.pruning_layer = get_pruning_layer(config=config, layer_type=layer_type) + self.pruning_method = config.pruning_parameters.pruning_method + self.overflow = "SAT_SYM" if config.quantization_parameters.use_symmetric_quantization else "SAT" + self.hgq_gamma = config.quantization_parameters.hgq_gamma + self.pruning_first = config.training_parameters.pruning_first + self.enable_quantization = config.quantization_parameters.enable_quantization + self.use_high_granularity_quantization = config.quantization_parameters.use_high_granularity_quantization + self.hgq_heterogeneous = config.quantization_parameters.hgq_heterogeneous + self.enable_pruning = config.pruning_parameters.enable_pruning + self.do_transpose_data = None + self.weight_transpose = None + self.data_transpose = None -def get_model_losses(model, losses): - if keras.backend.backend() == "torch": - from pquant.core.torch_impl.compressed_layers_torch import ( - get_model_losses_torch, + def set_quantization_bits(self, i_bits_w, f_bits_w, i_bits_b, f_bits_b): + self.i_weight = ops.convert_to_tensor(i_bits_w) + self.f_weight = ops.convert_to_tensor(f_bits_w) + self.i_bias = ops.convert_to_tensor(i_bits_b) + self.f_bias = ops.convert_to_tensor(f_bits_b) + + def set_enable_pruning(self, enable_pruning): + self.enable_pruning = 
enable_pruning + + def build(self, input_shape): + super().build(input_shape) + if self.use_high_granularity_quantization: + if self.hgq_heterogeneous: + self.hgq_weight = Quantizer( + k0=1.0, + i0=self.i_weight, + f0=self.f_weight, + round_mode="RND", + overflow_mode=self.overflow, + q_type="kif", + homogeneous_axis=(), + ) + self.hgq_weight.build(self.weight.shape) + if self.use_bias: + self.hgq_bias = Quantizer( + k0=1.0, + i0=self.i_bias, + f0=self.f_bias, + round_mode="RND", + overflow_mode=self.overflow, + q_type="kif", + homogeneous_axis=(), + ) + self.hgq_bias.build(self.bias.shape) + else: + self.hgq_weight = Quantizer( + k0=1.0, + i0=self.i_weight, + f0=self.f_weight, + round_mode="RND", + overflow_mode=self.overflow, + q_type="kif", + heterogeneous_axis=(), + ) + self.hgq_weight.build(self.weight.shape) + if self.use_bias: + self.hgq_bias = Quantizer( + k0=1.0, + i0=self.i_bias, + f0=self.f_bias, + round_mode="RND", + overflow_mode=self.overflow, + q_type="kif", + heterogeneous_axis=(), + ) + self.hgq_bias.build(self.bias.shape) + else: + self.quantizer = get_fixed_quantizer(round_mode="RND", overflow_mode=self.overflow) + + def save_weights(self): + self.init_weight = self.weight.value + + def rewind_weights(self): + self.weight.assign(self.init_weight) + + def hgq_loss(self): + if self.pruning_layer.is_pretraining: + return 0.0 + loss = (ops.sum(self.hgq_weight.quantizer.i) + ops.sum(self.hgq_weight.quantizer.f)) * self.hgq_gamma + if self.bias is not None: + loss += (ops.sum(self.hgq_bias.quantizer.i) + ops.sum(self.hgq_bias.quantizer.f)) * self.hgq_gamma + return loss + + def handle_transpose(self, x, transpose, do_transpose=False): + if do_transpose: + x = ops.transpose(x, transpose) + return x + + def quantize_i(self, weight, bias): + if self.enable_quantization: + if self.use_high_granularity_quantization: + weight = self.hgq_weight(weight) + bias = None if bias is None else self.hgq_bias(bias) + else: + weight = self.quantizer( + weight, 
k=ops.convert_to_tensor(1.0), i=self.i_weight, f=self.f_weight, training=True + ) + bias = ( + None + if bias is None + else self.quantizer(bias, k=ops.convert_to_tensor(1.0), i=self.i_bias, f=self.f_bias, training=True) + ) + return weight, bias + + def prune(self, weight): + if self.enable_pruning: + weight = self.handle_transpose(weight, self.weight_transpose, True) + weight = self.pruning_layer(weight) + weight = self.handle_transpose(weight, self.weight_transpose_back, True) + return weight + + def prune_and_quantize(self, weight, bias): + weight = ops.cast(weight, weight.dtype) + bias = ops.cast(bias, bias.dtype) if bias is not None else None + if self.pruning_first: + weight = self.prune(weight) + weight, bias = self.quantize_i(weight, bias) + else: + weight, bias = self.quantize_i(weight, bias) + weight = self.prune(weight) + return weight, bias + + def call(self, x): + return x + + def collect_input(self, x, weight, training): + collect_x = self.handle_transpose(x, self.data_transpose, self.do_transpose_data) + weight_channels_first = self.handle_transpose(weight, self.weight_transpose, True) + self.pruning_layer.collect_input(collect_x, weight_channels_first, training) + + def collect_output(self, x, training): + collect_x = self.handle_transpose(x, self.data_transpose, self.do_transpose_data) + self.pruning_layer.collect_output(collect_x, training) + + +class CompressedLayerDepthwiseConv2dKeras(CompressedLayerBase): + def __init__(self, config, layer, layer_type): + super().__init__(config, layer, layer_type) + self.depthwise_regularizer = layer.depthwise_regularizer + self.use_bias = layer.use_bias + self.strides = layer.strides + self.dilation_rate = layer.dilation_rate + self.padding = layer.padding + self.kernel_size = layer.kernel_size + self.bias_shape = layer.bias.shape if layer.use_bias else None + self.init_bias = layer.bias.value if layer.use_bias else None + self.weight_shape = layer.kernel.shape + self.init_weight = layer.kernel.value + 
self.weight_transpose = (3, 2, 0, 1) + self.weight_transpose_back = (2, 3, 1, 0) + self.data_transpose = (0, 3, 1, 2) + self.do_transpose_data = layer.data_format == "channels_last" + + def build(self, input_shape): + self.weight = self.add_weight( + self.weight_shape, initializer=self.init_weight, trainable=True, regularizer=self.depthwise_regularizer + ) + self.bias = ( + self.add_weight(self.bias_shape, initializer=self.init_bias, trainable=True) + if self.bias_shape is not None + else None ) + super().build(input_shape) - return get_model_losses_torch(model, losses) - else: - from pquant.core.tf_impl.compressed_layers_tf import get_model_losses_tf + def call(self, x, training=None): + weight, bias = self.prune_and_quantize(self.weight, self.bias) + if self.pruning_method == "wanda": + self.collect_input(x, weight, training) + x = ops.depthwise_conv( + x, weight, strides=self.strides, padding=self.padding, data_format=None, dilation_rate=self.dilation_rate + ) + if self.pruning_method == "activation_pruning": + self.collect_output(x, training) + return x - return get_model_losses_tf(model, losses) +class CompressedLayerConv2dKeras(CompressedLayerBase): + def __init__(self, config, layer, layer_type): + super().__init__(config, layer, layer_type) + self.kernel_regularizer = layer.kernel_regularizer + self.filters = layer.filters + self.use_bias = layer.use_bias + self.strides = layer.strides + self.dilation_rate = layer.dilation_rate + self.padding = layer.padding + self.kernel_size = layer.kernel_size + if hasattr(layer, "groups"): + self.groups = layer.groups + self.bias_shape = layer.bias.shape if layer.use_bias else None + self.init_bias = layer.bias.value if layer.use_bias else None + self.weight_shape = layer.kernel.shape + self.init_weight = layer.kernel.value + self.weight_transpose = (3, 2, 0, 1) + self.weight_transpose_back = (2, 3, 1, 0) + self.data_transpose = (0, 3, 1, 2) + self.do_transpose_data = layer.data_format == "channels_last" -def 
remove_pruning_from_model(model, config): - if keras.backend.backend() == "torch": - from pquant.core.torch_impl.compressed_layers_torch import ( - remove_pruning_from_model_torch, + def build(self, input_shape): + self.weight = self.add_weight( + self.weight_shape, initializer=self.init_weight, trainable=True, regularizer=self.kernel_regularizer + ) + self.bias = ( + self.add_weight(self.bias_shape, initializer=self.init_bias, trainable=True) + if self.bias_shape is not None + else None ) + super().build(input_shape) - return remove_pruning_from_model_torch(model, config) - else: - from pquant.core.tf_impl.compressed_layers_tf import ( - remove_pruning_from_model_tf, + def call(self, x, training=None): + weight, bias = self.prune_and_quantize(self.weight, self.bias) + if self.pruning_method == "wanda": + self.collect_input(x, weight, training) + x = ops.conv( + x, weight, strides=self.strides, padding=self.padding, data_format=None, dilation_rate=self.dilation_rate ) + if self.bias is not None: + x = ops.add(x, bias) + if self.pruning_method == "activation_pruning": + self.collect_output(x, training) + return x - return remove_pruning_from_model_tf(model, config) +class CompressedLayerSeparableConv2dKeras(Layer): + def __init__(self, config, layer): + super().__init__() + self.weight_transpose = (3, 2, 0, 1) + self.weight_transpose_back = (2, 3, 1, 0) + self.data_transpose = (0, 3, 1, 2) + layer.kernel = layer.depthwise_kernel + bias = layer.use_bias + layer.use_bias = False + self.depthwise_conv = CompressedLayerDepthwiseConv2dKeras(config, layer, "conv") + layer.kernel_regularizer = layer.pointwise_regularizer + layer.kernel_size = 1 + layer.kernel = layer.pointwise_kernel + layer.use_bias = bias + self.pointwise_conv = CompressedLayerConv2dKeras(config, layer, "conv") + self.do_transpose_data = layer.data_format == "channels_last" -def post_training_prune(model, calibration_data, config): - if keras.backend.backend() == "torch": - from 
pquant.core.torch_impl.compressed_layers_torch import ( - add_compression_layers_torch, - post_pretrain_functions, - remove_pruning_from_model_torch, + def build(self, input_shape): + super().build(input_shape) + + def call(self, x, training=None): + x = self.depthwise_conv(x, training=training) + x = self.pointwise_conv(x, training=training) + return x + + +class CompressedLayerConv1dKeras(CompressedLayerBase): + def __init__(self, config, layer, layer_type): + super().__init__(config, layer, layer_type) + self.kernel_regularizer = layer.kernel_regularizer + self.filters = layer.filters + self.use_bias = layer.use_bias + self.strides = layer.strides + self.dilation_rate = layer.dilation_rate + self.padding = layer.padding + self.kernel_size = layer.kernel_size + self.groups = layer.groups + self.bias_shape = layer.bias.shape if layer.use_bias else None + self.init_bias = layer.bias.value if layer.use_bias else None + self.weight_shape = layer.kernel.shape + self.init_weight = layer.kernel.value + self.weight_transpose = (2, 1, 0) + self.weight_transpose_back = (2, 1, 0) + self.data_transpose = (0, 2, 1) + self.do_transpose_data = layer.data_format == "channels_last" + + def build(self, input_shape): + self.weight = self.add_weight( + self.weight_shape, initializer=self.init_weight, trainable=True, regularizer=self.kernel_regularizer + ) + self.bias = ( + self.add_weight(self.bias_shape, initializer=self.init_bias, trainable=True) + if self.bias_shape is not None + else None ) + super().build(input_shape) - t_delta = config["pruning_parameters"]["t_delta"] - config["pruning_parameters"]["t_start_collecting_batch"] = 0 - for i in range(t_delta): - inputs = calibration_data[i] - if i == 0: - model = add_compression_layers_torch(model, config, inputs.shape) - post_pretrain_functions(model, config) - model(inputs) - return remove_pruning_from_model_torch(model, config) - else: - from pquant.core.tf_impl.compressed_layers_tf import ( - add_compression_layers_tf, - 
post_pretrain_functions, - remove_pruning_from_model_tf, + def call(self, x, training=None): + weight, bias = self.prune_and_quantize(self.weight, self.bias) + if self.pruning_method == "wanda": + self.collect_input(x, weight, training) + x = ops.conv( + x, weight, strides=self.strides, padding=self.padding, data_format=None, dilation_rate=self.dilation_rate + ) + if self.bias is not None: + x = ops.add(x, bias) + if self.pruning_method == "activation_pruning": + self.collect_output(x, training) + return x + + +class CompressedLayerDenseKeras(CompressedLayerBase): + def __init__(self, config, layer, layer_type): + super().__init__(config, layer, layer_type) + self.kernel_regularizer = layer.kernel_regularizer + self.use_bias = layer.use_bias + self.units = layer.units + self.bias_shape = layer.bias.shape if layer.use_bias else None + self.init_bias = layer.bias.value if layer.use_bias else None + self.weight_shape = layer.kernel.shape + self.init_weight = layer.kernel.value + self.weight_transpose = (1, 0) + self.weight_transpose_back = (1, 0) + self.data_transpose = (0, 1) # Always (BATCH_SIZE, OUT_FEATURES) + + def build(self, input_shape): + self.weight = self.add_weight( + self.weight_shape, initializer=self.init_weight, trainable=True, regularizer=self.kernel_regularizer + ) + self.bias = ( + self.add_weight(self.bias_shape, initializer=self.init_bias, trainable=True) + if self.bias_shape is not None + else None ) + super().build(input_shape) - t_delta = config["pruning_parameters"]["t_delta"] - config["pruning_parameters"]["t_start_collecting_batch"] = 0 + def call(self, x, training=None): + weight, bias = self.prune_and_quantize(self.weight, self.bias) + if self.pruning_method == "wanda": + self.collect_input(x, weight, training) + x = ops.matmul(x, weight) + if self.bias is not None: + x = ops.add(x, bias) + if self.pruning_method == "activation_pruning": + self.collect_output(x, training) + return x - for i in range(t_delta): - inputs = calibration_data[i] 
- if i == 0: - model = add_compression_layers_tf(model, config, inputs.shape) - post_pretrain_functions(model, config) - model(inputs, training=True) # True so pruning works - return remove_pruning_from_model_tf(model, config) + +def add_default_layer_quantization_pruning_to_config(model, config): + backend = get_backend() + return backend.add_default_layer_quantization_pruning_to_config(model, config) + + +def add_compression_layers(model, config, input_shape): + backend = get_backend() + return backend.add_compression_layers(model, config, input_shape) + + +def get_layer_keep_ratio(model): + backend = get_backend() + return backend.get_layer_keep_ratio(model) + + +def get_model_losses(model, losses): + backend = get_backend() + return backend.get_model_losses(model, losses) + + +def remove_pruning_from_model(model, config): + backend = get_backend() + return backend.remove_pruning_from_model(model, config) + + +def post_training_prune(model, calibration_data, config): + from pquant.core.tf_backend import TFBackend + backend = get_backend() + t_delta = config.pruning_parameters.t_delta + config.pruning_parameters.t_start_collecting_batch = 0 + for i in range(t_delta): + inputs = calibration_data[i] + if i == 0: + model = backend.add_compression_layers(model, config, inputs.shape) + backend.post_pretrain_functions(model, config) + if isinstance(backend, TFBackend): + model(inputs, training=True) + else: + model(inputs) + return backend.remove_pruning_from_model(model, config) diff --git a/src/pquant/core/tf_backend.py b/src/pquant/core/tf_backend.py new file mode 100644 index 0000000..4272954 --- /dev/null +++ b/src/pquant/core/tf_backend.py @@ -0,0 +1,764 @@ +import keras +from hgq.quantizer import Quantizer +from keras import ops +from keras.layers import ( + Activation, + AveragePooling1D, + AveragePooling2D, + AveragePooling3D, + Conv1D, + Conv2D, + Dense, + DepthwiseConv2D, + ReLU, + SeparableConv2D, +) +from quantizers import get_fixed_quantizer + +from 
pquant.core.activations_quantizer import QuantizedReLU, QuantizedTanh +from pquant.core.backend_interface import BackendInterface +from pquant.core.compressed_layers import ( + CompressedLayerConv1dKeras, + CompressedLayerConv2dKeras, + CompressedLayerDenseKeras, + CompressedLayerDepthwiseConv2dKeras, + CompressedLayerSeparableConv2dKeras, +) + + +class QuantizedPooling(keras.layers.Layer): + def __init__(self, config, layer): + super().__init__() + self.i = ops.convert_to_tensor(config.quantization_parameters.default_integer_bits) + self.f = ops.convert_to_tensor(config.quantization_parameters.default_fractional_bits) + + self.is_pretraining = True + + self.overflow = "SAT_SYM" if config.quantization_parameters.use_symmetric_quantization else "SAT" + self.hgq_gamma = config.quantization_parameters.hgq_gamma + + self.use_high_granularity_quantization = config.quantization_parameters.use_high_granularity_quantization + self.hgq_heterogeneous = config.quantization_parameters.hgq_heterogeneous + self.pool_size = layer.pool_size + self.strides = layer.strides + self.padding = layer.padding + self.data_format = layer.data_format + self.dimensions = layer.__class__.__name__[-2] + + def post_pre_train_function(self): + self.is_pretraining = False + + def set_quantization_bits(self, i_bits, f_bits): + self.i = ops.convert_to_tensor(i_bits) + self.f = ops.convert_to_tensor(f_bits) + + def build(self, input_shape): + super().build(input_shape) + if self.use_high_granularity_quantization: + if self.hgq_heterogeneous: + self.hgq = Quantizer( + k0=1.0, + i0=self.i, + f0=self.f, + round_mode="RND", + overflow_mode=self.overflow, + q_type="kif", + homogeneous_axis=(0,), + ) + self.hgq.build(input_shape) + else: + self.hgq = Quantizer( + k0=1.0, + i0=self.i, + f0=self.f, + round_mode="RND", + overflow_mode=self.overflow, + q_type="kif", + heterogeneous_axis=(), + ) + self.hgq.build(input_shape) + + self.hgq_gamma = self.hgq_gamma + else: + self.quantizer = 
get_fixed_quantizer(round_mode="RND", overflow_mode=self.overflow) + + def hgq_loss(self): + if self.is_pretraining: + return 0.0 + loss = (ops.sum(self.hgq_weight.quantizer.i) + ops.sum(self.hgq_weight.quantizer.f)) * self.hgq_gamma + if self.bias is not None: + loss += (ops.sum(self.hgq_bias.quantizer.i) + ops.sum(self.hgq_bias.quantizer.f)) * self.hgq_gamma + return loss + + def quantize_i(self, x): + if self.use_high_granularity_quantization: + x = self.hgq(x) + else: + x = self.quantizer(x, k=ops.convert_to_tensor(1.0), i=self.i, f=self.f, training=True) + return x + + def call(self, x): + x = ops.average_pool( + x, + pool_size=self.pool_size, + strides=self.strides, + padding=self.padding, + data_format=self.data_format, + ) + return self.quantize_i(x) + + def get_config(self): + config = super().get_config() + config.update( + { + "i": self.i, + "f": self.f, + "is_pretraining": self.is_pretraining, + "overflow": self.overflow, + "hgq_gamma": self.hgq_gamma, + "hgq_heterogeneous": self.hgq_heterogeneous, + "pooling": self.pooling, + } + ) + return config + + +class TFBackend(BackendInterface): + def iterative_train(self, model, config, train_func, valid_func, **kwargs): + """ + Generic training loop, user provides training and validation functions + """ + epoch = keras.ops.convert_to_tensor(0) # Keeps track of all the epochs completed + training_config = config.training_parameters + if training_config.pretraining_epochs > 0: + for e in range(training_config.pretraining_epochs): + self.pre_epoch_functions(model, e, training_config.pretraining_epochs) + train_func(model, epoch=epoch, **kwargs) + valid_func(model, epoch=epoch, **kwargs) + self.post_epoch_functions(model, e, training_config.pretraining_epochs) + epoch += 1 + self.post_pretrain_functions(model, config) + for r in range(training_config.rounds): + for e in range(training_config.epochs): + if r == 0 and training_config.save_weights_epoch == e: + self.save_weights_functions(model) + 
self.pre_epoch_functions(model, e, training_config.epochs) + train_func(model, epoch=epoch, **kwargs) + valid_func(model, epoch=epoch, **kwargs) + self.post_epoch_functions(model, e, training_config.epochs) + epoch += 1 + self.call_post_round_functions(model, training_config.rewind, training_config.rounds, r) + self.pre_finetune_functions(model) + if training_config.fine_tuning_epochs > 0: + for e in range(training_config.fine_tuning_epochs): + self.pre_epoch_functions(model, e, training_config.fine_tuning_epochs) + train_func(model, epoch=epoch, **kwargs) + valid_func(model, epoch=epoch, **kwargs) + self.post_epoch_functions(model, e, training_config.fine_tuning_epochs) + epoch += 1 + return model + + def add_default_layer_quantization_pruning_to_config(self, model, config): + custom_scheme = {"layer_specific": {}, "disable_pruning_for_layers": []} + for layer in model.layers: + if layer.__class__ in [Dense, Conv2D, Conv1D, DepthwiseConv2D]: + if layer.use_bias: + custom_scheme["layer_specific"][layer.name] = { + "weight": {"integer_bits": 0.0, "fractional_bits": 7.0}, + "bias": {"integer_bits": 0.0, "fractional_bits": 7.0}, + } + else: + custom_scheme["layer_specific"][layer.name] = {"weight": {"integer_bits": 0.0, "fractional_bits": 7.0}} + if hasattr(layer.activation, "__name__") and layer.activation.__name__ in ["relu", "tanh"]: + custom_scheme["layer_specific"][layer.name][layer.activation.__name__] = { + "integer_bits": 0.0, + "fractional_bits": 7.0, + } + custom_scheme["disable_pruning_for_layers"].append(layer.name) + if layer.__class__ == SeparableConv2D: + if layer.use_bias: + custom_scheme["layer_specific"][layer.name] = { + "depthwise": { + "weight": {"integer_bits": 0.0, "fractional_bits": 7.0}, + }, + "pointwise": { + "weight": {"integer_bits": 0.0, "fractional_bits": 7.0}, + "bias": {"integer_bits": 0.0, "fractional_bits": 7.0}, + }, + } + else: + custom_scheme["layer_specific"][layer.name] = { + "depthwise": {"weight": {"integer_bits": 0.0, 
"fractional_bits": 7.0}}, + "pointwise": {"weight": {"integer_bits": 0.0, "fractional_bits": 7.0}}, + } + if hasattr(layer.activation, "__name__") and layer.activation.__name__ in ["relu", "tanh"]: + custom_scheme["layer_specific"][layer.name][layer.activation.__name__] = { + "integer_bits": 0.0, + "fractional_bits": 7.0, + } + custom_scheme["disable_pruning_for_layers"].append(layer.name + "_depthwise") + custom_scheme["disable_pruning_for_layers"].append(layer.name + "_pointwise") + elif layer.__class__ in [Activation, ReLU, AveragePooling1D, AveragePooling2D, AveragePooling3D]: + custom_scheme["layer_specific"][layer.name] = {"integer_bits": 0.0, "fractional_bits": 7.0} + config.quantization_parameters.layer_specific = custom_scheme["layer_specific"] + config.pruning_parameters.disable_pruning_for_layers = custom_scheme["disable_pruning_for_layers"] + return config + + def remove_pruning_from_model(self, model, config): + x = model.layers[0].output + for layer in model.layers[1:]: + if isinstance(layer, CompressedLayerDepthwiseConv2dKeras): + new_layer = DepthwiseConv2D( + kernel_size=layer.kernel_size, + strides=layer.strides, + padding=layer.padding, + dilation_rate=layer.dilation_rate, + use_bias=layer.use_bias, + depthwise_regularizer=layer.depthwise_regularizer, + activity_regularizer=layer.activity_regularizer, + ) + x = new_layer(x) + use_bias = layer.use_bias + weight, bias = self._prune_and_quantize_layer(layer, use_bias) + new_layer.set_weights([weight, bias] if use_bias else [weight]) + elif isinstance(layer, CompressedLayerConv2dKeras): + new_layer = Conv2D( + filters=layer.filters, + kernel_size=layer.kernel_size, + strides=layer.strides, + padding=layer.padding, + dilation_rate=layer.dilation_rate, + use_bias=layer.use_bias, + kernel_regularizer=layer.kernel_regularizer, + activity_regularizer=layer.activity_regularizer, + ) + x = new_layer(x) + use_bias = layer.use_bias + weight, bias = self._prune_and_quantize_layer(layer, use_bias) + 
new_layer.set_weights([weight, bias] if use_bias else [weight]) + elif isinstance(layer, CompressedLayerSeparableConv2dKeras): + new_layer = SeparableConv2D( + filters=layer.pointwise_conv.filters, + kernel_size=layer.depthwise_conv.kernel_size, + strides=layer.depthwise_conv.strides, + padding=layer.depthwise_conv.padding, + dilation_rate=layer.depthwise_conv.dilation_rate, + use_bias=layer.pointwise_conv.use_bias, + depthwise_regularizer=layer.depthwise_conv.depthwise_regularizer, + pointwise_regularizer=layer.pointwise_conv.kernel_regularizer, + activity_regularizer=layer.activity_regularizer, + ) + x = new_layer(x) + use_bias = layer.pointwise_conv.use_bias + depthwise_weight, _ = self._prune_and_quantize_layer(layer.depthwise_conv, False) + pointwise_weight, bias = self._prune_and_quantize_layer(layer.pointwise_conv, layer.pointwise_conv.use_bias) + new_layer.set_weights( + [depthwise_weight, pointwise_weight, bias] if use_bias else [depthwise_weight, pointwise_weight] + ) + + elif isinstance(layer, CompressedLayerConv1dKeras): + new_layer = Conv1D( + filters=layer.filters, + kernel_size=layer.kernel_size, + strides=layer.strides, + padding=layer.padding, + dilation_rate=layer.dilation_rate, + use_bias=layer.use_bias, + kernel_regularizer=layer.kernel_regularizer, + activity_regularizer=layer.activity_regularizer, + ) + x = new_layer(x) + use_bias = layer.use_bias + weight, bias = self._prune_and_quantize_layer(layer, use_bias) + new_layer.set_weights([weight, bias] if use_bias else [weight]) + elif isinstance(layer, CompressedLayerDenseKeras): + new_layer = Dense(units=layer.units, use_bias=layer.use_bias, kernel_regularizer=layer.kernel_regularizer) + x = new_layer(x) + use_bias = new_layer.use_bias + weight, bias = self._prune_and_quantize_layer(layer, use_bias) + new_layer.set_weights([weight, bias] if use_bias else [weight]) + else: + x = layer(x) + replaced_model = keras.Model(inputs=model.inputs, outputs=x) + return replaced_model + + def 
add_compression_layers(self, model, config, input_shape=None): + # Pruning algorithms assume channels_first format + # Creates a new functional model from model, replacing certain layers with compressed / quantized variants + x = model.layers[0].output + for layer in model.layers[1:]: + act = None + if isinstance(layer, DepthwiseConv2D): + new_layer = CompressedLayerDepthwiseConv2dKeras(config, layer, layer_type="conv") + i_bits_w, f_bits_w, i_bits_b, f_bits_b = self.get_quantization_bits_weights_biases(config, layer) + new_layer.set_quantization_bits(i_bits_w, f_bits_w, i_bits_b, f_bits_b) + enable_pruning = self.get_enable_pruning(layer, config) + new_layer.set_enable_pruning(enable_pruning) + pruning_layer_input = layer.kernel + transpose_shape = new_layer.weight_transpose + pruning_layer_input = ops.transpose(pruning_layer_input, transpose_shape) + new_layer.pruning_layer.build(pruning_layer_input.shape) + + x = new_layer(x) + act = self.check_activation(layer, config) + elif isinstance(layer, Conv2D): + new_layer = CompressedLayerConv2dKeras(config, layer, layer_type="conv") + i_bits_w, f_bits_w, i_bits_b, f_bits_b = self.get_quantization_bits_weights_biases(config, layer) + new_layer.set_quantization_bits(i_bits_w, f_bits_w, i_bits_b, f_bits_b) + enable_pruning = self.get_enable_pruning(layer, config) + new_layer.set_enable_pruning(enable_pruning) + pruning_layer_input = layer.kernel + transpose_shape = new_layer.weight_transpose + pruning_layer_input = ops.transpose(pruning_layer_input, transpose_shape) + new_layer.pruning_layer.build(pruning_layer_input.shape) + x = new_layer(x) + act = self.check_activation(layer, config) + elif isinstance(layer, SeparableConv2D): + new_layer = CompressedLayerSeparableConv2dKeras(config, layer) + dw_i_bits_w, dw_f_bits_w, pw_i_bits_w, pw_f_bits_w, pw_i_bits_b, pw_f_bits_b = ( + self.get_quantization_bits_weights_biases(config, layer) + ) + new_layer.depthwise_conv.set_quantization_bits(dw_i_bits_w, dw_f_bits_w, 
pw_i_bits_b, pw_f_bits_b) + new_layer.pointwise_conv.set_quantization_bits(pw_i_bits_w, pw_f_bits_w, pw_i_bits_b, pw_f_bits_b) + enable_pruning_depthwise, enable_pruning_pointwise = self.get_enable_pruning(layer, config) + new_layer.depthwise_conv.set_enable_pruning(enable_pruning_depthwise) + new_layer.pointwise_conv.set_enable_pruning(enable_pruning_pointwise) + + pruning_layer_input = layer.depthwise_kernel + transpose_shape = new_layer.weight_transpose + pruning_layer_input = ops.transpose(pruning_layer_input, transpose_shape) + new_layer.depthwise_conv.pruning_layer.build(pruning_layer_input.shape) + + pointwise_pruning_layer_input = layer.pointwise_kernel + transpose_shape = new_layer.weight_transpose + pointwise_pruning_layer_input = ops.transpose(pointwise_pruning_layer_input, transpose_shape) + new_layer.pointwise_conv.pruning_layer.build(pointwise_pruning_layer_input.shape) + new_layer.depthwise_conv.build(x.shape) + y = new_layer.depthwise_conv(x).shape + new_layer.pointwise_conv.build(y) + x = new_layer(x) + act = self.check_activation(layer, config) + elif isinstance(layer, Conv1D): + new_layer = CompressedLayerConv1dKeras(config, layer, layer_type="conv") + i_bits_w, f_bits_w, i_bits_b, f_bits_b = self.get_quantization_bits_weights_biases(config, layer) + new_layer.set_quantization_bits(i_bits_w, f_bits_w, i_bits_b, f_bits_b) + enable_pruning = self.get_enable_pruning(layer, config) + new_layer.set_enable_pruning(enable_pruning) + pruning_layer_input = layer.kernel + transpose_shape = new_layer.weight_transpose + pruning_layer_input = ops.transpose(pruning_layer_input, transpose_shape) + new_layer.pruning_layer.build(pruning_layer_input.shape) + + x = new_layer(x) + act = self.check_activation(layer, config) + elif isinstance(layer, Dense): + new_layer = CompressedLayerDenseKeras(config, layer, layer_type="linear") + i_bits_w, f_bits_w, i_bits_b, f_bits_b = self.get_quantization_bits_weights_biases(config, layer) + 
new_layer.set_quantization_bits(i_bits_w, f_bits_w, i_bits_b, f_bits_b) + enable_pruning = self.get_enable_pruning(layer, config) + new_layer.set_enable_pruning(enable_pruning) + pruning_layer_input = layer.kernel + transpose_shape = new_layer.weight_transpose + pruning_layer_input = ops.transpose(pruning_layer_input, transpose_shape) + new_layer.pruning_layer.build(pruning_layer_input.shape) + x = new_layer(x) + act = self.check_activation(layer, config) + # Activation layers + elif isinstance(layer, ReLU): + if config.quantization_parameters.enable_quantization: + i_bits = config.quantization_parameters.default_integer_bits + f_bits = config.quantization_parameters.default_fractional_bits + i_bits, f_bits = self.get_quantization_bits_activations(config, layer) + new_layer = QuantizedReLU(config, i_bits, f_bits) + new_layer.build(layer.input.shape) + x = new_layer(x) + else: + x = layer(x) + elif isinstance(layer, Activation): + new_layer = self.check_activation(layer, config) + if new_layer is not None: + x = new_layer(x) + elif isinstance(layer, (AveragePooling1D, AveragePooling2D, AveragePooling3D)): + if config.quantization_parameters.enable_quantization: + i_bits, f_bits = self.get_quantization_bits_activations(config, layer) + new_layer = QuantizedPooling(config, layer) + new_layer.set_quantization_bits(i_bits, f_bits) + new_layer.build(layer.output.shape) + x = new_layer(x) + else: + x = layer(x) + else: + x = layer(x) + if act is not None: + x = act(x) + replaced_model = keras.Model(inputs=model.inputs, outputs=x) + return replaced_model + + def post_epoch_functions(self, model, epoch, total_epochs, **kwargs): + for layer in model.layers: + if isinstance( + layer, + ( + CompressedLayerDepthwiseConv2dKeras, + CompressedLayerConv2dKeras, + CompressedLayerConv1dKeras, + CompressedLayerDenseKeras, + ), + ): + layer.pruning_layer.post_epoch_function(epoch, total_epochs, **kwargs) + elif isinstance(layer, CompressedLayerSeparableConv2dKeras): + 
layer.depthwise_conv.pruning_layer.post_epoch_function(epoch, total_epochs, **kwargs) + layer.pointwise_conv.pruning_layer.post_epoch_function(epoch, total_epochs, **kwargs) + + def pdp_setup(self, model, config): + """ + Calculates a global sparsity threshold. Initializes target sparsity for each layer, which depends on + how large percentage of weights in the layer is smaller than the global threshold + """ + global_weights = None + for layer in model.layers: + if isinstance( + layer, + ( + CompressedLayerDepthwiseConv2dKeras, + CompressedLayerConv2dKeras, + CompressedLayerConv1dKeras, + CompressedLayerDenseKeras, + ), + ): + if global_weights is None: + global_weights = ops.ravel(layer.weight) + else: + global_weights = ops.concatenate((global_weights, ops.ravel(layer.weight))) + elif isinstance(layer, CompressedLayerSeparableConv2dKeras): + if global_weights is None: + global_weights = ops.ravel(layer.depthwise_conv.weight) + global_weights = ops.concatenate((global_weights, ops.ravel(layer.pointwise_conv.weight))) + else: + global_weights = ops.concatenate((global_weights, ops.ravel(layer.depthwise_conv.weight))) + global_weights = ops.concatenate((global_weights, ops.ravel(layer.pointwise_conv.weight))) + + abs_global_weights = ops.abs(global_weights) + global_weight_topk, _ = ops.top_k(abs_global_weights, ops.size(abs_global_weights)) + threshold = global_weight_topk[int((1 - config.pruning_parameters.sparsity) * float(ops.size(global_weight_topk)))] + global_weights_below_threshold = ops.where(abs_global_weights < threshold, 1, 0) + idx = 0 + for layer in model.layers: + if isinstance( + layer, + ( + CompressedLayerDepthwiseConv2dKeras, + CompressedLayerConv2dKeras, + CompressedLayerConv1dKeras, + CompressedLayerDenseKeras, + ), + ): + weight_size = ops.size(layer.weight) + w = ops.sum(global_weights_below_threshold[idx : idx + weight_size]) + layer.pruning_layer.init_r = ops.convert_to_tensor(w / weight_size, dtype=layer.weight.dtype) + 
layer.pruning_layer.sparsity = ops.convert_to_tensor(w / weight_size, dtype=layer.weight.dtype) # Wanda + idx += weight_size + elif isinstance(layer, CompressedLayerSeparableConv2dKeras): + weight_size = ops.size(layer.depthwise_conv.weight) + w = ops.sum(global_weights_below_threshold[idx : idx + weight_size]) + layer.depthwise_conv.pruning_layer.init_r = ops.convert_to_tensor( + w / weight_size, dtype=layer.depthwise_conv.weight.dtype + ) + layer.depthwise_conv.pruning_layer.sparsity = ops.convert_to_tensor( + w / weight_size, dtype=layer.depthwise_conv.weight.dtype + ) # Wanda + idx += weight_size + + weight_size = ops.size(layer.pointwise_conv.weight) + w = ops.sum(global_weights_below_threshold[idx : idx + weight_size]) + layer.pointwise_conv.pruning_layer.init_r = ops.convert_to_tensor( + w / weight_size, dtype=layer.pointwise_conv.weight.dtype + ) + layer.pointwise_conv.pruning_layer.sparsity = ops.convert_to_tensor( + w / weight_size, dtype=layer.pointwise_conv.weight.dtype + ) # Wanda + idx += weight_size + + def post_pretrain_functions(self, model, config): + for layer in model.layers: + if isinstance( + layer, + ( + CompressedLayerDepthwiseConv2dKeras, + CompressedLayerConv2dKeras, + CompressedLayerConv1dKeras, + CompressedLayerDenseKeras, + ), + ): + layer.pruning_layer.post_pre_train_function() + elif isinstance(layer, CompressedLayerSeparableConv2dKeras): + layer.depthwise_conv.pruning_layer.post_pre_train_function() + layer.pointwise_conv.pruning_layer.post_pre_train_function() + elif isinstance(layer, (QuantizedReLU, QuantizedTanh, QuantizedPooling)): + layer.post_pre_train_function() + if config.pruning_parameters.pruning_method == "pdp" or ( + config.pruning_parameters.pruning_method == "wanda" and config.pruning_parameters.calculate_pruning_budget + ): + self.pdp_setup(model, config) + + def pre_epoch_functions(self, model, epoch, total_epochs): + for layer in model.layers: + if isinstance( + layer, + ( + CompressedLayerDepthwiseConv2dKeras, + 
CompressedLayerConv2dKeras, + CompressedLayerConv1dKeras, + CompressedLayerDenseKeras, + ), + ): + layer.pruning_layer.pre_epoch_function(epoch, total_epochs) + elif isinstance(layer, CompressedLayerSeparableConv2dKeras): + layer.depthwise_conv.pruning_layer.pre_epoch_function(epoch, total_epochs) + layer.pointwise_conv.pruning_layer.pre_epoch_function(epoch, total_epochs) + + def pre_finetune_functions(self, model): + for layer in model.layers: + if isinstance( + layer, + ( + CompressedLayerDepthwiseConv2dKeras, + CompressedLayerConv2dKeras, + CompressedLayerConv1dKeras, + CompressedLayerDenseKeras, + ), + ): + layer.pruning_layer.pre_finetune_function() + elif isinstance(layer, CompressedLayerSeparableConv2dKeras): + layer.depthwise_conv.pruning_layer.pre_finetune_function() + layer.pointwise_conv.pruning_layer.pre_finetune_function() + + def save_weights_functions(self, model): + for layer in model.layers: + if isinstance( + layer, + ( + CompressedLayerDepthwiseConv2dKeras, + CompressedLayerConv2dKeras, + CompressedLayerConv1dKeras, + CompressedLayerDenseKeras, + ), + ): + layer.save_weights() + elif isinstance(layer, CompressedLayerSeparableConv2dKeras): + layer.depthwise_conv.save_weights() + layer.pointwise_conv.save_weights() + + def get_layer_keep_ratio(self, model): + total_w = 0 + remaining_weights = 0 + for layer in model.layers: + if isinstance( + layer, + ( + CompressedLayerDepthwiseConv2dKeras, + CompressedLayerConv2dKeras, + CompressedLayerConv1dKeras, + CompressedLayerDenseKeras, + ), + ): + # weight, bias = layer.prune_and_quantize(layer.weight, layer.bias) + weight = ops.cast(layer.weight, layer.weight.dtype) + bias = ops.cast(layer.bias, layer.bias.dtype) if layer.bias is not None else None + weight, bias = layer.quantize_i(weight, bias) + transpose = layer.weight_transpose + if layer.enable_pruning: + weight = layer.pruning_layer.get_hard_mask(ops.transpose(weight, transpose)) * ops.transpose( + weight, transpose + ) + total_w += 
ops.size(weight) + rem = ops.count_nonzero(weight) + remaining_weights += rem + elif isinstance(layer, CompressedLayerSeparableConv2dKeras): + depthwise_weight = ops.cast(layer.depthwise_conv.weight, layer.depthwise_conv.weight.dtype) + pointwise_weight = ops.cast(layer.pointwise_conv.weight, layer.pointwise_conv.weight.dtype) + bias = ( + ops.cast(layer.pointwise_conv.bias, layer.pointwise_conv.bias.dtype) + if layer.pointwise_conv.bias is not None + else None + ) + + depthwise_weight, _ = layer.depthwise_conv.quantize_i(depthwise_weight, None) + transpose = layer.depthwise_conv.weight_transpose + if layer.depthwise_conv.enable_pruning: + depthwise_weight = layer.depthwise_conv.pruning_layer.get_hard_mask( + ops.transpose(depthwise_weight, transpose) + ) * ops.transpose(depthwise_weight, transpose) + total_w += ops.size(layer.depthwise_conv.weight) + rem = ops.count_nonzero(depthwise_weight) + remaining_weights += rem + + pointwise_weight, _ = layer.pointwise_conv.quantize_i(pointwise_weight, bias) + transpose = layer.pointwise_conv.weight_transpose + if layer.pointwise_conv.enable_pruning: + pointwise_weight = layer.pointwise_conv.pruning_layer.get_hard_mask( + ops.transpose(pointwise_weight, transpose) + ) * ops.transpose(pointwise_weight, transpose) + total_w += ops.size(layer.pointwise_conv.weight) + rem = ops.count_nonzero(pointwise_weight) + remaining_weights += rem + + elif isinstance(layer, (Conv2D, Conv1D, DepthwiseConv2D, Dense)): + weight = layer.kernel + total_w += ops.size(weight) + remaining_weights += ops.count_nonzero(weight) + elif isinstance(layer, SeparableConv2D): + depthwise_weight = layer.depthwise_kernel + pointwise_weight = layer.pointwise_kernel + total_w += ops.size(depthwise_weight) + total_w += ops.size(pointwise_weight) + remaining_weights += ops.count_nonzero(depthwise_weight) + remaining_weights += ops.count_nonzero(pointwise_weight) + if total_w != 0: + return remaining_weights / total_w + return 0.0 + + def get_model_losses(self, 
model, losses): + for layer in model.layers: + if isinstance( + layer, + ( + CompressedLayerDepthwiseConv2dKeras, + CompressedLayerConv2dKeras, + CompressedLayerConv1dKeras, + CompressedLayerDenseKeras, + ), + ): + loss = layer.pruning_layer.calculate_additional_loss() + if layer.enable_quantization and layer.use_high_granularity_quantization: + loss += layer.hgq_loss() + losses += loss + elif isinstance(layer, CompressedLayerSeparableConv2dKeras): + loss = layer.depthwise_conv.pruning_layer.calculate_additional_loss() + loss += layer.pointwise_conv.pruning_layer.calculate_additional_loss() + if layer.enable_quantization and layer.use_high_granularity_quantization: + loss += layer.depthwise_conv.hgq_loss() + loss += layer.pointwise_conv.hgq_loss() + losses += loss + elif isinstance(layer, (QuantizedReLU, QuantizedTanh, QuantizedPooling)): + if layer.use_high_granularity_quantization: + losses += layer.hgq_loss() + return losses + + def _prune_and_quantize_layer(self, layer, use_bias): + layer_weights = layer.get_weights() + layer_weight = ops.cast(layer_weights[0], layer_weights[0].dtype) + + layer_bias = ops.cast(layer_weights[1], layer_weights[1].dtype) if use_bias else None + weight, bias = layer.prune_and_quantize(layer_weight, layer_bias) + return weight, bias + + def post_round_functions(self, model): + for layer in model.layers: + if isinstance( + layer, + ( + CompressedLayerDepthwiseConv2dKeras, + CompressedLayerConv2dKeras, + CompressedLayerConv1dKeras, + CompressedLayerDenseKeras, + ), + ): + layer.pruning_layer.post_round_function() + elif isinstance(layer, CompressedLayerSeparableConv2dKeras): + layer.depthwise_conv.pruning_layer.post_round_function() + layer.pointwise_conv.pruning_layer.post_round_function() + + def rewind_weights_functions(self, model): + for layer in model.layers: + if isinstance( + layer, + ( + CompressedLayerDepthwiseConv2dKeras, + CompressedLayerConv2dKeras, + CompressedLayerConv1dKeras, + CompressedLayerDenseKeras, + ), + ): + 
layer.rewind_weights() + elif isinstance(layer, CompressedLayerSeparableConv2dKeras): + layer.depthwise_conv.rewind_weights() + layer.pointwise_conv.rewind_weights() + + def check_activation(self, layer, config): + """ + Replaces activations with quantized activations. + The activation can be a part of another layer such as Conv2D, or an Activation layer + """ + quantization_enabled = config.quantization_parameters.enable_quantization + act = None + if hasattr(layer.activation, "__name__"): + if layer.activation.__name__ == "relu": + i_bits, f_bits = self.get_quantization_bits_activations(config, layer) + act = QuantizedReLU(config, i_bits, f_bits) if quantization_enabled else ReLU() + act.build(layer.input.shape) + elif layer.activation.__name__ == "tanh": + i_bits, f_bits = self.get_quantization_bits_activations(config, layer) + act = QuantizedTanh(config, i=i_bits, f=f_bits) if quantization_enabled else Activation(activation="tanh") + else: + act = None + return act + + def get_quantization_bits_activations(self, config, layer): + i_bits = config.quantization_parameters.default_integer_bits + f_bits = config.quantization_parameters.default_fractional_bits + if isinstance(layer, ReLU): + f_bits += 1 # Unsigned, add 1 bit to default value only + layer_specific = config.quantization_parameters.layer_specific + if layer.name in layer_specific: + if hasattr(layer, "activation") and layer.activation.__name__ in layer_specific[layer.name]: + i_bits = layer_specific[layer.name][layer.activation.__name__]["integer_bits"] + f_bits = layer_specific[layer.name][layer.activation.__name__]["fractional_bits"] + else: + i_bits = layer_specific[layer.name]["integer_bits"] + f_bits = layer_specific[layer.name]["fractional_bits"] + return i_bits, f_bits + + def get_quantization_bits_weights_biases(self, config, layer): + layer_specific = config.quantization_parameters.layer_specific + if isinstance(layer, SeparableConv2D): + dw_i_bits_w = pw_i_bits_w = pw_i_bits_b = 
config.quantization_parameters.default_integer_bits + dw_f_bits_w = pw_f_bits_w = pw_f_bits_b = config.quantization_parameters.default_fractional_bits + if layer.name in layer_specific: + if "depthwise" in layer_specific[layer.name]: + if "weight" in layer_specific[layer.name]["depthwise"]: + dw_i_bits_w = layer_specific[layer.name]["depthwise"]["weight"]["integer_bits"] + dw_f_bits_w = layer_specific[layer.name]["depthwise"]["weight"]["fractional_bits"] + if "pointwise" in layer_specific[layer.name]: + if "weight" in layer_specific[layer.name]["pointwise"]: + pw_i_bits_w = layer_specific[layer.name]["pointwise"]["weight"]["integer_bits"] + pw_f_bits_w = layer_specific[layer.name]["pointwise"]["weight"]["fractional_bits"] + if "bias" in layer_specific[layer.name]: + pw_i_bits_b = layer_specific[layer.name]["pointwise"]["bias"]["integer_bits"] + pw_f_bits_b = layer_specific[layer.name]["pointwise"]["bias"]["fractional_bits"] + return dw_i_bits_w, dw_f_bits_w, pw_i_bits_w, pw_f_bits_w, pw_i_bits_b, pw_f_bits_b + else: + i_bits_w = i_bits_b = config.quantization_parameters.default_integer_bits + f_bits_w = f_bits_b = config.quantization_parameters.default_fractional_bits + if layer.name in layer_specific: + if "weight" in layer_specific[layer.name]: + i_bits_w = layer_specific[layer.name]["weight"]["integer_bits"] + f_bits_w = layer_specific[layer.name]["weight"]["fractional_bits"] + if "bias" in layer_specific[layer.name]: + i_bits_b = layer_specific[layer.name]["bias"]["integer_bits"] + f_bits_b = layer_specific[layer.name]["bias"]["fractional_bits"] + return i_bits_w, f_bits_w, i_bits_b, f_bits_b + + def get_enable_pruning(self, layer, config): + enable_pruning = config.pruning_parameters.enable_pruning + if isinstance(layer, SeparableConv2D): + enable_pruning_depthwise = enable_pruning_pointwise = True + if layer.name + "_depthwise" in config.pruning_parameters.disable_pruning_for_layers: + enable_pruning_depthwise = False + if layer.name + "_pointwise" in 
config.pruning_parameters.disable_pruning_for_layers: + enable_pruning_pointwise = False + return enable_pruning_depthwise, enable_pruning_pointwise + else: + if layer.name in config.pruning_parameters.disable_pruning_for_layers: + enable_pruning = False + return enable_pruning diff --git a/src/pquant/core/tf_impl/compressed_layers_tf.py b/src/pquant/core/tf_impl/compressed_layers_tf.py deleted file mode 100644 index 42f9d13..0000000 --- a/src/pquant/core/tf_impl/compressed_layers_tf.py +++ /dev/null @@ -1,1093 +0,0 @@ -import keras -from hgq.quantizer import Quantizer -from keras import ops -from keras.layers import ( - Activation, - AveragePooling1D, - AveragePooling2D, - AveragePooling3D, - Conv1D, - Conv2D, - Dense, - DepthwiseConv2D, - Layer, - ReLU, - SeparableConv2D, -) -from quantizers import get_fixed_quantizer - -from pquant.core.activations_quantizer import QuantizedReLU, QuantizedTanh -from pquant.core.utils import get_pruning_layer - - -class CompressedLayerBase(keras.layers.Layer): - def __init__(self, config, layer, layer_type): - super().__init__() - i_bits = config["quantization_parameters"]["default_integer_bits"] - f_bits = config["quantization_parameters"]["default_fractional_bits"] - self.i_weight = ops.convert_to_tensor(i_bits) - self.f_weight = ops.convert_to_tensor(f_bits) - self.i_bias = ops.convert_to_tensor(i_bits) - self.f_bias = ops.convert_to_tensor(f_bits) - self.pruning_layer = get_pruning_layer(config=config, layer_type=layer_type) - self.pruning_method = config["pruning_parameters"]["pruning_method"] - self.overflow = "SAT_SYM" if config["quantization_parameters"]["use_symmetric_quantization"] else "SAT" - self.hgq_gamma = config["quantization_parameters"]["hgq_gamma"] - - self.pruning_first = config["training_parameters"]["pruning_first"] - self.enable_quantization = config["quantization_parameters"]["enable_quantization"] - self.use_high_granularity_quantization = 
config["quantization_parameters"]["use_high_granularity_quantization"] - self.hgq_heterogeneous = config["quantization_parameters"]["hgq_heterogeneous"] - self.enable_pruning = config["pruning_parameters"]["enable_pruning"] - self.do_transpose_data = None - self.weight_transpose = None - self.data_transpose = None - - def set_quantization_bits(self, i_bits_w, f_bits_w, i_bits_b, f_bits_b): - self.i_weight = ops.convert_to_tensor(i_bits_w) - self.f_weight = ops.convert_to_tensor(f_bits_w) - self.i_bias = ops.convert_to_tensor(i_bits_b) - self.f_bias = ops.convert_to_tensor(f_bits_b) - - def set_enable_pruning(self, enable_pruning): - self.enable_pruning = enable_pruning - - def build(self, input_shape): - super().build(input_shape) - if self.use_high_granularity_quantization: - if self.hgq_heterogeneous: - self.hgq_weight = Quantizer( - k0=1.0, - i0=self.i_weight, - f0=self.f_weight, - round_mode="RND", - overflow_mode=self.overflow, - q_type="kif", - homogeneous_axis=(), - ) - self.hgq_weight.build(self.weight.shape) - if self.use_bias: - self.hgq_bias = Quantizer( - k0=1.0, - i0=self.i_bias, - f0=self.f_bias, - round_mode="RND", - overflow_mode=self.overflow, - q_type="kif", - homogeneous_axis=(), - ) - self.hgq_bias.build(self.bias.shape) - else: - self.hgq_weight = Quantizer( - k0=1.0, - i0=self.i_weight, - f0=self.f_weight, - round_mode="RND", - overflow_mode=self.overflow, - q_type="kif", - heterogeneous_axis=(), - ) - self.hgq_weight.build(self.weight.shape) - if self.use_bias: - self.hgq_bias = Quantizer( - k0=1.0, - i0=self.i_bias, - f0=self.f_bias, - round_mode="RND", - overflow_mode=self.overflow, - q_type="kif", - heterogeneous_axis=(), - ) - self.hgq_bias.build(self.bias.shape) - else: - self.quantizer = get_fixed_quantizer(round_mode="RND", overflow_mode=self.overflow) - - def save_weights(self): - self.init_weight = self.weight.value - - def rewind_weights(self): - self.weight.assign(self.init_weight) - - def hgq_loss(self): - if 
self.pruning_layer.is_pretraining: - return 0.0 - loss = (ops.sum(self.hgq_weight.quantizer.i) + ops.sum(self.hgq_weight.quantizer.f)) * self.hgq_gamma - if self.bias is not None: - loss += (ops.sum(self.hgq_bias.quantizer.i) + ops.sum(self.hgq_bias.quantizer.f)) * self.hgq_gamma - return loss - - def handle_transpose(self, x, transpose, do_transpose=False): - if do_transpose: - x = ops.transpose(x, transpose) - return x - - def quantize_i(self, weight, bias): - if self.enable_quantization: - if self.use_high_granularity_quantization: - weight = self.hgq_weight(weight) - bias = None if bias is None else self.hgq_bias(bias) - else: - weight = self.quantizer( - weight, k=ops.convert_to_tensor(1.0), i=self.i_weight, f=self.f_weight, training=True - ) - bias = ( - None - if bias is None - else self.quantizer(bias, k=ops.convert_to_tensor(1.0), i=self.i_bias, f=self.f_bias, training=True) - ) - return weight, bias - - def prune(self, weight): - if self.enable_pruning: - weight = self.handle_transpose(weight, self.weight_transpose, True) - weight = self.pruning_layer(weight) - weight = self.handle_transpose(weight, self.weight_transpose_back, True) - return weight - - def prune_and_quantize(self, weight, bias): - weight = ops.cast(weight, weight.dtype) - bias = ops.cast(bias, bias.dtype) if bias is not None else None - if self.pruning_first: - weight = self.prune(weight) - weight, bias = self.quantize_i(weight, bias) - else: - weight, bias = self.quantize_i(weight, bias) - weight = self.prune(weight) - return weight, bias - - def call(self, x): - return x - - def collect_input(self, x, weight, training): - collect_x = self.handle_transpose(x, self.data_transpose, self.do_transpose_data) - weight_channels_first = self.handle_transpose(weight, self.weight_transpose, True) - self.pruning_layer.collect_input(collect_x, weight_channels_first, training) - - def collect_output(self, x, training): - collect_x = self.handle_transpose(x, self.data_transpose, 
class CompressedLayerDepthwiseConv2dKeras(CompressedLayerBase):
    """Prunable/quantizable drop-in replacement for keras DepthwiseConv2D."""

    def __init__(self, config, layer, layer_type):
        super().__init__(config, layer, layer_type)
        self.depthwise_regularizer = layer.depthwise_regularizer
        self.use_bias = layer.use_bias
        self.strides = layer.strides
        self.dilation_rate = layer.dilation_rate
        self.padding = layer.padding
        self.kernel_size = layer.kernel_size
        # Capture shapes/values of the original parameters so build() can
        # re-create them as trainable weights of this layer.
        if layer.use_bias:
            self.bias_shape = layer.bias.shape
            self.init_bias = layer.bias.value
        else:
            self.bias_shape = None
            self.init_bias = None
        self.weight_shape = layer.kernel.shape
        self.init_weight = layer.kernel.value
        # Pruning layers operate on channels-first tensors.
        self.weight_transpose = (3, 2, 0, 1)
        self.weight_transpose_back = (2, 3, 1, 0)
        self.data_transpose = (0, 3, 1, 2)
        self.do_transpose_data = layer.data_format == "channels_last"

    def build(self, input_shape):
        self.weight = self.add_weight(
            self.weight_shape, initializer=self.init_weight, trainable=True, regularizer=self.depthwise_regularizer
        )
        self.bias = (
            self.add_weight(self.bias_shape, initializer=self.init_bias, trainable=True)
            if self.bias_shape is not None
            else None
        )
        super().build(input_shape)

    def call(self, x, training=None):
        weight, bias = self.prune_and_quantize(self.weight, self.bias)
        if self.pruning_method == "wanda":
            self.collect_input(x, weight, training)
        x = ops.depthwise_conv(
            x, weight, strides=self.strides, padding=self.padding, data_format=None, dilation_rate=self.dilation_rate
        )
        # BUG FIX: the quantized bias was computed but never applied, so a
        # standalone DepthwiseConv2D with use_bias=True lost its bias term.
        if bias is not None:
            x = ops.add(x, bias)
        if self.pruning_method == "activation_pruning":
            self.collect_output(x, training)
        return x


class CompressedLayerConv2dKeras(CompressedLayerBase):
    """Prunable/quantizable drop-in replacement for keras Conv2D."""

    def __init__(self, config, layer, layer_type):
        super().__init__(config, layer, layer_type)
        self.kernel_regularizer = layer.kernel_regularizer
        self.filters = layer.filters
        self.use_bias = layer.use_bias
        self.strides = layer.strides
        self.dilation_rate = layer.dilation_rate
        self.padding = layer.padding
        self.kernel_size = layer.kernel_size
        if hasattr(layer, "groups"):
            self.groups = layer.groups
        if layer.use_bias:
            self.bias_shape = layer.bias.shape
            self.init_bias = layer.bias.value
        else:
            self.bias_shape = None
            self.init_bias = None
        self.weight_shape = layer.kernel.shape
        self.init_weight = layer.kernel.value
        # Pruning layers operate on channels-first tensors.
        self.weight_transpose = (3, 2, 0, 1)
        self.weight_transpose_back = (2, 3, 1, 0)
        self.data_transpose = (0, 3, 1, 2)
        self.do_transpose_data = layer.data_format == "channels_last"

    def build(self, input_shape):
        self.weight = self.add_weight(
            self.weight_shape, initializer=self.init_weight, trainable=True, regularizer=self.kernel_regularizer
        )
        self.bias = (
            self.add_weight(self.bias_shape, initializer=self.init_bias, trainable=True)
            if self.bias_shape is not None
            else None
        )
        super().build(input_shape)

    def call(self, x, training=None):
        weight, bias = self.prune_and_quantize(self.weight, self.bias)
        if self.pruning_method == "wanda":
            self.collect_input(x, weight, training)
        x = ops.conv(
            x, weight, strides=self.strides, padding=self.padding, data_format=None, dilation_rate=self.dilation_rate
        )
        if self.bias is not None:
            x = ops.add(x, bias)
        if self.pruning_method == "activation_pruning":
            self.collect_output(x, training)
        return x


class CompressedLayerSeparableConv2dKeras(Layer):
    """Separable conv built from a compressed depthwise stage followed by a
    compressed 1x1 pointwise stage. The bias (if any) lives on the pointwise
    stage only, mirroring keras SeparableConv2D."""

    def __init__(self, config, layer):
        super().__init__()
        self.weight_transpose = (3, 2, 0, 1)
        self.weight_transpose_back = (2, 3, 1, 0)
        self.data_transpose = (0, 3, 1, 2)
        # Temporarily rewire the original layer's attributes so the two
        # sub-layers can be constructed from it, restoring use_bias afterwards.
        layer.kernel = layer.depthwise_kernel
        original_use_bias = layer.use_bias
        layer.use_bias = False  # depthwise stage never carries the bias
        self.depthwise_conv = CompressedLayerDepthwiseConv2dKeras(config, layer, "conv")
        layer.kernel_regularizer = layer.pointwise_regularizer
        layer.kernel_size = 1
        layer.kernel = layer.pointwise_kernel
        layer.use_bias = original_use_bias
        self.pointwise_conv = CompressedLayerConv2dKeras(config, layer, "conv")
        self.do_transpose_data = layer.data_format == "channels_last"

    def build(self, input_shape):
        super().build(input_shape)

    def call(self, x, training=None):
        out = self.depthwise_conv(x, training=training)
        return self.pointwise_conv(out, training=training)


class CompressedLayerConv1dKeras(CompressedLayerBase):
    """Prunable/quantizable drop-in replacement for keras Conv1D."""

    def __init__(self, config, layer, layer_type):
        super().__init__(config, layer, layer_type)
        self.kernel_regularizer = layer.kernel_regularizer
        self.filters = layer.filters
        self.use_bias = layer.use_bias
        self.strides = layer.strides
        self.dilation_rate = layer.dilation_rate
        self.padding = layer.padding
        self.kernel_size = layer.kernel_size
        self.groups = layer.groups
        if layer.use_bias:
            self.bias_shape = layer.bias.shape
            self.init_bias = layer.bias.value
        else:
            self.bias_shape = None
            self.init_bias = None
        self.weight_shape = layer.kernel.shape
        self.init_weight = layer.kernel.value
        # Pruning layers operate on channels-first tensors.
        self.weight_transpose = (2, 1, 0)
        self.weight_transpose_back = (2, 1, 0)
        self.data_transpose = (0, 2, 1)
        self.do_transpose_data = layer.data_format == "channels_last"

    def build(self, input_shape):
        self.weight = self.add_weight(
            self.weight_shape, initializer=self.init_weight, trainable=True, regularizer=self.kernel_regularizer
        )
        self.bias = (
            self.add_weight(self.bias_shape, initializer=self.init_bias, trainable=True)
            if self.bias_shape is not None
            else None
        )
        super().build(input_shape)

    def call(self, x, training=None):
        weight, bias = self.prune_and_quantize(self.weight, self.bias)
        if self.pruning_method == "wanda":
            self.collect_input(x, weight, training)
        x = ops.conv(
            x, weight, strides=self.strides, padding=self.padding, data_format=None, dilation_rate=self.dilation_rate
        )
        if self.bias is not None:
            x = ops.add(x, bias)
        if self.pruning_method == "activation_pruning":
            self.collect_output(x, training)
        return x
class CompressedLayerDenseKeras(CompressedLayerBase):
    """Prunable/quantizable drop-in replacement for keras Dense."""

    def __init__(self, config, layer, layer_type):
        super().__init__(config, layer, layer_type)
        self.kernel_regularizer = layer.kernel_regularizer
        self.use_bias = layer.use_bias
        self.units = layer.units
        # Capture parameter shapes/values so build() can re-create them.
        if layer.use_bias:
            self.bias_shape = layer.bias.shape
            self.init_bias = layer.bias.value
        else:
            self.bias_shape = None
            self.init_bias = None
        self.weight_shape = layer.kernel.shape
        self.init_weight = layer.kernel.value
        self.weight_transpose = (1, 0)
        self.weight_transpose_back = (1, 0)
        self.data_transpose = (0, 1)  # Always (BATCH_SIZE, OUT_FEATURES)

    def build(self, input_shape):
        self.weight = self.add_weight(
            self.weight_shape, initializer=self.init_weight, trainable=True, regularizer=self.kernel_regularizer
        )
        if self.bias_shape is None:
            self.bias = None
        else:
            self.bias = self.add_weight(self.bias_shape, initializer=self.init_bias, trainable=True)
        super().build(input_shape)

    def call(self, x, training=None):
        weight, bias = self.prune_and_quantize(self.weight, self.bias)
        if self.pruning_method == "wanda":
            self.collect_input(x, weight, training)
        out = ops.matmul(x, weight)
        if self.bias is not None:
            out = ops.add(out, bias)
        if self.pruning_method == "activation_pruning":
            self.collect_output(out, training)
        return out
class QuantizedPooling(keras.layers.Layer):
    """Average pooling followed by fixed-point (or HGQ) quantization of the output."""

    def __init__(self, config, layer):
        super().__init__()
        self.i = ops.convert_to_tensor(config["quantization_parameters"]["default_integer_bits"])
        self.f = ops.convert_to_tensor(config["quantization_parameters"]["default_fractional_bits"])

        self.is_pretraining = True

        self.overflow = "SAT_SYM" if config["quantization_parameters"]["use_symmetric_quantization"] else "SAT"
        self.hgq_gamma = config["quantization_parameters"]["hgq_gamma"]

        self.use_high_granularity_quantization = config["quantization_parameters"]["use_high_granularity_quantization"]
        self.hgq_heterogeneous = config["quantization_parameters"]["hgq_heterogeneous"]
        # Copy the pooling geometry from the wrapped AveragePooling{1,2,3}D layer.
        self.pool_size = layer.pool_size
        self.strides = layer.strides
        self.padding = layer.padding
        self.data_format = layer.data_format
        # "AveragePooling2D"[-2] -> "2": the spatial dimensionality as a string.
        self.dimensions = layer.__class__.__name__[-2]

    def post_pre_train_function(self):
        self.is_pretraining = False

    def set_quantization_bits(self, i_bits, f_bits):
        self.i = ops.convert_to_tensor(i_bits)
        self.f = ops.convert_to_tensor(f_bits)

    def build(self, input_shape):
        super().build(input_shape)
        if self.use_high_granularity_quantization:
            if self.hgq_heterogeneous:
                # Heterogeneous: per-element bit widths, shared only over batch axis.
                self.hgq = Quantizer(
                    k0=1.0,
                    i0=self.i,
                    f0=self.f,
                    round_mode="RND",
                    overflow_mode=self.overflow,
                    q_type="kif",
                    homogeneous_axis=(0,),
                )
            else:
                # Homogeneous: a single bit width for the whole tensor.
                self.hgq = Quantizer(
                    k0=1.0,
                    i0=self.i,
                    f0=self.f,
                    round_mode="RND",
                    overflow_mode=self.overflow,
                    q_type="kif",
                    heterogeneous_axis=(),
                )
            self.hgq.build(input_shape)
        else:
            self.quantizer = get_fixed_quantizer(round_mode="RND", overflow_mode=self.overflow)

    def hgq_loss(self):
        """Regularization term penalizing the HGQ bit widths of the output quantizer."""
        if self.is_pretraining:
            return 0.0
        # BUG FIX: previously referenced self.hgq_weight / self.hgq_bias / self.bias,
        # none of which exist on this activation-side layer (copy-paste from a
        # weight layer). The only quantizer here is self.hgq.
        return (ops.sum(self.hgq.quantizer.i) + ops.sum(self.hgq.quantizer.f)) * self.hgq_gamma

    def quantize_i(self, x):
        if self.use_high_granularity_quantization:
            x = self.hgq(x)
        else:
            x = self.quantizer(x, k=ops.convert_to_tensor(1.0), i=self.i, f=self.f, training=True)
        return x

    def call(self, x):
        x = ops.average_pool(
            x,
            pool_size=self.pool_size,
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
        )
        return self.quantize_i(x)

    def get_config(self):
        config = super().get_config()
        # BUG FIX: previously serialized nonexistent self.pooling; serialize the
        # actual pooling geometry instead.
        config.update(
            {
                "i": self.i,
                "f": self.f,
                "is_pretraining": self.is_pretraining,
                "overflow": self.overflow,
                "hgq_gamma": self.hgq_gamma,
                "hgq_heterogeneous": self.hgq_heterogeneous,
                "pool_size": self.pool_size,
                "strides": self.strides,
                "padding": self.padding,
                "data_format": self.data_format,
            }
        )
        return config
def call_post_round_functions(model, rewind, rounds, r):
    """After a pruning round: rewind weights if the schedule says so, otherwise
    run the per-layer post-round hooks."""
    rewind_now = rewind == "round" or (rewind == "post-ticket-search" and r == rounds - 1)
    if rewind_now:
        rewind_weights_functions(model)
    else:
        post_round_functions(model)


def _prune_and_quantize_layer(layer, use_bias):
    """Return the pruned/quantized (weight, bias) of a compressed layer; bias is
    None when use_bias is False."""
    weights = layer.get_weights()
    weight_in = ops.cast(weights[0], weights[0].dtype)
    if use_bias:
        bias_in = ops.cast(weights[1], weights[1].dtype)
    else:
        bias_in = None
    return layer.prune_and_quantize(weight_in, bias_in)
def remove_pruning_from_model_tf(model, config):
    """Rebuild the model with plain keras layers, baking the pruned/quantized
    weights into them. Non-compressed layers are reused unchanged."""
    x = model.layers[0].output
    for layer in model.layers[1:]:
        if isinstance(layer, CompressedLayerDepthwiseConv2dKeras):
            new_layer = DepthwiseConv2D(
                kernel_size=layer.kernel_size,
                strides=layer.strides,
                padding=layer.padding,
                dilation_rate=layer.dilation_rate,
                use_bias=layer.use_bias,
                depthwise_regularizer=layer.depthwise_regularizer,
                activity_regularizer=layer.activity_regularizer,
            )
            x = new_layer(x)
            weight, bias = _prune_and_quantize_layer(layer, layer.use_bias)
            new_layer.set_weights([weight, bias] if layer.use_bias else [weight])
        elif isinstance(layer, CompressedLayerConv2dKeras):
            new_layer = Conv2D(
                filters=layer.filters,
                kernel_size=layer.kernel_size,
                strides=layer.strides,
                padding=layer.padding,
                dilation_rate=layer.dilation_rate,
                use_bias=layer.use_bias,
                kernel_regularizer=layer.kernel_regularizer,
                activity_regularizer=layer.activity_regularizer,
            )
            x = new_layer(x)
            weight, bias = _prune_and_quantize_layer(layer, layer.use_bias)
            new_layer.set_weights([weight, bias] if layer.use_bias else [weight])
        elif isinstance(layer, CompressedLayerSeparableConv2dKeras):
            new_layer = SeparableConv2D(
                filters=layer.pointwise_conv.filters,
                kernel_size=layer.depthwise_conv.kernel_size,
                strides=layer.depthwise_conv.strides,
                padding=layer.depthwise_conv.padding,
                dilation_rate=layer.depthwise_conv.dilation_rate,
                use_bias=layer.pointwise_conv.use_bias,
                depthwise_regularizer=layer.depthwise_conv.depthwise_regularizer,
                pointwise_regularizer=layer.pointwise_conv.kernel_regularizer,
                activity_regularizer=layer.activity_regularizer,
            )
            x = new_layer(x)
            # The bias, if present, belongs to the pointwise stage.
            depthwise_weight, _ = _prune_and_quantize_layer(layer.depthwise_conv, False)
            pointwise_weight, bias = _prune_and_quantize_layer(layer.pointwise_conv, layer.pointwise_conv.use_bias)
            if layer.pointwise_conv.use_bias:
                new_layer.set_weights([depthwise_weight, pointwise_weight, bias])
            else:
                new_layer.set_weights([depthwise_weight, pointwise_weight])
        elif isinstance(layer, CompressedLayerConv1dKeras):
            new_layer = Conv1D(
                filters=layer.filters,
                kernel_size=layer.kernel_size,
                strides=layer.strides,
                padding=layer.padding,
                dilation_rate=layer.dilation_rate,
                use_bias=layer.use_bias,
                kernel_regularizer=layer.kernel_regularizer,
                activity_regularizer=layer.activity_regularizer,
            )
            x = new_layer(x)
            weight, bias = _prune_and_quantize_layer(layer, layer.use_bias)
            new_layer.set_weights([weight, bias] if layer.use_bias else [weight])
        elif isinstance(layer, CompressedLayerDenseKeras):
            new_layer = Dense(units=layer.units, use_bias=layer.use_bias, kernel_regularizer=layer.kernel_regularizer)
            x = new_layer(x)
            weight, bias = _prune_and_quantize_layer(layer, new_layer.use_bias)
            new_layer.set_weights([weight, bias] if new_layer.use_bias else [weight])
        else:
            x = layer(x)
    return keras.Model(inputs=model.inputs, outputs=x)
def post_epoch_functions(model, epoch, total_epochs, **kwargs):
    """Run every pruning layer's post-epoch hook."""
    single_kernel = (
        CompressedLayerDepthwiseConv2dKeras,
        CompressedLayerConv2dKeras,
        CompressedLayerConv1dKeras,
        CompressedLayerDenseKeras,
    )
    for layer in model.layers:
        if isinstance(layer, single_kernel):
            layer.pruning_layer.post_epoch_function(epoch, total_epochs, **kwargs)
        elif isinstance(layer, CompressedLayerSeparableConv2dKeras):
            for sub in (layer.depthwise_conv, layer.pointwise_conv):
                sub.pruning_layer.post_epoch_function(epoch, total_epochs, **kwargs)


def pre_epoch_functions(model, epoch, total_epochs):
    """Run every pruning layer's pre-epoch hook."""
    single_kernel = (
        CompressedLayerDepthwiseConv2dKeras,
        CompressedLayerConv2dKeras,
        CompressedLayerConv1dKeras,
        CompressedLayerDenseKeras,
    )
    for layer in model.layers:
        if isinstance(layer, single_kernel):
            layer.pruning_layer.pre_epoch_function(epoch, total_epochs)
        elif isinstance(layer, CompressedLayerSeparableConv2dKeras):
            for sub in (layer.depthwise_conv, layer.pointwise_conv):
                sub.pruning_layer.pre_epoch_function(epoch, total_epochs)


def post_round_functions(model):
    """Run every pruning layer's post-round hook."""
    single_kernel = (
        CompressedLayerDepthwiseConv2dKeras,
        CompressedLayerConv2dKeras,
        CompressedLayerConv1dKeras,
        CompressedLayerDenseKeras,
    )
    for layer in model.layers:
        if isinstance(layer, single_kernel):
            layer.pruning_layer.post_round_function()
        elif isinstance(layer, CompressedLayerSeparableConv2dKeras):
            for sub in (layer.depthwise_conv, layer.pointwise_conv):
                sub.pruning_layer.post_round_function()


def save_weights_functions(model):
    """Snapshot weights of all compressed layers (for later rewinding)."""
    single_kernel = (
        CompressedLayerDepthwiseConv2dKeras,
        CompressedLayerConv2dKeras,
        CompressedLayerConv1dKeras,
        CompressedLayerDenseKeras,
    )
    for layer in model.layers:
        if isinstance(layer, single_kernel):
            layer.save_weights()
        elif isinstance(layer, CompressedLayerSeparableConv2dKeras):
            layer.depthwise_conv.save_weights()
            layer.pointwise_conv.save_weights()


def rewind_weights_functions(model):
    """Restore the previously saved weights on all compressed layers."""
    single_kernel = (
        CompressedLayerDepthwiseConv2dKeras,
        CompressedLayerConv2dKeras,
        CompressedLayerConv1dKeras,
        CompressedLayerDenseKeras,
    )
    for layer in model.layers:
        if isinstance(layer, single_kernel):
            layer.rewind_weights()
        elif isinstance(layer, CompressedLayerSeparableConv2dKeras):
            layer.depthwise_conv.rewind_weights()
            layer.pointwise_conv.rewind_weights()
def pre_finetune_functions(model):
    """Run every pruning layer's pre-finetune hook."""
    single_kernel = (
        CompressedLayerDepthwiseConv2dKeras,
        CompressedLayerConv2dKeras,
        CompressedLayerConv1dKeras,
        CompressedLayerDenseKeras,
    )
    for layer in model.layers:
        if isinstance(layer, single_kernel):
            layer.pruning_layer.pre_finetune_function()
        elif isinstance(layer, CompressedLayerSeparableConv2dKeras):
            for sub in (layer.depthwise_conv, layer.pointwise_conv):
                sub.pruning_layer.pre_finetune_function()


def post_pretrain_functions(model, config):
    """Switch all compressed/quantized layers out of pretraining mode and, for
    pdp (or budgeted wanda) pruning, compute per-layer target sparsities."""
    single_kernel = (
        CompressedLayerDepthwiseConv2dKeras,
        CompressedLayerConv2dKeras,
        CompressedLayerConv1dKeras,
        CompressedLayerDenseKeras,
    )
    for layer in model.layers:
        if isinstance(layer, single_kernel):
            layer.pruning_layer.post_pre_train_function()
        elif isinstance(layer, CompressedLayerSeparableConv2dKeras):
            for sub in (layer.depthwise_conv, layer.pointwise_conv):
                sub.pruning_layer.post_pre_train_function()
        elif isinstance(layer, (QuantizedReLU, QuantizedTanh, QuantizedPooling)):
            layer.post_pre_train_function()
    method = config["pruning_parameters"]["pruning_method"]
    if method == "pdp" or (method == "wanda" and config["pruning_parameters"]["calculate_pruning_budget"]):
        pdp_setup(model, config)
def pdp_setup(model, config):
    """
    Calculates a global sparsity threshold. Initializes target sparsity for each layer,
    which depends on how large a percentage of weights in the layer is smaller than
    the global threshold.
    """
    single_kernel = (
        CompressedLayerDepthwiseConv2dKeras,
        CompressedLayerConv2dKeras,
        CompressedLayerConv1dKeras,
        CompressedLayerDenseKeras,
    )
    # Gather every prunable weight tensor, flattened, in traversal order. The
    # same order is used below when slicing the below-threshold mask.
    flat_weights = []
    for layer in model.layers:
        if isinstance(layer, single_kernel):
            flat_weights.append(ops.ravel(layer.weight))
        elif isinstance(layer, CompressedLayerSeparableConv2dKeras):
            flat_weights.append(ops.ravel(layer.depthwise_conv.weight))
            flat_weights.append(ops.ravel(layer.pointwise_conv.weight))
    if not flat_weights:
        # BUG FIX: previously crashed (ops.abs(None)) when the model contained
        # no prunable layers.
        return

    abs_global_weights = ops.abs(ops.concatenate(flat_weights))
    n = ops.size(abs_global_weights)
    # Magnitudes sorted descending; the entry at rank (1 - sparsity) * n
    # separates kept weights from pruned ones.
    global_weight_topk, _ = ops.top_k(abs_global_weights, n)
    rank = int((1 - config["pruning_parameters"]["sparsity"]) * float(n))
    # BUG FIX: clamp the rank so that sparsity == 0.0 does not index one past
    # the end of the sorted array.
    rank = min(rank, int(n) - 1)
    threshold = global_weight_topk[rank]
    global_weights_below_threshold = ops.where(abs_global_weights < threshold, 1, 0)

    idx = 0

    def _assign_sparsity(sub):
        # Fraction of this layer's weights below the global threshold becomes
        # its initial/target sparsity.
        nonlocal idx
        weight_size = ops.size(sub.weight)
        w = ops.sum(global_weights_below_threshold[idx : idx + weight_size])
        sub.pruning_layer.init_r = ops.convert_to_tensor(w / weight_size, dtype=sub.weight.dtype)
        sub.pruning_layer.sparsity = ops.convert_to_tensor(w / weight_size, dtype=sub.weight.dtype)  # Wanda
        idx += weight_size

    for layer in model.layers:
        if isinstance(layer, single_kernel):
            _assign_sparsity(layer)
        elif isinstance(layer, CompressedLayerSeparableConv2dKeras):
            _assign_sparsity(layer.depthwise_conv)
            _assign_sparsity(layer.pointwise_conv)
def get_layer_keep_ratio_tf(model):
    """Return the fraction of nonzero weights remaining in the model (0.0 if it
    has no weighted layers). Compressed layers are quantized and hard-masked
    before counting."""
    total_w = 0
    remaining_weights = 0
    single_kernel = (
        CompressedLayerDepthwiseConv2dKeras,
        CompressedLayerConv2dKeras,
        CompressedLayerConv1dKeras,
        CompressedLayerDenseKeras,
    )
    for layer in model.layers:
        if isinstance(layer, single_kernel):
            weight = ops.cast(layer.weight, layer.weight.dtype)
            bias = ops.cast(layer.bias, layer.bias.dtype) if layer.bias is not None else None
            weight, bias = layer.quantize_i(weight, bias)
            if layer.enable_pruning:
                transposed = ops.transpose(weight, layer.weight_transpose)
                weight = layer.pruning_layer.get_hard_mask(transposed) * transposed
            total_w += ops.size(weight)
            remaining_weights += ops.count_nonzero(weight)
        elif isinstance(layer, CompressedLayerSeparableConv2dKeras):
            dw = layer.depthwise_conv
            pw = layer.pointwise_conv
            depthwise_weight = ops.cast(dw.weight, dw.weight.dtype)
            pointwise_weight = ops.cast(pw.weight, pw.weight.dtype)
            bias = ops.cast(pw.bias, pw.bias.dtype) if pw.bias is not None else None

            depthwise_weight, _ = dw.quantize_i(depthwise_weight, None)
            if dw.enable_pruning:
                transposed = ops.transpose(depthwise_weight, dw.weight_transpose)
                depthwise_weight = dw.pruning_layer.get_hard_mask(transposed) * transposed
            total_w += ops.size(dw.weight)
            remaining_weights += ops.count_nonzero(depthwise_weight)

            pointwise_weight, _ = pw.quantize_i(pointwise_weight, bias)
            if pw.enable_pruning:
                transposed = ops.transpose(pointwise_weight, pw.weight_transpose)
                pointwise_weight = pw.pruning_layer.get_hard_mask(transposed) * transposed
            total_w += ops.size(pw.weight)
            remaining_weights += ops.count_nonzero(pointwise_weight)
        elif isinstance(layer, (Conv2D, Conv1D, DepthwiseConv2D, Dense)):
            total_w += ops.size(layer.kernel)
            remaining_weights += ops.count_nonzero(layer.kernel)
        elif isinstance(layer, SeparableConv2D):
            total_w += ops.size(layer.depthwise_kernel)
            total_w += ops.size(layer.pointwise_kernel)
            remaining_weights += ops.count_nonzero(layer.depthwise_kernel)
            remaining_weights += ops.count_nonzero(layer.pointwise_kernel)
    if total_w != 0:
        return remaining_weights / total_w
    return 0.0
def get_model_losses_tf(model, losses):
    """Accumulate pruning and HGQ regularization losses from all layers onto
    `losses` and return the total."""
    single_kernel = (
        CompressedLayerDepthwiseConv2dKeras,
        CompressedLayerConv2dKeras,
        CompressedLayerConv1dKeras,
        CompressedLayerDenseKeras,
    )
    for layer in model.layers:
        if isinstance(layer, single_kernel):
            loss = layer.pruning_layer.calculate_additional_loss()
            if layer.enable_quantization and layer.use_high_granularity_quantization:
                loss += layer.hgq_loss()
            losses += loss
        elif isinstance(layer, CompressedLayerSeparableConv2dKeras):
            loss = layer.depthwise_conv.pruning_layer.calculate_additional_loss()
            loss += layer.pointwise_conv.pruning_layer.calculate_additional_loss()
            # NOTE(review): enable_quantization / use_high_granularity_quantization
            # are read from the Separable wrapper itself — assumed present there;
            # confirm against CompressedLayerSeparableConv2dKeras.
            if layer.enable_quantization and layer.use_high_granularity_quantization:
                loss += layer.depthwise_conv.hgq_loss()
                loss += layer.pointwise_conv.hgq_loss()
            losses += loss
        elif isinstance(layer, (QuantizedReLU, QuantizedTanh, QuantizedPooling)):
            if layer.use_high_granularity_quantization:
                losses += layer.hgq_loss()
    return losses


def check_activation(layer, config):
    """
    Replaces activations with quantized activations.
    The activation can be a part of another layer such as Conv2D, or an Activation layer.
    Returns the replacement activation layer, or None if there is nothing to replace.
    """
    quantization_enabled = config["quantization_parameters"]["enable_quantization"]
    act = None
    act_name = getattr(layer.activation, "__name__", None)
    if act_name == "relu":
        i_bits, f_bits = get_quantization_bits_activations(config, layer)
        act = QuantizedReLU(config, i_bits, f_bits) if quantization_enabled else ReLU()
        act.build(layer.input.shape)
    elif act_name == "tanh":
        i_bits, f_bits = get_quantization_bits_activations(config, layer)
        act = QuantizedTanh(config, i=i_bits, f=f_bits) if quantization_enabled else Activation(activation="tanh")
    return act
def add_compression_layers_tf(model, config, input_shape=None):
    """Return a new functional model where prunable/quantizable layers replace
    the originals.

    Pruning algorithms assume channels_first format, hence the weight transpose
    before building each pruning layer.
    """

    def _setup_compressed(new_layer, layer, kernel):
        # Quantization bits, pruning flag and pruning-layer build are identical
        # for every single-kernel layer type.
        i_bits_w, f_bits_w, i_bits_b, f_bits_b = get_quantization_bits_weights_biases(config, layer)
        new_layer.set_quantization_bits(i_bits_w, f_bits_w, i_bits_b, f_bits_b)
        new_layer.set_enable_pruning(get_enable_pruning(layer, config))
        pruning_layer_input = ops.transpose(kernel, new_layer.weight_transpose)
        new_layer.pruning_layer.build(pruning_layer_input.shape)

    x = model.layers[0].output
    for layer in model.layers[1:]:
        act = None
        if isinstance(layer, DepthwiseConv2D):
            new_layer = CompressedLayerDepthwiseConv2dKeras(config, layer, layer_type="conv")
            _setup_compressed(new_layer, layer, layer.kernel)
            x = new_layer(x)
            act = check_activation(layer, config)
        elif isinstance(layer, Conv2D):
            new_layer = CompressedLayerConv2dKeras(config, layer, layer_type="conv")
            _setup_compressed(new_layer, layer, layer.kernel)
            x = new_layer(x)
            act = check_activation(layer, config)
        elif isinstance(layer, SeparableConv2D):
            new_layer = CompressedLayerSeparableConv2dKeras(config, layer)
            dw_i_bits_w, dw_f_bits_w, pw_i_bits_w, pw_f_bits_w, pw_i_bits_b, pw_f_bits_b = (
                get_quantization_bits_weights_biases(config, layer)
            )
            # NOTE(review): the depthwise stage receives the pointwise bias bits;
            # harmless since it is built without a bias, but confirm intent.
            new_layer.depthwise_conv.set_quantization_bits(dw_i_bits_w, dw_f_bits_w, pw_i_bits_b, pw_f_bits_b)
            new_layer.pointwise_conv.set_quantization_bits(pw_i_bits_w, pw_f_bits_w, pw_i_bits_b, pw_f_bits_b)
            enable_pruning_depthwise, enable_pruning_pointwise = get_enable_pruning(layer, config)
            new_layer.depthwise_conv.set_enable_pruning(enable_pruning_depthwise)
            new_layer.pointwise_conv.set_enable_pruning(enable_pruning_pointwise)

            transpose_shape = new_layer.weight_transpose
            dw_input = ops.transpose(layer.depthwise_kernel, transpose_shape)
            new_layer.depthwise_conv.pruning_layer.build(dw_input.shape)
            pw_input = ops.transpose(layer.pointwise_kernel, transpose_shape)
            new_layer.pointwise_conv.pruning_layer.build(pw_input.shape)

            # The pointwise stage needs the depthwise output shape to build.
            new_layer.depthwise_conv.build(x.shape)
            y = new_layer.depthwise_conv(x).shape
            new_layer.pointwise_conv.build(y)
            x = new_layer(x)
            act = check_activation(layer, config)
        elif isinstance(layer, Conv1D):
            new_layer = CompressedLayerConv1dKeras(config, layer, layer_type="conv")
            _setup_compressed(new_layer, layer, layer.kernel)
            x = new_layer(x)
            act = check_activation(layer, config)
        elif isinstance(layer, Dense):
            new_layer = CompressedLayerDenseKeras(config, layer, layer_type="linear")
            _setup_compressed(new_layer, layer, layer.kernel)
            x = new_layer(x)
            act = check_activation(layer, config)
        # Activation layers
        elif isinstance(layer, ReLU):
            if config["quantization_parameters"]["enable_quantization"]:
                # (dead reassignment of the default bits removed —
                # get_quantization_bits_activations already starts from them)
                i_bits, f_bits = get_quantization_bits_activations(config, layer)
                new_layer = QuantizedReLU(config, i_bits, f_bits)
                new_layer.build(layer.input.shape)
                x = new_layer(x)
            else:
                x = layer(x)
        elif isinstance(layer, Activation):
            new_layer = check_activation(layer, config)
            if new_layer is not None:
                x = new_layer(x)
            else:
                # BUG FIX: non-relu/tanh Activation layers (e.g. softmax) were
                # silently dropped from the rebuilt model.
                x = layer(x)
        elif isinstance(layer, (AveragePooling1D, AveragePooling2D, AveragePooling3D)):
            if config["quantization_parameters"]["enable_quantization"]:
                i_bits, f_bits = get_quantization_bits_activations(config, layer)
                new_layer = QuantizedPooling(config, layer)
                new_layer.set_quantization_bits(i_bits, f_bits)
                new_layer.build(layer.output.shape)
                x = new_layer(x)
            else:
                x = layer(x)
        else:
            x = layer(x)
        if act is not None:
            x = act(x)
    return keras.Model(inputs=model.inputs, outputs=x)
def get_quantization_bits_activations(config, layer):
    """Return (integer_bits, fractional_bits) for a layer's activation, falling
    back to the config defaults when no layer-specific entry exists."""
    quant_params = config["quantization_parameters"]
    i_bits = quant_params["default_integer_bits"]
    f_bits = quant_params["default_fractional_bits"]
    if isinstance(layer, ReLU):
        f_bits += 1  # Unsigned, add 1 bit to default value only
    layer_specific = quant_params["layer_specific"]
    if layer.name in layer_specific:
        entry = layer_specific[layer.name]
        act_name = layer.activation.__name__ if hasattr(layer, "activation") else None
        if act_name is not None and act_name in entry:
            # Activation bits nested under the owning layer's entry.
            i_bits = entry[act_name]["integer_bits"]
            f_bits = entry[act_name]["fractional_bits"]
        else:
            i_bits = entry["integer_bits"]
            f_bits = entry["fractional_bits"]
    return i_bits, f_bits
def get_quantization_bits_weights_biases(config, layer):
    """Return the weight/bias quantization bits for a layer.

    For SeparableConv2D returns a 6-tuple
    (dw_i_w, dw_f_w, pw_i_w, pw_f_w, pw_i_b, pw_f_b); otherwise a 4-tuple
    (i_w, f_w, i_b, f_b). Values default to the config-wide defaults.
    """
    layer_specific = config["quantization_parameters"]["layer_specific"]
    if isinstance(layer, SeparableConv2D):
        dw_i_bits_w = pw_i_bits_w = pw_i_bits_b = config["quantization_parameters"]["default_integer_bits"]
        dw_f_bits_w = pw_f_bits_w = pw_f_bits_b = config["quantization_parameters"]["default_fractional_bits"]
        if layer.name in layer_specific:
            entry = layer_specific[layer.name]
            if "depthwise" in entry and "weight" in entry["depthwise"]:
                dw_i_bits_w = entry["depthwise"]["weight"]["integer_bits"]
                dw_f_bits_w = entry["depthwise"]["weight"]["fractional_bits"]
            if "pointwise" in entry:
                if "weight" in entry["pointwise"]:
                    pw_i_bits_w = entry["pointwise"]["weight"]["integer_bits"]
                    pw_f_bits_w = entry["pointwise"]["weight"]["fractional_bits"]
                # BUG FIX: bias bits live under "pointwise" in the config scheme
                # (see add_default_layer_quantization_pruning_to_config_tf), but
                # were previously looked up at the top level and never found.
                if "bias" in entry["pointwise"]:
                    pw_i_bits_b = entry["pointwise"]["bias"]["integer_bits"]
                    pw_f_bits_b = entry["pointwise"]["bias"]["fractional_bits"]
        return dw_i_bits_w, dw_f_bits_w, pw_i_bits_w, pw_f_bits_w, pw_i_bits_b, pw_f_bits_b
    else:
        i_bits_w = i_bits_b = config["quantization_parameters"]["default_integer_bits"]
        f_bits_w = f_bits_b = config["quantization_parameters"]["default_fractional_bits"]
        if layer.name in layer_specific:
            entry = layer_specific[layer.name]
            if "weight" in entry:
                i_bits_w = entry["weight"]["integer_bits"]
                f_bits_w = entry["weight"]["fractional_bits"]
            if "bias" in entry:
                i_bits_b = entry["bias"]["integer_bits"]
                f_bits_b = entry["bias"]["fractional_bits"]
        return i_bits_w, f_bits_w, i_bits_b, f_bits_b


def get_enable_pruning(layer, config):
    """Return whether pruning is enabled for the layer (a pair for
    SeparableConv2D: depthwise flag, pointwise flag)."""
    enable_pruning = config["pruning_parameters"]["enable_pruning"]
    disabled = config["pruning_parameters"]["disable_pruning_for_layers"]
    if isinstance(layer, SeparableConv2D):
        # NOTE(review): separable sub-layers start enabled regardless of the
        # global enable_pruning flag — confirm this asymmetry is intentional.
        enable_pruning_depthwise = enable_pruning_pointwise = True
        if layer.name + "_depthwise" in disabled:
            enable_pruning_depthwise = False
        # BUG FIX: was layer.name + "pointwise" (missing underscore), so the
        # "_pointwise" keys written by the default-config helper never matched.
        if layer.name + "_pointwise" in disabled:
            enable_pruning_pointwise = False
        return enable_pruning_depthwise, enable_pruning_pointwise
    if layer.name in disabled:
        enable_pruning = False
    return enable_pruning


def add_default_layer_quantization_pruning_to_config_tf(model, config):
    """Populate config with a default per-layer quantization scheme (0 integer /
    7 fractional bits everywhere) and disable pruning for every layer."""

    def _bits():
        # Fresh dict per entry so later per-layer edits cannot alias.
        return {"integer_bits": 0.0, "fractional_bits": 7.0}

    custom_scheme = {"layer_specific": {}, "disable_pruning_for_layers": []}
    for layer in model.layers:
        if layer.__class__ in [Dense, Conv2D, Conv1D, DepthwiseConv2D]:
            entry = {"weight": _bits()}
            if layer.use_bias:
                entry["bias"] = _bits()
            if hasattr(layer.activation, "__name__") and layer.activation.__name__ in ["relu", "tanh"]:
                entry[layer.activation.__name__] = _bits()
            custom_scheme["layer_specific"][layer.name] = entry
            custom_scheme["disable_pruning_for_layers"].append(layer.name)
        if layer.__class__ == SeparableConv2D:
            pointwise = {"weight": _bits()}
            if layer.use_bias:
                pointwise["bias"] = _bits()
            entry = {"depthwise": {"weight": _bits()}, "pointwise": pointwise}
            if hasattr(layer.activation, "__name__") and layer.activation.__name__ in ["relu", "tanh"]:
                entry[layer.activation.__name__] = _bits()
            custom_scheme["layer_specific"][layer.name] = entry
            custom_scheme["disable_pruning_for_layers"].append(layer.name + "_depthwise")
            custom_scheme["disable_pruning_for_layers"].append(layer.name + "_pointwise")
        elif layer.__class__ in [Activation, ReLU, AveragePooling1D, AveragePooling2D, AveragePooling3D]:
            custom_scheme["layer_specific"][layer.name] = _bits()
    config["quantization_parameters"]["layer_specific"] = custom_scheme["layer_specific"]
    config["pruning_parameters"]["disable_pruning_for_layers"] = custom_scheme["disable_pruning_for_layers"]
    return config
training_config["fine_tuning_epochs"] > 0: - for e in range(training_config["fine_tuning_epochs"]): - pre_epoch_functions(model, e, training_config["fine_tuning_epochs"]) - train_func(model, epoch=epoch, **kwargs) - valid_func(model, epoch=epoch, **kwargs) - post_epoch_functions(model, e, training_config["fine_tuning_epochs"]) - epoch += 1 - return model diff --git a/src/pquant/core/torch_backend.py b/src/pquant/core/torch_backend.py new file mode 100644 index 0000000..c67df02 --- /dev/null +++ b/src/pquant/core/torch_backend.py @@ -0,0 +1,366 @@ +import torch +import torch.nn as nn +from torch.fx import symbolic_trace + +from pquant.core.activations_quantizer import ( + QuantizedPooling, + QuantizedReLU, + QuantizedTanh, +) +from pquant.core.backend_interface import BackendInterface +from pquant.core.compressed_layers import ( + CompressedLayerBase, + CompressedLayerConv1d, + CompressedLayerConv2d, + CompressedLayerLinear, +) + + +class TorchBackend(BackendInterface): + def iterative_train(self, model, config, train_func, valid_func, **kwargs): + """ + Generic training loop, user provides training and validation functions + """ + epoch = torch.tensor(0) # Keeps track of all the epochs completed + training_config = config.training_parameters + if training_config.pretraining_epochs > 0: + for e in range(training_config.pretraining_epochs): + model.train() + self.pre_epoch_functions(model, e, training_config.pretraining_epochs) + train_func(model, epoch=epoch, **kwargs) + model.eval() + valid_func(model, epoch=epoch, **kwargs) + self.post_epoch_functions(model, e, training_config.pretraining_epochs) + epoch += 1 + self.post_pretrain_functions(model, config) + for r in range(training_config.rounds): + for e in range(training_config.epochs): + model.train() + if r == 0 and training_config.save_weights_epoch == e: + self.save_weights_functions(model) + self.pre_epoch_functions(model, e, training_config.epochs) + train_func(model, epoch=epoch, **kwargs) + model.eval() + 
valid_func(model, epoch=epoch, **kwargs) + self.post_epoch_functions(model, e, training_config.epochs) + epoch += 1 + self.call_post_round_functions(model, training_config.rewind, training_config.rounds, r) + self.pre_finetune_functions(model) + if training_config.fine_tuning_epochs > 0: + for e in range(training_config.fine_tuning_epochs): + model.train() + self.pre_epoch_functions(model, e, training_config.fine_tuning_epochs) + train_func(model, epoch=epoch, **kwargs) + model.eval() + valid_func(model, epoch=epoch, **kwargs) + self.post_epoch_functions(model, e, training_config.fine_tuning_epochs) + epoch += 1 + return model + + def create_default_layer_quantization_pruning_config(self, model): + config = {"layer_specific": {}, "disable_pruning_for_layers": []} + for name, layer in model.named_modules(): + if layer.__class__ in [nn.Linear, nn.Conv1d, nn.Conv2d]: + if layer.bias is None: + config["layer_specific"][name] = {"weight": {"integer_bits": 0, "fractional_bits": 7}} + else: + config["layer_specific"][name] = { + "weight": {"integer_bits": 0, "fractional_bits": 7}, + "bias": {"integer_bits": 0, "fractional_bits": 7}, + } + config["disable_pruning_for_layers"].append(name) + elif layer.__class__ in [nn.Tanh, nn.ReLU, nn.AvgPool1d, nn.AvgPool2d, nn.AvgPool3d]: + config["layer_specific"][name] = {"integer_bits": 0, "fractional_bits": 7} + return config + + def add_default_layer_quantization_pruning_to_config(self, model, config): + custom_scheme = self.create_default_layer_quantization_pruning_config(model) + config.quantization_parameters.layer_specific = custom_scheme["layer_specific"] + config.pruning_parameters.disable_pruning_for_layers = custom_scheme["disable_pruning_for_layers"] + return config + + def remove_pruning_from_model(self, module, config): + for name, layer in module.named_children(): + if isinstance(layer, CompressedLayerLinear): + if config.pruning_parameters.pruning_method == "pdp": # Find better solution later + if 
config.training_parameters.pruning_first: + weight = layer.weight + if layer.enable_pruning: + weight = layer.pruning_layer.get_hard_mask(weight) * weight + weight, bias = layer.quantize(weight, layer.bias) + else: + weight, bias = layer.quantize(layer.weight, layer.bias) + if layer.enable_pruning: + weight = layer.pruning_layer.get_hard_mask(weight) * weight + else: + weight, bias = layer.prune_and_quantize(layer.weight, layer.bias) + out_features = layer.out_features + bias_values = bias + in_features = layer.in_features + bias = True if bias_values is not None else False + setattr(module, name, nn.Linear(in_features=in_features, out_features=out_features, bias=bias)) + getattr(module, name).weight.data.copy_(weight) + if getattr(module, name).bias is not None: + getattr(module, name).bias.data.copy_(bias_values.data) + elif isinstance(layer, (CompressedLayerConv2d, CompressedLayerConv1d)): + if config.pruning_parameters.pruning_method == "pdp": # Find better solution later + if config.training_parameters.pruning_first: + weight = layer.weight + if layer.enable_pruning: + weight = layer.pruning_layer.get_hard_mask(weight) * weight + weight, bias = layer.quantize(weight, layer.bias) + else: + weight, bias = layer.quantize(layer.weight, layer.bias) + if layer.enable_pruning: + weight = layer.pruning_layer.get_hard_mask(weight) * weight + else: + weight, bias = layer.prune_and_quantize(layer.weight, layer.bias) + bias_values = bias + bias = True if bias_values is not None else False + conv = nn.Conv2d if isinstance(layer, CompressedLayerConv2d) else nn.Conv1d + setattr( + module, + name, + conv( + layer.in_channels, + layer.out_channels, + layer.kernel_size, + layer.stride, + layer.padding, + layer.dilation, + layer.groups, + bias, + layer.padding_mode, + ), + ) + getattr(module, name).weight.data.copy_(weight) + if getattr(module, name).bias is not None: + getattr(module, name).bias.data.copy_(bias_values.data) + else: + self.remove_pruning_from_model(layer, 
config) + return module + + def add_quantized_activations_to_model_layer(self, module, config): + if not config.quantization_parameters.enable_quantization: + return module + # Replaces ReLU and Tanh layers with quantized versions + for name, layer in module.named_children(): + i = config.quantization_parameters.default_integer_bits + f = config.quantization_parameters.default_fractional_bits + if layer.__class__ in [nn.ReLU]: + # For ReLU, if using default values, add 1 bit since values are unsigned. + # Otherwise user provides bits. TODO: Find better way to do this + f = config.quantization_parameters.default_fractional_bits + 1 + relu = QuantizedReLU(config, i=i, f=f) + setattr(module, name, relu) + elif layer.__class__ in [nn.Tanh]: + tanh = QuantizedTanh(config, i=0.0, f=f) + setattr(module, name, tanh) + elif layer.__class__ in [nn.AvgPool1d, nn.AvgPool2d, nn.AvgPool3d]: + new_layer = QuantizedPooling(config, layer) + setattr(module, name, new_layer) + else: + layer = self.add_quantized_activations_to_model_layer(layer, config) + return module + + def add_pruning_to_model(self, module, config): + for name, layer in module.named_children(): + if layer.__class__ is nn.Linear: + sparse_layer = CompressedLayerLinear(config, layer, "linear") + sparse_layer.pruning_layer.build(layer.weight.shape) + setattr(module, name, sparse_layer) + elif layer.__class__ is nn.Conv2d: + sparse_layer = CompressedLayerConv2d(config, layer, "conv") + sparse_layer.pruning_layer.build(layer.weight.shape) + setattr(module, name, sparse_layer) + elif layer.__class__ is nn.Conv1d: + sparse_layer = CompressedLayerConv1d(config, layer, "conv") + sparse_layer.pruning_layer.build(layer.weight.shape) + setattr(module, name, sparse_layer) + else: + self.add_pruning_to_model(layer, config) + return module + + def disable_pruning_from_layers(self, module, config): + for name, layer in module.named_modules(): + enable_pruning = name not in config.pruning_parameters.disable_pruning_for_layers + if 
( + layer.__class__ in [CompressedLayerLinear, CompressedLayerConv2d, CompressedLayerConv1d] + and not enable_pruning + ): + layer.enable_pruning = enable_pruning + return module + + def add_layer_specific_quantization_to_model(self, module, config): + for name, layer in module.named_modules(): + if isinstance(layer, CompressedLayerBase): + if name in config.quantization_parameters.layer_specific: + if "weight" in config.quantization_parameters.layer_specific[name]: + weight_int_bits = config.quantization_parameters.layer_specific[name]["weight"]["integer_bits"] + weight_fractional_bits = config.quantization_parameters.layer_specific[name]["weight"][ + "fractional_bits" + ] + layer.i_weight = torch.tensor(weight_int_bits) + layer.f_weight = torch.tensor(weight_fractional_bits) + if "bias" in config.quantization_parameters.layer_specific[name]: + bias_int_bits = config.quantization_parameters.layer_specific[name]["bias"]["integer_bits"] + bias_fractional_bits = config.quantization_parameters.layer_specific[name]["bias"]["fractional_bits"] + layer.i_bias = torch.tensor(bias_int_bits) + layer.f_bias = torch.tensor(bias_fractional_bits) + layer.build(None) + elif layer.__class__ in [QuantizedPooling, QuantizedReLU, QuantizedTanh]: + if name in config.quantization_parameters.layer_specific: + i = config.quantization_parameters.layer_specific[name]["integer_bits"] + f = config.quantization_parameters.layer_specific[name]["fractional_bits"] + layer.set_activation_bits(i, f) + return module + + def add_quantized_activations_to_model_functional(self, module, config): + # Currently not in use. 
TODO: Fix this + if config.quantization_parameters.use_high_granularity_quantization: + return module + # Replaces functional activation calls with quantized versions + traced_model = symbolic_trace(module) + for node in traced_model.graph.nodes: + if node.op in ["call_method", "call_function"] and ( + node.target == "tanh" or "function relu" in str(node.target) + ): + with traced_model.graph.inserting_after(node): + if node.name in config.quantization_parameters.layer_specific: + bits = config.quantization_parameters.layer_specific[node.name]["bits"] + else: + bits = ( + config.quantization_parameters.default_integer_bits + + config.quantization_parameters.default_fractional_bits + + 1 + ) # 1 sign bit + kwargs = {"bits": bits} + if node.target == "tanh": + kwargs["use_real_tanh"] = config.quantization_parameters.use_real_tanh + kwargs["use_symmetric"] = config.quantization_parameters.use_symmetric_quantization + # new_node = traced_model.graph.call_function(quantized_tanh, node.args, kwargs) + else: + kwargs = {"integer_bits": config.quantization_parameters.default_integer_bits, "bits": bits} + # new_node = traced_model.graph.call_function(quantized_relu, node.args, kwargs) + # node.replace_all_uses_with(new_node) + traced_model.graph.erase_node(node) + + traced_model.graph.lint() + traced_model.recompile() + return traced_model + + def add_compression_layers(self, model, config, input_shape): + model = self.add_quantized_activations_to_model_layer(model, config) + # model = self.add_quantized_activations_to_model_functional(model, config) + model = self.add_pruning_to_model(model, config) + model = self.disable_pruning_from_layers(model, config) + model = self.add_layer_specific_quantization_to_model(model, config) + model(torch.rand(input_shape, device=next(model.parameters()).device)) + return model + + def post_epoch_functions(self, model, epoch, total_epochs, **kwargs): + for layer in model.modules(): + if isinstance(layer, (CompressedLayerConv2d, 
CompressedLayerConv1d, CompressedLayerLinear)): + layer.pruning_layer.post_epoch_function(epoch, total_epochs, **kwargs) + + def pdp_setup(self, model, config): + """ + Calculates a global sparsity threshold. Initializes target sparsity for each layer, which depends on + how large percentage of weights in the layer is smaller than the global threshold + """ + global_weights = None + for layer in model.modules(): + if isinstance(layer, (CompressedLayerConv2d, CompressedLayerConv1d, CompressedLayerLinear)): + if global_weights is None: + global_weights = layer.weight.flatten() + else: + global_weights = torch.concat((global_weights, layer.weight.flatten())) + + abs_global_weights = torch.abs(global_weights) + global_weight_topk, _ = torch.topk(abs_global_weights, abs_global_weights.numel()) + threshold = global_weight_topk[int((1 - config.pruning_parameters.sparsity) * global_weight_topk.numel())] + global_weights_below_threshold = torch.where(abs_global_weights < threshold, 1, 0) + idx = 0 + for layer in model.modules(): + if isinstance(layer, (CompressedLayerConv2d, CompressedLayerConv1d, CompressedLayerLinear)): + weight_size = layer.weight.numel() + w = torch.sum(global_weights_below_threshold[idx : idx + weight_size]) + layer.pruning_layer.init_r = w / weight_size + layer.pruning_layer.sparsity = w / weight_size # Wanda + idx += weight_size + + def post_pretrain_functions(self, model, config): + for layer in model.modules(): + if isinstance(layer, (CompressedLayerConv2d, CompressedLayerConv1d, CompressedLayerLinear)): + layer.pruning_layer.post_pre_train_function() + elif isinstance(layer, (QuantizedReLU, QuantizedTanh, QuantizedPooling)): + layer.post_pre_train_function() + if config.pruning_parameters.pruning_method == "pdp" or ( + config.pruning_parameters.pruning_method == "wanda" and config.pruning_parameters.calculate_pruning_budget + ): + self.pdp_setup(model, config) + + def pre_epoch_functions(self, model, epoch, total_epochs): + for layer in 
model.modules(): + if isinstance(layer, (CompressedLayerConv2d, CompressedLayerConv1d, CompressedLayerLinear)): + layer.pruning_layer.pre_epoch_function(epoch, total_epochs) + + def pre_finetune_functions(self, model): + for layer in model.modules(): + if isinstance(layer, (CompressedLayerConv2d, CompressedLayerConv1d, CompressedLayerLinear)): + layer.pruning_layer.pre_finetune_function() + + def save_weights_functions(self, model): + for layer in model.modules(): + if isinstance(layer, (CompressedLayerConv2d, CompressedLayerConv1d, CompressedLayerLinear)): + layer.save_weights() + + def post_round_functions(self, model): + for layer in model.modules(): + if isinstance(layer, (CompressedLayerConv2d, CompressedLayerConv1d, CompressedLayerLinear)): + layer.pruning_layer.post_round_function() + + def rewind_weights_functions(self, model): + for layer in model.modules(): + if isinstance(layer, (CompressedLayerConv2d, CompressedLayerConv1d, CompressedLayerLinear)): + layer.rewind_weights() + + @torch.no_grad + def get_layer_keep_ratio(self, model): + total_w = 0 + remaining_weights = 0 + for layer in model.modules(): + if isinstance(layer, (CompressedLayerConv2d, CompressedLayerConv1d, CompressedLayerLinear)): + if layer.pruning_first: + weight = layer.weight + if layer.enable_pruning: + weight = layer.pruning_layer.get_hard_mask(weight) * weight + weight, bias = layer.quantize(weight, layer.bias) + total_w += weight.numel() + rem = torch.count_nonzero(weight) + remaining_weights += rem + else: + weight, bias = layer.quantize(layer.weight, layer.bias) + if layer.enable_pruning: + weight = layer.pruning_layer.get_hard_mask(weight) * weight + total_w += weight.numel() + rem = torch.count_nonzero(weight) + remaining_weights += rem + elif layer.__class__ in (nn.Conv2d, nn.Conv1d, nn.Linear): + total_w += layer.weight.numel() + remaining_weights += torch.count_nonzero(layer.weight) + if total_w != 0: + return remaining_weights / total_w + return 0.0 + + def 
get_model_losses(self, model, losses): + for layer in model.modules(): + if isinstance(layer, (CompressedLayerConv2d, CompressedLayerConv1d, CompressedLayerLinear)): + loss = layer.pruning_layer.calculate_additional_loss() + if layer.use_high_granularity_quantization: + loss += layer.hgq_loss() + losses += loss + elif isinstance(layer, (QuantizedReLU, QuantizedTanh, QuantizedPooling)): + if layer.use_high_granularity_quantization: + losses += layer.hgq_loss() + return losses diff --git a/src/pquant/core/torch_impl/compressed_layers_torch.py b/src/pquant/core/torch_impl/compressed_layers_torch.py deleted file mode 100644 index 6055d9e..0000000 --- a/src/pquant/core/torch_impl/compressed_layers_torch.py +++ /dev/null @@ -1,611 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from hgq.quantizer import Quantizer -from quantizers import get_fixed_quantizer -from torch.fx import symbolic_trace - -from pquant.core.activations_quantizer import QuantizedReLU, QuantizedTanh -from pquant.core.utils import get_pruning_layer - - -class CompressedLayerBase(nn.Module): - def __init__(self, config, layer, layer_type): - super().__init__() - self.f_weight = torch.tensor(config["quantization_parameters"]["default_fractional_bits"]) - self.i_weight = torch.tensor(config["quantization_parameters"]["default_integer_bits"]) - self.f_bias = torch.tensor(config["quantization_parameters"]["default_fractional_bits"]) - self.i_bias = torch.tensor(config["quantization_parameters"]["default_integer_bits"]) - - self.weight = nn.Parameter(layer.weight.clone()) - self.pruning_layer = get_pruning_layer(config=config, layer_type=layer_type) - self.pruning_method = config["pruning_parameters"]["pruning_method"] - self.overflow = "SAT_SYM" if config["quantization_parameters"]["use_symmetric_quantization"] else "SAT" - self.quantizer = get_fixed_quantizer(overflow_mode=self.overflow) - self.hgq_heterogeneous = config["quantization_parameters"]["hgq_heterogeneous"] - - 
self.bias = nn.Parameter(layer.bias.clone()) if layer.bias is not None else None - self.init_weight = self.weight.clone() - self.pruning_first = config["training_parameters"]["pruning_first"] - self.enable_quantization = config["quantization_parameters"]["enable_quantization"] - self.use_high_granularity_quantization = config["quantization_parameters"]["use_high_granularity_quantization"] - self.enable_pruning = config["pruning_parameters"]["enable_pruning"] - self.hgq_gamma = config["quantization_parameters"]["hgq_gamma"] - - def build(self, input_shape): - if self.use_high_granularity_quantization: - if self.hgq_heterogeneous: - self.hgq_weight = Quantizer( - k0=1.0, - i0=self.i_weight, - f0=self.f_weight, - round_mode="RND", - overflow_mode=self.overflow, - q_type="kif", - homogeneous_axis=(), - ) - self.hgq_weight.build(self.weight.shape) - if self.bias is not None: - self.hgq_bias = Quantizer( - k0=1.0, - i0=self.i_bias, - f0=self.f_bias, - round_mode="RND", - overflow_mode=self.overflow, - q_type="kif", - homogeneous_axis=(), - ) - self.hgq_bias.build(self.bias.shape) - else: - self.hgq_weight = Quantizer( - k0=1.0, - i0=self.i_weight, - f0=self.f_weight, - round_mode="RND", - overflow_mode=self.overflow, - q_type="kif", - heterogeneous_axis=(), - ) - self.hgq_weight.build(self.weight.shape) - if self.bias is not None: - self.hgq_bias = Quantizer( - k0=1.0, - i0=self.i_bias, - f0=self.f_bias, - round_mode="RND", - overflow_mode=self.overflow, - q_type="kif", - heterogeneous_axis=(), - ) - self.hgq_bias.build(self.bias.shape) - - def save_weights(self): - self.init_weight = self.weight.clone() - - def rewind_weights(self): - self.weight.data = self.init_weight.clone() - - def hgq_loss(self): - if self.pruning_layer.is_pretraining: - return 0.0 - loss = (torch.sum(self.hgq_weight.quantizer.i) + torch.sum(self.hgq_weight.quantizer.f)) * self.hgq_gamma - if self.bias is not None: - loss += (torch.sum(self.hgq_bias.quantizer.i) + 
torch.sum(self.hgq_bias.quantizer.f)) * self.hgq_gamma - return loss - - def quantize(self, weight, bias): - if self.enable_quantization: - if self.use_high_granularity_quantization: - weight = self.hgq_weight(weight) - bias = None if bias is None else self.hgq_bias(bias) - else: - weight = self.quantizer(weight, k=torch.tensor(1.0), i=self.i_weight, f=self.f_weight, training=True) - bias = ( - None - if bias is None - else self.quantizer(bias, k=torch.tensor(1.0), i=self.i_bias, f=self.f_bias, training=True) - ) - return weight, bias - - def prune(self, weight): - if self.enable_pruning: - weight = self.pruning_layer(weight) - return weight - - def prune_and_quantize(self, weight, bias): - if self.pruning_first: - weight = self.prune(weight) - weight, bias = self.quantize(weight, bias) - else: - weight, bias = self.quantize(weight, bias) - weight = self.prune(weight) - return weight, bias - - def forward(self, x): - weight, bias = self.prune_and_quantize(self.weight, self.bias) - if self.pruning_method == "wanda": - self.pruning_layer.collect_input(x, self.weight, self.training) - x = F.linear(x, weight, bias) - if self.pruning_method == "activation_pruning": - self.pruning_layer.collect_output(x, self.training) - return x - - -class CompressedLayerLinear(CompressedLayerBase): - def __init__(self, config, layer, layer_type): - super().__init__(config, layer, layer_type) - self.in_features = layer.in_features - self.out_features = layer.out_features - - def forward(self, x): - weight, bias = self.prune_and_quantize(self.weight, self.bias) - if self.pruning_method == "wanda": - self.pruning_layer.collect_input(x, self.weight, self.training) - x = F.linear(x, weight, bias) - if self.pruning_method == "activation_pruning": - self.pruning_layer.collect_output(x, self.training) - return x - - -class CompressedLayerConv2d(CompressedLayerBase): - def __init__(self, config, layer, layer_type): - super().__init__(config, layer, layer_type) - self.stride = layer.stride - 
self.dilation = layer.dilation - self.padding = layer.padding - self.groups = layer.groups - self.in_channels = layer.in_channels - self.out_channels = layer.out_channels - self.kernel_size = layer.kernel_size - self.padding_mode = layer.padding_mode - - def forward(self, x): - weight, bias = self.prune_and_quantize(self.weight, self.bias) - if self.pruning_method == "wanda": - self.pruning_layer.collect_input(x, weight, self.training) - x = F.conv2d( - input=x, - weight=weight, - bias=bias, - stride=self.stride, - padding=self.padding, - dilation=self.dilation, - groups=self.groups, - ) - if self.pruning_method == "activation_pruning": - self.pruning_layer.collect_output(x, self.training) - return x - - -class CompressedLayerConv1d(CompressedLayerBase): - def __init__(self, config, layer, layer_type): - super().__init__(config, layer, layer_type) - - self.stride = layer.stride - self.dilation = layer.dilation - self.padding = layer.padding - self.groups = layer.groups - self.in_channels = layer.in_channels - self.out_channels = layer.out_channels - self.kernel_size = layer.kernel_size - self.padding_mode = layer.padding_mode - - def forward(self, x): - weight, bias = self.prune_and_quantize(self.weight, self.bias) - if self.pruning_method == "wanda": - self.pruning_layer.collect_input(x, self.weight, self.training) - x = F.conv1d( - input=x, - weight=weight, - bias=bias, - stride=self.stride, - padding=self.padding, - dilation=self.dilation, - groups=self.groups, - ) - if self.pruning_method == "activation_pruning": - self.pruning_layer.collect_output(x, self.training) - return x - - -def add_compression_layers_torch(model, config, input_shape): - model = add_quantized_activations_to_model_layer(model, config) - # model = add_quantized_activations_to_model_functional(model, config) - model = add_pruning_to_model(model, config) - model = disable_pruning_from_layers(model, config) - model = add_layer_specific_quantization_to_model(model, config) - 
model(torch.rand(input_shape, device=next(model.parameters()).device)) - return model - - -class QuantizedPooling(nn.Module): - - def __init__(self, config, layer): - super().__init__() - self.f = torch.tensor(config["quantization_parameters"]["default_fractional_bits"]) - self.i = torch.tensor(config["quantization_parameters"]["default_integer_bits"]) - self.overflow = "SAT_SYM" if config["quantization_parameters"]["use_symmetric_quantization"] else "SAT" - self.config = config - self.hgq_heterogeneous = config["quantization_parameters"]["hgq_heterogeneous"] - self.is_pretraining = True - self.use_high_granularity_quantization = config["quantization_parameters"]["use_high_granularity_quantization"] - self.pooling = layer - self.hgq_gamma = config["quantization_parameters"]["hgq_gamma"] - - def build(self, input_shape): - if self.use_high_granularity_quantization: - if self.hgq_heterogeneous: - self.hgq = Quantizer( - k0=1.0, - i0=self.i, - f0=self.f, - round_mode="RND", - overflow_mode=self.overflow, - q_type="kif", - homogeneous_axis=(0,), - ) - - else: - self.hgq = Quantizer( - k0=1.0, - i0=self.i, - f0=self.f, - round_mode="RND", - overflow_mode=self.overflow, - q_type="kif", - heterogeneous_axis=(), - ) - self.hgq.build(input_shape) - else: - self.quantizer = get_fixed_quantizer(round_mode="RND", overflow_mode=self.overflow) - - def set_activation_bits(self, i, f): - self.i = torch.tensor(i) - self.f = torch.tensor(f) - - def post_pre_train_function(self): - self.is_pretraining = False - - def hgq_loss(self): - if self.is_pretraining: - return 0.0 - return (torch.sum(self.hgq.quantizer.i) + torch.sum(self.hgq.quantizer.f)) * self.config["quantization_parameters"][ - "hgq_gamma" - ] - - def quantize(self, x): - if not hasattr(self, "hgq") or not hasattr(self, "quantizer"): - self.build(x.shape) - if self.use_high_granularity_quantization: - x = self.hgq(x) - else: - x = self.quantizer(x, k=torch.tensor(1.0), i=self.i, f=self.f, training=True) - return x - - def 
forward(self, x): - x = self.pooling(x) - return self.quantize(x) - - -def add_layer_specific_quantization_to_model(module, config): - for name, layer in module.named_modules(): - if isinstance(layer, CompressedLayerBase): - if name in config["quantization_parameters"]["layer_specific"]: - if "weight" in config["quantization_parameters"]["layer_specific"][name]: - weight_int_bits = config["quantization_parameters"]["layer_specific"][name]["weight"]["integer_bits"] - weight_fractional_bits = config["quantization_parameters"]["layer_specific"][name]["weight"][ - "fractional_bits" - ] - layer.i_weight = torch.tensor(weight_int_bits) - layer.f_weight = torch.tensor(weight_fractional_bits) - if "bias" in config["quantization_parameters"]["layer_specific"][name]: - bias_int_bits = config["quantization_parameters"]["layer_specific"][name]["bias"]["integer_bits"] - bias_fractional_bits = config["quantization_parameters"]["layer_specific"][name]["bias"][ - "fractional_bits" - ] - layer.i_bias = torch.tensor(bias_int_bits) - layer.f_bias = torch.tensor(bias_fractional_bits) - layer.build(None) - elif layer.__class__ in [QuantizedPooling, QuantizedReLU, QuantizedTanh]: - if name in config["quantization_parameters"]["layer_specific"]: - i = config["quantization_parameters"]["layer_specific"][name]["integer_bits"] - f = config["quantization_parameters"]["layer_specific"][name]["fractional_bits"] - layer.set_activation_bits(i, f) - return module - - -def add_quantized_activations_to_model_layer(module, config): - if not config["quantization_parameters"]["enable_quantization"]: - return module - # Replaces ReLU and Tanh layers with quantized versions - for name, layer in module.named_children(): - i = config["quantization_parameters"]["default_integer_bits"] - f = config["quantization_parameters"]["default_fractional_bits"] - if layer.__class__ in [nn.ReLU]: - # For ReLU, if using default values, add 1 bit since values are unsigned. - # Otherwise user provides bits. 
TODO: Find better way to do this - f = config["quantization_parameters"]["default_fractional_bits"] + 1 - relu = QuantizedReLU(config, i=i, f=f) - setattr(module, name, relu) - elif layer.__class__ in [nn.Tanh]: - tanh = QuantizedTanh(config, i=0.0, f=f) - setattr(module, name, tanh) - elif layer.__class__ in [nn.AvgPool1d, nn.AvgPool2d, nn.AvgPool3d]: - new_layer = QuantizedPooling(config, layer) - setattr(module, name, new_layer) - else: - layer = add_quantized_activations_to_model_layer(layer, config) - return module - - -def add_quantized_activations_to_model_functional(module, config): - # Currently not in use. TODO: Fix this - if config["quantization_parameters"]["use_high_granularity_quantization"]: - return module - # Replaces functional activation calls with quantized versions - traced_model = symbolic_trace(module) - for node in traced_model.graph.nodes: - if node.op in ["call_method", "call_function"] and (node.target == "tanh" or "function relu" in str(node.target)): - with traced_model.graph.inserting_after(node): - if node.name in config["quantization_parameters"]["layer_specific"]: - bits = config["quantization_parameters"]["layer_specific"][node.name]["bits"] - else: - bits = ( - config["quantization_parameters"]["default_integer_bits"] - + config["quantization_parameters"]["default_fractional_bits"] - + 1 - ) # 1 sign bit - kwargs = {"bits": bits} - if node.target == "tanh": - kwargs["use_real_tanh"] = config["quantization_parameters"]["use_real_tanh"] - kwargs["use_symmetric"] = config["quantization_parameters"]["use_symmetric_quantization"] - # new_node = traced_model.graph.call_function(quantized_tanh, node.args, kwargs) - else: - kwargs = {"integer_bits": config["quantization_parameters"]["default_integer_bits"], "bits": bits} - # new_node = traced_model.graph.call_function(quantized_relu, node.args, kwargs) - # node.replace_all_uses_with(new_node) - traced_model.graph.erase_node(node) - - traced_model.graph.lint() - traced_model.recompile() - 
return traced_model - - -def disable_pruning_from_layers(module, config): - for name, layer in module.named_modules(): - enable_pruning = name not in config["pruning_parameters"]["disable_pruning_for_layers"] - if layer.__class__ in [CompressedLayerLinear, CompressedLayerConv2d, CompressedLayerConv1d] and not enable_pruning: - layer.enable_pruning = enable_pruning - return module - - -def add_pruning_to_model(module, config): - for name, layer in module.named_children(): - if layer.__class__ is nn.Linear: - sparse_layer = CompressedLayerLinear(config, layer, "linear") - sparse_layer.pruning_layer.build(layer.weight.shape) - setattr(module, name, sparse_layer) - elif layer.__class__ is nn.Conv2d: - sparse_layer = CompressedLayerConv2d(config, layer, "conv") - sparse_layer.pruning_layer.build(layer.weight.shape) - setattr(module, name, sparse_layer) - elif layer.__class__ is nn.Conv1d: - sparse_layer = CompressedLayerConv1d(config, layer, "conv") - sparse_layer.pruning_layer.build(layer.weight.shape) - setattr(module, name, sparse_layer) - else: - add_pruning_to_model(layer, config) - return module - - -def remove_pruning_from_model_torch(module, config): - for name, layer in module.named_children(): - if isinstance(layer, CompressedLayerLinear): - if config["pruning_parameters"]["pruning_method"] == "pdp": # Find better solution later - if config["training_parameters"]["pruning_first"]: - weight = layer.weight - if layer.enable_pruning: - weight = layer.pruning_layer.get_hard_mask(weight) * weight - weight, bias = layer.quantize(weight, layer.bias) - else: - weight, bias = layer.quantize(layer.weight, layer.bias) - if layer.enable_pruning: - weight = layer.pruning_layer.get_hard_mask(weight) * weight - else: - weight, bias = layer.prune_and_quantize(layer.weight, layer.bias) - out_features = layer.out_features - bias_values = bias - in_features = layer.in_features - bias = True if bias_values is not None else False - setattr(module, name, 
nn.Linear(in_features=in_features, out_features=out_features, bias=bias)) - getattr(module, name).weight.data.copy_(weight) - if getattr(module, name).bias is not None: - getattr(module, name).bias.data.copy_(bias_values.data) - elif isinstance(layer, (CompressedLayerConv2d, CompressedLayerConv1d)): - if config["pruning_parameters"]["pruning_method"] == "pdp": # Find better solution later - if config["training_parameters"]["pruning_first"]: - weight = layer.weight - if layer.enable_pruning: - weight = layer.pruning_layer.get_hard_mask(weight) * weight - weight, bias = layer.quantize(weight, layer.bias) - else: - weight, bias = layer.quantize(layer.weight, layer.bias) - if layer.enable_pruning: - weight = layer.pruning_layer.get_hard_mask(weight) * weight - else: - weight, bias = layer.prune_and_quantize(layer.weight, layer.bias) - bias_values = bias - bias = True if bias_values is not None else False - conv = nn.Conv2d if isinstance(layer, CompressedLayerConv2d) else nn.Conv1d - setattr( - module, - name, - conv( - layer.in_channels, - layer.out_channels, - layer.kernel_size, - layer.stride, - layer.padding, - layer.dilation, - layer.groups, - bias, - layer.padding_mode, - ), - ) - getattr(module, name).weight.data.copy_(weight) - if getattr(module, name).bias is not None: - getattr(module, name).bias.data.copy_(bias_values.data) - else: - remove_pruning_from_model_torch(layer, config) - return module - - -def call_post_round_functions(model, rewind, rounds, r): - if rewind == "round": - rewind_weights_functions(model) - elif rewind == "post-ticket-search" and r == rounds - 1: - rewind_weights_functions(model) - else: - post_round_functions(model) - - -def post_epoch_functions(model, epoch, total_epochs, **kwargs): - for layer in model.modules(): - if isinstance(layer, (CompressedLayerConv2d, CompressedLayerConv1d, CompressedLayerLinear)): - layer.pruning_layer.post_epoch_function(epoch, total_epochs, **kwargs) - - -def pre_epoch_functions(model, epoch, 
total_epochs): - for layer in model.modules(): - if isinstance(layer, (CompressedLayerConv2d, CompressedLayerConv1d, CompressedLayerLinear)): - layer.pruning_layer.pre_epoch_function(epoch, total_epochs) - - -def post_round_functions(model): - for layer in model.modules(): - if isinstance(layer, (CompressedLayerConv2d, CompressedLayerConv1d, CompressedLayerLinear)): - layer.pruning_layer.post_round_function() - - -def save_weights_functions(model): - for layer in model.modules(): - if isinstance(layer, (CompressedLayerConv2d, CompressedLayerConv1d, CompressedLayerLinear)): - layer.save_weights() - - -def rewind_weights_functions(model): - for layer in model.modules(): - if isinstance(layer, (CompressedLayerConv2d, CompressedLayerConv1d, CompressedLayerLinear)): - layer.rewind_weights() - - -def pre_finetune_functions(model): - for layer in model.modules(): - if isinstance(layer, (CompressedLayerConv2d, CompressedLayerConv1d, CompressedLayerLinear)): - layer.pruning_layer.pre_finetune_function() - - -def post_pretrain_functions(model, config): - for layer in model.modules(): - if isinstance(layer, (CompressedLayerConv2d, CompressedLayerConv1d, CompressedLayerLinear)): - layer.pruning_layer.post_pre_train_function() - elif isinstance(layer, (QuantizedReLU, QuantizedTanh, QuantizedPooling)): - layer.post_pre_train_function() - if config["pruning_parameters"]["pruning_method"] == "pdp" or ( - config["pruning_parameters"]["pruning_method"] == "wanda" - and config["pruning_parameters"]["calculate_pruning_budget"] - ): - pdp_setup(model, config) - - -def pdp_setup(model, config): - """ - Calculates a global sparsity threshold. 
Initializes target sparsity for each layer, which depends on - how large percentage of weights in the layer is smaller than the global threshold - """ - global_weights = None - for layer in model.modules(): - if isinstance(layer, (CompressedLayerConv2d, CompressedLayerConv1d, CompressedLayerLinear)): - if global_weights is None: - global_weights = layer.weight.flatten() - else: - global_weights = torch.concat((global_weights, layer.weight.flatten())) - - abs_global_weights = torch.abs(global_weights) - global_weight_topk, _ = torch.topk(abs_global_weights, abs_global_weights.numel()) - threshold = global_weight_topk[int((1 - config["pruning_parameters"]["sparsity"]) * global_weight_topk.numel())] - global_weights_below_threshold = torch.where(abs_global_weights < threshold, 1, 0) - idx = 0 - for layer in model.modules(): - if isinstance(layer, (CompressedLayerConv2d, CompressedLayerConv1d, CompressedLayerLinear)): - weight_size = layer.weight.numel() - w = torch.sum(global_weights_below_threshold[idx : idx + weight_size]) - layer.pruning_layer.init_r = w / weight_size - layer.pruning_layer.sparsity = w / weight_size # Wanda - idx += weight_size - - -@torch.no_grad -def get_layer_keep_ratio_torch(model): - total_w = 0 - remaining_weights = 0 - for layer in model.modules(): - if isinstance(layer, (CompressedLayerConv2d, CompressedLayerConv1d, CompressedLayerLinear)): - if layer.pruning_first: - weight = layer.weight - if layer.enable_pruning: - weight = layer.pruning_layer.get_hard_mask(weight) * weight - weight, bias = layer.quantize(weight, layer.bias) - total_w += weight.numel() - rem = torch.count_nonzero(weight) - remaining_weights += rem - else: - weight, bias = layer.quantize(layer.weight, layer.bias) - if layer.enable_pruning: - weight = layer.pruning_layer.get_hard_mask(weight) * weight - total_w += weight.numel() - rem = torch.count_nonzero(weight) - remaining_weights += rem - elif layer.__class__ in (nn.Conv2d, nn.Conv1d, nn.Linear): - total_w += 
layer.weight.numel() - remaining_weights += torch.count_nonzero(layer.weight) - if total_w != 0: - return remaining_weights / total_w - return 0.0 - - -def get_model_losses_torch(model, losses): - for layer in model.modules(): - if isinstance(layer, (CompressedLayerConv2d, CompressedLayerConv1d, CompressedLayerLinear)): - loss = layer.pruning_layer.calculate_additional_loss() - if layer.use_high_granularity_quantization: - loss += layer.hgq_loss() - losses += loss - elif isinstance(layer, (QuantizedReLU, QuantizedTanh, QuantizedPooling)): - if layer.use_high_granularity_quantization: - losses += layer.hgq_loss() - return losses - - -def create_default_layer_quantization_pruning_config(model): - config = {"layer_specific": {}, "disable_pruning_for_layers": []} - for name, layer in model.named_modules(): - if layer.__class__ in [nn.Linear, nn.Conv1d, nn.Conv2d]: - if layer.bias is None: - config["layer_specific"][name] = {"weight": {"integer_bits": 0, "fractional_bits": 7}} - else: - config["layer_specific"][name] = { - "weight": {"integer_bits": 0, "fractional_bits": 7}, - "bias": {"integer_bits": 0, "fractional_bits": 7}, - } - config["disable_pruning_for_layers"].append(name) - elif layer.__class__ in [nn.Tanh, nn.ReLU, nn.AvgPool1d, nn.AvgPool2d, nn.AvgPool3d]: - config["layer_specific"][name] = {"integer_bits": 0, "fractional_bits": 7} - return config - - -def add_default_layer_quantization_pruning_to_config_torch(model, config): - custom_scheme = create_default_layer_quantization_pruning_config(model) - config["quantization_parameters"]["layer_specific"] = custom_scheme["layer_specific"] - config["pruning_parameters"]["disable_pruning_for_layers"] = custom_scheme["disable_pruning_for_layers"] - return config diff --git a/src/pquant/core/torch_impl/train_torch.py b/src/pquant/core/torch_impl/train_torch.py deleted file mode 100644 index c2d788a..0000000 --- a/src/pquant/core/torch_impl/train_torch.py +++ /dev/null @@ -1,51 +0,0 @@ -import torch - -from 
pquant.core.torch_impl.compressed_layers_torch import ( - call_post_round_functions, - post_epoch_functions, - post_pretrain_functions, - pre_epoch_functions, - pre_finetune_functions, - save_weights_functions, -) - - -def iterative_train_torch(model, config, train_func, valid_func, **kwargs): - """ - Generic training loop, user provides training and validation functions - """ - epoch = torch.tensor(0) # Keeps track of all the epochs completed - training_config = config["training_parameters"] - if training_config["pretraining_epochs"] > 0: - for e in range(training_config["pretraining_epochs"]): - model.train() - pre_epoch_functions(model, e, training_config["pretraining_epochs"]) - train_func(model, epoch=epoch, **kwargs) - model.eval() - valid_func(model, epoch=epoch, **kwargs) - post_epoch_functions(model, e, training_config["pretraining_epochs"]) - epoch += 1 - post_pretrain_functions(model, config) - for r in range(training_config["rounds"]): - for e in range(training_config["epochs"]): - model.train() - if r == 0 and training_config["save_weights_epoch"] == e: - save_weights_functions(model) - pre_epoch_functions(model, e, training_config["epochs"]) - train_func(model, epoch=epoch, **kwargs) - model.eval() - valid_func(model, epoch=epoch, **kwargs) - post_epoch_functions(model, e, training_config["epochs"]) - epoch += 1 - call_post_round_functions(model, training_config["rewind"], training_config["rounds"], r) - pre_finetune_functions(model) - if training_config["fine_tuning_epochs"] > 0: - for e in range(training_config["fine_tuning_epochs"]): - model.train() - pre_epoch_functions(model, e, training_config["fine_tuning_epochs"]) - train_func(model, epoch=epoch, **kwargs) - model.eval() - valid_func(model, epoch=epoch, **kwargs) - post_epoch_functions(model, e, training_config["fine_tuning_epochs"]) - epoch += 1 - return model diff --git a/src/pquant/core/train.py b/src/pquant/core/train.py index 776a1d7..1308c57 100644 --- a/src/pquant/core/train.py +++ 
b/src/pquant/core/train.py @@ -1,12 +1,6 @@ -import keras +from pquant.core.utils import get_backend def iterative_train(model, config, train_func, valid_func, **kwargs): - if keras.backend.backend() == "torch": - from pquant.core.torch_impl.train_torch import iterative_train_torch - - return iterative_train_torch(model, config, train_func, valid_func, **kwargs) - else: - from pquant.core.tf_impl.train_tf import iterative_train_tf - - return iterative_train_tf(model, config, train_func, valid_func, **kwargs) + backend = get_backend() + return backend.iterative_train(model, config, train_func, valid_func, **kwargs) diff --git a/src/pquant/core/utils.py b/src/pquant/core/utils.py index 4d66e2c..16174d1 100644 --- a/src/pquant/core/utils.py +++ b/src/pquant/core/utils.py @@ -1,7 +1,9 @@ import os +import keras import yaml +from pquant.core.backend_interface import BackendInterface from pquant.pruning_methods.activation_pruning import ActivationPruning from pquant.pruning_methods.autosparse import AutoSparse from pquant.pruning_methods.cs import ContinuousSparsification @@ -11,6 +13,16 @@ from pquant.pruning_methods.mdmm import MDMM +def get_backend() -> BackendInterface: + from pquant.core.tf_backend import TFBackend + from pquant.core.torch_backend import TorchBackend + + if keras.backend.backend() == "torch": + return TorchBackend() + else: + return TFBackend() + + def get_pruning_layer(config, layer_type): pruning_method = config["pruning_parameters"]["pruning_method"] if pruning_method == "dst":