diff --git a/engine/main.cpp b/engine/main.cpp index 057ba245..47548d5e 100644 --- a/engine/main.cpp +++ b/engine/main.cpp @@ -417,7 +417,7 @@ __attribute__((weak)) int main(int argc, char *argv[]) { for (int i = 0; i < 8; i++) { pos.reset_startpos(); pos.mailbox[SQ_A2 + i] = NO_PIECE; - am.refresh_finny(pos, am.idx); + am.full_refresh(pos, 0); Value score = eval(pos, am); int diff = startpos_score - score; tot += diff; diff --git a/engine/nnue/accumulator.cpp b/engine/nnue/accumulator.cpp index c19bde2c..231d485b 100644 --- a/engine/nnue/accumulator.cpp +++ b/engine/nnue/accumulator.cpp @@ -3,7 +3,7 @@ void AccumulatorManager::AccumulatorPair::update_add(Square sq, PieceType pt, bool side, int wbucket, int bbucket) { uint16_t w_index = calculate_index(sq, pt, side, 0, wbucket); uint16_t b_index = calculate_index(sq, pt, side, 1, bbucket); - for (int i = 0; i < HL_SIZE; i++) { + for (int i = 0; i < L1_SIZE; i++) { w_acc.val[i] += nnue_network.accumulator_weights[w_index][i]; b_acc.val[i] += nnue_network.accumulator_weights[b_index][i]; } @@ -12,7 +12,7 @@ void AccumulatorManager::AccumulatorPair::update_add(Square sq, PieceType pt, bo void AccumulatorManager::AccumulatorPair::update_sub(Square sq, PieceType pt, bool side, int wbucket, int bbucket) { uint16_t w_index = calculate_index(sq, pt, side, 0, wbucket); uint16_t b_index = calculate_index(sq, pt, side, 1, bbucket); - for (int i = 0; i < HL_SIZE; i++) { + for (int i = 0; i < L1_SIZE; i++) { w_acc.val[i] -= nnue_network.accumulator_weights[w_index][i]; b_acc.val[i] -= nnue_network.accumulator_weights[b_index][i]; } @@ -20,7 +20,7 @@ void AccumulatorManager::AccumulatorPair::update_sub(Square sq, PieceType pt, bo void AccumulatorManager::full_refresh(Position &pos, int index) { // Init the first accumulator so we have a basepoint - for (int i = 0; i < HL_SIZE; i++) { + for (int i = 0; i < L1_SIZE; i++) { accs[index].w_acc.val[i] = nnue_network.accumulator_biases[i]; accs[index].b_acc.val[i] = 
nnue_network.accumulator_biases[i]; } @@ -57,7 +57,7 @@ void AccumulatorManager::refresh_finny(Position &pos, int index) { Accumulator &w_acc = accs[index].w_acc; Accumulator &b_acc = accs[index].b_acc; - for (int i = 0; i < HL_SIZE; i++) { + for (int i = 0; i < L1_SIZE; i++) { w_acc.val[i] = f_w_acc.val[i]; b_acc.val[i] = f_b_acc.val[i]; } @@ -75,7 +75,7 @@ void AccumulatorManager::refresh_finny(Position &pos, int index) { if (piece != NO_PIECE) { // Add to accumulator int index = calculate_index((Square)i, pt, side, 0, winbucket); - for (int k = 0; k < HL_SIZE; k++) { + for (int k = 0; k < L1_SIZE; k++) { w_acc.val[k] += nnue_network.accumulator_weights[index][k]; } } @@ -83,7 +83,7 @@ void AccumulatorManager::refresh_finny(Position &pos, int index) { if (prev_w_piece != NO_PIECE) { // Remove from accumulator int index = calculate_index((Square)i, prev_w_pt, prev_w_side, 0, winbucket); - for (int k = 0; k < HL_SIZE; k++) { + for (int k = 0; k < L1_SIZE; k++) { w_acc.val[k] -= nnue_network.accumulator_weights[index][k]; } } @@ -97,7 +97,7 @@ void AccumulatorManager::refresh_finny(Position &pos, int index) { if (piece != NO_PIECE) { // Add to accumulator int index = calculate_index((Square)i, pt, side, 1, binbucket); - for (int k = 0; k < HL_SIZE; k++) { + for (int k = 0; k < L1_SIZE; k++) { b_acc.val[k] += nnue_network.accumulator_weights[index][k]; } } @@ -105,7 +105,7 @@ void AccumulatorManager::refresh_finny(Position &pos, int index) { if (prev_b_piece != NO_PIECE) { // Remove from accumulator int index = calculate_index((Square)i, prev_b_pt, prev_b_side, 1, binbucket); - for (int k = 0; k < HL_SIZE; k++) { + for (int k = 0; k < L1_SIZE; k++) { b_acc.val[k] -= nnue_network.accumulator_weights[index][k]; } } @@ -113,7 +113,7 @@ void AccumulatorManager::refresh_finny(Position &pos, int index) { } // Update finny tables - for (int i = 0; i < HL_SIZE; i++) { + for (int i = 0; i < L1_SIZE; i++) { f_w_acc.val[i] = w_acc.val[i]; f_b_acc.val[i] = b_acc.val[i]; } @@ 
-158,19 +158,19 @@ void AccumulatorManager::apply_lazy(Position &pos) { auto &u = updates[i]; if (u.deltas == 2) { // -+ - for (int k = 0; k < HL_SIZE; k++) { + for (int k = 0; k < L1_SIZE; k++) { accs[i].w_acc.val[k] = accs[i-1].w_acc.val[k] - nnue_network.accumulator_weights[u.w_deltas[0]][k] + nnue_network.accumulator_weights[u.w_deltas[1]][k]; accs[i].b_acc.val[k] = accs[i-1].b_acc.val[k] - nnue_network.accumulator_weights[u.b_deltas[0]][k] + nnue_network.accumulator_weights[u.b_deltas[1]][k]; } } else if (u.deltas == 3) { // --+ - for (int k = 0; k < HL_SIZE; k++) { + for (int k = 0; k < L1_SIZE; k++) { accs[i].w_acc.val[k] = accs[i-1].w_acc.val[k] - nnue_network.accumulator_weights[u.w_deltas[0]][k] - nnue_network.accumulator_weights[u.w_deltas[1]][k] + nnue_network.accumulator_weights[u.w_deltas[2]][k]; accs[i].b_acc.val[k] = accs[i-1].b_acc.val[k] - nnue_network.accumulator_weights[u.b_deltas[0]][k] - nnue_network.accumulator_weights[u.b_deltas[1]][k] + nnue_network.accumulator_weights[u.b_deltas[2]][k]; } } else if (u.deltas == 4) { // --++ - for (int k = 0; k < HL_SIZE; k++) { + for (int k = 0; k < L1_SIZE; k++) { accs[i].w_acc.val[k] = accs[i-1].w_acc.val[k] - nnue_network.accumulator_weights[u.w_deltas[0]][k] - nnue_network.accumulator_weights[u.w_deltas[1]][k] + nnue_network.accumulator_weights[u.w_deltas[2]][k] + nnue_network.accumulator_weights[u.w_deltas[3]][k]; accs[i].b_acc.val[k] = accs[i-1].b_acc.val[k] - nnue_network.accumulator_weights[u.b_deltas[0]][k] - nnue_network.accumulator_weights[u.b_deltas[1]][k] + nnue_network.accumulator_weights[u.b_deltas[2]][k] + nnue_network.accumulator_weights[u.b_deltas[3]][k]; } diff --git a/engine/nnue/accumulator.hpp b/engine/nnue/accumulator.hpp index 4adab689..3a97e6e6 100644 --- a/engine/nnue/accumulator.hpp +++ b/engine/nnue/accumulator.hpp @@ -32,7 +32,7 @@ struct AccumulatorManager { std::fill(&mailboxes[0][0][0], &mailboxes[0][0][0] + NINPUTS * 2 * 2 * 64, NO_PIECE); for (int i = 0; i < NINPUTS * 2; 
i++) { - for (int j = 0; j < HL_SIZE; j++) { + for (int j = 0; j < L1_SIZE; j++) { accs[i].w_acc.val[j] = nnue_network.accumulator_biases[j]; accs[i].b_acc.val[j] = nnue_network.accumulator_biases[j]; } diff --git a/engine/nnue/avx2.hpp b/engine/nnue/avx2.hpp new file mode 100644 index 00000000..1a59e592 --- /dev/null +++ b/engine/nnue/avx2.hpp @@ -0,0 +1,96 @@ +#pragma once + +#include "simd.hpp" + +ivec simd::setzero_ivec() { + return _mm256_setzero_si256(); +} + +fvec simd::setzero_fvec() { + return _mm256_setzero_ps(); +} + +ivec simd::broadcast_i16(int16_t x) { + return _mm256_set1_epi16(x); +} + +fvec simd::broadcast_f32(float x) { + return _mm256_set1_ps(x); +} + +ivec simd::load_ivec(const ivec *p) { + return _mm256_loadu_si256(p); +} + +fvec simd::load_fvec(const float *p) { + return _mm256_loadu_ps(p); +} + +ivec simd::clamp_i16(ivec x, ivec lo, ivec hi) { + x = _mm256_max_epi16(x, lo); + x = _mm256_min_epi16(x, hi); + return x; +} + +fvec simd::clamp_f32(fvec x, fvec lo, fvec hi) { + x = _mm256_max_ps(x, lo); + x = _mm256_min_ps(x, hi); + return x; +} + +ivec simd::shift_mulhi(ivec a, ivec b) { + a = _mm256_slli_epi16(a, 7); + return _mm256_mulhrs_epi16(a, b); +} + +ivec simd::accdp_u8i8_i16(ivec a, ivec b, ivec c) { + ivec sum = _mm256_madd_epi16(_mm256_maddubs_epi16(a, b), _mm256_set1_epi16(1)); /* widen the i16 pair sums to i32 before accumulating */ + return _mm256_add_epi32(sum, c); /* c carries 32-bit lanes, so a long dot product can no longer wrap at 16 bits */ +} + +fvec simd::cvt_i32_f32(ivec v) { + return _mm256_cvtepi32_ps(v); +} + +fvec simd::fma_f32(fvec a, fvec b, fvec c) { + return _mm256_fmadd_ps(a, b, c); +} + +fvec simd::mul_f32(fvec a, fvec b) { + return _mm256_mul_ps(a, b); +} + +fvec simd::add_f32(fvec a, fvec b) { + return _mm256_add_ps(a, b); +} + +void simd::store_f32(float *p, fvec v) { + _mm256_storeu_ps(p, v); +} + +void simd::store_u16_u8(uint8_t *p, ivec v) { + __m256i res = _mm256_packus_epi16(v, v); + res = _mm256_permute4x64_epi64(res, _MM_SHUFFLE(3, 2, 2, 0)); /* same immediate (0xE8) as _MM_PERM_DCCA, spelled with the macro this file already uses */ + + _mm_storeu_si128((__m128i *)p, _mm256_castsi256_si128(res)); +} + +float simd::reduce_add_ps(fvec v) { + __m128 sum = 
_mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1)); + sum = _mm_add_ps(sum, _mm_movehdup_ps(sum)); + sum = _mm_add_ps(sum, _mm_movehl_ps(sum, sum)); + + return _mm_cvtss_f32(sum); +} + +int32_t simd::reduce_add_epi16(ivec v) { + /* accdp_u8i8_i16 already leaves 32-bit partial sums in every lane of v, so no widening is needed here. */ + + const __m256i wide = v; + + __m128i sum = _mm_add_epi32(_mm256_castsi256_si128(wide), _mm256_extracti128_si256(wide, 1)); + sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, _MM_SHUFFLE(1, 0, 3, 2))); + sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, _MM_SHUFFLE(2, 3, 0, 1))); + + return _mm_cvtsi128_si32(sum); +} diff --git a/engine/nnue/avx512.hpp b/engine/nnue/avx512.hpp new file mode 100644 index 00000000..82432d75 --- /dev/null +++ b/engine/nnue/avx512.hpp @@ -0,0 +1,96 @@ +#pragma once + +#if defined(__AVX512BW__) + +#include "simd.hpp" + +ivec simd::setzero_ivec() { + return _mm512_setzero_si512(); +} + +fvec simd::setzero_fvec() { + return _mm512_setzero_ps(); +} + +ivec simd::broadcast_i16(int16_t x) { + return _mm512_set1_epi16(x); +} + +fvec simd::broadcast_f32(float x) { + return _mm512_set1_ps(x); +} + +ivec simd::load_ivec(const ivec *p) { + return _mm512_loadu_si512(p); +} + +fvec simd::load_fvec(const float *p) { + return _mm512_loadu_ps(p); +} + +ivec simd::clamp_i16(ivec x, ivec lo, ivec hi) { + x = _mm512_max_epi16(x, lo); + x = _mm512_min_epi16(x, hi); + return x; +} + +fvec simd::clamp_f32(fvec x, fvec lo, fvec hi) { + x = _mm512_max_ps(x, lo); + x = _mm512_min_ps(x, hi); + return x; +} + +ivec simd::shift_mulhi(ivec a, ivec b) { + a = _mm512_slli_epi16(a, 7); + return _mm512_mulhrs_epi16(a, b); +} + +ivec simd::accdp_u8i8_i16(ivec a, ivec b, ivec c) { +#if defined(__AVX512VNNI__) + return _mm512_dpbusd_epi32(c, a, b); +#else + ivec sum = _mm512_madd_epi16(_mm512_maddubs_epi16(a, b), _mm512_set1_epi16(1)); /* widen the i16 pair sums to i32 so this path uses the same 32-bit lane width as the VNNI branch */ + return _mm512_add_epi32(sum, c); +#endif +} + +fvec simd::cvt_i32_f32(ivec v) { + return _mm512_cvtepi32_ps(v); +} + +fvec simd::fma_f32(fvec a, fvec b, fvec c) { + return 
_mm512_fmadd_ps(a, b, c); +} + +fvec simd::mul_f32(fvec a, fvec b) { + return _mm512_mul_ps(a, b); +} + +fvec simd::add_f32(fvec a, fvec b) { + return _mm512_add_ps(a, b); +} + +void simd::store_f32(float *p, fvec v) { + _mm512_storeu_ps(p, v); +} + +void simd::store_u16_u8(uint8_t *p, ivec v) { + _mm256_storeu_si256((__m256i *)p, _mm512_cvtusepi16_epi8(v)); +} + +float simd::reduce_add_ps(fvec v) { + return _mm512_reduce_add_ps(v); +} + +int32_t simd::reduce_add_epi16(ivec v) { + /* + * Both accdp_u8i8_i16 branches (VNNI dpbusd and the maddubs/madd + * fallback) leave 32-bit partial sums in every lane of v, + * so the lanes can be added up directly. + */ + const __m512i wide = v; + + return _mm512_reduce_add_epi32(wide); +} + +#endif diff --git a/engine/nnue/network.cpp b/engine/nnue/network.cpp index fe51e6f3..aa859c67 100644 --- a/engine/nnue/network.cpp +++ b/engine/nnue/network.cpp @@ -1,4 +1,6 @@ #include "network.hpp" +#include "simd.hpp" + #include "incbin.h" extern "C" { @@ -7,13 +9,35 @@ extern "C" { void Network::load() { char *ptr = (char *)gnetwork_weightsData; + memcpy(accumulator_weights, ptr, sizeof(accumulator_weights)); ptr += sizeof(accumulator_weights); + memcpy(accumulator_biases, ptr, sizeof(accumulator_biases)); ptr += sizeof(accumulator_biases); + + memcpy(l1_weights, ptr, sizeof(l1_weights)); + ptr += sizeof(l1_weights); + + memcpy(l1_biases, ptr, sizeof(l1_biases)); + ptr += sizeof(l1_biases); + + for (int i = 0; i < NBUCKETS; i++) { + for (int j = 0; j < L3_SIZE; j++) { + for (int k = 0; k < L2_SIZE; k++) { + memcpy(&l2_weights[i][k][j], ptr, 4); + ptr += 4; + } + } + } + + memcpy(l2_biases, ptr, sizeof(l2_biases)); + ptr += sizeof(l2_biases); + memcpy(output_weights, ptr, sizeof(output_weights)); ptr += sizeof(output_weights); - memcpy(&output_bias, ptr, sizeof(output_bias)); + + memcpy(&output_biases, ptr, sizeof(output_biases)); } int calculate_index(Square sq, PieceType pt, 
bool side, bool perspective, int nb } int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator &ntm, uint8_t nbucket) { - __m256i sum = _mm256_setzero_si256(); - const __m256i zero = _mm256_setzero_si256(); - const __m256i qa_vec = _mm256_set1_epi16(QA); + const ivec zero = simd::setzero_ivec(); + const ivec clip = simd::broadcast_i16(QA); + const fvec f_zero = simd::setzero_fvec(); + const fvec f_clip = simd::broadcast_f32(1.0f); + const fvec div = simd::broadcast_f32(1.0f / (QA * QA * QB / 256.0f)); + + alignas(64) uint8_t l1[L1_SIZE]; + alignas(64) int32_t l2i[L2_SIZE]; + alignas(64) float l2[L2_SIZE]; + alignas(64) float l3[L3_SIZE]; + + // Pairwise mul + for (int i = 0; i < L1_SIZE / 2; i += SHORTS_PER_VEC) { + ivec stm_val1 = simd::load_ivec((ivec *)&stm.val[i]); + ivec stm_val2 = simd::load_ivec((ivec *)&stm.val[i + L1_SIZE / 2]); + ivec ntm_val1 = simd::load_ivec((ivec *)&ntm.val[i]); + ivec ntm_val2 = simd::load_ivec((ivec *)&ntm.val[i + L1_SIZE / 2]); + + stm_val1 = simd::clamp_i16(stm_val1, zero, clip); + stm_val2 = simd::clamp_i16(stm_val2, zero, clip); + + ntm_val1 = simd::clamp_i16(ntm_val1, zero, clip); + ntm_val2 = simd::clamp_i16(ntm_val2, zero, clip); + + ivec stm_pair = simd::shift_mulhi(stm_val1, stm_val2); + ivec ntm_pair = simd::shift_mulhi(ntm_val1, ntm_val2); + + simd::store_u16_u8(&l1[i], stm_pair); + simd::store_u16_u8(&l1[i + L1_SIZE / 2], ntm_pair); + } - for (int i = 0; i < HL_SIZE / 2; i += 16) { - __m256i stm_vals1 = _mm256_loadu_si256((__m256i *)&stm.val[i]); - __m256i stm_vals2 = _mm256_loadu_si256((__m256i *)&stm.val[i + HL_SIZE / 2]); - __m256i ntm_vals1 = _mm256_loadu_si256((__m256i *)&ntm.val[i]); - __m256i ntm_vals2 = _mm256_loadu_si256((__m256i *)&ntm.val[i + HL_SIZE / 2]); + for (int i = 0; i < L2_SIZE; i += L1_UNROLL) { + ivec sums[L1_UNROLL]; + for (int j = 0; j < L1_UNROLL; j++) + sums[j] = zero; - stm_vals1 = _mm256_max_epi16(stm_vals1, zero); - stm_vals1 = _mm256_min_epi16(stm_vals1, qa_vec); - 
stm_vals2 = _mm256_max_epi16(stm_vals2, zero); - stm_vals2 = _mm256_min_epi16(stm_vals2, qa_vec); + for (int j = 0; j < L1_SIZE; j += BYTES_PER_VEC) { + ivec val = simd::load_ivec((ivec *)&l1[j]); - ntm_vals1 = _mm256_max_epi16(ntm_vals1, zero); - ntm_vals1 = _mm256_min_epi16(ntm_vals1, qa_vec); - ntm_vals2 = _mm256_max_epi16(ntm_vals2, zero); - ntm_vals2 = _mm256_min_epi16(ntm_vals2, qa_vec); + for (int k = 0; k < L1_UNROLL; k++) { + ivec weight = simd::load_ivec((ivec *)&net.l1_weights[nbucket][i + k][j]); + sums[k] = simd::accdp_u8i8_i16(val, weight, sums[k]); + } + } - __m256i stm_weights = _mm256_loadu_si256((__m256i *)&net.output_weights[nbucket][i]); - __m256i ntm_weights = _mm256_loadu_si256((__m256i *)&net.output_weights[nbucket][i + HL_SIZE / 2]); + for (int j = 0; j < L1_UNROLL; j++) + l2i[i + j] = simd::reduce_add_epi16(sums[j]); + } + + // Convert l2 into a proper float array + for (int i = 0; i < L2_SIZE; i += FLOATS_PER_VEC) { + ivec i_val = simd::load_ivec((ivec *)&l2i[i]); + fvec val = simd::cvt_i32_f32(i_val); + + fvec bias = simd::load_fvec(&net.l1_biases[nbucket][i]); + + val = simd::fma_f32(val, div, bias); + + val = simd::clamp_f32(val, f_zero, f_clip); + + val = simd::mul_f32(val, val); + + simd::store_f32(&l2[i], val); + } + + for (int i = 0; i < L3_SIZE; i += FLOATS_PER_VEC * L2_UNROLL) { + fvec sums[L2_UNROLL]; + for (int j = 0; j < L2_UNROLL; j++) + sums[j] = simd::load_fvec(&net.l2_biases[nbucket][i + j * FLOATS_PER_VEC]); - __m256i stm_prod = _mm256_mullo_epi16(stm_vals1, stm_weights); - __m256i ntm_prod = _mm256_mullo_epi16(ntm_vals1, ntm_weights); + for (int j = 0; j < L2_SIZE; j++) { + fvec val = simd::broadcast_f32(l2[j]); - __m256i stm_res = _mm256_madd_epi16(stm_prod, stm_vals2); - __m256i ntm_res = _mm256_madd_epi16(ntm_prod, ntm_vals2); + for (int k = 0; k < L2_UNROLL; k++) { + fvec weight = simd::load_fvec(&net.l2_weights[nbucket][j][i + k * FLOATS_PER_VEC]); + sums[k] = simd::fma_f32(val, weight, sums[k]); + } + } + + for (int 
j = 0; j < L2_UNROLL; j++) + simd::store_f32(&l3[i + j * FLOATS_PER_VEC], sums[j]); + } + + fvec sums[L3_UNROLL]; + for (int i = 0; i < L3_UNROLL; i++) + sums[i] = f_zero; + + for (int i = 0; i < L3_SIZE; i += FLOATS_PER_VEC) { + fvec val = simd::load_fvec(&l3[i]); + + val = simd::clamp_f32(val, f_zero, f_clip); + + fvec weight = simd::load_fvec(&net.output_weights[nbucket][i]); + + int idx = i / FLOATS_PER_VEC % L3_UNROLL; + sums[idx] = simd::fma_f32(simd::mul_f32(val, val), weight, sums[idx]); + } - sum = _mm256_add_epi32(sum, stm_res); - sum = _mm256_add_epi32(sum, ntm_res); + int num = L3_UNROLL; + while (num > 1) { + num /= 2; + for (int i = 0; i < num; i++) + sums[i] = simd::add_f32(sums[i], sums[i + num]); } - __m128i sum_128 = _mm_add_epi32(_mm256_extracti128_si256(sum, 0), _mm256_extracti128_si256(sum, 1)); - sum_128 = _mm_add_epi32(sum_128, _mm_shuffle_epi32(sum_128, _MM_SHUFFLE(2, 3, 0, 1))); - sum_128 = _mm_add_epi32(sum_128, _mm_shuffle_epi32(sum_128, _MM_SHUFFLE(1, 0, 3, 2))); - int32_t score = _mm_cvtsi128_si32(sum_128); + float score = simd::reduce_add_ps(sums[0]) + net.output_biases[nbucket]; - score /= QA; - score += net.output_bias[nbucket]; - score *= SCALE; - score /= QA * QB; - return score; + return score * SCALE; } diff --git a/engine/nnue/network.hpp b/engine/nnue/network.hpp index 6122ed09..fad0b948 100644 --- a/engine/nnue/network.hpp +++ b/engine/nnue/network.hpp @@ -4,9 +4,11 @@ #define INPUT_SIZE 768 #define NINPUTS 8 -#define HL_SIZE 1024 +#define L1_SIZE 1024 +#define L2_SIZE 16 +#define L3_SIZE 32 #define NBUCKETS 8 -#define SCALE 358 +#define SCALE 400 #define QA 255 #define QB 64 @@ -22,14 +24,21 @@ constexpr int IBUCKET_LAYOUT[] = { }; struct Accumulator { - int16_t val[HL_SIZE] = {}; + alignas(32) int16_t val[L1_SIZE] = {}; }; -struct Network { - int16_t accumulator_weights[INPUT_SIZE * NINPUTS][HL_SIZE]; - int16_t accumulator_biases[HL_SIZE]; - int16_t output_weights[NBUCKETS][HL_SIZE]; - int16_t output_bias[NBUCKETS]; +struct 
alignas(32) Network { + int16_t accumulator_weights[INPUT_SIZE * NINPUTS][L1_SIZE]; + int16_t accumulator_biases[L1_SIZE]; + + int8_t l1_weights[NBUCKETS][L2_SIZE][L1_SIZE]; + float l1_biases[NBUCKETS][L2_SIZE]; + + float l2_weights[NBUCKETS][L2_SIZE][L3_SIZE]; + float l2_biases[NBUCKETS][L3_SIZE]; + + float output_weights[NBUCKETS][L3_SIZE]; + float output_biases[NBUCKETS]; void load(); }; diff --git a/engine/nnue/simd.cpp b/engine/nnue/simd.cpp new file mode 100644 index 00000000..49edd0ea --- /dev/null +++ b/engine/nnue/simd.cpp @@ -0,0 +1,5 @@ +#if defined(__AVX512BW__) +#include "avx512.hpp" +#else +#include "avx2.hpp" +#endif diff --git a/engine/nnue/simd.hpp b/engine/nnue/simd.hpp new file mode 100644 index 00000000..fd406898 --- /dev/null +++ b/engine/nnue/simd.hpp @@ -0,0 +1,59 @@ +#pragma once + +#include <cstdint> +#include <immintrin.h> + +#if defined(__AVX512BW__) + +using ivec = __m512i; +using fvec = __m512; +#define VEC_SIZE 512 + +#define L1_UNROLL 4 +#define L2_UNROLL 2 +#define L3_UNROLL 1 + +#else + +using ivec = __m256i; +using fvec = __m256; +#define VEC_SIZE 256 + +#define L1_UNROLL 4 +#define L2_UNROLL 4 +#define L3_UNROLL 2 + +#endif + +#define BYTES_PER_VEC (VEC_SIZE / 8) +#define SHORTS_PER_VEC (VEC_SIZE / 16) +#define FLOATS_PER_VEC (VEC_SIZE / 32) + +namespace simd { + ivec setzero_ivec(); + fvec setzero_fvec(); + + ivec broadcast_i16(int16_t x); + fvec broadcast_f32(float x); + + ivec load_ivec(const ivec *p); + fvec load_fvec(const float *p); + + ivec clamp_i16(ivec x, ivec lo, ivec hi); + fvec clamp_f32(fvec x, fvec lo, fvec hi); + + ivec shift_mulhi(ivec a, ivec b); + ivec accdp_u8i8_i16(ivec a, ivec b, ivec c); + + fvec cvt_i32_f32(ivec v); + + fvec fma_f32(fvec a, fvec b, fvec c); + fvec mul_f32(fvec a, fvec b); + fvec add_f32(fvec a, fvec b); + + void store_f32(float *p, fvec v); + + void store_u16_u8(uint8_t *p, ivec v); + float reduce_add_ps(fvec v); + int32_t reduce_add_epi16(ivec v); +}; diff --git a/engine/params.hpp b/engine/params.hpp index 
b819b734..2c5f9561 100644 --- a/engine/params.hpp +++ b/engine/params.hpp @@ -35,53 +35,53 @@ void handle_set(std::string optionname, std::string optionvalue); inline int name() { return name##Param.value; } // Tunables go below -TUNE(qs_see, -10, -100, 100, 10); -TUNE(qsfp_margin, 150, 100, 500, 25); -TUNE(rfp_threshold, 58, 20, 110, 10); +TUNE(qs_see, -18, -100, 100, 10); +TUNE(qsfp_margin, 192, 100, 500, 25); +TUNE(rfp_threshold, 63, 20, 110, 10); TUNE(rfp_improving, 31, 10, 50, 8); TUNE(rfp_quad, 5, 1, 8, 1); -TUNE(rfp_cutnode, 21, 5, 30, 5); -TUNE(nmp_margin, 202, 50, 400, 25); -TUNE(nmp_depth, 17, 0, 40, 5); -TUNE(razor_margin, 228, 150, 400, 20); -TUNE(probcut_margin, 349, 200, 500, 30); -TUNE(dext_base, 22, 10, 50, 5); -TUNE(dext_capt, 2, -100, 100, 10); -TUNE(dext_pv, 11, -100, 100, 10); -TUNE(dext_improving, -2, -50, 50, 10); -TUNE(text_base, 92, 50, 200, 10); -TUNE(text_capt, -5, -200, 200, 20); -TUNE(fp_const, 284, 100, 500, 30); -TUNE(fp_depth, 96, 20, 200, 10); -TUNE(fp_hist, 32, 10, 64, 5); -TUNE(history_margin, 2990, 1000, 4000, 256); -TUNE(see_quad, 25, 10, 40, 4); -TUNE(see_lin, 58, 20, 80, 6); -TUNE(lmr_base, 163, -1024, 1024, 160); -TUNE(lmr_pv, 1055, 512, 2048, 128); -TUNE(lmr_cutnode, 1940, 0, 3072, 256); -TUNE(lmr_cutnode_nott, -303, -1024, 1024, 256); -TUNE(lmr_cutoffcnt, 714, 0, 2048, 128); -TUNE(lmr_ttpv, 959, 0, 2048, 128); -TUNE(lmr_killer, 717, 0, 2048, 128); +TUNE(rfp_cutnode, 15, 5, 30, 5); +TUNE(nmp_margin, 214, 50, 400, 25); +TUNE(nmp_depth, 13, 0, 40, 5); +TUNE(razor_margin, 227, 150, 400, 20); +TUNE(probcut_margin, 377, 200, 500, 30); +TUNE(dext_base, 24, 10, 50, 5); +TUNE(dext_capt, -2, -100, 100, 10); +TUNE(dext_pv, 16, -100, 100, 10); +TUNE(dext_improving, 12, -50, 50, 10); +TUNE(text_base, 111, 50, 200, 10); +TUNE(text_capt, -11, -200, 200, 20); +TUNE(fp_const, 282, 100, 500, 30); +TUNE(fp_depth, 92, 20, 200, 10); +TUNE(fp_hist, 30, 10, 64, 5); +TUNE(history_margin, 3014, 1000, 4000, 256); +TUNE(see_quad, 17, 10, 40, 4); 
+TUNE(see_lin, 60, 20, 80, 6); +TUNE(lmr_base, 223, -1024, 1024, 160); +TUNE(lmr_pv, 1109, 512, 2048, 128); +TUNE(lmr_cutnode, 2286, 0, 3072, 256); +TUNE(lmr_cutnode_nott, -643, -1024, 1024, 256); +TUNE(lmr_cutoffcnt, 835, 0, 2048, 128); +TUNE(lmr_ttpv, 896, 0, 2048, 128); +TUNE(lmr_killer, 835, 0, 2048, 128); TUNE(lmr_hist, 12, 1, 16, 2); -TUNE(lmr_ttcapt, 1019, 0, 2048, 128); -TUNE(lmr_ttpv_alpha, -84, -1024, 1024, 256); -TUNE(dodeeper_margin, 91, 20, 200, 10); -TUNE(doshallower_margin, -8, -64, 64, 8); -TUNE(hist_large_margin, 100, 0, 256, 16); +TUNE(lmr_ttcapt, 914, 0, 2048, 128); +TUNE(lmr_ttpv_alpha, -34, -1024, 1024, 256); +TUNE(dodeeper_margin, 109, 20, 200, 10); +TUNE(doshallower_margin, -3, -64, 64, 8); +TUNE(hist_large_margin, 96, 0, 256, 16); TUNE(hist_quad, 3, 0, 6, 1); -TUNE(hist_lin, 122, 64, 256, 15); -TUNE(hist_const, 120, 0, 256, 16); +TUNE(hist_lin, 131, 64, 256, 15); +TUNE(hist_const, 144, 0, 256, 16); TUNE(asp_window, 14, 1, 30, 4); -TUNE(bm_base, 182, 80, 250, 15); -TUNE(bm_mul, 33, 10, 80, 10); -TUNE(corr_ps, 133, 64, 256, 12); -TUNE(corr_np, 139, 64, 256, 12); -TUNE(corr_maj, 62, 32, 128, 8); -TUNE(corr_min, 73, 32, 128, 8); -TUNE(corr_cont, 132, 64, 256, 16); -TUNE(corr_cont2, 157, 64, 256, 16); -TUNE(corr_threat, 100, 64, 256, 16); -TUNE(probcut_see, 102, 50, 200, 10); -TUNE(badnoisy_div, 51, 10, 100, 8); \ No newline at end of file +TUNE(bm_base, 180, 80, 250, 15); +TUNE(bm_mul, 20, 10, 80, 10); +TUNE(corr_ps, 128, 64, 256, 12); +TUNE(corr_np, 141, 64, 256, 12); +TUNE(corr_maj, 63, 32, 128, 8); +TUNE(corr_min, 84, 32, 128, 8); +TUNE(corr_cont, 137, 64, 256, 16); +TUNE(corr_cont2, 143, 64, 256, 16); +TUNE(corr_threat, 103, 64, 256, 16); +TUNE(probcut_see, 84, 50, 200, 10); +TUNE(badnoisy_div, 48, 10, 100, 8); \ No newline at end of file