Skip to content
21 changes: 18 additions & 3 deletions ggml/src/ggml-hexagon/ggml-hexagon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2164,8 +2164,14 @@ static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session
}

// src0, src1 & dst must be mapped to the same session
if (!hex_supported_buffer(sess, src0, src1, dst)) {
return false;
if(src1){
if (!hex_supported_buffer(sess, src0, src1, dst)) {
return false;
}
}else{
if (!hex_supported_buffer(sess, src0, dst)) {
return false;
}
}

return true;
Expand Down Expand Up @@ -2665,6 +2671,10 @@ static void ggml_hexagon_unary(const struct ggml_tensor * op, uint32_t flags) {
req.op = HTP_OP_UNARY_SILU;
supported = true;
}
else if (ggml_get_unary_op(dst) == GGML_UNARY_OP_GELU){
req.op = HTP_OP_UNARY_GELU;
supported = true;
}
break;

case GGML_OP_GLU:
Expand All @@ -2680,6 +2690,7 @@ static void ggml_hexagon_unary(const struct ggml_tensor * op, uint32_t flags) {
case GGML_OP_SOFT_MAX:
req.op = HTP_OP_SOFTMAX;
supported = true;
break;

default:
break;
Expand Down Expand Up @@ -2959,6 +2970,8 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
case GGML_OP_UNARY:
if (ggml_get_unary_op(node) == GGML_UNARY_OP_SILU) {
ggml_hexagon_unary(node, flags);
} else if (ggml_get_unary_op(node) == GGML_UNARY_OP_GELU) {
ggml_hexagon_unary(node, flags);
}
break;
case GGML_OP_GLU:
Expand Down Expand Up @@ -3257,7 +3270,6 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
auto sess = static_cast<ggml_hexagon_session *>(dev->context);

bool supp = false;

switch (op->op) {
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
Expand Down Expand Up @@ -3297,6 +3309,9 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
if (ggml_get_unary_op(op) == GGML_UNARY_OP_SILU) {
supp = ggml_hexagon_supported_activations(sess, op);
}
else if (ggml_get_unary_op(op) == GGML_UNARY_OP_GELU){
supp = ggml_hexagon_supported_activations(sess, op);
}
break;

case GGML_OP_GLU:
Expand Down
94 changes: 93 additions & 1 deletion ggml/src/ggml-hexagon/htp/act-ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,95 @@ static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0,
src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}


static void unary_gelu_fp32_per_thread(const struct htp_tensor * src0,
struct htp_tensor * dst,
const int32_t * op_params,
struct htp_spad * src0_spad,
struct htp_spad * dst_spad,
uint32_t nth,
uint32_t ith,
uint32_t src0_nrows_per_thread) {
htp_act_preamble2;

uint64_t t1, t2;
t1 = HAP_perf_get_qtimer_count();

const size_t src0_row_size = nb01;
const size_t dst_row_size = nb1;

const uint32_t src0_nrows = ne01 * ne02 * ne03;

const uint32_t src0_start_row = src0_nrows_per_thread * ith;
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);

// no work for this thread
if (src0_start_row >= src0_end_row) {
return;
}

int is_aligned = 1;
int opt_path = 0;
if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) {
is_aligned = 0;
FARF(HIGH, "silu-f32: unaligned addresses in elementwise op, possibly slower execution\n");
}
if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
opt_path = 1;
}

const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
uint8_t * restrict data_dst = (uint8_t *) dst->data;

uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_row_size);
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_row_size);

const int BLOCK = 8;
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
const uint32_t block_end = MIN(ir + BLOCK, src0_end_row);

// Prefetch next block
if (block_end < src0_end_row) {
const float * restrict prefetch_ptr = (float *) (data_src0 + (block_end * src0_row_size));
htp_l2fetch(prefetch_ptr, 1, block_end * src0_row_size, src0_row_size);
}

// Process rows in current block
for (uint32_t ib = ir; ib < block_end; ib++) {
const float * restrict src0 = (float *) (data_src0 + (ib * src0_row_size));
float * restrict dst = (float *) (data_dst + (ib * dst_row_size));

// gelu = 0.5 * x * (1.0 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3) )) // gelu_tanh
// gelu = x * sigmoid(1.702 * x) // current implementation
if (1 == opt_path) {
hvx_mul_scalar_f32( (const uint8_t *) src0, (float)1.702, (uint8_t *) src0_spad_data, ne0);
hvx_fast_sigmoid_f32((const uint8_t *) src0_spad_data, (uint8_t *) src0_spad_data, ne0);

hvx_mul_f32_opt((const uint8_t *) src0, src0_spad_data, (uint8_t *) dst, ne0);
}
else {
hvx_mul_scalar_f32( (const uint8_t *) src0, (float)1.702, (uint8_t *) src0_spad_data, ne0);
// sigmoid
hvx_sigmoid_f32((const uint8_t *) src0_spad_data, (uint8_t *) src0_spad_data, ne0);
hvx_mul_f32((const uint8_t *) src0, src0_spad_data, (uint8_t *) dst, ne0);
}
}
}

t2 = HAP_perf_get_qtimer_count();

FARF(HIGH, "gelu-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, ne00, ne01, ne02,
ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}

static void unary_gelu_fp32(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = (struct htp_ops_context *) data;
unary_gelu_fp32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
octx->src0_nrows_per_thread);
}



static void unary_silu_fp32_per_thread(const struct htp_tensor * src0,
struct htp_tensor * dst,
const int32_t * op_params,
Expand Down Expand Up @@ -371,7 +460,10 @@ static int execute_op_activations_fp32(struct htp_ops_context * octx) {
act_op_func = glu_swiglu_oai_fp32;
op_type = "swiglu-oai-f32";
break;

case HTP_OP_UNARY_GELU:
act_op_func = unary_gelu_fp32;
op_type = "gelu-f32";
break;
default:
FARF(ERROR, "Unsupported activations Op %u\n", octx->op);
return HTP_STATUS_NO_SUPPORT;
Expand Down
11 changes: 6 additions & 5 deletions ggml/src/ggml-hexagon/htp/htp-msg.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,12 @@ enum htp_op {
HTP_OP_MUL_MAT_ID = 5,
HTP_OP_RMS_NORM = 6,
HTP_OP_UNARY_SILU = 7,
HTP_OP_GLU_SWIGLU = 8,
HTP_OP_GLU_SWIGLU_OAI = 9,
HTP_OP_SOFTMAX = 10,
HTP_OP_ADD_ID = 11,
HTP_OP_ROPE = 12,
HTP_OP_UNARY_GELU = 8,
HTP_OP_GLU_SWIGLU = 9,
HTP_OP_GLU_SWIGLU_OAI = 10,
HTP_OP_SOFTMAX = 11,
HTP_OP_ADD_ID = 12,
HTP_OP_ROPE = 13,
INVALID
};

Expand Down
Loading