Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion engine/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ __attribute__((weak)) int main(int argc, char *argv[]) {
for (int i = 0; i < 8; i++) {
pos.reset_startpos();
pos.mailbox[SQ_A2 + i] = NO_PIECE;
am.refresh_finny(pos, am.idx);
am.full_refresh(pos, 0);
Value score = eval(pos, am);
int diff = startpos_score - score;
tot += diff;
Expand Down
24 changes: 12 additions & 12 deletions engine/nnue/accumulator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
void AccumulatorManager::AccumulatorPair::update_add(Square sq, PieceType pt, bool side, int wbucket, int bbucket) {
uint16_t w_index = calculate_index(sq, pt, side, 0, wbucket);
uint16_t b_index = calculate_index(sq, pt, side, 1, bbucket);
for (int i = 0; i < HL_SIZE; i++) {
for (int i = 0; i < L1_SIZE; i++) {
w_acc.val[i] += nnue_network.accumulator_weights[w_index][i];
b_acc.val[i] += nnue_network.accumulator_weights[b_index][i];
}
Expand All @@ -12,15 +12,15 @@ void AccumulatorManager::AccumulatorPair::update_add(Square sq, PieceType pt, bo
// Subtracts one input feature, identified by (sq, pt, side), from both
// accumulators of this pair.
// wbucket / bbucket are forwarded to calculate_index for the w_acc and b_acc
// rows respectively.
void AccumulatorManager::AccumulatorPair::update_sub(Square sq, PieceType pt, bool side, int wbucket, int bbucket) {
	// Weight-row index per accumulator: 4th argument 0 is used for w_acc, 1 for b_acc.
	uint16_t w_index = calculate_index(sq, pt, side, 0, wbucket);
	uint16_t b_index = calculate_index(sq, pt, side, 1, bbucket);
	// Single loop over the L1_SIZE accumulator entries (the stale HL_SIZE loop header is dropped).
	for (int i = 0; i < L1_SIZE; i++) {
		w_acc.val[i] -= nnue_network.accumulator_weights[w_index][i];
		b_acc.val[i] -= nnue_network.accumulator_weights[b_index][i];
	}
}

void AccumulatorManager::full_refresh(Position &pos, int index) {
// Init the first accumulator so we have a basepoint
for (int i = 0; i < HL_SIZE; i++) {
for (int i = 0; i < L1_SIZE; i++) {
accs[index].w_acc.val[i] = nnue_network.accumulator_biases[i];
accs[index].b_acc.val[i] = nnue_network.accumulator_biases[i];
}
Expand Down Expand Up @@ -57,7 +57,7 @@ void AccumulatorManager::refresh_finny(Position &pos, int index) {
Accumulator &w_acc = accs[index].w_acc;
Accumulator &b_acc = accs[index].b_acc;

for (int i = 0; i < HL_SIZE; i++) {
for (int i = 0; i < L1_SIZE; i++) {
w_acc.val[i] = f_w_acc.val[i];
b_acc.val[i] = f_b_acc.val[i];
}
Expand All @@ -75,15 +75,15 @@ void AccumulatorManager::refresh_finny(Position &pos, int index) {
if (piece != NO_PIECE) {
// Add to accumulator
int index = calculate_index((Square)i, pt, side, 0, winbucket);
for (int k = 0; k < HL_SIZE; k++) {
for (int k = 0; k < L1_SIZE; k++) {
w_acc.val[k] += nnue_network.accumulator_weights[index][k];
}
}

if (prev_w_piece != NO_PIECE) {
// Remove from accumulator
int index = calculate_index((Square)i, prev_w_pt, prev_w_side, 0, winbucket);
for (int k = 0; k < HL_SIZE; k++) {
for (int k = 0; k < L1_SIZE; k++) {
w_acc.val[k] -= nnue_network.accumulator_weights[index][k];
}
}
Expand All @@ -97,23 +97,23 @@ void AccumulatorManager::refresh_finny(Position &pos, int index) {
if (piece != NO_PIECE) {
// Add to accumulator
int index = calculate_index((Square)i, pt, side, 1, binbucket);
for (int k = 0; k < HL_SIZE; k++) {
for (int k = 0; k < L1_SIZE; k++) {
b_acc.val[k] += nnue_network.accumulator_weights[index][k];
}
}

if (prev_b_piece != NO_PIECE) {
// Remove from accumulator
int index = calculate_index((Square)i, prev_b_pt, prev_b_side, 1, binbucket);
for (int k = 0; k < HL_SIZE; k++) {
for (int k = 0; k < L1_SIZE; k++) {
b_acc.val[k] -= nnue_network.accumulator_weights[index][k];
}
}
}
}

// Update finny tables
for (int i = 0; i < HL_SIZE; i++) {
for (int i = 0; i < L1_SIZE; i++) {
f_w_acc.val[i] = w_acc.val[i];
f_b_acc.val[i] = b_acc.val[i];
}
Expand Down Expand Up @@ -158,19 +158,19 @@ void AccumulatorManager::apply_lazy(Position &pos) {
auto &u = updates[i];
if (u.deltas == 2) {
// -+
for (int k = 0; k < HL_SIZE; k++) {
for (int k = 0; k < L1_SIZE; k++) {
accs[i].w_acc.val[k] = accs[i-1].w_acc.val[k] - nnue_network.accumulator_weights[u.w_deltas[0]][k] + nnue_network.accumulator_weights[u.w_deltas[1]][k];
accs[i].b_acc.val[k] = accs[i-1].b_acc.val[k] - nnue_network.accumulator_weights[u.b_deltas[0]][k] + nnue_network.accumulator_weights[u.b_deltas[1]][k];
}
} else if (u.deltas == 3) {
// --+
for (int k = 0; k < HL_SIZE; k++) {
for (int k = 0; k < L1_SIZE; k++) {
accs[i].w_acc.val[k] = accs[i-1].w_acc.val[k] - nnue_network.accumulator_weights[u.w_deltas[0]][k] - nnue_network.accumulator_weights[u.w_deltas[1]][k] + nnue_network.accumulator_weights[u.w_deltas[2]][k];
accs[i].b_acc.val[k] = accs[i-1].b_acc.val[k] - nnue_network.accumulator_weights[u.b_deltas[0]][k] - nnue_network.accumulator_weights[u.b_deltas[1]][k] + nnue_network.accumulator_weights[u.b_deltas[2]][k];
}
} else if (u.deltas == 4) {
// --++
for (int k = 0; k < HL_SIZE; k++) {
for (int k = 0; k < L1_SIZE; k++) {
accs[i].w_acc.val[k] = accs[i-1].w_acc.val[k] - nnue_network.accumulator_weights[u.w_deltas[0]][k] - nnue_network.accumulator_weights[u.w_deltas[1]][k] + nnue_network.accumulator_weights[u.w_deltas[2]][k] + nnue_network.accumulator_weights[u.w_deltas[3]][k];
accs[i].b_acc.val[k] = accs[i-1].b_acc.val[k] - nnue_network.accumulator_weights[u.b_deltas[0]][k] - nnue_network.accumulator_weights[u.b_deltas[1]][k] + nnue_network.accumulator_weights[u.b_deltas[2]][k] + nnue_network.accumulator_weights[u.b_deltas[3]][k];
}
Expand Down
2 changes: 1 addition & 1 deletion engine/nnue/accumulator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ struct AccumulatorManager {
std::fill(&mailboxes[0][0][0], &mailboxes[0][0][0] + NINPUTS * 2 * 2 * 64, NO_PIECE);

for (int i = 0; i < NINPUTS * 2; i++) {
for (int j = 0; j < HL_SIZE; j++) {
for (int j = 0; j < L1_SIZE; j++) {
accs[i].w_acc.val[j] = nnue_network.accumulator_biases[j];
accs[i].b_acc.val[j] = nnue_network.accumulator_biases[j];
}
Expand Down
96 changes: 96 additions & 0 deletions engine/nnue/avx2.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
#pragma once

#include "simd.hpp"

// Returns a 256-bit integer vector with every bit cleared.
ivec simd::setzero_ivec() {
	const ivec zeroed = _mm256_setzero_si256();
	return zeroed;
}

// Returns a 256-bit float vector whose eight lanes are all 0.0f.
fvec simd::setzero_fvec() {
	const fvec zeroed = _mm256_setzero_ps();
	return zeroed;
}

// Replicates the 16-bit value x into all sixteen 16-bit lanes of a 256-bit vector.
ivec simd::broadcast_i16(int16_t x) {
	const ivec splat = _mm256_set1_epi16(x);
	return splat;
}

// Replicates the float x into all eight lanes of a 256-bit vector.
fvec simd::broadcast_f32(float x) {
	const fvec splat = _mm256_set1_ps(x);
	return splat;
}

// Loads 32 bytes from p into an integer vector.
// Uses the unaligned load, so p does not need to be 32-byte aligned.
ivec simd::load_ivec(const ivec *p) {
	const ivec loaded = _mm256_loadu_si256(p);
	return loaded;
}

// Loads eight consecutive floats starting at p.
// Uses the unaligned load, so p does not need to be 32-byte aligned.
fvec simd::load_fvec(const float *p) {
	const fvec loaded = _mm256_loadu_ps(p);
	return loaded;
}

// Clamps each signed 16-bit lane of x into [lo, hi] (lane-wise bounds).
// The lower bound is applied first, then the upper bound.
ivec simd::clamp_i16(ivec x, ivec lo, ivec hi) {
	return _mm256_min_epi16(_mm256_max_epi16(x, lo), hi);
}

// Clamps each float lane of x into [lo, hi] (lane-wise bounds).
// The lower bound is applied first, then the upper bound; operand order is
// kept as (x, bound) because max_ps/min_ps are not symmetric for NaN inputs.
fvec simd::clamp_f32(fvec x, fvec lo, fvec hi) {
	return _mm256_min_ps(_mm256_max_ps(x, lo), hi);
}

// Per 16-bit lane: shifts a left by 7 bits, then multiplies by b with
// _mm256_mulhrs_epi16, which yields (((x * y) >> 14) + 1) >> 1.
// The net effect is a * b / 256 rounded to nearest, provided a << 7 still
// fits in a signed 16-bit lane (the shift discards overflowing bits).
ivec simd::shift_mulhi(ivec a, ivec b) {
	const ivec scaled = _mm256_slli_epi16(a, 7);
	return _mm256_mulhrs_epi16(scaled, b);
}

// Dot-product accumulate: a is read as unsigned 8-bit lanes, b as signed
// 8-bit lanes. Each adjacent pair of products is summed with signed
// saturation into one 16-bit lane (_mm256_maddubs_epi16), and the result is
// added to the 16-bit accumulator lanes of c.
// Note: the final add wraps on 16-bit overflow; it does not saturate.
ivec simd::accdp_u8i8_i16(ivec a, ivec b, ivec c) {
	const ivec pair_sums = _mm256_maddubs_epi16(a, b);
	return _mm256_add_epi16(pair_sums, c);
}

// Converts eight signed 32-bit integer lanes to eight float lanes.
fvec simd::cvt_i32_f32(ivec v) {
	const fvec as_float = _mm256_cvtepi32_ps(v);
	return as_float;
}

// Fused multiply-add per float lane: returns a * b + c.
fvec simd::fma_f32(fvec a, fvec b, fvec c) {
	const fvec fused = _mm256_fmadd_ps(a, b, c);
	return fused;
}

// Lane-wise float multiplication: returns a * b.
fvec simd::mul_f32(fvec a, fvec b) {
	const fvec product = _mm256_mul_ps(a, b);
	return product;
}

// Lane-wise float addition: returns a + b.
fvec simd::add_f32(fvec a, fvec b) {
	const fvec total = _mm256_add_ps(a, b);
	return total;
}

// Stores the eight float lanes of v to p[0..7].
// Uses the unaligned store, so p does not need to be 32-byte aligned.
void simd::store_f32(float *p, fvec v) {
_mm256_storeu_ps(p, v);
}

// Narrows the sixteen 16-bit lanes of v to bytes and writes them to p[0..15].
// Lanes are read as signed 16-bit and saturated to the unsigned range
// [0, 255] (_mm256_packus_epi16). The store is unaligned.
void simd::store_u16_u8(uint8_t *p, ivec v) {
	// packus works per 128-bit half, so after packing v with itself
	// qword 0 holds lanes 0-7 and qword 2 holds lanes 8-15.
	const __m256i packed = _mm256_packus_epi16(v, v);
	// Move qwords 0 and 2 into the low 128 bits. _MM_SHUFFLE(3, 2, 2, 0) is
	// 0xE8, the same selector as _MM_PERM_DCCA, without depending on the
	// AVX-512 _MM_PERM_ENUM being declared in an AVX2-only build.
	const __m256i gathered = _mm256_permute4x64_epi64(packed, _MM_SHUFFLE(3, 2, 2, 0));

	_mm_storeu_si128(reinterpret_cast<__m128i *>(p), _mm256_castsi256_si128(gathered));
}

// Horizontal sum: returns the sum of the eight float lanes of v.
float simd::reduce_add_ps(fvec v) {
	// Fold the upper 128-bit half onto the lower one: 8 lanes -> 4.
	const __m128 low_half = _mm256_castps256_ps128(v);
	const __m128 high_half = _mm256_extractf128_ps(v, 1);
	const __m128 quad = _mm_add_ps(low_half, high_half);

	// Add each odd lane onto the even lane below it; lanes 0 and 2 now hold pair sums.
	const __m128 pairs = _mm_add_ps(quad, _mm_movehdup_ps(quad));

	// Bring lane 2 down onto lane 0 to complete the total.
	const __m128 total = _mm_add_ps(pairs, _mm_movehl_ps(pairs, pairs));

	return _mm_cvtss_f32(total);
}

// Horizontal sum: returns the sum of the sixteen signed 16-bit lanes of v
// as a 32-bit integer.
int32_t simd::reduce_add_epi16(ivec v) {
	// madd against all-ones widens to 32 bits while summing adjacent lane pairs.
	const __m256i pair_sums = _mm256_madd_epi16(v, _mm256_set1_epi16(1));

	// Fold the upper 128-bit half onto the lower one: 8 lanes -> 4.
	const __m128i low_half = _mm256_castsi256_si128(pair_sums);
	const __m128i high_half = _mm256_extracti128_si256(pair_sums, 1);
	__m128i acc = _mm_add_epi32(low_half, high_half);

	// Swap the 64-bit halves and add, then swap neighbouring 32-bit lanes and add.
	acc = _mm_add_epi32(acc, _mm_shuffle_epi32(acc, _MM_SHUFFLE(1, 0, 3, 2)));
	acc = _mm_add_epi32(acc, _mm_shuffle_epi32(acc, _MM_SHUFFLE(2, 3, 0, 1)));

	return _mm_cvtsi128_si32(acc);
}
96 changes: 96 additions & 0 deletions engine/nnue/avx512.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
#pragma once

#if defined(__AVX512BW__)

#include "simd.hpp"

// Returns a 512-bit integer vector with every bit cleared.
ivec simd::setzero_ivec() {
	const ivec zeroed = _mm512_setzero_si512();
	return zeroed;
}

// Returns a 512-bit float vector whose sixteen lanes are all 0.0f.
fvec simd::setzero_fvec() {
	const fvec zeroed = _mm512_setzero_ps();
	return zeroed;
}

// Replicates the 16-bit value x into all thirty-two 16-bit lanes of a 512-bit vector.
ivec simd::broadcast_i16(int16_t x) {
	const ivec splat = _mm512_set1_epi16(x);
	return splat;
}

// Replicates the float x into all sixteen lanes of a 512-bit vector.
fvec simd::broadcast_f32(float x) {
	const fvec splat = _mm512_set1_ps(x);
	return splat;
}

// Loads 64 bytes from p into an integer vector.
// Uses the unaligned load, so p does not need to be 64-byte aligned.
ivec simd::load_ivec(const ivec *p) {
	const ivec loaded = _mm512_loadu_si512(p);
	return loaded;
}

// Loads sixteen consecutive floats starting at p.
// Uses the unaligned load, so p does not need to be 64-byte aligned.
fvec simd::load_fvec(const float *p) {
	const fvec loaded = _mm512_loadu_ps(p);
	return loaded;
}

// Clamps each signed 16-bit lane of x into [lo, hi] (lane-wise bounds).
// The lower bound is applied first, then the upper bound.
ivec simd::clamp_i16(ivec x, ivec lo, ivec hi) {
	return _mm512_min_epi16(_mm512_max_epi16(x, lo), hi);
}

// Clamps each float lane of x into [lo, hi] (lane-wise bounds).
// The lower bound is applied first, then the upper bound; operand order is
// kept as (x, bound) because max_ps/min_ps are not symmetric for NaN inputs.
fvec simd::clamp_f32(fvec x, fvec lo, fvec hi) {
	return _mm512_min_ps(_mm512_max_ps(x, lo), hi);
}

// Per 16-bit lane: shifts a left by 7 bits, then multiplies by b with
// _mm512_mulhrs_epi16, which yields (((x * y) >> 14) + 1) >> 1.
// The net effect is a * b / 256 rounded to nearest, provided a << 7 still
// fits in a signed 16-bit lane (the shift discards overflowing bits).
ivec simd::shift_mulhi(ivec a, ivec b) {
	const ivec scaled = _mm512_slli_epi16(a, 7);
	return _mm512_mulhrs_epi16(scaled, b);
}

// Dot-product accumulate of unsigned 8-bit lanes (a) with signed 8-bit
// lanes (b) into the accumulator c.
// The accumulator lane width depends on the build:
//  - with AVX512-VNNI, groups of four products are summed into the 32-bit
//    lanes of c (_mm512_dpbusd_epi32);
//  - otherwise, adjacent pairs of products are summed with signed saturation
//    into 16-bit lanes (_mm512_maddubs_epi16) and added to the 16-bit lanes
//    of c with wraparound.
ivec simd::accdp_u8i8_i16(ivec a, ivec b, ivec c) {
#if defined(__AVX512VNNI__)
	const ivec acc32 = _mm512_dpbusd_epi32(c, a, b);
	return acc32;
#else
	const ivec pair_sums = _mm512_maddubs_epi16(a, b);
	return _mm512_add_epi16(pair_sums, c);
#endif
}
Comment thread
kevlu8 marked this conversation as resolved.

// Converts sixteen signed 32-bit integer lanes to sixteen float lanes.
fvec simd::cvt_i32_f32(ivec v) {
	const fvec as_float = _mm512_cvtepi32_ps(v);
	return as_float;
}

// Fused multiply-add per float lane: returns a * b + c.
fvec simd::fma_f32(fvec a, fvec b, fvec c) {
	const fvec fused = _mm512_fmadd_ps(a, b, c);
	return fused;
}

// Lane-wise float multiplication: returns a * b.
fvec simd::mul_f32(fvec a, fvec b) {
	const fvec product = _mm512_mul_ps(a, b);
	return product;
}

// Lane-wise float addition: returns a + b.
fvec simd::add_f32(fvec a, fvec b) {
	const fvec total = _mm512_add_ps(a, b);
	return total;
}

// Stores the sixteen float lanes of v to p[0..15].
// Uses the unaligned store, so p does not need to be 64-byte aligned.
void simd::store_f32(float *p, fvec v) {
_mm512_storeu_ps(p, v);
}

// Narrows the thirty-two 16-bit lanes of v to bytes and writes them to
// p[0..31] with an unaligned store.
// _mm512_cvtusepi16_epi8 reads the lanes as UNSIGNED 16-bit and saturates to
// 255, so a lane holding a negative int16 becomes 255 (a signed->unsigned
// pack such as packus would give 0 instead).
// NOTE(review): presumably callers clamp to a non-negative range first — confirm.
void simd::store_u16_u8(uint8_t *p, ivec v) {
	const __m256i narrowed = _mm512_cvtusepi16_epi8(v);
	_mm256_storeu_si256(reinterpret_cast<__m256i *>(p), narrowed);
}

// Horizontal sum: returns the sum of the sixteen float lanes of v.
float simd::reduce_add_ps(fvec v) {
	const float total = _mm512_reduce_add_ps(v);
	return total;
}

// Horizontal sum of the accumulator vector v, returned as a 32-bit integer.
//  - With AVX512-VNNI, v is summed as sixteen 32-bit lanes as-is
//    (accdp_u8i8_i16 accumulates with _mm512_dpbusd_epi32 in that build).
//  - Otherwise v is read as thirty-two signed 16-bit lanes, which are first
//    widened to 32 bits by a madd against all-ones (summing adjacent pairs).
int32_t simd::reduce_add_epi16(ivec v) {
#if defined(__AVX512VNNI__)
	const __m512i lanes32 = v;
#else
	const __m512i lanes32 = _mm512_madd_epi16(v, _mm512_set1_epi16(1));
#endif

	return _mm512_reduce_add_epi32(lanes32);
}

#endif
Loading
Loading