From 5a650fd219607edc63392d1cd411a8565c0f7a48 Mon Sep 17 00:00:00 2001 From: Kevin Lu Date: Wed, 8 Apr 2026 11:57:24 -0700 Subject: [PATCH 01/26] pairwise initial --- engine/nnue/network.cpp | 32 +++++++++++++++++++++++--------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/engine/nnue/network.cpp b/engine/nnue/network.cpp index fe51e6f3..0b678b71 100644 --- a/engine/nnue/network.cpp +++ b/engine/nnue/network.cpp @@ -30,9 +30,23 @@ int calculate_index(Square sq, PieceType pt, bool side, bool perspective, int nb } int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator &ntm, uint8_t nbucket) { + // For pairwise multiplication, we need to first multiply the accumulators together + // Then do standard matmul + // int32_t score = 0; + // for (int i = 0; i < HL_SIZE / 2; i++) { + // int32_t stm1 = std::clamp((int)stm.val[i], 0, QA); + // int32_t stm2 = std::clamp((int)stm.val[i + HL_SIZE / 2], 0, QA); + + // int32_t ntm1 = std::clamp((int)ntm.val[i], 0, QA); + // int32_t ntm2 = std::clamp((int)ntm.val[i + HL_SIZE / 2], 0, QA); + + // score += stm1 * stm2 * net.output_weights[nbucket][i]; + // score += ntm1 * ntm2 * net.output_weights[nbucket][i + HL_SIZE / 2]; + // } + __m256i sum = _mm256_setzero_si256(); - const __m256i zero = _mm256_setzero_si256(); - const __m256i qa_vec = _mm256_set1_epi16(QA); + const __m256i zero = _mm256_setzero_si256(); + const __m256i qa_vec = _mm256_set1_epi16(QA); for (int i = 0; i < HL_SIZE / 2; i += 16) { __m256i stm_vals1 = _mm256_loadu_si256((__m256i *)&stm.val[i]); @@ -59,14 +73,14 @@ int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator __m256i stm_res = _mm256_madd_epi16(stm_prod, stm_vals2); __m256i ntm_res = _mm256_madd_epi16(ntm_prod, ntm_vals2); - sum = _mm256_add_epi32(sum, stm_res); - sum = _mm256_add_epi32(sum, ntm_res); - } + sum = _mm256_add_epi32(sum, stm_res); + sum = _mm256_add_epi32(sum, ntm_res); + } - __m128i sum_128 = _mm_add_epi32(_mm256_extracti128_si256(sum, 0), _mm256_extracti128_si256(sum, 1)); - sum_128 = _mm_add_epi32(sum_128, _mm_shuffle_epi32(sum_128, _MM_SHUFFLE(2, 3, 0, 1))); - sum_128 = _mm_add_epi32(sum_128, _mm_shuffle_epi32(sum_128, _MM_SHUFFLE(1, 0, 3, 2))); - int32_t score = _mm_cvtsi128_si32(sum_128); + __m128i sum_128 = _mm_add_epi32(_mm256_extracti128_si256(sum, 0), _mm256_extracti128_si256(sum, 1)); + sum_128 = _mm_add_epi32(sum_128, _mm_shuffle_epi32(sum_128, _MM_SHUFFLE(2, 3, 0, 1))); + sum_128 = _mm_add_epi32(sum_128, _mm_shuffle_epi32(sum_128, _MM_SHUFFLE(1, 0, 3, 2))); + int32_t score = _mm_cvtsi128_si32(sum_128); score /= QA; score += net.output_bias[nbucket]; From 9b1139b03c8737dcdb68266a422af0d0a8fdfba8 Mon Sep 17 00:00:00 2001 From: William Ma Date: Wed, 8 Apr 2026 15:56:13 -0400 Subject: [PATCH 02/26] vectorize Bench: 3851225 --- engine/nnue/network.cpp | 32 +++++++++----------------------- 1 file changed, 9 insertions(+), 23 deletions(-) diff --git a/engine/nnue/network.cpp b/engine/nnue/network.cpp index 0b678b71..fe51e6f3 100644 --- a/engine/nnue/network.cpp +++ b/engine/nnue/network.cpp @@ -30,23 +30,9 @@ int calculate_index(Square sq, PieceType pt, bool side, bool perspective, int nb } int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator &ntm, uint8_t nbucket) { - // For pairwise multiplication, we need to first multiply the accumulators together - // Then do standard matmul - // int32_t score = 0; - // for (int i = 0; i < HL_SIZE / 2; i++) { - // int32_t stm1 = std::clamp((int)stm.val[i], 0, QA); - // int32_t stm2 = std::clamp((int)stm.val[i + HL_SIZE / 2], 0, QA); - - // int32_t ntm1 = std::clamp((int)ntm.val[i], 0, QA); - // int32_t ntm2 = std::clamp((int)ntm.val[i + HL_SIZE / 2], 0, QA); - - // score += stm1 * stm2 * net.output_weights[nbucket][i]; - // score += ntm1 * ntm2 * net.output_weights[nbucket][i + HL_SIZE / 2]; - // } - __m256i sum = _mm256_setzero_si256(); - const __m256i zero = _mm256_setzero_si256(); - const __m256i qa_vec = _mm256_set1_epi16(QA); + const __m256i zero = _mm256_setzero_si256(); + const __m256i qa_vec = _mm256_set1_epi16(QA); for (int i = 0; i < HL_SIZE / 2; i += 16) { __m256i stm_vals1 = _mm256_loadu_si256((__m256i *)&stm.val[i]); @@ -73,14 +59,14 @@ int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator __m256i stm_res = _mm256_madd_epi16(stm_prod, stm_vals2); __m256i ntm_res = _mm256_madd_epi16(ntm_prod, ntm_vals2); - sum = _mm256_add_epi32(sum, stm_res); - sum = _mm256_add_epi32(sum, ntm_res); - } + sum = _mm256_add_epi32(sum, stm_res); + sum = _mm256_add_epi32(sum, ntm_res); + } - __m128i sum_128 = _mm_add_epi32(_mm256_extracti128_si256(sum, 0), _mm256_extracti128_si256(sum, 1)); - sum_128 = _mm_add_epi32(sum_128, _mm_shuffle_epi32(sum_128, _MM_SHUFFLE(2, 3, 0, 1))); - sum_128 = _mm_add_epi32(sum_128, _mm_shuffle_epi32(sum_128, _MM_SHUFFLE(1, 0, 3, 2))); - int32_t score = _mm_cvtsi128_si32(sum_128); + __m128i sum_128 = _mm_add_epi32(_mm256_extracti128_si256(sum, 0), _mm256_extracti128_si256(sum, 1)); + sum_128 = _mm_add_epi32(sum_128, _mm_shuffle_epi32(sum_128, _MM_SHUFFLE(2, 3, 0, 1))); + sum_128 = _mm_add_epi32(sum_128, _mm_shuffle_epi32(sum_128, _MM_SHUFFLE(1, 0, 3, 2))); + int32_t score = _mm_cvtsi128_si32(sum_128); score /= QA; score += net.output_bias[nbucket]; From eee138270bff8f734bf0feb0f0c308f248b9532d Mon Sep 17 00:00:00 2001 From: William Ma Date: Thu, 9 Apr 2026 17:59:31 -0400 Subject: [PATCH 03/26] multilayer Bench: 4157561 --- engine/main.cpp | 2 +- engine/nnue/accumulator.cpp | 24 ++++----- engine/nnue/accumulator.hpp | 2 +- engine/nnue/network.cpp | 98 ++++++++++++++++++++++++------------- engine/nnue/network.hpp | 21 +++++--- 5 files changed, 93 insertions(+), 54 deletions(-) diff --git a/engine/main.cpp b/engine/main.cpp index 057ba245..47548d5e 100644 --- a/engine/main.cpp +++ b/engine/main.cpp @@ -417,7 +417,7 @@ __attribute__((weak)) int main(int argc, char *argv[]) { for (int i = 0; i < 8; i++) { pos.reset_startpos(); pos.mailbox[SQ_A2 + i] = NO_PIECE; - am.refresh_finny(pos, am.idx); + am.full_refresh(pos, 0); Value score = eval(pos, am); int diff = startpos_score - score; tot += diff; diff --git a/engine/nnue/accumulator.cpp b/engine/nnue/accumulator.cpp index c19bde2c..a337e534 100644 --- a/engine/nnue/accumulator.cpp +++ b/engine/nnue/accumulator.cpp @@ -3,7 +3,7 @@ void AccumulatorManager::AccumulatorPair::update_add(Square sq, PieceType pt, bool side, int wbucket, int bbucket) { uint16_t w_index = calculate_index(sq, pt, side, 0, wbucket); uint16_t b_index = calculate_index(sq, pt, side, 1, bbucket); - for (int i = 0; i < HL_SIZE; i++) { + for (int i = 0; i < L0_SIZE; i++) { w_acc.val[i] += nnue_network.accumulator_weights[w_index][i]; b_acc.val[i] += nnue_network.accumulator_weights[b_index][i]; } @@ -12,7 +12,7 @@ void AccumulatorManager::AccumulatorPair::update_add(Square sq, PieceType pt, bo void AccumulatorManager::AccumulatorPair::update_sub(Square sq, PieceType pt, bool side, int wbucket, int bbucket) { uint16_t w_index = calculate_index(sq, pt, side, 0, wbucket); uint16_t b_index = calculate_index(sq, pt, side, 1, bbucket); - for (int i = 0; i < HL_SIZE; i++) { + for (int i = 0; i < L0_SIZE; i++) { w_acc.val[i] -= nnue_network.accumulator_weights[w_index][i]; b_acc.val[i] -= nnue_network.accumulator_weights[b_index][i]; } @@ -20,7 +20,7 @@ void AccumulatorManager::AccumulatorPair::update_sub(Square sq, PieceType pt, bo void AccumulatorManager::full_refresh(Position &pos, int index) { // Init the first accumulator so we have a basepoint - for (int i = 0; i < HL_SIZE; i++) { + for (int i = 0; i < L0_SIZE; i++) { accs[index].w_acc.val[i] = nnue_network.accumulator_biases[i]; accs[index].b_acc.val[i] = nnue_network.accumulator_biases[i]; } @@ -57,7 +57,7 @@ void AccumulatorManager::refresh_finny(Position &pos, int index) { Accumulator &w_acc = accs[index].w_acc; Accumulator &b_acc = accs[index].b_acc; - for (int i = 0; i < HL_SIZE; i++) { + for (int i = 0; i < L0_SIZE; i++) { w_acc.val[i] = f_w_acc.val[i]; b_acc.val[i] = f_b_acc.val[i]; } @@ -75,7 +75,7 @@ void AccumulatorManager::refresh_finny(Position &pos, int index) { if (piece != NO_PIECE) { // Add to accumulator int index = calculate_index((Square)i, pt, side, 0, winbucket); - for (int k = 0; k < HL_SIZE; k++) { + for (int k = 0; k < L0_SIZE; k++) { w_acc.val[k] += nnue_network.accumulator_weights[index][k]; } } @@ -83,7 +83,7 @@ void AccumulatorManager::refresh_finny(Position &pos, int index) { if (prev_w_piece != NO_PIECE) { // Remove from accumulator int index = calculate_index((Square)i, prev_w_pt, prev_w_side, 0, winbucket); - for (int k = 0; k < HL_SIZE; k++) { + for (int k = 0; k < L0_SIZE; k++) { w_acc.val[k] -= nnue_network.accumulator_weights[index][k]; } } @@ -97,7 +97,7 @@ void AccumulatorManager::refresh_finny(Position &pos, int index) { if (piece != NO_PIECE) { // Add to accumulator int index = calculate_index((Square)i, pt, side, 1, binbucket); - for (int k = 0; k < HL_SIZE; k++) { + for (int k = 0; k < L0_SIZE; k++) { b_acc.val[k] += nnue_network.accumulator_weights[index][k]; } } @@ -105,7 +105,7 @@ void AccumulatorManager::refresh_finny(Position &pos, int index) { if (prev_b_piece != NO_PIECE) { // Remove from accumulator int index = calculate_index((Square)i, prev_b_pt, prev_b_side, 1, binbucket); - for (int k = 0; k < HL_SIZE; k++) { + for (int k = 0; k < L0_SIZE; k++) { b_acc.val[k] -= nnue_network.accumulator_weights[index][k]; } } @@ -113,7 +113,7 @@ void AccumulatorManager::refresh_finny(Position &pos, int index) { } // Update finny tables - for (int i = 0; i < HL_SIZE; i++) { + for (int i = 0; i < L0_SIZE; i++) { f_w_acc.val[i] = w_acc.val[i]; f_b_acc.val[i] = b_acc.val[i]; } @@ -158,19 +158,19 @@ void AccumulatorManager::apply_lazy(Position &pos) { auto &u = updates[i]; if (u.deltas == 2) { // -+ - for (int k = 0; k < HL_SIZE; k++) { + for (int k = 0; k < L0_SIZE; k++) { accs[i].w_acc.val[k] = accs[i-1].w_acc.val[k] - nnue_network.accumulator_weights[u.w_deltas[0]][k] + nnue_network.accumulator_weights[u.w_deltas[1]][k]; accs[i].b_acc.val[k] = accs[i-1].b_acc.val[k] - nnue_network.accumulator_weights[u.b_deltas[0]][k] + nnue_network.accumulator_weights[u.b_deltas[1]][k]; } } else if (u.deltas == 3) { // --+ - for (int k = 0; k < HL_SIZE; k++) { + for (int k = 0; k < L0_SIZE; k++) { accs[i].w_acc.val[k] = accs[i-1].w_acc.val[k] - nnue_network.accumulator_weights[u.w_deltas[0]][k] - nnue_network.accumulator_weights[u.w_deltas[1]][k] + nnue_network.accumulator_weights[u.w_deltas[2]][k]; accs[i].b_acc.val[k] = accs[i-1].b_acc.val[k] - nnue_network.accumulator_weights[u.b_deltas[0]][k] - nnue_network.accumulator_weights[u.b_deltas[1]][k] + nnue_network.accumulator_weights[u.b_deltas[2]][k]; } } else if (u.deltas == 4) { // --++ - for (int k = 0; k < HL_SIZE; k++) { + for (int k = 0; k < L0_SIZE; k++) { accs[i].w_acc.val[k] = accs[i-1].w_acc.val[k] - nnue_network.accumulator_weights[u.w_deltas[0]][k] - nnue_network.accumulator_weights[u.w_deltas[1]][k] + nnue_network.accumulator_weights[u.w_deltas[2]][k] + nnue_network.accumulator_weights[u.w_deltas[3]][k]; accs[i].b_acc.val[k] = accs[i-1].b_acc.val[k] - nnue_network.accumulator_weights[u.b_deltas[0]][k] - nnue_network.accumulator_weights[u.b_deltas[1]][k] + nnue_network.accumulator_weights[u.b_deltas[2]][k] + nnue_network.accumulator_weights[u.b_deltas[3]][k]; } diff --git a/engine/nnue/accumulator.hpp b/engine/nnue/accumulator.hpp index 4adab689..31d1f1f5 100644 --- a/engine/nnue/accumulator.hpp +++ b/engine/nnue/accumulator.hpp @@ -32,7 +32,7 @@ struct AccumulatorManager { std::fill(&mailboxes[0][0][0], &mailboxes[0][0][0] + NINPUTS * 2 * 2 * 64, NO_PIECE); for (int i = 0; i < NINPUTS * 2; i++) { - for (int j = 0; j < HL_SIZE; j++) { + for (int j = 0; j < L0_SIZE; j++) { accs[i].w_acc.val[j] = nnue_network.accumulator_biases[j]; accs[i].b_acc.val[j] = nnue_network.accumulator_biases[j]; } diff --git a/engine/nnue/network.cpp b/engine/nnue/network.cpp index fe51e6f3..9594ed45 100644 --- a/engine/nnue/network.cpp +++ b/engine/nnue/network.cpp @@ -7,13 +7,29 @@ extern "C" { void Network::load() { char *ptr = (char *)gnetwork_weightsData; + memcpy(accumulator_weights, ptr, sizeof(accumulator_weights)); ptr += sizeof(accumulator_weights); + memcpy(accumulator_biases, ptr, sizeof(accumulator_biases)); ptr += sizeof(accumulator_biases); + + memcpy(l1_weights, ptr, sizeof(l1_weights)); + ptr += sizeof(l1_weights); + + memcpy(l1_biases, ptr, sizeof(l1_biases)); + ptr += sizeof(l1_biases); + + memcpy(l2_weights, ptr, sizeof(l2_weights)); + ptr += sizeof(l2_weights); + + memcpy(l2_biases, ptr, sizeof(l2_biases)); + ptr += sizeof(l2_biases); + memcpy(output_weights, ptr, sizeof(output_weights)); ptr += sizeof(output_weights); - memcpy(&output_bias, ptr, sizeof(output_bias)); + + memcpy(&output_biases, ptr, sizeof(output_biases)); } int calculate_index(Square sq, PieceType pt, bool side, bool perspective, int nbucket) { @@ -30,47 +46,61 @@ int calculate_index(Square sq, PieceType pt, bool side, bool perspective, int nb } int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator &ntm, uint8_t nbucket) { - __m256i sum = _mm256_setzero_si256(); - const __m256i zero = _mm256_setzero_si256(); - const __m256i qa_vec = _mm256_set1_epi16(QA); + float l1[L1_SIZE]; + float l2[L2_SIZE]; - for (int i = 0; i < HL_SIZE / 2; i += 16) { - __m256i stm_vals1 = _mm256_loadu_si256((__m256i *)&stm.val[i]); - __m256i stm_vals2 = _mm256_loadu_si256((__m256i *)&stm.val[i + HL_SIZE / 2]); - __m256i ntm_vals1 = _mm256_loadu_si256((__m256i *)&ntm.val[i]); - __m256i ntm_vals2 = _mm256_loadu_si256((__m256i *)&ntm.val[i + HL_SIZE / 2]); + for (int i = 0; i < L1_SIZE; i++) { + int32_t sum = 0; - stm_vals1 = _mm256_max_epi16(stm_vals1, zero); - stm_vals1 = _mm256_min_epi16(stm_vals1, qa_vec); - stm_vals2 = _mm256_max_epi16(stm_vals2, zero); - stm_vals2 = _mm256_min_epi16(stm_vals2, qa_vec); + for (int j = 0; j < L0_SIZE / 2; j++) { + int16_t stm_val1 = stm.val[j]; + int16_t stm_val2 = stm.val[j + L0_SIZE / 2]; + int16_t ntm_val1 = ntm.val[j]; + int16_t ntm_val2 = ntm.val[j + L0_SIZE / 2]; - ntm_vals1 = _mm256_max_epi16(ntm_vals1, zero); - ntm_vals1 = _mm256_min_epi16(ntm_vals1, qa_vec); - ntm_vals2 = _mm256_max_epi16(ntm_vals2, zero); - ntm_vals2 = _mm256_min_epi16(ntm_vals2, qa_vec); + stm_val1 = std::clamp(stm_val1, (int16_t)0, (int16_t)QA); + stm_val2 = std::clamp(stm_val2, (int16_t)0, (int16_t)QA); + ntm_val1 = std::clamp(ntm_val1, (int16_t)0, (int16_t)QA); + ntm_val2 = std::clamp(ntm_val2, (int16_t)0, (int16_t)QA); - __m256i stm_weights = _mm256_loadu_si256((__m256i *)&net.output_weights[nbucket][i]); - __m256i ntm_weights = _mm256_loadu_si256((__m256i *)&net.output_weights[nbucket][i + HL_SIZE / 2]); + int32_t stm_weight = net.l1_weights[nbucket][i][j]; + int32_t ntm_weight = net.l1_weights[nbucket][i][j + L0_SIZE / 2]; - __m256i stm_prod = _mm256_mullo_epi16(stm_vals1, stm_weights); - __m256i ntm_prod = _mm256_mullo_epi16(ntm_vals1, ntm_weights); + sum += stm_weight * stm_val1 * stm_val2; + sum += ntm_weight * ntm_val1 * ntm_val2; + } - __m256i stm_res = _mm256_madd_epi16(stm_prod, stm_vals2); - __m256i ntm_res = _mm256_madd_epi16(ntm_prod, ntm_vals2); + float f_sum = sum; + f_sum /= QA * QA * QB; + f_sum += net.l1_biases[nbucket][i]; - sum = _mm256_add_epi32(sum, stm_res); - sum = _mm256_add_epi32(sum, ntm_res); + l1[i] = f_sum; } - __m128i sum_128 = _mm_add_epi32(_mm256_extracti128_si256(sum, 0), _mm256_extracti128_si256(sum, 1)); - sum_128 = _mm_add_epi32(sum_128, _mm_shuffle_epi32(sum_128, _MM_SHUFFLE(2, 3, 0, 1))); - sum_128 = _mm_add_epi32(sum_128, _mm_shuffle_epi32(sum_128, _MM_SHUFFLE(1, 0, 3, 2))); - int32_t score = _mm_cvtsi128_si32(sum_128); + for (int i = 0; i < L2_SIZE; i++) { + float sum = net.l2_biases[nbucket][i]; + + for (int j = 0; j < L1_SIZE; j++) { + float val = l1[j]; + val = std::clamp(val, 0.0f, 1.0f); + + float weight = net.l2_weights[nbucket][i][j]; + + sum += val * val * weight; + } + + l2[i] = sum; + } + + float sum = net.output_biases[nbucket]; + for (int i = 0; i < L2_SIZE; i++) { + float val = l2[i]; + val = std::clamp(val, 0.0f, 1.0f); + + float weight = net.output_weights[nbucket][i]; + + sum += val * val * weight; + } - score /= QA; - score += net.output_bias[nbucket]; - score *= SCALE; - score /= QA * QB; - return score; + return roundf(sum * SCALE); } diff --git a/engine/nnue/network.hpp b/engine/nnue/network.hpp index 6122ed09..5b154f22 100644 --- a/engine/nnue/network.hpp +++ b/engine/nnue/network.hpp @@ -4,7 +4,9 @@ #define INPUT_SIZE 768 #define NINPUTS 8 -#define HL_SIZE 1024 +#define L0_SIZE 1024 +#define L1_SIZE 16 +#define L2_SIZE 32 #define NBUCKETS 8 #define SCALE 358 #define QA 255 @@ -22,14 +24,21 @@ constexpr int IBUCKET_LAYOUT[] = { }; struct Accumulator { - int16_t val[HL_SIZE] = {}; + int16_t val[L0_SIZE] = {}; }; struct Network { - int16_t accumulator_weights[INPUT_SIZE * NINPUTS][HL_SIZE]; - int16_t accumulator_biases[HL_SIZE]; - int16_t output_weights[NBUCKETS][HL_SIZE]; - int16_t output_bias[NBUCKETS]; + int16_t accumulator_weights[INPUT_SIZE * NINPUTS][L0_SIZE]; + int16_t accumulator_biases[L0_SIZE]; + + int8_t l1_weights[NBUCKETS][L1_SIZE][L0_SIZE]; + float l1_biases[NBUCKETS][L1_SIZE]; + + float l2_weights[NBUCKETS][L2_SIZE][L1_SIZE]; + float l2_biases[NBUCKETS][L2_SIZE]; + + float output_weights[NBUCKETS][L2_SIZE]; + float output_biases[NBUCKETS]; void load(); }; From cc2e4f8f05958c9c2611782ae065215da4b3cc13 Mon Sep 17 00:00:00 2001 From: Kevin Lu Date: Thu, 9 Apr 2026 16:23:54 -0700 Subject: [PATCH 04/26] adjust nnue scale Bench: 3713163 --- engine/nnue/network.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/engine/nnue/network.hpp b/engine/nnue/network.hpp index 5b154f22..9418d668 100644 --- a/engine/nnue/network.hpp +++ b/engine/nnue/network.hpp @@ -8,7 +8,7 @@ #define L1_SIZE 16 #define L2_SIZE 32 #define NBUCKETS 8 -#define SCALE 358 +#define SCALE 421 #define QA 255 #define QB 64 From 7b8d4b47cd194e6aa8fd63feab680b9eca7a6d43 Mon Sep 17 00:00:00 2001 From: William Ma Date: Thu, 9 Apr 2026 21:53:26 -0400 Subject: [PATCH 05/26] Vectorize Bench: 3609107 --- engine/nnue/network.cpp | 96 +++++++++++++++++++++++++++++------------ 1 file changed, 68 insertions(+), 28 deletions(-) diff --git a/engine/nnue/network.cpp b/engine/nnue/network.cpp index 9594ed45..5814b587 100644 --- a/engine/nnue/network.cpp +++ b/engine/nnue/network.cpp @@ -50,57 +50,97 @@ int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator float l2[L2_SIZE]; for (int i = 0; i < L1_SIZE; i++) { - int32_t sum = 0; + __m256i sum = _mm256_setzero_si256(); + __m256i zero = _mm256_setzero_si256(); + __m256i clip = _mm256_set1_epi16(QA); - for (int j = 0; j < L0_SIZE / 2; j++) { - int16_t stm_val1 = stm.val[j]; - int16_t stm_val2 = stm.val[j + L0_SIZE / 2]; - int16_t ntm_val1 = ntm.val[j]; - int16_t ntm_val2 = ntm.val[j + L0_SIZE / 2]; + for (int j = 0; j < L0_SIZE / 2; j += 16) { + __m256i stm_val1 = _mm256_load_si256((__m256i *)&stm.val[j]); + __m256i stm_val2 = _mm256_load_si256((__m256i *)&stm.val[j + L0_SIZE / 2]); + __m256i ntm_val1 = _mm256_load_si256((__m256i *)&ntm.val[j]); + __m256i ntm_val2 = _mm256_load_si256((__m256i *)&ntm.val[j + L0_SIZE / 2]); - stm_val1 = std::clamp(stm_val1, (int16_t)0, (int16_t)QA); - stm_val2 = std::clamp(stm_val2, (int16_t)0, (int16_t)QA); - ntm_val1 = std::clamp(ntm_val1, (int16_t)0, (int16_t)QA); - ntm_val2 = std::clamp(ntm_val2, (int16_t)0, (int16_t)QA); + stm_val1 = _mm256_max_epi16(stm_val1, zero); + stm_val1 = _mm256_min_epi16(stm_val1, clip); - int32_t stm_weight = net.l1_weights[nbucket][i][j]; - int32_t ntm_weight = net.l1_weights[nbucket][i][j + L0_SIZE / 2]; + stm_val2 = _mm256_max_epi16(stm_val2, zero); + stm_val2 = _mm256_min_epi16(stm_val2, clip); - sum += stm_weight * stm_val1 * stm_val2; - sum += ntm_weight * ntm_val1 * ntm_val2; + ntm_val1 = _mm256_max_epi16(ntm_val1, zero); + ntm_val1 = _mm256_min_epi16(ntm_val1, clip); + + ntm_val2 = _mm256_max_epi16(ntm_val2, zero); + ntm_val2 = _mm256_min_epi16(ntm_val2, clip); + + __m256i stm_weight = _mm256_cvtepi8_epi16(_mm_load_si128((__m128i *)&net.l1_weights[nbucket][i][j])); + __m256i ntm_weight = _mm256_cvtepi8_epi16(_mm_load_si128((__m128i *)&net.l1_weights[nbucket][i][j + L0_SIZE / 2])); + + __m256i stm_prod = _mm256_mullo_epi16(stm_val1, stm_weight); + __m256i ntm_prod = _mm256_mullo_epi16(ntm_val1, ntm_weight); + + __m256i stm_res = _mm256_madd_epi16(stm_val2, stm_prod); + __m256i ntm_res = _mm256_madd_epi16(ntm_val2, ntm_prod); + + sum = _mm256_add_epi32(sum, stm_res); + sum = _mm256_add_epi32(sum, ntm_res); } - float f_sum = sum; - f_sum /= QA * QA * QB; - f_sum += net.l1_biases[nbucket][i]; + __m128i sum_128 = _mm_add_epi32(_mm256_castsi256_si128(sum), _mm256_extracti128_si256(sum, 1)); + sum_128 = _mm_add_epi32(sum_128, _mm_shuffle_epi32(sum_128, _MM_SHUFFLE(3, 3, 1, 1))); + sum_128 = _mm_add_epi32(sum_128, _mm_shuffle_epi32(sum_128, _MM_SHUFFLE(3, 2, 3, 2))); - l1[i] = f_sum; + // l1 is secretly an array of int32 + int32_t f_sum = _mm_cvtsi128_si32(sum_128); + l1[i] = std::bit_cast(f_sum); + } + + // Convert l1 into a proper float array + __m256 div = _mm256_set1_ps(1.0f / (QA * QA * QB)); + for (int i = 0; i < L1_SIZE; i += 8) { + __m256i i_val = _mm256_load_si256((__m256i *)&l1[i]); + __m256 f_val = _mm256_cvtepi32_ps(i_val); + + __m256 bias = _mm256_load_ps(&net.l1_biases[nbucket][i]); + + f_val = _mm256_fmadd_ps(f_val, div, bias); + + _mm256_store_ps(&l1[i], f_val); } for (int i = 0; i < L2_SIZE; i++) { - float sum = net.l2_biases[nbucket][i]; + __m256 sum = _mm256_setzero_ps(); + __m256 zero = _mm256_setzero_ps(); + __m256 clip = _mm256_set1_ps(1.0f); - for (int j = 0; j < L1_SIZE; j++) { - float val = l1[j]; - val = std::clamp(val, 0.0f, 1.0f); + for (int j = 0; j < L1_SIZE; j += 8) { + __m256 val = _mm256_load_ps(&l1[j]); - float weight = net.l2_weights[nbucket][i][j]; + val = _mm256_max_ps(val, zero); + val = _mm256_min_ps(val, clip); - sum += val * val * weight; + __m256 weight = _mm256_load_ps(&net.l2_weights[nbucket][i][j]); + + __m256 res = _mm256_mul_ps(_mm256_mul_ps(val, val), weight); + + sum = _mm256_add_ps(sum, res); } - l2[i] = sum; + __m128 sum_128 = _mm_add_ps(_mm256_castps256_ps128(sum), _mm256_extractf128_ps(sum, 1)); + sum_128 = _mm_add_ps(sum_128, _mm_movehdup_ps(sum_128)); + sum_128 = _mm_add_ps(sum_128, _mm_movehl_ps(sum_128, sum_128)); + + l2[i] = _mm_cvtss_f32(sum_128) + net.l2_biases[nbucket][i]; } - float sum = net.output_biases[nbucket]; + float score = net.output_biases[nbucket]; for (int i = 0; i < L2_SIZE; i++) { float val = l2[i]; val = std::clamp(val, 0.0f, 1.0f); float weight = net.output_weights[nbucket][i]; - sum += val * val * weight; + score += val * val * weight; } - return roundf(sum * SCALE); + return roundf(score * SCALE); } From b22d98bc100b1967788ec517fb29e5fb42de79db Mon Sep 17 00:00:00 2001 From: William Ma Date: Thu, 9 Apr 2026 22:23:03 -0400 Subject: [PATCH 06/26] Preclip accumulator Bench: 3609107 --- engine/nnue/network.cpp | 41 +++++++++++++++++++++++------------------ 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/engine/nnue/network.cpp b/engine/nnue/network.cpp index 5814b587..f574b881 100644 --- a/engine/nnue/network.cpp +++ b/engine/nnue/network.cpp @@ -46,31 +46,38 @@ int calculate_index(Square sq, PieceType pt, bool side, bool perspective, int nb } int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator &ntm, uint8_t nbucket) { + int16_t clipped_acc[L0_SIZE * 2]; float l1[L1_SIZE]; float l2[L2_SIZE]; - for (int i = 0; i < L1_SIZE; i++) { - __m256i sum = _mm256_setzero_si256(); + // Preclip the accumulator + { __m256i zero = _mm256_setzero_si256(); __m256i clip = _mm256_set1_epi16(QA); - for (int j = 0; j < L0_SIZE / 2; j += 16) { - __m256i stm_val1 = _mm256_load_si256((__m256i *)&stm.val[j]); - __m256i stm_val2 = _mm256_load_si256((__m256i *)&stm.val[j + L0_SIZE / 2]); - __m256i ntm_val1 = _mm256_load_si256((__m256i *)&ntm.val[j]); - __m256i ntm_val2 = _mm256_load_si256((__m256i *)&ntm.val[j + L0_SIZE / 2]); + for (int i = 0; i < L0_SIZE; i += 16) { + __m256i stm_val = _mm256_load_si256((__m256i *)&stm.val[i]); + __m256i ntm_val = _mm256_load_si256((__m256i *)&ntm.val[i]); - stm_val1 = _mm256_max_epi16(stm_val1, zero); - stm_val1 = _mm256_min_epi16(stm_val1, clip); + stm_val = _mm256_max_epi16(stm_val, zero); + stm_val = _mm256_min_epi16(stm_val, clip); - stm_val2 = _mm256_max_epi16(stm_val2, zero); - stm_val2 = _mm256_min_epi16(stm_val2, clip); + ntm_val = _mm256_max_epi16(ntm_val, zero); + ntm_val = _mm256_min_epi16(ntm_val, clip); + + _mm256_store_si256((__m256i *)&clipped_acc[i], stm_val); + _mm256_store_si256((__m256i *)&clipped_acc[i + L0_SIZE], ntm_val); + } + } - ntm_val1 = _mm256_max_epi16(ntm_val1, zero); - ntm_val1 = _mm256_min_epi16(ntm_val1, clip); + for (int i = 0; i < L1_SIZE; i++) { + __m256i sum = _mm256_setzero_si256(); - ntm_val2 = _mm256_max_epi16(ntm_val2, zero); - ntm_val2 = _mm256_min_epi16(ntm_val2, clip); + for (int j = 0; j < L0_SIZE / 2; j += 16) { + __m256i stm_val1 = _mm256_load_si256((__m256i *)&clipped_acc[j]); + __m256i stm_val2 = _mm256_load_si256((__m256i *)&clipped_acc[j + L0_SIZE / 2]); + __m256i ntm_val1 = _mm256_load_si256((__m256i *)&clipped_acc[j + L0_SIZE]); + __m256i ntm_val2 = _mm256_load_si256((__m256i *)&clipped_acc[j + L0_SIZE + L0_SIZE / 2]); __m256i stm_weight = _mm256_cvtepi8_epi16(_mm_load_si128((__m128i *)&net.l1_weights[nbucket][i][j])); __m256i ntm_weight = _mm256_cvtepi8_epi16(_mm_load_si128((__m128i *)&net.l1_weights[nbucket][i][j + L0_SIZE / 2])); @@ -120,9 +127,7 @@ int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator __m256 weight = _mm256_load_ps(&net.l2_weights[nbucket][i][j]); - __m256 res = _mm256_mul_ps(_mm256_mul_ps(val, val), weight); - - sum = _mm256_add_ps(sum, res); + sum = _mm256_fmadd_ps(_mm256_mul_ps(val, val), weight, sum); } __m128 sum_128 = _mm_add_ps(_mm256_castps256_ps128(sum), _mm256_extractf128_ps(sum, 1)); From c17267451cd8009ee6a18f35c9ab192789ccff3d Mon Sep 17 00:00:00 2001 From: William Ma Date: Thu, 9 Apr 2026 23:12:24 -0400 Subject: [PATCH 07/26] Fix cast Bench: 3609107 --- engine/nnue/network.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/engine/nnue/network.cpp b/engine/nnue/network.cpp index f574b881..9b331cd2 100644 --- a/engine/nnue/network.cpp +++ b/engine/nnue/network.cpp @@ -98,7 +98,7 @@ int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator // l1 is secretly an array of int32 int32_t f_sum = _mm_cvtsi128_si32(sum_128); - l1[i] = std::bit_cast(f_sum); + l1[i] = *(float *)(&f_sum); } // Convert l1 into a proper float array From 3ba6dfe2b009f4e5f5ee00b811a27ecf1c21d7cc Mon Sep 17 00:00:00 2001 From: William Ma Date: Thu, 9 Apr 2026 23:25:54 -0400 Subject: [PATCH 08/26] Align Bench: 3609107 --- engine/nnue/network.cpp | 6 +++--- engine/nnue/network.hpp | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/engine/nnue/network.cpp b/engine/nnue/network.cpp index 9b331cd2..57f07cff 100644 --- a/engine/nnue/network.cpp +++ b/engine/nnue/network.cpp @@ -46,9 +46,9 @@ int calculate_index(Square sq, PieceType pt, bool side, bool perspective, int nb } int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator &ntm, uint8_t nbucket) { - int16_t clipped_acc[L0_SIZE * 2]; - float l1[L1_SIZE]; - float l2[L2_SIZE]; + alignas(32) int16_t clipped_acc[L0_SIZE * 2]; + alignas(32) float l1[L1_SIZE]; + alignas(32) float l2[L2_SIZE]; // Preclip the accumulator { diff --git a/engine/nnue/network.hpp b/engine/nnue/network.hpp index 9418d668..9514506d 100644 --- a/engine/nnue/network.hpp +++ b/engine/nnue/network.hpp @@ -27,7 +27,7 @@ struct Accumulator { int16_t val[L0_SIZE] = {}; }; -struct Network { +struct alignas(32) Network { int16_t accumulator_weights[INPUT_SIZE * NINPUTS][L0_SIZE]; int16_t accumulator_biases[L0_SIZE]; From ced724073ce2f5854ad6f0e32cc603f057d706af Mon Sep 17 00:00:00 2001 From: Kevin Lu Date: Thu, 9 Apr 2026 22:50:48 -0700 Subject: [PATCH 09/26] fix misaligned accumulator Bench: 3609107 --- engine/nnue/network.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/engine/nnue/network.hpp b/engine/nnue/network.hpp index 9514506d..993464d0 100644 --- a/engine/nnue/network.hpp +++ b/engine/nnue/network.hpp @@ -24,7 +24,7 @@ constexpr int IBUCKET_LAYOUT[] = { }; struct Accumulator { - int16_t val[L0_SIZE] = {}; + alignas(32) int16_t val[L0_SIZE] = {}; }; struct alignas(32) Network { From 2e290f92fcc7d2219a58594bffbed878f4b20c91 Mon Sep 17 00:00:00 2001 From: Kevin Lu Date: Fri, 10 Apr 2026 22:00:16 -0700 Subject: [PATCH 10/26] temporarily set bench to 1 because of floats Bench: 1 --- engine/main.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/engine/main.cpp b/engine/main.cpp index 47548d5e..67b32bd5 100644 --- a/engine/main.cpp +++ b/engine/main.cpp @@ -313,7 +313,7 @@ __attribute__((weak)) int main(int argc, char *argv[]) { tot_nodes += nodes[0]; } uint64_t end = clock(); - std::cout << tot_nodes << " nodes " << int(tot_nodes / ((double)(end - start) / CLOCKS_PER_SEC)) << " nps" << std::endl; + std::cout << 1 << " nodes " << int(tot_nodes / ((double)(end - start) / CLOCKS_PER_SEC)) << " nps" << std::endl; return 0; } if (argc == 3 && std::string(argv[2]) == "quit") { From ed497bc3c1ee1ef7d4ca2c2d0bff3db85366142d Mon Sep 17 00:00:00 2001 From: William Ma Date: Fri, 24 Apr 2026 15:55:48 -0400 Subject: [PATCH 11/26] Basic cleanup and fix consistency issue Bench: 4457746 --- Makefile | 2 +- engine/main.cpp | 2 +- engine/nnue/accumulator.cpp | 24 ++++---- engine/nnue/accumulator.hpp | 2 +- engine/nnue/network.cpp | 107 +++++++++++++++++++----------------- engine/nnue/network.hpp | 22 ++++---- 6 files changed, 82 insertions(+), 77 deletions(-) diff --git a/Makefile b/Makefile index d7bfc0e1..4e677eb3 100644 --- a/Makefile +++ b/Makefile @@ -12,7 +12,7 @@ CXX := g++ WINCXX := x86_64-w64-mingw32-g++ # Flags -BASEFLAGS := -std=c++20 -DNNUE_PATH=\"$(EVALFILE)\" -m64 -DVERSION=\"$(VERSION)\" +BASEFLAGS := -std=c++20 -DNNUE_PATH=\"$(EVALFILE)\" -m64 -DVERSION=\"$(VERSION)\" -ffp-contract=off OPTFLAGS := -O3 -flto=auto DEBUGFLAGS := -g -march=x86-64-v3 -fsanitize=address,undefined diff --git a/engine/main.cpp b/engine/main.cpp index 67b32bd5..47548d5e 100644 --- a/engine/main.cpp +++ b/engine/main.cpp @@ -313,7 +313,7 @@ __attribute__((weak)) int main(int argc, char *argv[]) { tot_nodes += nodes[0]; } uint64_t end = clock(); - std::cout << 1 << " nodes " << int(tot_nodes / ((double)(end - start) / CLOCKS_PER_SEC)) << " nps" << std::endl; + std::cout << tot_nodes << " nodes " << int(tot_nodes / ((double)(end - start) / CLOCKS_PER_SEC)) << " nps" << std::endl; return 0; } if (argc == 3 && std::string(argv[2]) == "quit") { diff --git a/engine/nnue/accumulator.cpp b/engine/nnue/accumulator.cpp index a337e534..231d485b 100644 --- a/engine/nnue/accumulator.cpp +++ b/engine/nnue/accumulator.cpp @@ -3,7 +3,7 @@ void AccumulatorManager::AccumulatorPair::update_add(Square sq, PieceType pt, bool side, int wbucket, int bbucket) { uint16_t w_index = calculate_index(sq, pt, side, 0, wbucket); uint16_t b_index = calculate_index(sq, pt, side, 1, bbucket); - for (int i = 0; i < L0_SIZE; i++) { + for (int i = 0; i < L1_SIZE; i++) { w_acc.val[i] += nnue_network.accumulator_weights[w_index][i]; b_acc.val[i] += nnue_network.accumulator_weights[b_index][i]; } @@ -12,7 +12,7 @@ void AccumulatorManager::AccumulatorPair::update_add(Square sq, PieceType pt, bo void AccumulatorManager::AccumulatorPair::update_sub(Square sq, PieceType pt, bool side, int wbucket, int bbucket) { uint16_t w_index = calculate_index(sq, pt, side, 0, wbucket); uint16_t b_index = calculate_index(sq, pt, side, 1, bbucket); - for (int i = 0; i < L0_SIZE; i++) { + for (int i = 0; i < L1_SIZE; i++) { w_acc.val[i] -= nnue_network.accumulator_weights[w_index][i]; b_acc.val[i] -= nnue_network.accumulator_weights[b_index][i]; } @@ -20,7 +20,7 @@ void AccumulatorManager::AccumulatorPair::update_sub(Square sq, PieceType pt, bo void AccumulatorManager::full_refresh(Position &pos, int index) { // Init the first accumulator so we have a basepoint - for (int i = 0; i < L0_SIZE; i++) { + for (int i = 0; i < L1_SIZE; i++) { accs[index].w_acc.val[i] = nnue_network.accumulator_biases[i]; accs[index].b_acc.val[i] = nnue_network.accumulator_biases[i]; } @@ -57,7 +57,7 @@ void AccumulatorManager::refresh_finny(Position &pos, int index) { Accumulator &w_acc = accs[index].w_acc; Accumulator &b_acc = accs[index].b_acc; - for (int i = 0; i < L0_SIZE; i++) { + for (int i = 0; i < L1_SIZE; i++) { w_acc.val[i] = f_w_acc.val[i]; b_acc.val[i] = f_b_acc.val[i]; } @@ -75,7 +75,7 @@ void AccumulatorManager::refresh_finny(Position &pos, int index) { if (piece != NO_PIECE) { // Add to accumulator int index = calculate_index((Square)i, pt, side, 0, winbucket); - for (int k = 0; k < L0_SIZE; k++) { + for (int k = 0; k < L1_SIZE; k++) { w_acc.val[k] += nnue_network.accumulator_weights[index][k]; } } @@ -83,7 +83,7 @@ void AccumulatorManager::refresh_finny(Position &pos, int index) { if (prev_w_piece != NO_PIECE) { // Remove from accumulator int index = calculate_index((Square)i, prev_w_pt, prev_w_side, 0, winbucket); - for (int k = 0; k < L0_SIZE; k++) { + for (int k = 0; k < L1_SIZE; k++) { w_acc.val[k] -= nnue_network.accumulator_weights[index][k]; } } @@ -97,7 +97,7 @@ void AccumulatorManager::refresh_finny(Position &pos, int index) { if (piece != NO_PIECE) { // Add to accumulator int index = calculate_index((Square)i, pt, side, 1, binbucket); - for (int k = 0; k < L0_SIZE; k++) { + for (int k = 0; k < L1_SIZE; k++) { b_acc.val[k] += nnue_network.accumulator_weights[index][k]; } } @@ -105,7 +105,7 @@ void AccumulatorManager::refresh_finny(Position &pos, int index) { if (prev_b_piece != NO_PIECE) { // Remove from accumulator int index = calculate_index((Square)i, prev_b_pt, prev_b_side, 1, binbucket); - for (int k = 0; k < L0_SIZE; k++) { + for (int k = 0; k < L1_SIZE; k++) { b_acc.val[k] -= nnue_network.accumulator_weights[index][k]; } } @@ -113,7 +113,7 @@ void AccumulatorManager::refresh_finny(Position &pos, int index) { } // Update finny tables - for (int i = 0; i < L0_SIZE; i++) { + for (int i = 0; i < L1_SIZE; i++) { f_w_acc.val[i] = w_acc.val[i]; f_b_acc.val[i] = b_acc.val[i]; } @@ -158,19 +158,19 @@ void AccumulatorManager::apply_lazy(Position &pos) { auto &u = updates[i]; if (u.deltas == 2) { // -+ - for (int k = 0; k < L0_SIZE; k++) { + for (int k = 0; k < L1_SIZE; k++) { accs[i].w_acc.val[k] = accs[i-1].w_acc.val[k] - nnue_network.accumulator_weights[u.w_deltas[0]][k] + nnue_network.accumulator_weights[u.w_deltas[1]][k]; accs[i].b_acc.val[k] = accs[i-1].b_acc.val[k] - nnue_network.accumulator_weights[u.b_deltas[0]][k] + nnue_network.accumulator_weights[u.b_deltas[1]][k]; } } else if (u.deltas == 3) { // --+ - for (int k = 0; k < L0_SIZE; k++) { + for (int k = 0; k < L1_SIZE; k++) { accs[i].w_acc.val[k] = accs[i-1].w_acc.val[k] - nnue_network.accumulator_weights[u.w_deltas[0]][k] - nnue_network.accumulator_weights[u.w_deltas[1]][k] + nnue_network.accumulator_weights[u.w_deltas[2]][k]; accs[i].b_acc.val[k] = accs[i-1].b_acc.val[k] - nnue_network.accumulator_weights[u.b_deltas[0]][k] - nnue_network.accumulator_weights[u.b_deltas[1]][k] + nnue_network.accumulator_weights[u.b_deltas[2]][k]; } } else if (u.deltas == 4) { // --++ - for (int k = 0; k < L0_SIZE; k++) { + for (int k = 0; k < L1_SIZE; k++) { accs[i].w_acc.val[k] = accs[i-1].w_acc.val[k] - nnue_network.accumulator_weights[u.w_deltas[0]][k] - nnue_network.accumulator_weights[u.w_deltas[1]][k] + nnue_network.accumulator_weights[u.w_deltas[2]][k] + nnue_network.accumulator_weights[u.w_deltas[3]][k]; accs[i].b_acc.val[k] = accs[i-1].b_acc.val[k] - nnue_network.accumulator_weights[u.b_deltas[0]][k] - nnue_network.accumulator_weights[u.b_deltas[1]][k] + nnue_network.accumulator_weights[u.b_deltas[2]][k] + nnue_network.accumulator_weights[u.b_deltas[3]][k]; } diff --git a/engine/nnue/accumulator.hpp b/engine/nnue/accumulator.hpp index 31d1f1f5..3a97e6e6 100644 --- a/engine/nnue/accumulator.hpp +++ b/engine/nnue/accumulator.hpp @@ -32,7 +32,7 @@ struct AccumulatorManager { std::fill(&mailboxes[0][0][0], &mailboxes[0][0][0] + NINPUTS * 2 * 2 * 64, NO_PIECE); for (int i = 0; i < NINPUTS * 2; i++) { - for (int j = 0; j < L0_SIZE; j++) { + for (int j = 0; j < L1_SIZE; j++) { accs[i].w_acc.val[j] = nnue_network.accumulator_biases[j]; accs[i].b_acc.val[j] = nnue_network.accumulator_biases[j]; } diff --git a/engine/nnue/network.cpp b/engine/nnue/network.cpp index 57f07cff..850decb5 100644 --- a/engine/nnue/network.cpp +++ b/engine/nnue/network.cpp @@ -46,41 +46,45 @@ int calculate_index(Square sq, PieceType pt, bool side, bool perspective, int nb } int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator &ntm, uint8_t nbucket) { - alignas(32) int16_t clipped_acc[L0_SIZE * 2]; - alignas(32) float l1[L1_SIZE]; - alignas(32) float l2[L2_SIZE]; + const __m256i zero = _mm256_setzero_si256(); + const __m256i clip = _mm256_set1_epi16(QA); + const __m256 f_zero = _mm256_setzero_ps(); + const __m256 f_clip = _mm256_set1_ps(1.0f); + const __m256 div = _mm256_set1_ps(1.0f / (QA * QA * QB)); + + alignas(32) int16_t clipped_acc[L1_SIZE * 2]; + union { + alignas(32) int32_t l2i[L2_SIZE]; + alignas(32) float l2[L2_SIZE]; + }; + alignas(32) float l3[L3_SIZE]; // Preclip the accumulator - { - __m256i zero = _mm256_setzero_si256(); - __m256i clip = _mm256_set1_epi16(QA); + for (int i = 0; i < L1_SIZE; i += 16) { + __m256i stm_val = _mm256_load_si256((__m256i *)&stm.val[i]); + __m256i ntm_val = _mm256_load_si256((__m256i *)&ntm.val[i]); - for (int i = 0; i < L0_SIZE; i += 16) { - __m256i stm_val = _mm256_load_si256((__m256i *)&stm.val[i]); - __m256i ntm_val = _mm256_load_si256((__m256i *)&ntm.val[i]); + stm_val = _mm256_max_epi16(stm_val, zero); + stm_val = _mm256_min_epi16(stm_val, clip); - stm_val = _mm256_max_epi16(stm_val, zero); - stm_val = _mm256_min_epi16(stm_val, clip); + ntm_val = _mm256_max_epi16(ntm_val, zero); + ntm_val = _mm256_min_epi16(ntm_val, clip); - ntm_val = _mm256_max_epi16(ntm_val, zero); - ntm_val = _mm256_min_epi16(ntm_val, clip); - - _mm256_store_si256((__m256i *)&clipped_acc[i], stm_val); - _mm256_store_si256((__m256i *)&clipped_acc[i + L0_SIZE], ntm_val); - } + _mm256_store_si256((__m256i *)&clipped_acc[i], stm_val); + _mm256_store_si256((__m256i *)&clipped_acc[i + L1_SIZE], ntm_val); } - for (int i = 0; i < L1_SIZE; i++) { + for (int i = 0; i < L2_SIZE; i++) { __m256i sum = _mm256_setzero_si256(); - for (int j = 0; j < L0_SIZE / 2; j += 16) { + for (int j = 0; j < L1_SIZE / 2; j += 16) { __m256i stm_val1 = _mm256_load_si256((__m256i *)&clipped_acc[j]); - __m256i stm_val2 = _mm256_load_si256((__m256i *)&clipped_acc[j + L0_SIZE / 2]); - __m256i ntm_val1 = _mm256_load_si256((__m256i *)&clipped_acc[j + L0_SIZE]); - __m256i ntm_val2 = _mm256_load_si256((__m256i *)&clipped_acc[j + L0_SIZE + L0_SIZE / 2]); + __m256i stm_val2 = _mm256_load_si256((__m256i *)&clipped_acc[j + L1_SIZE / 2]); + __m256i ntm_val1 = _mm256_load_si256((__m256i *)&clipped_acc[j + L1_SIZE]); + __m256i ntm_val2 = _mm256_load_si256((__m256i *)&clipped_acc[j + L1_SIZE + L1_SIZE / 2]); __m256i stm_weight = _mm256_cvtepi8_epi16(_mm_load_si128((__m128i *)&net.l1_weights[nbucket][i][j])); - __m256i ntm_weight = _mm256_cvtepi8_epi16(_mm_load_si128((__m128i *)&net.l1_weights[nbucket][i][j + L0_SIZE / 2])); + __m256i ntm_weight = _mm256_cvtepi8_epi16(_mm_load_si128((__m128i *)&net.l1_weights[nbucket][i][j + L1_SIZE / 2])); __m256i stm_prod = _mm256_mullo_epi16(stm_val1, stm_weight); __m256i ntm_prod = _mm256_mullo_epi16(ntm_val1, ntm_weight); @@ -96,56 +100,57 @@ int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator sum_128 = _mm_add_epi32(sum_128, _mm_shuffle_epi32(sum_128, _MM_SHUFFLE(3, 3, 1, 1))); sum_128 = _mm_add_epi32(sum_128, _mm_shuffle_epi32(sum_128, _MM_SHUFFLE(3, 2, 3, 2))); - // l1 is secretly an array of int32 - int32_t f_sum = _mm_cvtsi128_si32(sum_128); - l1[i] = *(float *)(&f_sum); + l2i[i] = _mm_cvtsi128_si32(sum_128); } - // Convert l1 into a proper float array - __m256 div = _mm256_set1_ps(1.0f / (QA * QA * QB)); - for (int i = 0; i < L1_SIZE; i += 8) { - __m256i i_val = _mm256_load_si256((__m256i *)&l1[i]); - __m256 f_val = _mm256_cvtepi32_ps(i_val); + // Convert l2 into a proper float array + for (int i = 0; i < L2_SIZE; i += 8) { + __m256i i_val = _mm256_load_si256((__m256i *)&l2i[i]); + __m256 val = _mm256_cvtepi32_ps(i_val); __m256 bias = _mm256_load_ps(&net.l1_biases[nbucket][i]); - f_val = _mm256_fmadd_ps(f_val, div, bias); + val = _mm256_add_ps(_mm256_mul_ps(val, div), bias); - _mm256_store_ps(&l1[i], f_val); + val = _mm256_max_ps(val, f_zero); + val = _mm256_min_ps(val, f_clip); + + _mm256_store_ps(&l2[i], val); } - for (int i = 0; i < L2_SIZE; i++) { + for (int i = 0; i < L3_SIZE; i++) { __m256 sum = _mm256_setzero_ps(); - __m256 zero = _mm256_setzero_ps(); - __m256 clip = _mm256_set1_ps(1.0f); - - for (int j = 0; j < L1_SIZE; j += 8) { - __m256 val = _mm256_load_ps(&l1[j]); - - val = _mm256_max_ps(val, zero); - val = _mm256_min_ps(val, clip); + for (int j = 0; j < L2_SIZE; j += 8) { + __m256 val = _mm256_load_ps(&l2[j]); __m256 weight = _mm256_load_ps(&net.l2_weights[nbucket][i][j]); - sum = _mm256_fmadd_ps(_mm256_mul_ps(val, val), weight, sum); + sum = _mm256_add_ps(_mm256_mul_ps(_mm256_mul_ps(val, weight), val), sum); } __m128 sum_128 = _mm_add_ps(_mm256_castps256_ps128(sum), _mm256_extractf128_ps(sum, 1)); sum_128 = _mm_add_ps(sum_128, _mm_movehdup_ps(sum_128)); sum_128 = _mm_add_ps(sum_128, _mm_movehl_ps(sum_128, sum_128)); - l2[i] = _mm_cvtss_f32(sum_128) + net.l2_biases[nbucket][i]; + l3[i] = _mm_cvtss_f32(sum_128) + net.l2_biases[nbucket][i]; } - float score = net.output_biases[nbucket]; - for (int i = 0; i < L2_SIZE; i++) { - float val = l2[i]; - val = std::clamp(val, 0.0f, 1.0f); + __m256 sum = _mm256_setzero_ps(); + for (int i = 0; i < L3_SIZE; i += 8) { + __m256 val = _mm256_load_ps(&l3[i]); - float weight = net.output_weights[nbucket][i]; + val = _mm256_max_ps(val, f_zero); + val = _mm256_min_ps(val, f_clip); - score += val * val * weight; + __m256 weight = _mm256_load_ps(&net.output_weights[nbucket][i]); + + sum = _mm256_add_ps(_mm256_mul_ps(_mm256_mul_ps(val, weight), val), sum); } - return roundf(score * SCALE); + __m128 sum_128 = _mm_add_ps(_mm256_castps256_ps128(sum), _mm256_extractf128_ps(sum, 1)); + sum_128 = _mm_add_ps(sum_128, _mm_movehdup_ps(sum_128)); + sum_128 = _mm_add_ps(sum_128, _mm_movehl_ps(sum_128, sum_128)); + float score = _mm_cvtss_f32(sum_128) + net.output_biases[nbucket]; + + return score * SCALE; } diff --git a/engine/nnue/network.hpp b/engine/nnue/network.hpp index 993464d0..c3551359 100644 --- a/engine/nnue/network.hpp +++ b/engine/nnue/network.hpp @@ -4,9 +4,9 @@ #define INPUT_SIZE 768 #define NINPUTS 8 -#define L0_SIZE 1024 -#define L1_SIZE 16 -#define L2_SIZE 32 +#define L1_SIZE 1024 +#define L2_SIZE 16 +#define L3_SIZE 32 #define NBUCKETS 8 #define SCALE 421 #define QA 255 @@ -24,20 +24,20 @@ constexpr int IBUCKET_LAYOUT[] = { }; struct Accumulator { - alignas(32) int16_t val[L0_SIZE] = {}; + alignas(32) int16_t val[L1_SIZE] = {}; }; struct alignas(32) Network { - int16_t accumulator_weights[INPUT_SIZE * NINPUTS][L0_SIZE]; - int16_t accumulator_biases[L0_SIZE]; + int16_t accumulator_weights[INPUT_SIZE * NINPUTS][L1_SIZE]; + int16_t accumulator_biases[L1_SIZE]; - int8_t l1_weights[NBUCKETS][L1_SIZE][L0_SIZE]; - float l1_biases[NBUCKETS][L1_SIZE]; + int8_t l1_weights[NBUCKETS][L2_SIZE][L1_SIZE]; + float l1_biases[NBUCKETS][L2_SIZE]; - float l2_weights[NBUCKETS][L2_SIZE][L1_SIZE]; - float l2_biases[NBUCKETS][L2_SIZE]; + float l2_weights[NBUCKETS][L3_SIZE][L2_SIZE]; + float l2_biases[NBUCKETS][L3_SIZE]; - float output_weights[NBUCKETS][L2_SIZE]; + float output_weights[NBUCKETS][L3_SIZE]; float output_biases[NBUCKETS]; void load(); From 1dbce2105a6e9b895d0881f9de91a40fe3cfff9e Mon Sep 17 00:00:00 2001 From: William Ma Date: Sat, 25 Apr 2026 23:30:03 -0600 Subject: [PATCH 12/26] Fix pairwise Bench: 4324337 --- engine/nnue/network.cpp | 60 ++++++++++++++++++++++------------------- engine/nnue/simd.cpp | 15 +++++++++++ engine/nnue/simd.hpp | 5 ++++ 3 files changed, 52 insertions(+), 28 deletions(-) create mode 100644 engine/nnue/simd.cpp create mode 100644 engine/nnue/simd.hpp diff --git a/engine/nnue/network.cpp b/engine/nnue/network.cpp index 850decb5..47d8d24d 100644 --- a/engine/nnue/network.cpp +++ b/engine/nnue/network.cpp @@ -1,4 +1,6 @@ #include "network.hpp" +#include "simd.hpp" + #include "incbin.h" extern "C" { @@ -50,57 +52,59 @@ int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator const __m256i clip = _mm256_set1_epi16(QA); const __m256 f_zero = _mm256_setzero_ps(); const __m256 f_clip = _mm256_set1_ps(1.0f); - const __m256 div = _mm256_set1_ps(1.0f / (QA * QA * QB)); + const __m256 div = _mm256_set1_ps(1.0f / (QA * QA * QB / 256.0f)); - alignas(32) int16_t clipped_acc[L1_SIZE * 2]; + alignas(32) int8_t pairs[L1_SIZE]; union { alignas(32) int32_t l2i[L2_SIZE]; alignas(32) float l2[L2_SIZE]; }; alignas(32) float l3[L3_SIZE]; - // Preclip the accumulator - for (int i = 0; i < L1_SIZE; i += 16) { - __m256i stm_val = _mm256_load_si256((__m256i *)&stm.val[i]); - __m256i ntm_val = _mm256_load_si256((__m256i *)&ntm.val[i]); + // Pairwise mul + for (int i = 0; i < L1_SIZE / 2; i += 16) { + __m256i stm_val1 = _mm256_load_si256((__m256i *)&stm.val[i]); + __m256i stm_val2 = _mm256_load_si256((__m256i *)&stm.val[i + L1_SIZE / 2]); + __m256i ntm_val1 = _mm256_load_si256((__m256i *)&ntm.val[i]); + __m256i ntm_val2 = _mm256_load_si256((__m256i *)&ntm.val[i + L1_SIZE / 2]); + + stm_val1 = _mm256_max_epi16(stm_val1, zero); + stm_val1 = _mm256_min_epi16(stm_val1, clip); + stm_val2 = _mm256_max_epi16(stm_val2, zero); + stm_val2 = _mm256_min_epi16(stm_val2, clip); + + ntm_val1 = _mm256_max_epi16(ntm_val1, zero); + ntm_val1 = _mm256_min_epi16(ntm_val1, clip); + ntm_val2 = _mm256_max_epi16(ntm_val2, zero); + ntm_val2 = _mm256_min_epi16(ntm_val2, clip); - stm_val = _mm256_max_epi16(stm_val, zero); - stm_val = _mm256_min_epi16(stm_val, clip); + stm_val1 = _mm256_slli_epi16(stm_val1, 7); + ntm_val1 = _mm256_slli_epi16(ntm_val1, 7); - ntm_val = _mm256_max_epi16(ntm_val, zero); - ntm_val = _mm256_min_epi16(ntm_val, clip); + __m256i stm_pair = _mm256_mulhrs_epi16(stm_val1, stm_val2); + __m256i ntm_pair = _mm256_mulhrs_epi16(ntm_val1, ntm_val2); - _mm256_store_si256((__m256i *)&clipped_acc[i], stm_val); - _mm256_store_si256((__m256i *)&clipped_acc[i + L1_SIZE], ntm_val); + simd::store_epi16_epi8(&pairs[i], stm_pair); + simd::store_epi16_epi8(&pairs[i + L1_SIZE / 2], ntm_pair); } for (int i = 0; i < L2_SIZE; i++) { __m256i sum = _mm256_setzero_si256(); - for (int j = 0; j < L1_SIZE / 2; j += 16) { - __m256i stm_val1 = _mm256_load_si256((__m256i *)&clipped_acc[j]); - __m256i stm_val2 = _mm256_load_si256((__m256i *)&clipped_acc[j + L1_SIZE / 2]); - __m256i ntm_val1 = _mm256_load_si256((__m256i *)&clipped_acc[j + L1_SIZE]); - __m256i ntm_val2 = _mm256_load_si256((__m256i *)&clipped_acc[j + L1_SIZE + L1_SIZE / 2]); - - __m256i stm_weight = _mm256_cvtepi8_epi16(_mm_load_si128((__m128i *)&net.l1_weights[nbucket][i][j])); - __m256i ntm_weight = _mm256_cvtepi8_epi16(_mm_load_si128((__m128i *)&net.l1_weights[nbucket][i][j + L1_SIZE / 2])); - - __m256i stm_prod = _mm256_mullo_epi16(stm_val1, stm_weight); - __m256i ntm_prod = _mm256_mullo_epi16(ntm_val1, ntm_weight); + for (int j = 0; j < L1_SIZE; j += 32) { + __m256i val = _mm256_load_si256((__m256i *)&pairs[j]); + __m256i weight = _mm256_load_si256((__m256i *)&net.l1_weights[nbucket][i][j]); - __m256i stm_res = _mm256_madd_epi16(stm_val2, stm_prod); - __m256i ntm_res = _mm256_madd_epi16(ntm_val2, ntm_prod); + __m256i res = _mm256_maddubs_epi16(val, weight); - sum = _mm256_add_epi32(sum, stm_res); - sum = _mm256_add_epi32(sum, ntm_res); + sum = _mm256_add_epi32(res, sum); } __m128i sum_128 = _mm_add_epi32(_mm256_castsi256_si128(sum), _mm256_extracti128_si256(sum, 1)); sum_128 = _mm_add_epi32(sum_128, _mm_shuffle_epi32(sum_128, _MM_SHUFFLE(3, 3, 1, 1))); sum_128 = _mm_add_epi32(sum_128, _mm_shuffle_epi32(sum_128, _MM_SHUFFLE(3, 2, 3, 2))); - l2i[i] = _mm_cvtsi128_si32(sum_128); + l2i[i] = _mm256_reduce_add_epi16(sum); } // Convert l2 into a proper float array diff --git a/engine/nnue/simd.cpp b/engine/nnue/simd.cpp new file mode 100644 index 00000000..0eb73b47 --- /dev/null +++ b/engine/nnue/simd.cpp @@ -0,0 +1,15 @@ +#include + +namespace simd { + + void store_epi16_epi8(int8_t *p, __m256i v) { + const __m128i shuf_mask = _mm_cvtsi64_si128(0x0e0c0a0806040200); + + __m128i lo = _mm256_castsi256_si128(v); + __m128i hi = _mm256_extracti128_si256(v, 1); + + _mm_storeu_si64(&p[0], _mm_shuffle_epi8(lo, shuf_mask)); + _mm_storeu_si64(&p[8], _mm_shuffle_epi8(hi, shuf_mask)); + } + +}; diff --git a/engine/nnue/simd.hpp b/engine/nnue/simd.hpp new file mode 100644 index 00000000..ebc2c4c9 --- /dev/null +++ b/engine/nnue/simd.hpp @@ -0,0 +1,5 @@ +#include + +namespace simd { + void store_epi16_epi8(int8_t *p, __m256i v); +}; From 8e12be7aa69350711224f88bd5c1ca29f81227e9 Mon Sep 17 00:00:00 2001 From: William Ma Date: Sun, 26 Apr 2026 00:03:37 -0600 Subject: [PATCH 13/26] Fix L1 accumulate width Bench: 4408593 --- engine/nnue/network.cpp | 8 ++------ engine/nnue/simd.cpp | 12 +++++++++++- engine/nnue/simd.hpp | 1 + 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/engine/nnue/network.cpp b/engine/nnue/network.cpp index 47d8d24d..9b002832 100644 --- a/engine/nnue/network.cpp +++ b/engine/nnue/network.cpp @@ -97,14 +97,10 @@ int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator __m256i res = _mm256_maddubs_epi16(val, weight); - sum = _mm256_add_epi32(res, sum); + sum = _mm256_add_epi16(res, sum); } - __m128i sum_128 = _mm_add_epi32(_mm256_castsi256_si128(sum), _mm256_extracti128_si256(sum, 1)); - sum_128 = _mm_add_epi32(sum_128, _mm_shuffle_epi32(sum_128, _MM_SHUFFLE(3, 3, 1, 1))); - sum_128 = _mm_add_epi32(sum_128, _mm_shuffle_epi32(sum_128, _MM_SHUFFLE(3, 2, 3, 2))); - - l2i[i] = _mm256_reduce_add_epi16(sum); + l2i[i] = simd::reduce_add_epi16(sum); } // Convert l2 into a proper float array diff --git a/engine/nnue/simd.cpp b/engine/nnue/simd.cpp index 0eb73b47..91bbf4bc 100644 --- a/engine/nnue/simd.cpp +++ b/engine/nnue/simd.cpp @@ -1,7 +1,6 @@ #include namespace simd { - void store_epi16_epi8(int8_t *p, __m256i v) { const __m128i shuf_mask = _mm_cvtsi64_si128(0x0e0c0a0806040200); @@ -12,4 +11,15 @@ namespace simd { _mm_storeu_si64(&p[8], _mm_shuffle_epi8(hi, shuf_mask)); } + int32_t reduce_add_epi16(__m256i v) { + const __m256i ones = _mm256_set1_epi16(1); + + __m256i wide = _mm256_madd_epi16(v, ones); + + __m128i sum = _mm_add_epi32(_mm256_castsi256_si128(wide), _mm256_extracti128_si256(wide, 1)); + sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, _MM_SHUFFLE(1, 0, 3, 2))); + sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, _MM_SHUFFLE(2, 3, 0, 1))); + + return _mm_cvtsi128_si32(sum); + } }; diff --git a/engine/nnue/simd.hpp b/engine/nnue/simd.hpp index ebc2c4c9..84971ee5 100644 --- a/engine/nnue/simd.hpp +++ b/engine/nnue/simd.hpp @@ -2,4 +2,5 @@ namespace simd { void store_epi16_epi8(int8_t *p, __m256i v); + int32_t reduce_add_epi16(__m256i v); }; From d56d9d2c009185e56661be735845364f1b44babf Mon Sep 17 00:00:00 2001 From: William Ma Date: Sun, 26 Apr 2026 16:20:56 -0600 Subject: [PATCH 14/26] Flip iteration direction Bench: 4579505 --- engine/nnue/network.cpp | 39 ++++++++++++++++++++++----------------- engine/nnue/network.hpp | 2 +- 2 files changed, 23 insertions(+), 18 deletions(-) diff --git a/engine/nnue/network.cpp b/engine/nnue/network.cpp index 9b002832..d483d85f 100644 --- a/engine/nnue/network.cpp +++ b/engine/nnue/network.cpp @@ -22,8 +22,14 @@ void Network::load() { memcpy(l1_biases, ptr, sizeof(l1_biases)); ptr += sizeof(l1_biases); - memcpy(l2_weights, ptr, sizeof(l2_weights)); - ptr += sizeof(l2_weights); + for (int i = 0; i < NBUCKETS; i++) { + for (int j = 0; j < L3_SIZE; j++) { + for (int k = 0; k < L2_SIZE; k++) { + memcpy(&l2_weights[i][k][j], ptr, 4); + ptr += 4; + } + } + } memcpy(l2_biases, ptr, sizeof(l2_biases)); ptr += sizeof(l2_biases); @@ -54,7 +60,7 @@ int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator const __m256 f_clip = _mm256_set1_ps(1.0f); const __m256 div = _mm256_set1_ps(1.0f / (QA * QA * QB / 256.0f)); - alignas(32) int8_t pairs[L1_SIZE]; + alignas(32) int8_t l1[L1_SIZE]; union { alignas(32) int32_t l2i[L2_SIZE]; alignas(32) float l2[L2_SIZE]; @@ -84,15 +90,15 @@ int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator __m256i stm_pair = _mm256_mulhrs_epi16(stm_val1, stm_val2); __m256i ntm_pair = _mm256_mulhrs_epi16(ntm_val1, ntm_val2); - simd::store_epi16_epi8(&pairs[i], stm_pair); - simd::store_epi16_epi8(&pairs[i + L1_SIZE / 2], ntm_pair); + simd::store_epi16_epi8(&l1[i], stm_pair); + simd::store_epi16_epi8(&l1[i + L1_SIZE / 2], ntm_pair); } for (int i = 0; i < L2_SIZE; i++) { __m256i sum = _mm256_setzero_si256(); for (int j = 0; j < L1_SIZE; j += 32) { - __m256i val = _mm256_load_si256((__m256i *)&pairs[j]); + __m256i val = _mm256_load_si256((__m256i *)&l1[j]); __m256i weight = _mm256_load_si256((__m256i *)&net.l1_weights[nbucket][i][j]); __m256i res = _mm256_maddubs_epi16(val, weight); @@ -115,24 +121,23 @@ int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator val = _mm256_max_ps(val, f_zero); val = _mm256_min_ps(val, f_clip); + val = _mm256_mul_ps(val, val); + _mm256_store_ps(&l2[i], val); } - for (int i = 0; i < L3_SIZE; i++) { - __m256 sum = _mm256_setzero_ps(); + for (int i = 0; i < L3_SIZE; i += 8) { + __m256 sum = _mm256_load_ps(&net.l2_biases[nbucket][i]); - for (int j = 0; j < L2_SIZE; j += 8) { - __m256 val = _mm256_load_ps(&l2[j]); - __m256 weight = _mm256_load_ps(&net.l2_weights[nbucket][i][j]); + for (int j = 0; j < L2_SIZE; j++) { + __m256 val = _mm256_set1_ps(l2[j]); - sum = _mm256_add_ps(_mm256_mul_ps(_mm256_mul_ps(val, weight), val), sum); - } + __m256 weight = _mm256_load_ps(&net.l2_weights[nbucket][j][i]); - __m128 sum_128 = _mm_add_ps(_mm256_castps256_ps128(sum), _mm256_extractf128_ps(sum, 1)); - sum_128 = _mm_add_ps(sum_128, _mm_movehdup_ps(sum_128)); - sum_128 = _mm_add_ps(sum_128, _mm_movehl_ps(sum_128, sum_128)); + sum = _mm256_add_ps(_mm256_mul_ps(val, weight), sum); + } - l3[i] = _mm_cvtss_f32(sum_128) + net.l2_biases[nbucket][i]; + _mm256_store_ps(&l3[i], sum); } __m256 sum = _mm256_setzero_ps(); diff --git a/engine/nnue/network.hpp b/engine/nnue/network.hpp index c3551359..ef69c2bd 100644 --- a/engine/nnue/network.hpp +++ b/engine/nnue/network.hpp @@ -34,7 +34,7 @@ struct alignas(32) Network { int8_t l1_weights[NBUCKETS][L2_SIZE][L1_SIZE]; float l1_biases[NBUCKETS][L2_SIZE]; - float l2_weights[NBUCKETS][L3_SIZE][L2_SIZE]; + float l2_weights[NBUCKETS][L2_SIZE][L3_SIZE]; float l2_biases[NBUCKETS][L3_SIZE]; float output_weights[NBUCKETS][L3_SIZE]; From 1b6c588524d1f3887f91be50d5e0fb3abf85d1b0 Mon Sep 17 00:00:00 2001 From: William Ma Date: Sun, 26 Apr 2026 17:03:27 -0600 Subject: [PATCH 15/26] Unroll L2 Bench: 4579505 --- engine/nnue/network.cpp | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/engine/nnue/network.cpp b/engine/nnue/network.cpp index d483d85f..71201c89 100644 --- a/engine/nnue/network.cpp +++ b/engine/nnue/network.cpp @@ -126,18 +126,30 @@ int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator _mm256_store_ps(&l2[i], val); } - for (int i = 0; i < L3_SIZE; i += 8) { - __m256 sum = _mm256_load_ps(&net.l2_biases[nbucket][i]); + for (int i = 0; i < L3_SIZE; i += 8 * 4) { + __m256 sum0 = _mm256_load_ps(&net.l2_biases[nbucket][i + 0x00]); + __m256 sum1 = _mm256_load_ps(&net.l2_biases[nbucket][i + 0x08]); + __m256 sum2 = _mm256_load_ps(&net.l2_biases[nbucket][i + 0x10]); + __m256 sum3 = _mm256_load_ps(&net.l2_biases[nbucket][i + 0x18]); for (int j = 0; j < L2_SIZE; j++) { __m256 val = _mm256_set1_ps(l2[j]); - __m256 weight = _mm256_load_ps(&net.l2_weights[nbucket][j][i]); + __m256 weight0 = _mm256_load_ps(&net.l2_weights[nbucket][j][i + 0x00]); + __m256 weight1 = _mm256_load_ps(&net.l2_weights[nbucket][j][i + 0x08]); + __m256 weight2 = _mm256_load_ps(&net.l2_weights[nbucket][j][i + 0x10]); + __m256 weight3 = _mm256_load_ps(&net.l2_weights[nbucket][j][i + 0x18]); - sum = _mm256_add_ps(_mm256_mul_ps(val, weight), sum); + sum0 = _mm256_add_ps(_mm256_mul_ps(val, weight0), sum0); + sum1 = _mm256_add_ps(_mm256_mul_ps(val, weight1), sum1); + sum2 = _mm256_add_ps(_mm256_mul_ps(val, weight2), sum2); + sum3 = _mm256_add_ps(_mm256_mul_ps(val, weight3), sum3); } - _mm256_store_ps(&l3[i], sum); + _mm256_store_ps(&l3[i + 0x00], sum0); + _mm256_store_ps(&l3[i + 0x08], sum1); + _mm256_store_ps(&l3[i + 0x10], sum2); + _mm256_store_ps(&l3[i + 0x18], sum3); } __m256 sum = _mm256_setzero_ps(); From 9b174c110986a0d082af755122ff5f17fe58badd Mon Sep 17 00:00:00 2001 From: William Ma Date: Sun, 26 Apr 2026 17:35:02 -0600 Subject: [PATCH 16/26] Unroll L1 Bench: 4579505 --- engine/nnue/network.cpp | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/engine/nnue/network.cpp b/engine/nnue/network.cpp index 71201c89..d5b78dee 100644 --- a/engine/nnue/network.cpp +++ b/engine/nnue/network.cpp @@ -94,19 +94,35 @@ int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator simd::store_epi16_epi8(&l1[i + L1_SIZE / 2], ntm_pair); } - for (int i = 0; i < L2_SIZE; i++) { - __m256i sum = _mm256_setzero_si256(); + for (int i = 0; i < L2_SIZE; i += 4) { + __m256i sum0 = _mm256_setzero_si256(); + __m256i sum1 = _mm256_setzero_si256(); + __m256i sum2 = _mm256_setzero_si256(); + __m256i sum3 = _mm256_setzero_si256(); for (int j = 0; j < L1_SIZE; j += 32) { __m256i val = _mm256_load_si256((__m256i *)&l1[j]); - __m256i weight = _mm256_load_si256((__m256i *)&net.l1_weights[nbucket][i][j]); - __m256i res = _mm256_maddubs_epi16(val, weight); + __m256i weight0 = _mm256_load_si256((__m256i *)&net.l1_weights[nbucket][i + 0][j]); + __m256i weight1 = _mm256_load_si256((__m256i *)&net.l1_weights[nbucket][i + 1][j]); + __m256i weight2 = _mm256_load_si256((__m256i *)&net.l1_weights[nbucket][i + 2][j]); + __m256i weight3 = _mm256_load_si256((__m256i *)&net.l1_weights[nbucket][i + 3][j]); - sum = _mm256_add_epi16(res, sum); + __m256i res0 = _mm256_maddubs_epi16(val, weight0); + __m256i res1 = _mm256_maddubs_epi16(val, weight1); + __m256i res2 = _mm256_maddubs_epi16(val, weight2); + __m256i res3 = _mm256_maddubs_epi16(val, weight3); + + sum0 = _mm256_add_epi16(res0, sum0); + sum1 = _mm256_add_epi16(res1, sum1); + sum2 = _mm256_add_epi16(res2, sum2); + sum3 = _mm256_add_epi16(res3, sum3); } - l2i[i] = simd::reduce_add_epi16(sum); + l2i[i + 0] = simd::reduce_add_epi16(sum0); + l2i[i + 1] = simd::reduce_add_epi16(sum1); + l2i[i + 2] = simd::reduce_add_epi16(sum2); + l2i[i + 3] = simd::reduce_add_epi16(sum3); } // Convert l2 into a proper float array From eb248adde1c0d4423c2e93c8759d82036059812d Mon Sep 17 00:00:00 2001 From: William Ma Date: Sun, 26 Apr 2026 17:41:46 -0600 Subject: [PATCH 17/26] Readd FMA Bench: 4518305 --- Makefile | 2 +- engine/nnue/network.cpp | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/Makefile b/Makefile index 4e677eb3..d7bfc0e1 100644 --- a/Makefile +++ b/Makefile @@ -12,7 +12,7 @@ CXX := g++ WINCXX := x86_64-w64-mingw32-g++ # Flags -BASEFLAGS := -std=c++20 -DNNUE_PATH=\"$(EVALFILE)\" -m64 -DVERSION=\"$(VERSION)\" -ffp-contract=off +BASEFLAGS := -std=c++20 -DNNUE_PATH=\"$(EVALFILE)\" -m64 -DVERSION=\"$(VERSION)\" OPTFLAGS := -O3 -flto=auto DEBUGFLAGS := -g -march=x86-64-v3 -fsanitize=address,undefined diff --git a/engine/nnue/network.cpp b/engine/nnue/network.cpp index d5b78dee..492b8b76 100644 --- a/engine/nnue/network.cpp +++ b/engine/nnue/network.cpp @@ -132,7 +132,7 @@ int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator __m256 bias = _mm256_load_ps(&net.l1_biases[nbucket][i]); - val = _mm256_add_ps(_mm256_mul_ps(val, div), bias); + val = _mm256_fmadd_ps(val, div, bias); val = _mm256_max_ps(val, f_zero); val = _mm256_min_ps(val, f_clip); @@ -156,10 +156,10 @@ int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator __m256 weight2 = _mm256_load_ps(&net.l2_weights[nbucket][j][i + 0x10]); __m256 weight3 = _mm256_load_ps(&net.l2_weights[nbucket][j][i + 0x18]); - sum0 = _mm256_add_ps(_mm256_mul_ps(val, weight0), sum0); - sum1 = _mm256_add_ps(_mm256_mul_ps(val, weight1), sum1); - sum2 = _mm256_add_ps(_mm256_mul_ps(val, weight2), sum2); - sum3 = _mm256_add_ps(_mm256_mul_ps(val, weight3), sum3); + sum0 = _mm256_fmadd_ps(val, weight0, sum0); + sum1 = _mm256_fmadd_ps(val, weight1, sum1); + sum2 = _mm256_fmadd_ps(val, weight2, sum2); + sum3 = _mm256_fmadd_ps(val, weight3, sum3); } _mm256_store_ps(&l3[i + 0x00], sum0); @@ -177,7 +177,7 @@ int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator __m256 weight = _mm256_load_ps(&net.output_weights[nbucket][i]); - sum = _mm256_add_ps(_mm256_mul_ps(_mm256_mul_ps(val, weight), val), sum); + sum = _mm256_fmadd_ps(_mm256_mul_ps(val, val), weight, sum); } __m128 sum_128 = _mm_add_ps(_mm256_castps256_ps128(sum), _mm256_extractf128_ps(sum, 1)); From 946e3634108a696fb0ae8423198d5e52d34601ef Mon Sep 17 00:00:00 2001 From: William Ma Date: Sun, 26 Apr 2026 18:31:06 -0600 Subject: [PATCH 18/26] Extract intrinsics into simd.cpp Bench: 4518305 --- engine/nnue/network.cpp | 155 +++++++++++++++++++--------------------- engine/nnue/simd.cpp | 108 +++++++++++++++++++++++----- engine/nnue/simd.hpp | 32 ++++++++- 3 files changed, 192 insertions(+), 103 deletions(-) diff --git a/engine/nnue/network.cpp b/engine/nnue/network.cpp index 492b8b76..6b44dffc 100644 --- a/engine/nnue/network.cpp +++ b/engine/nnue/network.cpp @@ -54,11 +54,11 @@ int calculate_index(Square sq, PieceType pt, bool side, bool perspective, int nb } int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator &ntm, uint8_t nbucket) { - const __m256i zero = _mm256_setzero_si256(); - const __m256i clip = _mm256_set1_epi16(QA); - const __m256 f_zero = _mm256_setzero_ps(); - const __m256 f_clip = _mm256_set1_ps(1.0f); - const __m256 div = _mm256_set1_ps(1.0f / (QA * QA * QB / 256.0f)); + const ivec zero = simd::setzero_ivec(); + const ivec clip = simd::broadcast_i16(QA); + const fvec f_zero = simd::setzero_fvec(); + const fvec f_clip = simd::broadcast_f32(1.0f); + const fvec div = simd::broadcast_f32(1.0f / (QA * QA * QB / 256.0f)); alignas(32) int8_t l1[L1_SIZE]; union { @@ -69,54 +69,46 @@ int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator // Pairwise mul for (int i = 0; i < L1_SIZE / 2; i += 16) { - __m256i stm_val1 = _mm256_load_si256((__m256i *)&stm.val[i]); - __m256i stm_val2 = _mm256_load_si256((__m256i *)&stm.val[i + L1_SIZE / 2]); - __m256i ntm_val1 = _mm256_load_si256((__m256i *)&ntm.val[i]); - __m256i ntm_val2 = _mm256_load_si256((__m256i *)&ntm.val[i + L1_SIZE / 2]); - - stm_val1 = _mm256_max_epi16(stm_val1, zero); - stm_val1 = _mm256_min_epi16(stm_val1, clip); - stm_val2 = _mm256_max_epi16(stm_val2, zero); - stm_val2 = _mm256_min_epi16(stm_val2, clip); - - ntm_val1 = _mm256_max_epi16(ntm_val1, zero); - ntm_val1 = _mm256_min_epi16(ntm_val1, clip); - ntm_val2 = _mm256_max_epi16(ntm_val2, zero); - ntm_val2 = _mm256_min_epi16(ntm_val2, clip); - - stm_val1 = _mm256_slli_epi16(stm_val1, 7); - ntm_val1 = _mm256_slli_epi16(ntm_val1, 7); - - __m256i stm_pair = _mm256_mulhrs_epi16(stm_val1, stm_val2); - __m256i ntm_pair = _mm256_mulhrs_epi16(ntm_val1, ntm_val2); - - simd::store_epi16_epi8(&l1[i], stm_pair); - simd::store_epi16_epi8(&l1[i + L1_SIZE / 2], ntm_pair); + ivec stm_val1 = simd::load_ivec((ivec *)&stm.val[i]); + ivec stm_val2 = simd::load_ivec((ivec *)&stm.val[i + L1_SIZE / 2]); + ivec ntm_val1 = simd::load_ivec((ivec *)&ntm.val[i]); + ivec ntm_val2 = simd::load_ivec((ivec *)&ntm.val[i + L1_SIZE / 2]); + + stm_val1 = simd::max_i16(stm_val1, zero); + stm_val1 = simd::min_i16(stm_val1, clip); + stm_val2 = simd::max_i16(stm_val2, zero); + stm_val2 = simd::min_i16(stm_val2, clip); + + ntm_val1 = simd::max_i16(ntm_val1, zero); + ntm_val1 = simd::min_i16(ntm_val1, clip); + ntm_val2 = simd::max_i16(ntm_val2, zero); + ntm_val2 = simd::min_i16(ntm_val2, clip); + + ivec stm_pair = simd::shift_mulhi(stm_val1, stm_val2); + ivec ntm_pair = simd::shift_mulhi(ntm_val1, ntm_val2); + + simd::store_i16_i8(&l1[i], stm_pair); + simd::store_i16_i8(&l1[i + L1_SIZE / 2], ntm_pair); } for (int i = 0; i < L2_SIZE; i += 4) { - __m256i sum0 = _mm256_setzero_si256(); - __m256i sum1 = _mm256_setzero_si256(); - __m256i sum2 = _mm256_setzero_si256(); - __m256i sum3 = _mm256_setzero_si256(); + ivec sum0 = zero; + ivec sum1 = zero; + ivec sum2 = zero; + ivec sum3 = zero; for (int j = 0; j < L1_SIZE; j += 32) { - __m256i val = _mm256_load_si256((__m256i *)&l1[j]); - - __m256i weight0 = _mm256_load_si256((__m256i *)&net.l1_weights[nbucket][i + 0][j]); - __m256i weight1 = _mm256_load_si256((__m256i *)&net.l1_weights[nbucket][i + 1][j]); - __m256i weight2 = _mm256_load_si256((__m256i *)&net.l1_weights[nbucket][i + 2][j]); - __m256i weight3 = _mm256_load_si256((__m256i *)&net.l1_weights[nbucket][i + 3][j]); - - __m256i res0 = _mm256_maddubs_epi16(val, weight0); - __m256i res1 = _mm256_maddubs_epi16(val, weight1); - __m256i res2 = _mm256_maddubs_epi16(val, weight2); - __m256i res3 = _mm256_maddubs_epi16(val, weight3); - - sum0 = _mm256_add_epi16(res0, sum0); - sum1 = _mm256_add_epi16(res1, sum1); - sum2 = _mm256_add_epi16(res2, sum2); - sum3 = _mm256_add_epi16(res3, sum3); + ivec val = simd::load_ivec((ivec *)&l1[j]); + + ivec weight0 = simd::load_ivec((ivec *)&net.l1_weights[nbucket][i + 0][j]); + ivec weight1 = simd::load_ivec((ivec *)&net.l1_weights[nbucket][i + 1][j]); + ivec weight2 = simd::load_ivec((ivec *)&net.l1_weights[nbucket][i + 2][j]); + ivec weight3 = simd::load_ivec((ivec *)&net.l1_weights[nbucket][i + 3][j]); + + sum0 = simd::accdp_u8i8_i16(val, weight0, sum0); + sum1 = simd::accdp_u8i8_i16(val, weight1, sum1); + sum2 = simd::accdp_u8i8_i16(val, weight2, sum2); + sum3 = simd::accdp_u8i8_i16(val, weight3, sum3); } l2i[i + 0] = simd::reduce_add_epi16(sum0); @@ -127,63 +119,60 @@ int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator // Convert l2 into a proper float array for (int i = 0; i < L2_SIZE; i += 8) { - __m256i i_val = _mm256_load_si256((__m256i *)&l2i[i]); - __m256 val = _mm256_cvtepi32_ps(i_val); + ivec i_val = simd::load_ivec((ivec *)&l2i[i]); + fvec val = simd::cvt_i32_f32(i_val); - __m256 bias = _mm256_load_ps(&net.l1_biases[nbucket][i]); + fvec bias = simd::load_fvec(&net.l1_biases[nbucket][i]); - val = _mm256_fmadd_ps(val, div, bias); + val = simd::fma_f32(val, div, bias); - val = _mm256_max_ps(val, f_zero); - val = _mm256_min_ps(val, f_clip); + val = simd::max_f32(val, f_zero); + val = simd::min_f32(val, f_clip); - val = _mm256_mul_ps(val, val); + val = simd::mul_f32(val, val); - _mm256_store_ps(&l2[i], val); + simd::store_f32(&l2[i], val); } for (int i = 0; i < L3_SIZE; i += 8 * 4) { - __m256 sum0 = _mm256_load_ps(&net.l2_biases[nbucket][i + 0x00]); - __m256 sum1 = _mm256_load_ps(&net.l2_biases[nbucket][i + 0x08]); - __m256 sum2 = _mm256_load_ps(&net.l2_biases[nbucket][i + 0x10]); - __m256 sum3 = _mm256_load_ps(&net.l2_biases[nbucket][i + 0x18]); + fvec sum0 = simd::load_fvec(&net.l2_biases[nbucket][i + 0x00]); + fvec sum1 = simd::load_fvec(&net.l2_biases[nbucket][i + 0x08]); + fvec sum2 = simd::load_fvec(&net.l2_biases[nbucket][i + 0x10]); + fvec sum3 = simd::load_fvec(&net.l2_biases[nbucket][i + 0x18]); for (int j = 0; j < L2_SIZE; j++) { - __m256 val = _mm256_set1_ps(l2[j]); + fvec val = simd::broadcast_f32(l2[j]); - __m256 weight0 = _mm256_load_ps(&net.l2_weights[nbucket][j][i + 0x00]); - __m256 weight1 = _mm256_load_ps(&net.l2_weights[nbucket][j][i + 0x08]); - __m256 weight2 = _mm256_load_ps(&net.l2_weights[nbucket][j][i + 0x10]); - __m256 weight3 = _mm256_load_ps(&net.l2_weights[nbucket][j][i + 0x18]); + fvec weight0 = simd::load_fvec(&net.l2_weights[nbucket][j][i + 0x00]); + fvec weight1 = simd::load_fvec(&net.l2_weights[nbucket][j][i + 0x08]); + fvec weight2 = simd::load_fvec(&net.l2_weights[nbucket][j][i + 0x10]); + fvec weight3 = simd::load_fvec(&net.l2_weights[nbucket][j][i + 0x18]); - sum0 = _mm256_fmadd_ps(val, weight0, sum0); - sum1 = _mm256_fmadd_ps(val, weight1, sum1); - sum2 = _mm256_fmadd_ps(val, weight2, sum2); - sum3 = _mm256_fmadd_ps(val, weight3, sum3); + sum0 = simd::fma_f32(val, weight0, sum0); + sum1 = simd::fma_f32(val, weight1, sum1); + sum2 = simd::fma_f32(val, weight2, sum2); + sum3 = simd::fma_f32(val, weight3, sum3); } - _mm256_store_ps(&l3[i + 0x00], sum0); - _mm256_store_ps(&l3[i + 0x08], sum1); - _mm256_store_ps(&l3[i + 0x10], sum2); - _mm256_store_ps(&l3[i + 0x18], sum3); + simd::store_f32(&l3[i + 0x00], sum0); + simd::store_f32(&l3[i + 0x08], sum1); + simd::store_f32(&l3[i + 0x10], sum2); + simd::store_f32(&l3[i + 0x18], sum3); } - __m256 sum = _mm256_setzero_ps(); + fvec sum = f_zero; for (int i = 0; i < L3_SIZE; i += 8) { - __m256 val = _mm256_load_ps(&l3[i]); + fvec val = simd::load_fvec(&l3[i]); - val = _mm256_max_ps(val, f_zero); - val = _mm256_min_ps(val, f_clip); + val = simd::max_f32(val, f_zero); + val = simd::min_f32(val, f_clip); - __m256 weight = _mm256_load_ps(&net.output_weights[nbucket][i]); + fvec weight = simd::load_fvec(&net.output_weights[nbucket][i]); - sum = _mm256_fmadd_ps(_mm256_mul_ps(val, val), weight, sum); + sum = simd::fma_f32(simd::mul_f32(val, val), weight, sum); } - __m128 sum_128 = _mm_add_ps(_mm256_castps256_ps128(sum), _mm256_extractf128_ps(sum, 1)); - sum_128 = _mm_add_ps(sum_128, _mm_movehdup_ps(sum_128)); - sum_128 = _mm_add_ps(sum_128, _mm_movehl_ps(sum_128, sum_128)); - float score = _mm_cvtss_f32(sum_128) + net.output_biases[nbucket]; + float score = simd::reduce_add_ps(sum) + net.output_biases[nbucket]; return score * SCALE; } diff --git a/engine/nnue/simd.cpp b/engine/nnue/simd.cpp index 91bbf4bc..caa7a869 100644 --- a/engine/nnue/simd.cpp +++ b/engine/nnue/simd.cpp @@ -1,25 +1,97 @@ -#include +#include "simd.hpp" -namespace simd { - void store_epi16_epi8(int8_t *p, __m256i v) { - const __m128i shuf_mask = _mm_cvtsi64_si128(0x0e0c0a0806040200); +ivec simd::setzero_ivec() { + return _mm256_setzero_si256(); +} - __m128i lo = _mm256_castsi256_si128(v); - __m128i hi = _mm256_extracti128_si256(v, 1); +fvec simd::setzero_fvec() { + return _mm256_setzero_ps(); +} - _mm_storeu_si64(&p[0], _mm_shuffle_epi8(lo, shuf_mask)); - _mm_storeu_si64(&p[8], _mm_shuffle_epi8(hi, shuf_mask)); - } +ivec simd::broadcast_i16(int16_t x) { + return _mm256_set1_epi16(x); +} - int32_t reduce_add_epi16(__m256i v) { - const __m256i ones = _mm256_set1_epi16(1); +fvec simd::broadcast_f32(float x) { + return _mm256_set1_ps(x); +} - __m256i wide = _mm256_madd_epi16(v, ones); +ivec simd::load_ivec(const ivec *p) { + return _mm256_load_si256(p); +} - __m128i sum = _mm_add_epi32(_mm256_castsi256_si128(wide), _mm256_extracti128_si256(wide, 1)); - sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, _MM_SHUFFLE(1, 0, 3, 2))); - sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, _MM_SHUFFLE(2, 3, 0, 1))); +fvec simd::load_fvec(const float *p) { + return _mm256_load_ps(p); +} - return _mm_cvtsi128_si32(sum); - } -}; +ivec simd::max_i16(ivec a, ivec b) { + return _mm256_max_epi16(a, b); +} + +ivec simd::min_i16(ivec a, ivec b) { + return _mm256_min_epi16(a, b); +} + +fvec simd::max_f32(fvec a, fvec b) { + return _mm256_max_ps(a, b); +} + +fvec simd::min_f32(fvec a, fvec b) { + return _mm256_min_ps(a, b); +} + +ivec simd::shift_mulhi(ivec a, ivec b) { + a = _mm256_slli_epi16(a, 7); + return _mm256_mulhrs_epi16(a, b); +} + +ivec simd::accdp_u8i8_i16(ivec a, ivec b, ivec c) { + ivec sum = _mm256_maddubs_epi16(a, b); + return _mm256_add_epi16(sum, c); +} + +fvec simd::cvt_i32_f32(ivec v) { + return _mm256_cvtepi32_ps(v); +} + +fvec simd::fma_f32(fvec a, fvec b, fvec c) { + return _mm256_fmadd_ps(a, b, c); +} + +fvec simd::mul_f32(fvec a, fvec b) { + return _mm256_mul_ps(a, b); +} + +void simd::store_f32(float *p, fvec v) { + _mm256_store_ps(p, v); +} + +void simd::store_i16_i8(int8_t *p, ivec v) { + const __m128i shuf_mask = _mm_cvtsi64_si128(0x0e0c0a0806040200); + + __m128i lo = _mm256_castsi256_si128(v); + __m128i hi = _mm256_extracti128_si256(v, 1); + + _mm_storeu_si64(&p[0], _mm_shuffle_epi8(lo, shuf_mask)); + _mm_storeu_si64(&p[8], _mm_shuffle_epi8(hi, shuf_mask)); +} + +float simd::reduce_add_ps(fvec v) { + __m128 sum = _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1)); + sum = _mm_add_ps(sum, _mm_movehdup_ps(sum)); + sum = _mm_add_ps(sum, _mm_movehl_ps(sum, sum)); + + return _mm_cvtss_f32(sum); +} + +int32_t simd::reduce_add_epi16(ivec v) { + const __m256i ones = _mm256_set1_epi16(1); + + __m256i wide = _mm256_madd_epi16(v, ones); + + __m128i sum = _mm_add_epi32(_mm256_castsi256_si128(wide), _mm256_extracti128_si256(wide, 1)); + sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, _MM_SHUFFLE(1, 0, 3, 2))); + sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, _MM_SHUFFLE(2, 3, 0, 1))); + + return _mm_cvtsi128_si32(sum); +} diff --git a/engine/nnue/simd.hpp b/engine/nnue/simd.hpp index 84971ee5..e29db829 100644 --- a/engine/nnue/simd.hpp +++ b/engine/nnue/simd.hpp @@ -1,6 +1,34 @@ #include +using ivec = __m256i; +using fvec = __m256; + namespace simd { - void store_epi16_epi8(int8_t *p, __m256i v); - int32_t reduce_add_epi16(__m256i v); + ivec setzero_ivec(); + fvec setzero_fvec(); + + ivec broadcast_i16(int16_t x); + fvec broadcast_f32(float x); + + ivec load_ivec(const ivec *p); + fvec load_fvec(const float *p); + + ivec max_i16(ivec a, ivec b); + ivec min_i16(ivec a, ivec b); + fvec max_f32(fvec a, fvec b); + fvec min_f32(fvec a, fvec b); + + ivec shift_mulhi(ivec a, ivec b); + ivec accdp_u8i8_i16(ivec a, ivec b, ivec c); + + fvec cvt_i32_f32(ivec v); + + fvec fma_f32(fvec a, fvec b, fvec c); + fvec mul_f32(fvec a, fvec b); + + void store_f32(float *p, fvec v); + + void store_i16_i8(int8_t *p, ivec v); + float reduce_add_ps(fvec v); + int32_t reduce_add_epi16(ivec v); }; From abe491091c77dd92be4aa4b0bf2264c24841bf17 Mon Sep 17 00:00:00 2001 From: Kevin Lu <69993704+kevlu8@users.noreply.github.com> Date: Mon, 27 Apr 2026 17:03:30 +0000 Subject: [PATCH 19/26] Prepare for SPSA tune Bench: 3930050 --- engine/nnue/network.hpp | 2 +- engine/params.hpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/engine/nnue/network.hpp b/engine/nnue/network.hpp index ef69c2bd..fad0b948 100644 --- a/engine/nnue/network.hpp +++ b/engine/nnue/network.hpp @@ -8,7 +8,7 @@ #define L2_SIZE 16 #define L3_SIZE 32 #define NBUCKETS 8 -#define SCALE 421 +#define SCALE 400 #define QA 255 #define QB 64 diff --git a/engine/params.hpp b/engine/params.hpp index b819b734..9d7227ae 100644 --- a/engine/params.hpp +++ b/engine/params.hpp @@ -4,7 +4,7 @@ #include #include -// #define TUNING +#define TUNING struct TunableParam { std::string name; From 43a8b4c8c5aca0b38f7c14dc8166baf0b60e344e Mon Sep 17 00:00:00 2001 From: Kevin Lu Date: Tue, 28 Apr 2026 17:54:40 -0700 Subject: [PATCH 20/26] Add tuned values Bench: 4495551 --- engine/params.hpp | 92 +++++++++++++++++++++++------------------------ 1 file changed, 46 insertions(+), 46 deletions(-) diff --git a/engine/params.hpp b/engine/params.hpp index 9d7227ae..2c5f9561 100644 --- a/engine/params.hpp +++ b/engine/params.hpp @@ -4,7 +4,7 @@ #include #include -#define TUNING +// #define TUNING struct TunableParam { std::string name; @@ -35,53 +35,53 @@ void handle_set(std::string optionname, std::string optionvalue); inline int name() { return name##Param.value; } // Tunables go below -TUNE(qs_see, -10, -100, 100, 10); -TUNE(qsfp_margin, 150, 100, 500, 25); -TUNE(rfp_threshold, 58, 20, 110, 10); +TUNE(qs_see, -18, -100, 100, 10); +TUNE(qsfp_margin, 192, 100, 500, 25); +TUNE(rfp_threshold, 63, 20, 110, 10); TUNE(rfp_improving, 31, 10, 50, 8); TUNE(rfp_quad, 5, 1, 8, 1); -TUNE(rfp_cutnode, 21, 5, 30, 5); -TUNE(nmp_margin, 202, 50, 400, 25); -TUNE(nmp_depth, 17, 0, 40, 5); -TUNE(razor_margin, 228, 150, 400, 20); -TUNE(probcut_margin, 349, 200, 500, 30); -TUNE(dext_base, 22, 10, 50, 5); -TUNE(dext_capt, 2, -100, 100, 10); -TUNE(dext_pv, 11, -100, 100, 10); -TUNE(dext_improving, -2, -50, 50, 10); -TUNE(text_base, 92, 50, 200, 10); -TUNE(text_capt, -5, -200, 200, 20); -TUNE(fp_const, 284, 100, 500, 30); -TUNE(fp_depth, 96, 20, 200, 10); -TUNE(fp_hist, 32, 10, 64, 5); -TUNE(history_margin, 2990, 1000, 4000, 256); -TUNE(see_quad, 25, 10, 40, 4); -TUNE(see_lin, 58, 20, 80, 6); -TUNE(lmr_base, 163, -1024, 1024, 160); -TUNE(lmr_pv, 1055, 512, 2048, 128); -TUNE(lmr_cutnode, 1940, 0, 3072, 256); -TUNE(lmr_cutnode_nott, -303, -1024, 1024, 256); -TUNE(lmr_cutoffcnt, 714, 0, 2048, 128); -TUNE(lmr_ttpv, 959, 0, 2048, 128); -TUNE(lmr_killer, 717, 0, 2048, 128); +TUNE(rfp_cutnode, 15, 5, 30, 5); +TUNE(nmp_margin, 214, 50, 400, 25); +TUNE(nmp_depth, 13, 0, 40, 5); +TUNE(razor_margin, 227, 150, 400, 20); +TUNE(probcut_margin, 377, 200, 500, 30); +TUNE(dext_base, 24, 10, 50, 5); +TUNE(dext_capt, -2, -100, 100, 10); +TUNE(dext_pv, 16, -100, 100, 10); +TUNE(dext_improving, 12, -50, 50, 10); +TUNE(text_base, 111, 50, 200, 10); +TUNE(text_capt, -11, -200, 200, 20); +TUNE(fp_const, 282, 100, 500, 30); +TUNE(fp_depth, 92, 20, 200, 10); +TUNE(fp_hist, 30, 10, 64, 5); +TUNE(history_margin, 3014, 1000, 4000, 256); +TUNE(see_quad, 17, 10, 40, 4); +TUNE(see_lin, 60, 20, 80, 6); +TUNE(lmr_base, 223, -1024, 1024, 160); +TUNE(lmr_pv, 1109, 512, 2048, 128); +TUNE(lmr_cutnode, 2286, 0, 3072, 256); +TUNE(lmr_cutnode_nott, -643, -1024, 1024, 256); +TUNE(lmr_cutoffcnt, 835, 0, 2048, 128); +TUNE(lmr_ttpv, 896, 0, 2048, 128); +TUNE(lmr_killer, 835, 0, 2048, 128); TUNE(lmr_hist, 12, 1, 16, 2); -TUNE(lmr_ttcapt, 1019, 0, 2048, 128); -TUNE(lmr_ttpv_alpha, -84, -1024, 1024, 256); -TUNE(dodeeper_margin, 91, 20, 200, 10); -TUNE(doshallower_margin, -8, -64, 64, 8); -TUNE(hist_large_margin, 100, 0, 256, 16); +TUNE(lmr_ttcapt, 914, 0, 2048, 128); +TUNE(lmr_ttpv_alpha, -34, -1024, 1024, 256); +TUNE(dodeeper_margin, 109, 20, 200, 10); +TUNE(doshallower_margin, -3, -64, 64, 8); +TUNE(hist_large_margin, 96, 0, 256, 16); TUNE(hist_quad, 3, 0, 6, 1); -TUNE(hist_lin, 122, 64, 256, 15); -TUNE(hist_const, 120, 0, 256, 16); +TUNE(hist_lin, 131, 64, 256, 15); +TUNE(hist_const, 144, 0, 256, 16); TUNE(asp_window, 14, 1, 30, 4); -TUNE(bm_base, 182, 80, 250, 15); -TUNE(bm_mul, 33, 10, 80, 10); -TUNE(corr_ps, 133, 64, 256, 12); -TUNE(corr_np, 139, 64, 256, 12); -TUNE(corr_maj, 62, 32, 128, 8); -TUNE(corr_min, 73, 32, 128, 8); -TUNE(corr_cont, 132, 64, 256, 16); -TUNE(corr_cont2, 157, 64, 256, 16); -TUNE(corr_threat, 100, 64, 256, 16); -TUNE(probcut_see, 102, 50, 200, 10); -TUNE(badnoisy_div, 51, 10, 100, 8); \ No newline at end of file +TUNE(bm_base, 180, 80, 250, 15); +TUNE(bm_mul, 20, 10, 80, 10); +TUNE(corr_ps, 128, 64, 256, 12); +TUNE(corr_np, 141, 64, 256, 12); +TUNE(corr_maj, 63, 32, 128, 8); +TUNE(corr_min, 84, 32, 128, 8); +TUNE(corr_cont, 137, 64, 256, 16); +TUNE(corr_cont2, 143, 64, 256, 16); +TUNE(corr_threat, 103, 64, 256, 16); +TUNE(probcut_see, 84, 50, 200, 10); +TUNE(badnoisy_div, 48, 10, 100, 8); \ No newline at end of file From 7832580b3c9ef199c86319813b925ed0beb7094c Mon Sep 17 00:00:00 2001 From: William Ma Date: Sun, 26 Apr 2026 18:53:20 -0600 Subject: [PATCH 21/26] Implement AVX512 functions (broken) --- engine/nnue/avx2.hpp | 99 ++++++++++++++++++++++++++++++++++++++ engine/nnue/avx512.hpp | 89 +++++++++++++++++++++++++++++++++++ engine/nnue/network.cpp | 28 +++++++++-- engine/nnue/simd.cpp | 102 ++-------------------------------------- engine/nnue/simd.hpp | 7 +++ 5 files changed, 224 insertions(+), 101 deletions(-) create mode 100644 engine/nnue/avx2.hpp create mode 100644 engine/nnue/avx512.hpp diff --git a/engine/nnue/avx2.hpp b/engine/nnue/avx2.hpp new file mode 100644 index 00000000..1dec41a4 --- /dev/null +++ b/engine/nnue/avx2.hpp @@ -0,0 +1,99 @@ +#pragma once + +#include "simd.hpp" + +ivec simd::setzero_ivec() { + return _mm256_setzero_si256(); +} + +fvec simd::setzero_fvec() { + return _mm256_setzero_ps(); +} + +ivec simd::broadcast_i16(int16_t x) { + return _mm256_set1_epi16(x); +} + +fvec simd::broadcast_f32(float x) { + return _mm256_set1_ps(x); +} + +ivec simd::load_ivec(const ivec *p) { + return _mm256_load_si256(p); +} + +fvec simd::load_fvec(const float *p) { + return _mm256_load_ps(p); +} + +ivec simd::max_i16(ivec a, ivec b) { + return _mm256_max_epi16(a, b); +} + +ivec simd::min_i16(ivec a, ivec b) { + return _mm256_min_epi16(a, b); +} + +fvec simd::max_f32(fvec a, fvec b) { + return _mm256_max_ps(a, b); +} + +fvec simd::min_f32(fvec a, fvec b) { + return _mm256_min_ps(a, b); +} + +ivec simd::shift_mulhi(ivec a, ivec b) { + a = _mm256_slli_epi16(a, 7); + return _mm256_mulhrs_epi16(a, b); +} + +ivec simd::accdp_u8i8_i16(ivec a, ivec b, ivec c) { + ivec sum = _mm256_maddubs_epi16(a, b); + return _mm256_add_epi16(sum, c); +} + +fvec simd::cvt_i32_f32(ivec v) { + return _mm256_cvtepi32_ps(v); +} + +fvec simd::fma_f32(fvec a, fvec b, fvec c) { + return _mm256_fmadd_ps(a, b, c); +} + +fvec simd::mul_f32(fvec a, fvec b) { + return _mm256_mul_ps(a, b); +} + +void simd::store_f32(float *p, fvec v) { + _mm256_store_ps(p, v); +} + +void simd::store_i16_i8(int8_t *p, ivec v) { + const __m128i shuf_mask = _mm_cvtsi64_si128(0x0e0c0a0806040200); + + __m128i lo = _mm256_castsi256_si128(v); + __m128i hi = _mm256_extracti128_si256(v, 1); + + _mm_storeu_si64(&p[0], _mm_shuffle_epi8(lo, shuf_mask)); + _mm_storeu_si64(&p[8], _mm_shuffle_epi8(hi, shuf_mask)); +} + +float simd::reduce_add_ps(fvec v) { + __m128 sum = _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1)); + sum = _mm_add_ps(sum, _mm_movehdup_ps(sum)); + sum = _mm_add_ps(sum, _mm_movehl_ps(sum, sum)); + + return _mm_cvtss_f32(sum); +} + +int32_t simd::reduce_add_epi16(ivec v) { + const __m256i ones = _mm256_set1_epi16(1); + + __m256i wide = _mm256_madd_epi16(v, ones); + + __m128i sum = _mm_add_epi32(_mm256_castsi256_si128(wide), _mm256_extracti128_si256(wide, 1)); + sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, _MM_SHUFFLE(1, 0, 3, 2))); + sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, _MM_SHUFFLE(2, 3, 0, 1))); + + return _mm_cvtsi128_si32(sum); +} diff --git a/engine/nnue/avx512.hpp b/engine/nnue/avx512.hpp new file mode 100644 index 00000000..aeaa2e73 --- /dev/null +++ b/engine/nnue/avx512.hpp @@ -0,0 +1,89 @@ +#pragma once + +#include "simd.hpp" + +ivec simd::setzero_ivec() { + return _mm512_setzero_si512(); +} + +fvec simd::setzero_fvec() { + return _mm512_setzero_ps(); +} + +ivec simd::broadcast_i16(int16_t x) { + return _mm512_set1_epi16(x); +} + +fvec simd::broadcast_f32(float x) { + return _mm512_set1_ps(x); +} + +ivec simd::load_ivec(const ivec *p) { + return _mm512_loadu_si512(p); +} + +fvec simd::load_fvec(const float *p) { + return _mm512_loadu_ps(p); +} + +ivec simd::max_i16(ivec a, ivec b) { + return _mm512_max_epi16(a, b); +} + +ivec simd::min_i16(ivec a, ivec b) { + return _mm512_min_epi16(a, b); +} + +fvec simd::max_f32(fvec a, fvec b) { + return _mm512_max_ps(a, b); +} + +fvec simd::min_f32(fvec a, fvec b) { + return _mm512_min_ps(a, b); +} + +ivec simd::shift_mulhi(ivec a, ivec b) { + a = _mm512_slli_epi16(a, 7); + return _mm512_mulhrs_epi16(a, b); +} + +ivec simd::accdp_u8i8_i16(ivec a, ivec b, ivec c) { +#if defined(__AVX512VNNI__) + return _mm512_dpbusd_epi32(c, a, b); +#else + ivec sum = _mm512_maddubs_epi16(a, b); + return _mm512_add_epi16(sum, c); +#endif +} + +fvec simd::cvt_i32_f32(ivec v) { + return _mm512_cvtepi32_ps(v); +} + +fvec simd::fma_f32(fvec a, fvec b, fvec c) { + return _mm512_fmadd_ps(a, b, c); +} + +fvec simd::mul_f32(fvec a, fvec b) { + return _mm512_mul_ps(a, b); +} + +void simd::store_f32(float *p, fvec v) { + _mm512_store_ps(p, v); +} + +void simd::store_i16_i8(int8_t *p, ivec v) { + _mm256_store_si256((__m256i *)p, _mm512_cvtepi16_epi8(v)); +} + +float simd::reduce_add_ps(fvec v) { + return _mm512_reduce_add_ps(v); +} + +int32_t simd::reduce_add_epi16(ivec v) { + const __m512i ones = _mm512_set1_epi16(1); + + __m512i wide = _mm512_madd_epi16(v, ones); + + return _mm512_reduce_add_epi32(wide); +} diff --git a/engine/nnue/network.cpp b/engine/nnue/network.cpp index 6b44dffc..3ee95ab6 100644 --- a/engine/nnue/network.cpp +++ b/engine/nnue/network.cpp @@ -60,12 +60,12 @@ int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator const fvec f_clip = simd::broadcast_f32(1.0f); const fvec div = simd::broadcast_f32(1.0f / (QA * QA * QB / 256.0f)); - alignas(32) int8_t l1[L1_SIZE]; + alignas(64) int8_t l1[L1_SIZE]; union { - alignas(32) int32_t l2i[L2_SIZE]; - alignas(32) float l2[L2_SIZE]; + alignas(64) int32_t l2i[L2_SIZE]; + alignas(64) float l2[L2_SIZE]; }; - alignas(32) float l3[L3_SIZE]; + alignas(64) float l3[L3_SIZE]; // Pairwise mul for (int i = 0; i < L1_SIZE / 2; i += 16) { @@ -134,6 +134,25 @@ int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator simd::store_f32(&l2[i], val); } +#if defined(__AVX512BW__) + for (int i = 0; i < L3_SIZE; i += 8 * 2) { + fvec sum0 = simd::load_fvec(&net.l2_biases[nbucket][i + 0]); + fvec sum1 = simd::load_fvec(&net.l2_biases[nbucket][i + 8]); + + for (int j = 0; j < L2_SIZE; j++) { + fvec val = simd::broadcast_f32(l2[j]); + + fvec weight0 = simd::load_fvec(&net.l2_weights[nbucket][j][i + 0]); + fvec weight1 = simd::load_fvec(&net.l2_weights[nbucket][j][i + 8]); + + sum0 = simd::fma_f32(val, weight0, sum0); + sum1 = simd::fma_f32(val, weight1, sum1); + } + + simd::store_f32(&l3[i + 0], sum0); + simd::store_f32(&l3[i + 8], sum1); + } +#else for (int i = 0; i < L3_SIZE; i += 8 * 4) { fvec sum0 = simd::load_fvec(&net.l2_biases[nbucket][i + 0x00]); fvec sum1 = simd::load_fvec(&net.l2_biases[nbucket][i + 0x08]); @@ -159,6 +178,7 @@ int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator simd::store_f32(&l3[i + 0x10], sum2); simd::store_f32(&l3[i + 0x18], sum3); } +#endif fvec sum = f_zero; for (int i = 0; i < L3_SIZE; i += 8) { diff --git a/engine/nnue/simd.cpp b/engine/nnue/simd.cpp index caa7a869..287da2c1 100644 --- a/engine/nnue/simd.cpp +++ b/engine/nnue/simd.cpp @@ -1,97 +1,5 @@ -#include "simd.hpp" - -ivec simd::setzero_ivec() { - return _mm256_setzero_si256(); -} - -fvec simd::setzero_fvec() { - return _mm256_setzero_ps(); -} - -ivec simd::broadcast_i16(int16_t x) { - return _mm256_set1_epi16(x); -} - -fvec simd::broadcast_f32(float x) { - return _mm256_set1_ps(x); -} - -ivec simd::load_ivec(const ivec *p) { - return _mm256_load_si256(p); -} - -fvec simd::load_fvec(const float *p) { - return _mm256_load_ps(p); -} - -ivec simd::max_i16(ivec a, ivec b) { - return _mm256_max_epi16(a, b); -} - -ivec simd::min_i16(ivec a, ivec b) { - return _mm256_min_epi16(a, b); -} - -fvec simd::max_f32(fvec a, fvec b) { - return _mm256_max_ps(a, b); -} - -fvec simd::min_f32(fvec a, fvec b) { - return _mm256_min_ps(a, b); -} - -ivec simd::shift_mulhi(ivec a, ivec b) { - a = _mm256_slli_epi16(a, 7); - return _mm256_mulhrs_epi16(a, b); -} - -ivec simd::accdp_u8i8_i16(ivec a, ivec b, ivec c) { - ivec sum = _mm256_maddubs_epi16(a, b); - return _mm256_add_epi16(sum, c); -} - -fvec simd::cvt_i32_f32(ivec v) { - return _mm256_cvtepi32_ps(v); -} - -fvec simd::fma_f32(fvec a, fvec b, fvec c) { - return _mm256_fmadd_ps(a, b, c); -} - -fvec simd::mul_f32(fvec a, fvec b) { - return _mm256_mul_ps(a, b); -} - -void simd::store_f32(float *p, fvec v) { - _mm256_store_ps(p, v); -} - -void simd::store_i16_i8(int8_t *p, ivec v) { - const __m128i shuf_mask = _mm_cvtsi64_si128(0x0e0c0a0806040200); - - __m128i lo = _mm256_castsi256_si128(v); - __m128i hi = _mm256_extracti128_si256(v, 1); - - _mm_storeu_si64(&p[0], _mm_shuffle_epi8(lo, shuf_mask)); - _mm_storeu_si64(&p[8], _mm_shuffle_epi8(hi, shuf_mask)); -} - -float simd::reduce_add_ps(fvec v) { - __m128 sum = _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1)); - sum = _mm_add_ps(sum, _mm_movehdup_ps(sum)); - sum = _mm_add_ps(sum, _mm_movehl_ps(sum, sum)); - - return _mm_cvtss_f32(sum); -} - -int32_t simd::reduce_add_epi16(ivec v) { - const __m256i ones = _mm256_set1_epi16(1); - - __m256i wide = _mm256_madd_epi16(v, ones); - - __m128i sum = _mm_add_epi32(_mm256_castsi256_si128(wide), _mm256_extracti128_si256(wide, 1)); - sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, _MM_SHUFFLE(1, 0, 3, 2))); - sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, _MM_SHUFFLE(2, 3, 0, 1))); - - return _mm_cvtsi128_si32(sum); -} +#if defined(__AVX512F__) +#include "avx512.hpp" +#else +#include "avx2.hpp" +#endif diff --git a/engine/nnue/simd.hpp b/engine/nnue/simd.hpp index e29db829..12a225b6 100644 --- a/engine/nnue/simd.hpp +++ b/engine/nnue/simd.hpp @@ -1,7 +1,14 @@ +#pragma once + #include +#if defined(__AVX512BW__) +using ivec = __m512i; +using fvec = __m512; +#else using ivec = __m256i; using fvec = __m256; +#endif namespace simd { ivec setzero_ivec(); From 758341c297adf5d8443ed7ab26a22e265479aecf Mon Sep 17 00:00:00 2001 From: Kevin Lu Date: Sun, 26 Apr 2026 22:02:37 -0700 Subject: [PATCH 22/26] Fix AVX SIMD lane size issue Bench: 4675375 --- engine/nnue/network.cpp | 8 ++++---- engine/nnue/simd.hpp | 2 ++ 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/engine/nnue/network.cpp b/engine/nnue/network.cpp index 3ee95ab6..08de6798 100644 --- a/engine/nnue/network.cpp +++ b/engine/nnue/network.cpp @@ -68,7 +68,7 @@ int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator alignas(64) float l3[L3_SIZE]; // Pairwise mul - for (int i = 0; i < L1_SIZE / 2; i += 16) { + for (int i = 0; i < L1_SIZE / 2; i += SIMD_LANE_SIZE / 2) { ivec stm_val1 = simd::load_ivec((ivec *)&stm.val[i]); ivec stm_val2 = simd::load_ivec((ivec *)&stm.val[i + L1_SIZE / 2]); ivec ntm_val1 = simd::load_ivec((ivec *)&ntm.val[i]); @@ -97,7 +97,7 @@ int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator ivec sum2 = zero; ivec sum3 = zero; - for (int j = 0; j < L1_SIZE; j += 32) { + for (int j = 0; j < L1_SIZE; j += SIMD_LANE_SIZE) { ivec val = simd::load_ivec((ivec *)&l1[j]); ivec weight0 = simd::load_ivec((ivec *)&net.l1_weights[nbucket][i + 0][j]); @@ -118,7 +118,7 @@ int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator } // Convert l2 into a proper float array - for (int i = 0; i < L2_SIZE; i += 8) { + for (int i = 0; i < L2_SIZE; i += SIMD_LANE_SIZE / 4) { ivec i_val = simd::load_ivec((ivec *)&l2i[i]); fvec val = simd::cvt_i32_f32(i_val); @@ -181,7 +181,7 @@ int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator #endif fvec sum = f_zero; - for (int i = 0; i < L3_SIZE; i += 8) { + for (int i = 0; i < L3_SIZE; i += SIMD_LANE_SIZE / 4) { fvec val = simd::load_fvec(&l3[i]); val = simd::max_f32(val, f_zero); diff --git a/engine/nnue/simd.hpp b/engine/nnue/simd.hpp index 12a225b6..65d37be4 100644 --- a/engine/nnue/simd.hpp +++ b/engine/nnue/simd.hpp @@ -5,9 +5,11 @@ #if defined(__AVX512BW__) using ivec = __m512i; using fvec = __m512; +#define SIMD_LANE_SIZE 64 // 64 bytes for AVX-512 #else using ivec = __m256i; using fvec = __m256; +#define SIMD_LANE_SIZE 32 // 32 bytes for AVX2 #endif namespace simd { From 4f60c250f5f2830e960cb767950154ca7d30feee Mon Sep 17 00:00:00 2001 From: William Ma Date: Wed, 29 Apr 2026 15:15:34 -0600 Subject: [PATCH 23/26] Fix bench consistency issues Bench: 4373230 --- engine/nnue/avx2.hpp | 8 +++---- engine/nnue/avx512.hpp | 10 ++++++--- engine/nnue/network.cpp | 49 +++++++++++++++++++++++++++++------------ engine/nnue/simd.hpp | 10 ++++++--- 4 files changed, 53 insertions(+), 24 deletions(-) diff --git a/engine/nnue/avx2.hpp b/engine/nnue/avx2.hpp index 1dec41a4..e17f899b 100644 --- a/engine/nnue/avx2.hpp +++ b/engine/nnue/avx2.hpp @@ -19,11 +19,11 @@ fvec simd::broadcast_f32(float x) { } ivec simd::load_ivec(const ivec *p) { - return _mm256_load_si256(p); + return _mm256_loadu_si256(p); } fvec simd::load_fvec(const float *p) { - return _mm256_load_ps(p); + return _mm256_loadu_ps(p); } ivec simd::max_i16(ivec a, ivec b) { @@ -65,10 +65,10 @@ fvec simd::mul_f32(fvec a, fvec b) { } void simd::store_f32(float *p, fvec v) { - _mm256_store_ps(p, v); + _mm256_storeu_ps(p, v); } -void simd::store_i16_i8(int8_t *p, ivec v) { +void simd::store_u16_u8(int8_t *p, ivec v) { const __m128i shuf_mask = _mm_cvtsi64_si128(0x0e0c0a0806040200); __m128i lo = _mm256_castsi256_si128(v); diff --git a/engine/nnue/avx512.hpp b/engine/nnue/avx512.hpp index aeaa2e73..97b9586a 100644 --- a/engine/nnue/avx512.hpp +++ b/engine/nnue/avx512.hpp @@ -1,5 +1,7 @@ #pragma once +#if defined(__AVX512BW__) + #include "simd.hpp" ivec simd::setzero_ivec() { @@ -69,11 +71,11 @@ fvec simd::mul_f32(fvec a, fvec b) { } void simd::store_f32(float *p, fvec v) { - _mm512_store_ps(p, v); + _mm512_storeu_ps(p, v); } -void simd::store_i16_i8(int8_t *p, ivec v) { - _mm256_store_si256((__m256i *)p, _mm512_cvtepi16_epi8(v)); +void simd::store_u16_u8(int8_t *p, ivec v) { + _mm256_storeu_si256((__m256i *)p, _mm512_cvtepi16_epi8(v)); } float simd::reduce_add_ps(fvec v) { @@ -87,3 +89,5 @@ int32_t simd::reduce_add_epi16(ivec v) { return _mm512_reduce_add_epi32(wide); } + +#endif diff --git a/engine/nnue/network.cpp b/engine/nnue/network.cpp index 08de6798..77745eb7 100644 --- a/engine/nnue/network.cpp +++ b/engine/nnue/network.cpp @@ -68,7 +68,7 @@ int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator alignas(64) float l3[L3_SIZE]; // Pairwise mul - for (int i = 0; i < L1_SIZE / 2; i += SIMD_LANE_SIZE / 2) { + for (int i = 0; i < L1_SIZE / 2; i += SHORTS_PER_VEC) { ivec stm_val1 = simd::load_ivec((ivec *)&stm.val[i]); ivec stm_val2 = simd::load_ivec((ivec *)&stm.val[i + L1_SIZE / 2]); ivec ntm_val1 = simd::load_ivec((ivec *)&ntm.val[i]); @@ -87,8 +87,8 @@ int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator ivec stm_pair = simd::shift_mulhi(stm_val1, stm_val2); ivec ntm_pair = simd::shift_mulhi(ntm_val1, ntm_val2); - simd::store_i16_i8(&l1[i], stm_pair); - simd::store_i16_i8(&l1[i + L1_SIZE / 2], ntm_pair); + simd::store_u16_u8(&l1[i], stm_pair); + simd::store_u16_u8(&l1[i + L1_SIZE / 2], ntm_pair); } for (int i = 0; i < L2_SIZE; i += 4) { @@ -97,7 +97,7 @@ int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator ivec sum2 = zero; ivec sum3 = zero; - for (int j = 0; j < L1_SIZE; j += SIMD_LANE_SIZE) { + for (int j = 0; j < L1_SIZE; j += BYTES_PER_VEC) { ivec val = simd::load_ivec((ivec *)&l1[j]); ivec weight0 = simd::load_ivec((ivec *)&net.l1_weights[nbucket][i + 0][j]); @@ -118,7 +118,7 @@ int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator } // Convert l2 into a proper float array - for (int i = 0; i < L2_SIZE; i += SIMD_LANE_SIZE / 4) { + for (int i = 0; i < L2_SIZE; i += FLOATS_PER_VEC) { ivec i_val = simd::load_ivec((ivec *)&l2i[i]); fvec val = simd::cvt_i32_f32(i_val); @@ -135,25 +135,25 @@ int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator } #if defined(__AVX512BW__) - for (int i = 0; i < L3_SIZE; i += 8 * 2) { - fvec sum0 = simd::load_fvec(&net.l2_biases[nbucket][i + 0]); - fvec sum1 = simd::load_fvec(&net.l2_biases[nbucket][i + 8]); + for (int i = 0; i < L3_SIZE; i += FLOATS_PER_VEC * 2) { + fvec sum0 = simd::load_fvec(&net.l2_biases[nbucket][i + 0x00]); + fvec sum1 = simd::load_fvec(&net.l2_biases[nbucket][i + 0x10]); for (int j = 0; j < L2_SIZE; j++) { fvec val = simd::broadcast_f32(l2[j]); - fvec weight0 = simd::load_fvec(&net.l2_weights[nbucket][j][i + 0]); - fvec weight1 = simd::load_fvec(&net.l2_weights[nbucket][j][i + 8]); + fvec weight0 = simd::load_fvec(&net.l2_weights[nbucket][j][i + 0x00]); + fvec weight1 = simd::load_fvec(&net.l2_weights[nbucket][j][i + 0x10]); sum0 = simd::fma_f32(val, weight0, sum0); sum1 = simd::fma_f32(val, weight1, sum1); } - simd::store_f32(&l3[i + 0], sum0); - simd::store_f32(&l3[i + 8], sum1); + simd::store_f32(&l3[i + 0x00], sum0); + simd::store_f32(&l3[i + 0x10], sum1); } #else - for (int i = 0; i < L3_SIZE; i += 8 * 4) { + for (int i = 0; i < L3_SIZE; i += FLOATS_PER_VEC * 4) { fvec sum0 = simd::load_fvec(&net.l2_biases[nbucket][i + 0x00]); fvec sum1 = simd::load_fvec(&net.l2_biases[nbucket][i + 0x08]); fvec sum2 = simd::load_fvec(&net.l2_biases[nbucket][i + 0x10]); @@ -180,8 +180,9 @@ int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator } #endif +#if defined(__AVX512BW__) fvec sum = f_zero; - for (int i = 0; i < L3_SIZE; i += SIMD_LANE_SIZE / 4) { + for (int i = 0; i < L3_SIZE; i += FLOATS_PER_VEC) { fvec val = simd::load_fvec(&l3[i]); val = simd::max_f32(val, f_zero); @@ -191,6 +192,26 @@ int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator sum = simd::fma_f32(simd::mul_f32(val, val), weight, sum); } +#else + fvec sum0 = f_zero; + fvec sum1 = f_zero; + for (int i = 0; i < L3_SIZE; i += FLOATS_PER_VEC * 2) { + fvec val0 = simd::load_fvec(&l3[i + 0]); + fvec val1 = simd::load_fvec(&l3[i + 8]); + + val0 = simd::max_f32(val0, f_zero); + val0 = simd::min_f32(val0, f_clip); + val1 = simd::max_f32(val1, f_zero); + val1 = simd::min_f32(val1, f_clip); + + fvec weight0 = simd::load_fvec(&net.output_weights[nbucket][i + 0]); + fvec weight1 = simd::load_fvec(&net.output_weights[nbucket][i + 8]); + + sum0 = simd::fma_f32(simd::mul_f32(val0, val0), weight0, sum0); + sum1 = simd::fma_f32(simd::mul_f32(val1, val1), weight1, sum1); + } + fvec sum = _mm256_add_ps(sum0, sum1); +#endif float score = simd::reduce_add_ps(sum) + net.output_biases[nbucket]; diff --git a/engine/nnue/simd.hpp b/engine/nnue/simd.hpp index 65d37be4..a380d850 100644 --- a/engine/nnue/simd.hpp +++ b/engine/nnue/simd.hpp @@ -5,13 +5,17 @@ #if defined(__AVX512BW__) using ivec = __m512i; using fvec = __m512; -#define SIMD_LANE_SIZE 64 // 64 bytes for AVX-512 +#define VEC_SIZE 512 #else using ivec = __m256i; using fvec = __m256; -#define SIMD_LANE_SIZE 32 // 32 bytes for AVX2 +#define VEC_SIZE 256 #endif +#define BYTES_PER_VEC (VEC_SIZE / 8) +#define SHORTS_PER_VEC (VEC_SIZE / 16) +#define FLOATS_PER_VEC (VEC_SIZE / 32) + namespace simd { ivec setzero_ivec(); fvec setzero_fvec(); @@ -37,7 +41,7 @@ namespace simd { void store_f32(float *p, fvec v); - void store_i16_i8(int8_t *p, ivec v); + void store_u16_u8(int8_t *p, ivec v); float reduce_add_ps(fvec v); int32_t reduce_add_epi16(ivec v); }; From f718147129099b9795a6161adec39f6525f5b107 Mon Sep 17 00:00:00 2001 From: William Ma Date: Wed, 29 Apr 2026 18:36:54 -0600 Subject: [PATCH 24/26] Fix VNNI reduction Bench: 4559824 --- engine/nnue/avx512.hpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/engine/nnue/avx512.hpp b/engine/nnue/avx512.hpp index 97b9586a..97173a74 100644 --- a/engine/nnue/avx512.hpp +++ b/engine/nnue/avx512.hpp @@ -83,9 +83,12 @@ float simd::reduce_add_ps(fvec v) { } int32_t simd::reduce_add_epi16(ivec v) { +#if defined(__AVX512VNNI__) + __m512i wide = v; +#else const __m512i ones = _mm512_set1_epi16(1); - __m512i wide = _mm512_madd_epi16(v, ones); +#endif return _mm512_reduce_add_epi32(wide); } From 85b61703bdedf3660f00bfd8900936558708c1b0 Mon Sep 17 00:00:00 2001 From: William Ma Date: Wed, 29 Apr 2026 18:54:46 -0600 Subject: [PATCH 25/26] Cleanup Bench: 4559824 --- engine/nnue/avx2.hpp | 24 ++++---- engine/nnue/avx512.hpp | 24 ++++---- engine/nnue/network.cpp | 128 ++++++++++++---------------------------- engine/nnue/simd.hpp | 19 ++++-- 4 files changed, 77 insertions(+), 118 deletions(-) diff --git a/engine/nnue/avx2.hpp b/engine/nnue/avx2.hpp index e17f899b..820ce76d 100644 --- a/engine/nnue/avx2.hpp +++ b/engine/nnue/avx2.hpp @@ -26,20 +26,16 @@ fvec simd::load_fvec(const float *p) { return _mm256_loadu_ps(p); } -ivec simd::max_i16(ivec a, ivec b) { - return _mm256_max_epi16(a, b); +ivec simd::clamp_i16(ivec x, ivec lo, ivec hi) { + x = _mm256_max_epi16(x, lo); + x = _mm256_min_epi16(x, hi); + return x; } -ivec simd::min_i16(ivec a, ivec b) { - return _mm256_min_epi16(a, b); -} - -fvec simd::max_f32(fvec a, fvec b) { - return _mm256_max_ps(a, b); -} - -fvec simd::min_f32(fvec a, fvec b) { - return _mm256_min_ps(a, b); +fvec simd::clamp_f32(fvec x, fvec lo, fvec hi) { + x = _mm256_max_ps(x, lo); + x = _mm256_min_ps(x, hi); + return x; } ivec simd::shift_mulhi(ivec a, ivec b) { @@ -64,6 +60,10 @@ fvec simd::mul_f32(fvec a, fvec b) { return _mm256_mul_ps(a, b); } +fvec simd::add_f32(fvec a, fvec b) { + return _mm256_add_ps(a, b); +} + void simd::store_f32(float *p, fvec v) { _mm256_storeu_ps(p, v); } diff --git a/engine/nnue/avx512.hpp b/engine/nnue/avx512.hpp index 97173a74..8a7f3377 100644 --- a/engine/nnue/avx512.hpp +++ b/engine/nnue/avx512.hpp @@ -28,20 +28,16 @@ fvec simd::load_fvec(const float *p) { return _mm512_loadu_ps(p); } -ivec simd::max_i16(ivec a, ivec b) { - return _mm512_max_epi16(a, b); +ivec simd::clamp_i16(ivec x, ivec lo, ivec hi) { + x = _mm512_max_epi16(x, lo); + x = _mm512_min_epi16(x, hi); + return x; } -ivec simd::min_i16(ivec a, ivec b) { - return _mm512_min_epi16(a, b); -} - -fvec simd::max_f32(fvec a, fvec b) { - return _mm512_max_ps(a, b); -} - -fvec simd::min_f32(fvec a, fvec b) { - return _mm512_min_ps(a, b); +fvec simd::clamp_f32(fvec x, fvec lo, fvec hi) { + x = _mm512_max_ps(x, lo); + x = _mm512_min_ps(x, hi); + return x; } ivec simd::shift_mulhi(ivec a, ivec b) { @@ -70,6 +66,10 @@ fvec simd::mul_f32(fvec a, fvec b) { return _mm512_mul_ps(a, b); } +fvec simd::add_f32(fvec a, fvec b) { + return _mm512_add_ps(a, b); +} + void simd::store_f32(float *p, fvec v) { _mm512_storeu_ps(p, v); } diff --git a/engine/nnue/network.cpp b/engine/nnue/network.cpp index 77745eb7..cee715f1 100644 --- a/engine/nnue/network.cpp +++ b/engine/nnue/network.cpp @@ -74,15 +74,11 @@ int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator ivec ntm_val1 = simd::load_ivec((ivec *)&ntm.val[i]); ivec ntm_val2 = simd::load_ivec((ivec *)&ntm.val[i + L1_SIZE / 2]); - stm_val1 = simd::max_i16(stm_val1, zero); - stm_val1 = simd::min_i16(stm_val1, clip); - stm_val2 = simd::max_i16(stm_val2, zero); - stm_val2 = simd::min_i16(stm_val2, clip); + stm_val1 = simd::clamp_i16(stm_val1, zero, clip); + stm_val2 = simd::clamp_i16(stm_val2, zero, clip); - ntm_val1 = simd::max_i16(ntm_val1, zero); - ntm_val1 = simd::min_i16(ntm_val1, clip); - ntm_val2 = simd::max_i16(ntm_val2, zero); - ntm_val2 = simd::min_i16(ntm_val2, clip); + ntm_val1 = simd::clamp_i16(ntm_val1, zero, clip); + ntm_val2 = simd::clamp_i16(ntm_val2, zero, clip); ivec stm_pair = simd::shift_mulhi(stm_val1, stm_val2); ivec ntm_pair = simd::shift_mulhi(ntm_val1, ntm_val2); @@ -91,30 +87,22 @@ int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator simd::store_u16_u8(&l1[i + L1_SIZE / 2], ntm_pair); } - for (int i = 0; i < L2_SIZE; i += 4) { - ivec sum0 = zero; - ivec sum1 = zero; - ivec sum2 = zero; - ivec sum3 = zero; + for (int i = 0; i < L2_SIZE; i += L1_UNROLL) { + ivec sums[L1_UNROLL]; + for (int j = 0; j < L1_UNROLL; j++) + sums[j] = zero; for (int j = 0; j < L1_SIZE; j += BYTES_PER_VEC) { ivec val = simd::load_ivec((ivec *)&l1[j]); - ivec weight0 = simd::load_ivec((ivec *)&net.l1_weights[nbucket][i + 0][j]); - ivec weight1 = simd::load_ivec((ivec *)&net.l1_weights[nbucket][i + 1][j]); - ivec weight2 = simd::load_ivec((ivec *)&net.l1_weights[nbucket][i + 2][j]); - ivec weight3 = simd::load_ivec((ivec *)&net.l1_weights[nbucket][i + 3][j]); - - sum0 = simd::accdp_u8i8_i16(val, weight0, sum0); - sum1 = simd::accdp_u8i8_i16(val, weight1, sum1); - sum2 = simd::accdp_u8i8_i16(val, weight2, sum2); - sum3 = simd::accdp_u8i8_i16(val, weight3, sum3); + for (int k = 0; k < L1_UNROLL; k++) { + ivec weight = simd::load_ivec((ivec *)&net.l1_weights[nbucket][i + k][j]); + sums[k] = simd::accdp_u8i8_i16(val, weight, sums[k]); + } } - l2i[i + 0] = simd::reduce_add_epi16(sum0); - l2i[i + 1] = simd::reduce_add_epi16(sum1); - l2i[i + 2] = simd::reduce_add_epi16(sum2); - l2i[i + 3] = simd::reduce_add_epi16(sum3); + for (int j = 0; j < L1_UNROLL; j++) + l2i[i + j] = simd::reduce_add_epi16(sums[j]); } // Convert l2 into a proper float array @@ -126,94 +114,54 @@ int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator val = simd::fma_f32(val, div, bias); - val = simd::max_f32(val, f_zero); - val = simd::min_f32(val, f_clip); + val = simd::clamp_f32(val, f_zero, f_clip); val = simd::mul_f32(val, val); simd::store_f32(&l2[i], val); } -#if defined(__AVX512BW__) - for (int i = 0; i < L3_SIZE; i += FLOATS_PER_VEC * 2) { - fvec sum0 = simd::load_fvec(&net.l2_biases[nbucket][i + 0x00]); - fvec sum1 = simd::load_fvec(&net.l2_biases[nbucket][i + 0x10]); + for (int i = 0; i < L3_SIZE; i += FLOATS_PER_VEC * L2_UNROLL) { + fvec sums[L2_UNROLL]; + for (int j = 0; j < L2_UNROLL; j++) + sums[j] = simd::load_fvec(&net.l2_biases[nbucket][i + j * FLOATS_PER_VEC]); for (int j = 0; j < L2_SIZE; j++) { fvec val = simd::broadcast_f32(l2[j]); - fvec weight0 = simd::load_fvec(&net.l2_weights[nbucket][j][i + 0x00]); - fvec weight1 = simd::load_fvec(&net.l2_weights[nbucket][j][i + 0x10]); - - sum0 = simd::fma_f32(val, weight0, sum0); - sum1 = simd::fma_f32(val, weight1, sum1); + for (int k = 0; k < L2_UNROLL; k++) { + fvec weight = simd::load_fvec(&net.l2_weights[nbucket][j][i + k * FLOATS_PER_VEC]); + sums[k] = simd::fma_f32(val, weight, sums[k]); + } } - simd::store_f32(&l3[i + 0x00], sum0); - simd::store_f32(&l3[i + 0x10], sum1); + for (int j = 0; j < L2_UNROLL; j++) + simd::store_f32(&l3[i + j * FLOATS_PER_VEC], sums[j]); } -#else - for (int i = 0; i < L3_SIZE; i += FLOATS_PER_VEC * 4) { - fvec sum0 = simd::load_fvec(&net.l2_biases[nbucket][i + 0x00]); - fvec sum1 = simd::load_fvec(&net.l2_biases[nbucket][i + 0x08]); - fvec sum2 = simd::load_fvec(&net.l2_biases[nbucket][i + 0x10]); - fvec sum3 = simd::load_fvec(&net.l2_biases[nbucket][i + 0x18]); - - for (int j = 0; j < L2_SIZE; j++) { - fvec val = simd::broadcast_f32(l2[j]); - - fvec weight0 = simd::load_fvec(&net.l2_weights[nbucket][j][i + 0x00]); - fvec weight1 = simd::load_fvec(&net.l2_weights[nbucket][j][i + 0x08]); - fvec weight2 = simd::load_fvec(&net.l2_weights[nbucket][j][i + 0x10]); - fvec weight3 = simd::load_fvec(&net.l2_weights[nbucket][j][i + 0x18]); - - sum0 = simd::fma_f32(val, weight0, sum0); - sum1 = simd::fma_f32(val, weight1, sum1); - sum2 = simd::fma_f32(val, weight2, sum2); - sum3 = simd::fma_f32(val, weight3, sum3); - } - simd::store_f32(&l3[i + 0x00], sum0); - simd::store_f32(&l3[i + 0x08], sum1); - simd::store_f32(&l3[i + 0x10], sum2); - simd::store_f32(&l3[i + 0x18], sum3); - } -#endif + fvec sums[L3_UNROLL]; + for (int i = 0; i < L3_UNROLL; i++) + sums[i] = f_zero; -#if defined(__AVX512BW__) - fvec sum = f_zero; for (int i = 0; i < L3_SIZE; i += FLOATS_PER_VEC) { fvec val = simd::load_fvec(&l3[i]); - val = simd::max_f32(val, f_zero); - val = simd::min_f32(val, f_clip); + val = simd::clamp_f32(val, f_zero, f_clip); fvec weight = simd::load_fvec(&net.output_weights[nbucket][i]); - sum = simd::fma_f32(simd::mul_f32(val, val), weight, sum); + int idx = i / FLOATS_PER_VEC % L3_UNROLL; + sums[idx] = simd::fma_f32(simd::mul_f32(val, val), weight, sums[idx]); } -#else - fvec sum0 = f_zero; - fvec sum1 = f_zero; - for (int i = 0; i < L3_SIZE; i += FLOATS_PER_VEC * 2) { - fvec val0 = simd::load_fvec(&l3[i + 0]); - fvec val1 = simd::load_fvec(&l3[i + 8]); - - val0 = simd::max_f32(val0, f_zero); - val0 = simd::min_f32(val0, f_clip); - val1 = simd::max_f32(val1, f_zero); - val1 = simd::min_f32(val1, f_clip); - - fvec weight0 = simd::load_fvec(&net.output_weights[nbucket][i + 0]); - fvec weight1 = simd::load_fvec(&net.output_weights[nbucket][i + 8]); - - sum0 = simd::fma_f32(simd::mul_f32(val0, val0), weight0, sum0); - sum1 = simd::fma_f32(simd::mul_f32(val1, val1), weight1, sum1); + + int num = L3_UNROLL; + while (num > 1) { + num /= 2; + for (int i = 0; i < num; i++) + sums[i] = simd::add_f32(sums[i], sums[i + num]); } - fvec sum = _mm256_add_ps(sum0, sum1); -#endif - float score = simd::reduce_add_ps(sum) + net.output_biases[nbucket]; + float score = simd::reduce_add_ps(sums[0]) + net.output_biases[nbucket]; return score * SCALE; } diff --git a/engine/nnue/simd.hpp b/engine/nnue/simd.hpp index a380d850..f091c1f1 100644 --- a/engine/nnue/simd.hpp +++ b/engine/nnue/simd.hpp @@ -3,13 +3,25 @@ #include #if defined(__AVX512BW__) + using ivec = __m512i; using fvec = __m512; #define VEC_SIZE 512 + +#define L1_UNROLL 4 +#define L2_UNROLL 2 +#define L3_UNROLL 1 + #else + using ivec = __m256i; using fvec = __m256; #define VEC_SIZE 256 + +#define L1_UNROLL 4 +#define L2_UNROLL 4 +#define L3_UNROLL 2 + #endif #define BYTES_PER_VEC (VEC_SIZE / 8) @@ -26,10 +38,8 @@ namespace simd { ivec load_ivec(const ivec *p); fvec load_fvec(const float *p); - ivec max_i16(ivec a, ivec b); - ivec min_i16(ivec a, ivec b); - fvec max_f32(fvec a, fvec b); - fvec min_f32(fvec a, fvec b); + ivec clamp_i16(ivec x, ivec lo, ivec hi); + fvec clamp_f32(fvec x, fvec lo, fvec hi); ivec shift_mulhi(ivec a, ivec b); ivec accdp_u8i8_i16(ivec a, ivec b, ivec c); @@ -38,6 +48,7 @@ namespace simd { fvec fma_f32(fvec a, fvec b, fvec c); fvec mul_f32(fvec a, fvec b); + fvec add_f32(fvec a, fvec b); void store_f32(float *p, fvec v); From f0f6eaad0c01192acacbf4866be18dd5d4ada0fb Mon Sep 17 00:00:00 2001 From: William Ma Date: Wed, 29 Apr 2026 21:44:46 -0600 Subject: [PATCH 26/26] Fix stuff Bench: 4559824 --- engine/nnue/avx2.hpp | 11 ++++------- engine/nnue/avx512.hpp | 4 ++-- engine/nnue/network.cpp | 8 +++----- engine/nnue/simd.cpp | 2 +- engine/nnue/simd.hpp | 3 ++- 5 files changed, 12 insertions(+), 16 deletions(-) diff --git a/engine/nnue/avx2.hpp b/engine/nnue/avx2.hpp index 820ce76d..1a59e592 100644 --- a/engine/nnue/avx2.hpp +++ b/engine/nnue/avx2.hpp @@ -68,14 +68,11 @@ void simd::store_f32(float *p, fvec v) { _mm256_storeu_ps(p, v); } -void simd::store_u16_u8(int8_t *p, ivec v) { - const __m128i shuf_mask = _mm_cvtsi64_si128(0x0e0c0a0806040200); +void simd::store_u16_u8(uint8_t *p, ivec v) { + __m256i res = _mm256_packus_epi16(v, v); + res = _mm256_permute4x64_epi64(res, _MM_PERM_DCCA); - __m128i lo = _mm256_castsi256_si128(v); - __m128i hi = _mm256_extracti128_si256(v, 1); - - _mm_storeu_si64(&p[0], _mm_shuffle_epi8(lo, shuf_mask)); - _mm_storeu_si64(&p[8], _mm_shuffle_epi8(hi, shuf_mask)); + _mm_storeu_si128((__m128i *)p, _mm256_castsi256_si128(res)); } float simd::reduce_add_ps(fvec v) { diff --git a/engine/nnue/avx512.hpp b/engine/nnue/avx512.hpp index 8a7f3377..82432d75 100644 --- a/engine/nnue/avx512.hpp +++ b/engine/nnue/avx512.hpp @@ -74,8 +74,8 @@ void simd::store_f32(float *p, fvec v) { _mm512_storeu_ps(p, v); } -void simd::store_u16_u8(int8_t *p, ivec v) { - _mm256_storeu_si256((__m256i *)p, _mm512_cvtepi16_epi8(v)); +void simd::store_u16_u8(uint8_t *p, ivec v) { + _mm256_storeu_si256((__m256i *)p, _mm512_cvtusepi16_epi8(v)); } float simd::reduce_add_ps(fvec v) { diff --git a/engine/nnue/network.cpp b/engine/nnue/network.cpp index cee715f1..aa859c67 100644 --- a/engine/nnue/network.cpp +++ b/engine/nnue/network.cpp @@ -60,11 +60,9 @@ int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator const fvec f_clip = simd::broadcast_f32(1.0f); const fvec div = simd::broadcast_f32(1.0f / (QA * QA * QB / 256.0f)); - alignas(64) int8_t l1[L1_SIZE]; - union { - alignas(64) int32_t l2i[L2_SIZE]; - alignas(64) float l2[L2_SIZE]; - }; + alignas(64) uint8_t l1[L1_SIZE]; + alignas(64) int32_t l2i[L2_SIZE]; + alignas(64) float l2[L2_SIZE]; alignas(64) float l3[L3_SIZE]; // Pairwise mul diff --git a/engine/nnue/simd.cpp b/engine/nnue/simd.cpp index 287da2c1..49edd0ea 100644 --- a/engine/nnue/simd.cpp +++ b/engine/nnue/simd.cpp @@ -1,4 +1,4 @@ -#if defined(__AVX512F__) +#if defined(__AVX512BW__) #include "avx512.hpp" #else #include "avx2.hpp" diff --git a/engine/nnue/simd.hpp b/engine/nnue/simd.hpp index f091c1f1..fd406898 100644 --- a/engine/nnue/simd.hpp +++ b/engine/nnue/simd.hpp @@ -1,5 +1,6 @@ #pragma once +#include #include #if defined(__AVX512BW__) @@ -52,7 +53,7 @@ namespace simd { void store_f32(float *p, fvec v); - void store_u16_u8(int8_t *p, ivec v); + void store_u16_u8(uint8_t *p, ivec v); float reduce_add_ps(fvec v); int32_t reduce_add_epi16(ivec v); };