Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions src/operator-run.c
Original file line number Diff line number Diff line change
Expand Up @@ -1538,11 +1538,57 @@ void xnn_compute_pad_qd8_params(
}
}

typedef struct xnn_qd8_quantization_params(bf16_quantization_params_fn)(
xnn_bfloat16 min, xnn_bfloat16 max, xnn_bfloat16* bf16_scale);
typedef struct xnn_qd8_quantization_params(f16_quantization_params_fn)(
xnn_float16 min, xnn_float16 max, xnn_float16* f32_scale);
typedef struct xnn_qd8_quantization_params(f32_quantization_params_fn)(
float min, float max, float* f32_scale);

void xnn_compute_bf16_qx8_convert(
struct bf16_qd8_convert_context* restrict context,
bf16_quantization_params_fn quantization_params_function,
size_t batch_index) {
const size_t x_stride = context->x_stride;
const size_t y_stride = context->y_stride;
const size_t n = context->n;
const void* input =
(const void*)((uintptr_t)context->x + x_stride * batch_index);
void* output = (void*)((uintptr_t)context->y + y_stride * batch_index);

xnn_bfloat16 minmax[2] = {xnn_bfloat16_from_bits(UINT16_C(0x7F80)),
xnn_bfloat16_from_bits(UINT16_C(0xFF80))};
context->rminmax_ukernel(n, input, minmax, &context->params);
xnn_bfloat16 bf16_scale;
context->quantization_params[batch_index] =
quantization_params_function(minmax[0], minmax[1], &bf16_scale);

struct xnn_bf16_qs8_cvt_params params;
params.scalar.scale = bf16_scale;
params.scalar.output_zero_point =
context->quantization_params[batch_index].zero_point;
context->convert_ukernel(n, input, output, (union xnn_unary_uparams*)&params);

if (context->rsum_ukernel) {
// Compute and store the row sum of the quantized output.
const size_t num_bytes = n / sizeof(xnn_bfloat16) * sizeof(int8_t);
int32_t row_sum = 0;
struct xnn_qs8_rsum_params rsum_params = {0,};
context->rsum_ukernel(num_bytes, output, &row_sum, &rsum_params);
context->row_sum[batch_index] = (float)row_sum;
}
}

void xnn_compute_bf16_qd8_convert(
struct bf16_qd8_convert_context* restrict context, size_t batch_offset,
size_t batch_range) {
for (size_t batch_index = batch_offset;
batch_index < batch_offset + batch_range; batch_index++) {
xnn_compute_bf16_qx8_convert(
context, xnn_bf16_qd8_asymmetric_quantization_params, batch_index);
}
}

void xnn_compute_f16_qx8_convert(
struct f16_qd8_convert_context* restrict context,
f16_quantization_params_fn quantization_params_function,
Expand Down
129 changes: 129 additions & 0 deletions src/operators/unary-elementwise-nc.c
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,15 @@ enum xnn_status create_convert_nc_qx8(
return status;
}

enum xnn_status xnn_create_convert_nc_bf16_qd8(
uint32_t flags,
xnn_operator_t* convert_op_out) {
return create_convert_nc_qx8(
flags, xnn_init_bf16_to_qs8_cvt_config(), xnn_init_bf16_rminmax_config(),
xnn_init_qs8_rsum_config(), xnn_operator_type_convert_nc_bf16_qd8,
convert_op_out);
}

enum xnn_status xnn_create_convert_nc_f16_qd8(
uint32_t flags,
xnn_operator_t* convert_op_out) {
Expand Down Expand Up @@ -928,6 +937,69 @@ enum xnn_status xnn_create_copy_nc_x32(
xnn_operator_type_copy_nc_x32, copy_op_out);
}

enum xnn_status reshape_convert_nc_bf16_qx8(
xnn_operator_t convert_op,
size_t batch_size,
size_t channels,
size_t input_stride,
size_t output_stride,
enum xnn_operator_type expected_type,
pthreadpool_t threadpool)
{
if (convert_op->type != expected_type) {
xnn_log_error(
"failed to setup operator: operator type mismatch (expected %s, got "
"%s)",
xnn_operator_type_to_string(expected_type),
xnn_operator_type_to_string_v2(convert_op));
return xnn_status_invalid_parameter;
}
convert_op->state = xnn_run_state_invalid;

if (batch_size == 0) {
convert_op->state = xnn_run_state_skip;
return xnn_status_success;
}

convert_op->batch_size = batch_size;

convert_op->context.bf16_qd8_convert = (struct bf16_qd8_convert_context) {
.n = channels * sizeof(uint16_t),
.x_stride = input_stride * sizeof(uint16_t),
.y_stride = output_stride,
.batch_size = batch_size,
.rminmax_ukernel = convert_op->reduce_config->ukernel,
.convert_ukernel = convert_op->unary_elementwise_config->ukernel,
.init_params = convert_op->unary_elementwise_config->init,
};

if (convert_op->flags & XNN_NODE_FLAG_REQUIRES_ROW_SUM) {
convert_op->context.bf16_qd8_convert.rsum_ukernel = convert_op->reduce_config2->ukernel;
}
memcpy(&convert_op->context.bf16_qd8_convert.params, &convert_op->params.bf16_default, sizeof(convert_op->params.bf16_default));

convert_op->compute[0].type = xnn_parallelization_type_1d_tile_1d_dynamic;
switch (expected_type) {
case xnn_operator_type_convert_nc_bf16_qd8:
convert_op->compute[0].task_1d_tile_1d_dynamic =
(pthreadpool_task_1d_tile_1d_dynamic_t)xnn_compute_bf16_qd8_convert;
break;
default:
XNN_UNREACHABLE;
}
convert_op->compute[0].range[0] = batch_size;
convert_op->compute[0].tile[0] = divide_round_up(
get_tile_size(convert_op), convert_op->context.bf16_qd8_convert.n);

convert_op->compute[1].type = xnn_parallelization_type_1d;
convert_op->compute[1].task_1d = (pthreadpool_task_1d_t) xnn_compute_pad_qd8_params;
convert_op->compute[1].range[0] = 1;

convert_op->state = xnn_run_state_needs_setup;

return xnn_status_success;
}

enum xnn_status reshape_convert_nc_f16_qx8(
xnn_operator_t convert_op,
size_t batch_size,
Expand Down Expand Up @@ -1063,6 +1135,12 @@ enum xnn_status reshape_convert_nc_f32_qx8(
return xnn_status_success;
}

enum xnn_status xnn_reshape_convert_nc_bf16_qd8(
xnn_operator_t convert_op, size_t batch_size, size_t channels,
size_t input_stride, size_t output_stride, pthreadpool_t threadpool) {
return reshape_convert_nc_bf16_qx8(convert_op, batch_size, channels, input_stride, output_stride, xnn_operator_type_convert_nc_bf16_qd8, threadpool);
}

enum xnn_status xnn_reshape_convert_nc_f16_qd8(
xnn_operator_t convert_op,
size_t batch_size,
Expand Down Expand Up @@ -1221,6 +1299,47 @@ enum xnn_status xnn_reshape_copy_nc_x32(
threadpool);
}

enum xnn_status setup_convert_nc_bf16_qx8(
xnn_operator_t convert_op,
const void* input,
void* output,
enum xnn_operator_type expected_operator_type,
void* row_sum,
struct xnn_quantization_params* quantization_params)
{
if (convert_op->type != expected_operator_type) {
xnn_log_error(
"failed to setup operator: operator type mismatch (expected %s, got "
"%s)",
xnn_operator_type_to_string(expected_operator_type),
xnn_operator_type_to_string_v2(convert_op));
return xnn_status_invalid_parameter;
}

switch (convert_op->state) {
case xnn_run_state_skip:
return xnn_status_success;
case xnn_run_state_invalid:
xnn_log_error(
"failed to setup %s operator: operator has not been reshaped yet",
xnn_operator_type_to_string_v2(convert_op));
return xnn_status_invalid_state;
case xnn_run_state_needs_setup:
// Operator has been reshaped, but not setup, continue with setup.
case xnn_run_state_ready:
// Operator has been reshaped, and we are setting up with different pointers.
break;
}

convert_op->context.bf16_qd8_convert.x = input;
convert_op->context.bf16_qd8_convert.y = output;
convert_op->context.bf16_qd8_convert.quantization_params = (struct xnn_qd8_quantization_params*) quantization_params;
convert_op->context.bf16_qd8_convert.row_sum = row_sum;
convert_op->state = xnn_run_state_ready;

return xnn_status_success;
}

enum xnn_status setup_convert_nc_f16_qx8(
xnn_operator_t convert_op,
const void* input,
Expand Down Expand Up @@ -1325,6 +1444,16 @@ enum xnn_status xnn_setup_convert_nc_f16_qdu8(
return setup_convert_nc_f16_qx8(convert_op, input, output, xnn_operator_type_convert_nc_f16_qdu8, row_sum, quantization_params);
}

enum xnn_status xnn_setup_convert_nc_bf16_qd8(
xnn_operator_t convert_op,
const void* input,
int8_t* output,
float* row_sum,
struct xnn_quantization_params* quantization_params)
{
return setup_convert_nc_bf16_qx8(convert_op, input, output, xnn_operator_type_convert_nc_bf16_qd8, row_sum, quantization_params);
}

enum xnn_status xnn_setup_convert_nc_f32_qd8(
xnn_operator_t convert_op,
const float* input,
Expand Down
22 changes: 22 additions & 0 deletions src/xnnpack/compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -1116,6 +1116,24 @@ XNN_PRIVATE void xnn_compute_slice_4d(struct slice_context* context, size_t i,
XNN_PRIVATE void xnn_compute_slice_5d(struct slice_context* context, size_t i,
size_t j, size_t k, size_t l, size_t m);

struct bf16_qd8_convert_context {
size_t n;
const void* x;
size_t x_stride;
int8_t* y;
size_t y_stride;
size_t batch_size;
struct xnn_qd8_quantization_params* quantization_params;
float* row_sum;
xnn_reduce_ukernel_fn rminmax_ukernel;
xnn_reduce_ukernel_fn rsum_ukernel;
xnn_vunary_ukernel_fn convert_ukernel;
xnn_init_unary_uparams_fn init_params;
union {
struct xnn_bf16_default_params bf16_default;
} params;
};

struct f16_qd8_convert_context {
size_t n;
const void* x;
Expand Down Expand Up @@ -1152,6 +1170,10 @@ struct f32_qd8_convert_context {
} params;
};

XNN_PRIVATE void xnn_compute_bf16_qd8_convert(
struct bf16_qd8_convert_context* context, size_t batch_offset,
size_t batch_range);

XNN_PRIVATE void xnn_compute_f16_qd8_convert(
struct f16_qd8_convert_context* context, size_t batch_offset,
size_t batch_range);
Expand Down
13 changes: 13 additions & 0 deletions src/xnnpack/internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,19 @@ enum xnn_status xnn_create_convolution2d_nhwc_pf32(
float output_min, float output_max, uint32_t flags,
xnn_weights_cache_t weights_cache, xnn_operator_t* convolution_op_out);

// quantization_params must be padded with at least
// XNN_EXTRA_QUANTIZATION_PARAMS entries.
enum xnn_status xnn_setup_convert_nc_bf16_qd8(
xnn_operator_t convert_op, const void* input, int8_t* output,
float* row_sum, struct xnn_quantization_params* quantization_params);

enum xnn_status xnn_create_convert_nc_bf16_qd8(uint32_t flags,
xnn_operator_t* convert_op_out);

enum xnn_status xnn_reshape_convert_nc_bf16_qd8(
xnn_operator_t convert_op, size_t batch_size, size_t channels,
size_t input_stride, size_t output_stride, pthreadpool_t threadpool);

// quantization_params must be padded with at least
// XNN_EXTRA_QUANTIZATION_PARAMS entries.
enum xnn_status xnn_setup_convert_nc_f16_qdu8(
Expand Down
1 change: 1 addition & 0 deletions src/xnnpack/operator-type-defs.inc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ XNN_ENUM_ITEM(xnn_operator_type_binary_elementwise, "Binary Elementwise (ND)")
XNN_ENUM_ITEM(xnn_operator_type_constant_pad_nd_x8, "Constant Pad (ND, X8)")
XNN_ENUM_ITEM(xnn_operator_type_constant_pad_nd_x16, "Constant Pad (ND, X16)")
XNN_ENUM_ITEM(xnn_operator_type_constant_pad_nd_x32, "Constant Pad (ND, X32)")
XNN_ENUM_ITEM(xnn_operator_type_convert_nc_bf16_qd8, "Convert (NC, BF16, QD8)")
XNN_ENUM_ITEM(xnn_operator_type_convert_nc_f16_qd8, "Convert (NC, F16, QD8)")
XNN_ENUM_ITEM(xnn_operator_type_convert_nc_f16_qdu8, "Convert (NC, F16, QDU8)")
XNN_ENUM_ITEM(xnn_operator_type_convert_nc_f32_qd8, "Convert (NC, F32, QD8)")
Expand Down
2 changes: 2 additions & 0 deletions src/xnnpack/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ struct xnn_convolution_operator {

union xnn_params {
union xnn_binary_uparams binary;
struct xnn_bf16_default_params bf16_default;
struct xnn_f16_default_params f16_default;
struct xnn_f32_default_params f32_default;
struct xnn_f16_minmax_params f16_minmax;
Expand Down Expand Up @@ -353,6 +354,7 @@ struct xnn_operator {
struct transpose_context transpose;
struct floating_point_softmax_context floating_point_softmax;
struct u8_softmax_context u8_softmax;
struct bf16_qd8_convert_context bf16_qd8_convert;
struct f16_qd8_convert_context f16_qd8_convert;
struct f32_qd8_convert_context f32_qd8_convert;
struct f32_qp8_convert_context f32_qp8_convert;
Expand Down
11 changes: 11 additions & 0 deletions src/xnnpack/quantization.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,4 +90,15 @@ xnn_f16_qd8_asymmetric_quantization_params(xnn_float16 min, xnn_float16 max,
return params;
}

static inline struct xnn_qd8_quantization_params
xnn_bf16_qd8_asymmetric_quantization_params(xnn_bfloat16 min, xnn_bfloat16 max,
xnn_bfloat16* bf16_scale) {
struct xnn_qd8_quantization_params params =
xnn_qd8_asymmetric_quantization_params(xnn_bfloat16_to_float(min),
xnn_bfloat16_to_float(max));
*bf16_scale = xnn_bfloat16_from_float(params.inv_scale);
params.inv_scale = 1.f / params.inv_scale;
return params;
}

#endif // XNNPACK_SRC_XNNPACK_QUANTIZATION_H_
1 change: 1 addition & 0 deletions test/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ IF(XNNPACK_BUILD_LIBRARY)
binary-elementwise-nd
constant-pad-nd-eager
constant-pad-nd
convert-nc
convolution-nchw
convolution-nhwc
copy-nc-eager
Expand Down
Loading
Loading