diff --git a/src/operators/reduce-nd.c b/src/operators/reduce-nd.c index cdc737eb0a4..d304fd059c2 100644 --- a/src/operators/reduce-nd.c +++ b/src/operators/reduce-nd.c @@ -241,8 +241,16 @@ static enum xnn_status reshape_reduce_nd( size_t num_reduction_elements; if (normalized_reduction_axes[num_reduction_axes - 1] == num_input_dims - 1) { if (workspace_size != NULL) { - const size_t num_output_elements = normalized_input_shape[0] * normalized_input_shape[2] * normalized_input_shape[4]; - *workspace_size = (num_output_elements << log2_accumulator_element_size) + XNN_EXTRA_BYTES; + size_t num_output_elements; + size_t tmp; + if (__builtin_mul_overflow(normalized_input_shape[0], normalized_input_shape[2], &tmp) || + __builtin_mul_overflow(tmp, normalized_input_shape[4], &num_output_elements) || + __builtin_mul_overflow(num_output_elements, (size_t)1 << log2_accumulator_element_size, &tmp)) { + xnn_log_error("failed to reshape %s operator: workspace size overflow", + xnn_operator_type_to_string_v2(reduce_op)); + return xnn_status_invalid_parameter; + } + *workspace_size = tmp + XNN_EXTRA_BYTES; } num_reduction_elements = normalized_input_shape[1] * normalized_input_shape[3] * normalized_input_shape[5]; const size_t axis_dim = normalized_input_shape[5]; @@ -283,8 +291,16 @@ static enum xnn_status reshape_reduce_nd( // Reduction along the non-innermost dimension const size_t channel_like_dim = normalized_input_shape[XNN_MAX_TENSOR_DIMS - 1]; if (workspace_size != NULL) { - const size_t num_output_elements = normalized_input_shape[1] * normalized_input_shape[3] * normalized_input_shape[5]; - *workspace_size = (num_output_elements << log2_accumulator_element_size) + XNN_EXTRA_BYTES; + size_t num_output_elements; + size_t tmp; + if (__builtin_mul_overflow(normalized_input_shape[1], normalized_input_shape[3], &tmp) || + __builtin_mul_overflow(tmp, normalized_input_shape[5], &num_output_elements) || + __builtin_mul_overflow(num_output_elements, (size_t)1 << log2_accumulator_element_size, &tmp)) { + xnn_log_error("failed to reshape %s operator: workspace size overflow", + xnn_operator_type_to_string_v2(reduce_op)); + return xnn_status_invalid_parameter; + } + *workspace_size = tmp + XNN_EXTRA_BYTES; } num_reduction_elements = normalized_input_shape[0] * normalized_input_shape[2] * normalized_input_shape[4]; const size_t axis_dim = normalized_input_shape[4];