Skip to content

Commit fb4f8a1

Browse files
New sampling flag for disjoint sampling (#5342)
Admin merging while we refactor `rapids-bot` to assist with `NBS`. This PR modifies the PLC, C and C++ APIs to support a flag for disjoint sampling. The use of that parameter is going to be part of a follow-up PR. Supports #5042
1 parent 68a5a44 commit fb4f8a1

15 files changed

+134
-32
lines changed

cpp/include/cugraph/sampling_functions.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,11 @@ struct sampling_flags_t {
7373
*/
7474
temporal_sampling_comparison_t temporal_sampling_comparison{
7575
temporal_sampling_comparison_t::STRICTLY_INCREASING};
76+
77+
/**
78+
* Specifies if disjoint sampling should be enforced. Default is false.
79+
*/
80+
bool disjoint_sampling{false};
7681
};
7782

7883
/**

cpp/include/cugraph_c/sampling_algorithms.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,17 @@ void cugraph_sampling_set_dedupe_sources(cugraph_sampling_options_t* options, bo
309309
void cugraph_sampling_set_temporal_sampling_comparison(
310310
cugraph_sampling_options_t* options, cugraph_temporal_sampling_comparison_t comparison);
311311

312+
/**
313+
* @ingroup samplingC
314+
* @brief Set flag to perform disjoint sampling
315+
*
316+
* Note: This flag is not supported in the current implementation.
317+
*
318+
* @param options - opaque pointer to the sampling options
319+
* @param value - Boolean value to assign to the option
320+
*/
321+
void cugraph_sampling_set_disjoint_sampling(cugraph_sampling_options_t* options, bool_t value);
322+
312323
/**
313324
* @ingroup samplingC
314325
* @brief Free sampling options object

cpp/src/c_api/neighbor_sampling.cpp

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -279,10 +279,13 @@ struct neighbor_sampling_functor : public cugraph::c_api::abstract_functor {
279279
: std::nullopt,
280280
raft::host_span<const int>(fan_out_->as_type<const int>(), fan_out_->size_),
281281
num_edge_types_,
282-
cugraph::sampling_flags_t{options_.prior_sources_behavior_,
283-
options_.return_hops_ == TRUE,
284-
options_.dedupe_sources_ == TRUE,
285-
options_.with_replacement_ == TRUE},
282+
cugraph::sampling_flags_t{
283+
options_.prior_sources_behavior_,
284+
options_.return_hops_ == TRUE,
285+
options_.dedupe_sources_ == TRUE,
286+
options_.with_replacement_ == TRUE,
287+
cugraph::temporal_sampling_comparison_t::STRICTLY_INCREASING,
288+
options_.disjoint_sampling_ == TRUE},
286289
do_expensive_check_);
287290
} else {
288291
std::tie(sampled_srcs,
@@ -309,10 +312,13 @@ struct neighbor_sampling_functor : public cugraph::c_api::abstract_functor {
309312
: std::nullopt,
310313
raft::host_span<const int>(fan_out_->as_type<const int>(), fan_out_->size_),
311314
num_edge_types_,
312-
cugraph::sampling_flags_t{options_.prior_sources_behavior_,
313-
options_.return_hops_ == TRUE,
314-
options_.dedupe_sources_ == TRUE,
315-
options_.with_replacement_ == TRUE},
315+
cugraph::sampling_flags_t{
316+
options_.prior_sources_behavior_,
317+
options_.return_hops_ == TRUE,
318+
options_.dedupe_sources_ == TRUE,
319+
options_.with_replacement_ == TRUE,
320+
cugraph::temporal_sampling_comparison_t::STRICTLY_INCREASING,
321+
options_.disjoint_sampling_ == TRUE},
316322
do_expensive_check_);
317323
}
318324
} else {
@@ -342,10 +348,13 @@ struct neighbor_sampling_functor : public cugraph::c_api::abstract_functor {
342348
(*label_to_comm_rank).data(), (*label_to_comm_rank).size()})
343349
: std::nullopt,
344350
raft::host_span<const int>(fan_out_->as_type<const int>(), fan_out_->size_),
345-
cugraph::sampling_flags_t{options_.prior_sources_behavior_,
346-
options_.return_hops_ == TRUE,
347-
options_.dedupe_sources_ == TRUE,
348-
options_.with_replacement_ == TRUE},
351+
cugraph::sampling_flags_t{
352+
options_.prior_sources_behavior_,
353+
options_.return_hops_ == TRUE,
354+
options_.dedupe_sources_ == TRUE,
355+
options_.with_replacement_ == TRUE,
356+
cugraph::temporal_sampling_comparison_t::STRICTLY_INCREASING,
357+
options_.disjoint_sampling_ == TRUE},
349358
do_expensive_check_);
350359
} else {
351360
std::tie(sampled_srcs,
@@ -371,10 +380,13 @@ struct neighbor_sampling_functor : public cugraph::c_api::abstract_functor {
371380
(*label_to_comm_rank).data(), (*label_to_comm_rank).size()})
372381
: std::nullopt,
373382
raft::host_span<const int>(fan_out_->as_type<const int>(), fan_out_->size_),
374-
cugraph::sampling_flags_t{options_.prior_sources_behavior_,
375-
options_.return_hops_ == TRUE,
376-
options_.dedupe_sources_ == TRUE,
377-
options_.with_replacement_ == TRUE},
383+
cugraph::sampling_flags_t{
384+
options_.prior_sources_behavior_,
385+
options_.return_hops_ == TRUE,
386+
options_.dedupe_sources_ == TRUE,
387+
options_.with_replacement_ == TRUE,
388+
cugraph::temporal_sampling_comparison_t::STRICTLY_INCREASING,
389+
options_.disjoint_sampling_ == TRUE},
378390
do_expensive_check_);
379391
}
380392
}

cpp/src/c_api/sampling_common.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ struct cugraph_sampling_options_t {
2323
bool_t retain_seeds_{FALSE};
2424
cugraph_temporal_sampling_comparison_t temporal_sampling_comparison_{
2525
cugraph_temporal_sampling_comparison_t::STRICTLY_INCREASING};
26+
bool_t disjoint_sampling_{FALSE};
2627
};
2728

2829
struct sampling_flags_t {

cpp/src/c_api/sampling_result.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,13 @@ extern "C" void cugraph_sampling_set_temporal_sampling_comparison(
9999
internal_pointer->temporal_sampling_comparison_ = value;
100100
}
101101

102+
extern "C" void cugraph_sampling_set_disjoint_sampling(cugraph_sampling_options_t* options,
103+
bool_t value)
104+
{
105+
auto internal_pointer = reinterpret_cast<cugraph::c_api::cugraph_sampling_options_t*>(options);
106+
internal_pointer->disjoint_sampling_ = value;
107+
}
108+
102109
extern "C" void cugraph_sampling_options_free(cugraph_sampling_options_t* options)
103110
{
104111
auto internal_pointer = reinterpret_cast<cugraph::c_api::cugraph_sampling_options_t*>(options);

cpp/src/c_api/temporal_neighbor_sampling.cpp

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -364,10 +364,11 @@ struct temporal_neighbor_sampling_functor : public cugraph::c_api::abstract_func
364364
raft::host_span<const int>(fan_out_->as_type<const int>(), fan_out_->size_),
365365
num_edge_types_,
366366
cugraph::sampling_flags_t{options_.prior_sources_behavior_,
367-
options_.return_hops_,
368-
options_.dedupe_sources_,
369-
options_.with_replacement_,
370-
temporal_sampling_comparison},
367+
options_.return_hops_ == TRUE,
368+
options_.dedupe_sources_ == TRUE,
369+
options_.with_replacement_ == TRUE,
370+
temporal_sampling_comparison,
371+
options_.disjoint_sampling_ == TRUE},
371372
do_expensive_check_);
372373
} else {
373374
std::tie(sampled_edge_srcs,
@@ -404,10 +405,11 @@ struct temporal_neighbor_sampling_functor : public cugraph::c_api::abstract_func
404405
raft::host_span<const int>(fan_out_->as_type<const int>(), fan_out_->size_),
405406
num_edge_types_,
406407
cugraph::sampling_flags_t{options_.prior_sources_behavior_,
407-
options_.return_hops_,
408-
options_.dedupe_sources_,
409-
options_.with_replacement_,
410-
temporal_sampling_comparison},
408+
options_.return_hops_ == TRUE,
409+
options_.dedupe_sources_ == TRUE,
410+
options_.with_replacement_ == TRUE,
411+
temporal_sampling_comparison,
412+
options_.disjoint_sampling_ == TRUE},
411413
do_expensive_check_);
412414
}
413415
} else {
@@ -447,10 +449,11 @@ struct temporal_neighbor_sampling_functor : public cugraph::c_api::abstract_func
447449
: std::nullopt,
448450
raft::host_span<const int>(fan_out_->as_type<const int>(), fan_out_->size_),
449451
cugraph::sampling_flags_t{options_.prior_sources_behavior_,
450-
options_.return_hops_,
451-
options_.dedupe_sources_,
452-
options_.with_replacement_,
453-
temporal_sampling_comparison},
452+
options_.return_hops_ == TRUE,
453+
options_.dedupe_sources_ == TRUE,
454+
options_.with_replacement_ == TRUE,
455+
temporal_sampling_comparison,
456+
options_.disjoint_sampling_ == TRUE},
454457
do_expensive_check_);
455458
} else {
456459
std::tie(sampled_edge_srcs,
@@ -486,10 +489,11 @@ struct temporal_neighbor_sampling_functor : public cugraph::c_api::abstract_func
486489
: std::nullopt,
487490
raft::host_span<const int>(fan_out_->as_type<const int>(), fan_out_->size_),
488491
cugraph::sampling_flags_t{options_.prior_sources_behavior_,
489-
options_.return_hops_,
490-
options_.dedupe_sources_,
491-
options_.with_replacement_,
492-
temporal_sampling_comparison},
492+
options_.return_hops_ == TRUE,
493+
options_.dedupe_sources_ == TRUE,
494+
options_.with_replacement_ == TRUE,
495+
temporal_sampling_comparison,
496+
options_.disjoint_sampling_ == TRUE},
493497
do_expensive_check_);
494498
}
495499
}

python/pylibcugraph/pylibcugraph/_cugraph_c/algorithms.pxd

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,12 @@ cdef extern from "cugraph_c/algorithms.h":
340340
cugraph_compression_type_t value,
341341
)
342342

343+
cdef void \
344+
cugraph_sampling_set_disjoint_sampling(
345+
cugraph_sampling_options_t* options,
346+
bool_t value,
347+
)
348+
343349
cdef void \
344350
cugraph_sampling_options_free(
345351
cugraph_sampling_options_t* options,

python/pylibcugraph/pylibcugraph/heterogeneous_biased_neighbor_sample.pyx

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ from pylibcugraph._cugraph_c.algorithms cimport (
4545
cugraph_sampling_set_compress_per_hop,
4646
cugraph_sampling_set_compression_type,
4747
cugraph_sampling_set_retain_seeds,
48+
cugraph_sampling_set_disjoint_sampling,
4849
)
4950
from pylibcugraph._cugraph_c.sampling_algorithms cimport (
5051
cugraph_heterogeneous_biased_neighbor_sample,
@@ -85,6 +86,7 @@ def heterogeneous_biased_neighbor_sample(ResourceHandle resource_handle,
8586
bool_t do_expensive_check,
8687
prior_sources_behavior=None,
8788
deduplicate_sources=False,
89+
disjoint_sampling=False,
8890
return_hops=False,
8991
renumber=False,
9092
retain_seeds=False,
@@ -174,6 +176,10 @@ def heterogeneous_biased_neighbor_sample(ResourceHandle resource_handle,
174176
If True, will create a separate compressed edgelist per hop within
175177
a batch.
176178
179+
disjoint_sampling: bool (Optional)
180+
If True, enables disjoint sampling between seeds per hop when supported.
181+
Defaults to False.
182+
177183
random_state: int (Optional)
178184
Random state to use when generating samples. Optional argument,
179185
defaults to a hash of process id, time, and hostname.
@@ -357,6 +363,7 @@ def heterogeneous_biased_neighbor_sample(ResourceHandle resource_handle,
357363
cugraph_sampling_set_compression_type(sampling_options, compression_behavior_e)
358364
cugraph_sampling_set_compress_per_hop(sampling_options, c_compress_per_hop)
359365
cugraph_sampling_set_retain_seeds(sampling_options, retain_seeds)
366+
cugraph_sampling_set_disjoint_sampling(sampling_options, disjoint_sampling)
360367

361368
error_code = cugraph_heterogeneous_biased_neighbor_sample(
362369
c_resource_handle_ptr,

python/pylibcugraph/pylibcugraph/heterogeneous_biased_temporal_neighbor_sample.pyx

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ from pylibcugraph._cugraph_c.algorithms cimport (
4545
cugraph_sampling_set_compress_per_hop,
4646
cugraph_sampling_set_compression_type,
4747
cugraph_sampling_set_retain_seeds,
48+
cugraph_sampling_set_disjoint_sampling,
4849
)
4950
from pylibcugraph._cugraph_c.sampling_algorithms cimport (
5051
cugraph_heterogeneous_biased_temporal_neighbor_sample,
@@ -87,6 +88,7 @@ def heterogeneous_biased_temporal_neighbor_sample(ResourceHandle resource_handle
8788
bool_t do_expensive_check,
8889
prior_sources_behavior=None,
8990
deduplicate_sources=False,
91+
disjoint_sampling=False,
9092
return_hops=False,
9193
renumber=False,
9294
retain_seeds=False,
@@ -197,6 +199,10 @@ def heterogeneous_biased_temporal_neighbor_sample(ResourceHandle resource_handle
197199
defaults to a hash of process id, time, and hostname.
198200
(See pylibcugraph.random.CuGraphRandomState)
199201
202+
disjoint_sampling: bool (Optional)
203+
If True, enables disjoint sampling between seeds per hop when supported.
204+
Defaults to False.
205+
200206
Returns
201207
-------
202208
A tuple of device arrays, where the first and second items in the tuple
@@ -396,6 +402,7 @@ def heterogeneous_biased_temporal_neighbor_sample(ResourceHandle resource_handle
396402
cugraph_sampling_set_compression_type(sampling_options, compression_behavior_e)
397403
cugraph_sampling_set_compress_per_hop(sampling_options, c_compress_per_hop)
398404
cugraph_sampling_set_retain_seeds(sampling_options, retain_seeds)
405+
cugraph_sampling_set_disjoint_sampling(sampling_options, disjoint_sampling)
399406

400407
error_code = cugraph_heterogeneous_biased_temporal_neighbor_sample(
401408
c_resource_handle_ptr,

python/pylibcugraph/pylibcugraph/heterogeneous_uniform_neighbor_sample.pyx

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ from pylibcugraph._cugraph_c.algorithms cimport (
4242
cugraph_sampling_set_compress_per_hop,
4343
cugraph_sampling_set_compression_type,
4444
cugraph_sampling_set_retain_seeds,
45+
cugraph_sampling_set_disjoint_sampling,
4546
)
4647
from pylibcugraph._cugraph_c.sampling_algorithms cimport (
4748
cugraph_heterogeneous_uniform_neighbor_sample,
@@ -82,6 +83,7 @@ def heterogeneous_uniform_neighbor_sample(ResourceHandle resource_handle,
8283
bool_t do_expensive_check,
8384
prior_sources_behavior=None,
8485
deduplicate_sources=False,
86+
disjoint_sampling=False,
8587
return_hops=False,
8688
renumber=False,
8789
retain_seeds=False,
@@ -169,6 +171,10 @@ def heterogeneous_uniform_neighbor_sample(ResourceHandle resource_handle,
169171
If True, will create a separate compressed edgelist per hop within
170172
a batch.
171173
174+
disjoint_sampling: bool (Optional)
175+
If True, enables disjoint sampling between seeds per hop when supported.
176+
Defaults to False.
177+
172178
random_state: int (Optional)
173179
Random state to use when generating samples. Optional argument,
174180
defaults to a hash of process id, time, and hostname.
@@ -350,6 +356,7 @@ def heterogeneous_uniform_neighbor_sample(ResourceHandle resource_handle,
350356
cugraph_sampling_set_compression_type(sampling_options, compression_behavior_e)
351357
cugraph_sampling_set_compress_per_hop(sampling_options, c_compress_per_hop)
352358
cugraph_sampling_set_retain_seeds(sampling_options, retain_seeds)
359+
cugraph_sampling_set_disjoint_sampling(sampling_options, disjoint_sampling)
353360

354361
error_code = cugraph_heterogeneous_uniform_neighbor_sample(
355362
c_resource_handle_ptr,

0 commit comments

Comments
 (0)