diff --git a/src/common/algorithm/ReservoirSampling.h b/src/common/algorithm/ReservoirSampling.h index 7ad90edcbad..5e9d1e3e5b5 100644 --- a/src/common/algorithm/ReservoirSampling.h +++ b/src/common/algorithm/ReservoirSampling.h @@ -33,7 +33,7 @@ class ReservoirSampling final { ++cnt_; return true; } else { - auto index = folly::Random::rand64(cnt_); + auto index = folly::Random::rand64(cnt_ + 1); if (index < num_) { samples_[index] = (std::move(sample)); ++cnt_; diff --git a/src/common/algorithm/test/ReservoirSamplingTest.cpp b/src/common/algorithm/test/ReservoirSamplingTest.cpp index 0486b473ea7..57b153f1e9d 100644 --- a/src/common/algorithm/test/ReservoirSamplingTest.cpp +++ b/src/common/algorithm/test/ReservoirSamplingTest.cpp @@ -39,6 +39,21 @@ TEST(ReservoirSamplingTest, Sample) { EXPECT_EQ(2, result[2]); } } + { + std::unordered_set hit; + for (size_t time = 0; time < 1024; time++) { + ReservoirSampling sampler(1); + sampler.sampling(1); + sampler.sampling(2); + auto result = sampler.samples(); + EXPECT_EQ(1, result.size()); + EXPECT_TRUE(result[0] == 1 || result[0] == 2); + hit.insert(result[0]); + } + EXPECT_EQ(2, hit.size()); + EXPECT_TRUE(hit.find(1) != hit.end()); + EXPECT_TRUE(hit.find(2) != hit.end()); + } } } // namespace algorithm } // namespace nebula