Skip to content

Commit b7d1cd9

Browse files
[FEA] Update pylibcugraph to support different temporal comparisons (#5345)
Updates pylibcugraph to support different temporal comparisons (i.e. >=, <, <=, >, last). Authors: - Alex Barghi (https://github.com/alexbarghi-nv) Approvers: - Joseph Nke (https://github.com/jnke2016) - Chuck Hastings (https://github.com/ChuckHastings) - Rick Ratzel (https://github.com/rlratzel) URL: #5345
1 parent f90b24c commit b7d1cd9

File tree

6 files changed

+161
-9
lines changed

6 files changed

+161
-9
lines changed

python/pylibcugraph/pylibcugraph/_cugraph_c/algorithms.pxd

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,12 +286,25 @@ cdef extern from "cugraph_c/algorithms.h":
286286
DCSR
287287
DCSC
288288

289+
ctypedef enum cugraph_temporal_sampling_comparison_t:
290+
STRICTLY_INCREASING=0
291+
MONOTONICALLY_INCREASING
292+
STRICTLY_DECREASING
293+
MONOTONICALLY_DECREASING
294+
LAST
295+
289296
cdef cugraph_error_code_t \
290297
cugraph_sampling_options_create(
291298
cugraph_sampling_options_t** options,
292299
cugraph_error_t** error,
293300
)
294301

302+
cdef void \
303+
cugraph_sampling_set_temporal_sampling_comparison(
304+
cugraph_sampling_options_t* options,
305+
cugraph_temporal_sampling_comparison_t comparison,
306+
)
307+
295308
cdef void \
296309
cugraph_sampling_set_renumber_results(
297310
cugraph_sampling_options_t* options,

python/pylibcugraph/pylibcugraph/heterogeneous_biased_temporal_neighbor_sample.pyx

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ from pylibcugraph._cugraph_c.algorithms cimport (
4545
cugraph_sampling_set_compress_per_hop,
4646
cugraph_sampling_set_compression_type,
4747
cugraph_sampling_set_retain_seeds,
48+
cugraph_sampling_set_temporal_sampling_comparison,
49+
cugraph_temporal_sampling_comparison_t,
4850
cugraph_sampling_set_disjoint_sampling,
4951
)
5052
from pylibcugraph._cugraph_c.sampling_algorithms cimport (
@@ -94,7 +96,8 @@ def heterogeneous_biased_temporal_neighbor_sample(ResourceHandle resource_handle
9496
retain_seeds=False,
9597
compression='COO',
9698
compress_per_hop=False,
97-
random_state=None):
99+
random_state=None,
100+
temporal_sampling_comparison='strictly_increasing'):
98101
"""
99102
Performs biased temporal neighborhood sampling, which samples nodes from
100103
a graph based on the current node's neighbors, with a corresponding fan_out
@@ -199,6 +202,10 @@ def heterogeneous_biased_temporal_neighbor_sample(ResourceHandle resource_handle
199202
defaults to a hash of process id, time, and hostname.
200203
(See pylibcugraph.random.CuGraphRandomState)
201204
205+
temporal_sampling_comparison: str (Optional)
206+
Options: 'strictly_increasing' (default), 'strictly_decreasing', 'monotonically_increasing', 'monotonically_decreasing', 'last'
207+
Sets the comparison operator for temporal sampling.
208+
202209
disjoint_sampling: bool (Optional)
203210
If True, enables disjoint sampling between seeds per hop when supported.
204211
Defaults to False.
@@ -404,6 +411,21 @@ def heterogeneous_biased_temporal_neighbor_sample(ResourceHandle resource_handle
404411
cugraph_sampling_set_retain_seeds(sampling_options, retain_seeds)
405412
cugraph_sampling_set_disjoint_sampling(sampling_options, disjoint_sampling)
406413

414+
cdef cugraph_temporal_sampling_comparison_t temporal_sampling_comparison_e
415+
if temporal_sampling_comparison is None or temporal_sampling_comparison == 'strictly_increasing':
416+
temporal_sampling_comparison_e = cugraph_temporal_sampling_comparison_t.STRICTLY_INCREASING
417+
elif temporal_sampling_comparison == 'strictly_decreasing':
418+
temporal_sampling_comparison_e = cugraph_temporal_sampling_comparison_t.STRICTLY_DECREASING
419+
elif temporal_sampling_comparison == 'monotonically_increasing':
420+
temporal_sampling_comparison_e = cugraph_temporal_sampling_comparison_t.MONOTONICALLY_INCREASING
421+
elif temporal_sampling_comparison == 'monotonically_decreasing':
422+
temporal_sampling_comparison_e = cugraph_temporal_sampling_comparison_t.MONOTONICALLY_DECREASING
423+
elif temporal_sampling_comparison == "last":
424+
raise NotImplementedError('The "last" comparison type is currently unsupported.')
425+
else:
426+
raise ValueError(f'Invalid option {temporal_sampling_comparison} for temporal sampling comparison')
427+
cugraph_sampling_set_temporal_sampling_comparison(sampling_options, temporal_sampling_comparison_e)
428+
407429
error_code = cugraph_heterogeneous_biased_temporal_neighbor_sample(
408430
c_resource_handle_ptr,
409431
rng_state_ptr,

python/pylibcugraph/pylibcugraph/heterogeneous_uniform_temporal_neighbor_sample.pyx

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ from pylibcugraph._cugraph_c.algorithms cimport (
4242
cugraph_sampling_set_compress_per_hop,
4343
cugraph_sampling_set_compression_type,
4444
cugraph_sampling_set_retain_seeds,
45+
cugraph_sampling_set_temporal_sampling_comparison,
46+
cugraph_temporal_sampling_comparison_t,
4547
cugraph_sampling_set_disjoint_sampling,
4648
)
4749
from pylibcugraph._cugraph_c.sampling_algorithms cimport (
@@ -91,7 +93,8 @@ def heterogeneous_uniform_temporal_neighbor_sample(ResourceHandle resource_handl
9193
retain_seeds=False,
9294
compression='COO',
9395
compress_per_hop=False,
94-
random_state=None):
96+
random_state=None,
97+
temporal_sampling_comparison='strictly_increasing'):
9598
"""
9699
Performs uniform temporal neighborhood sampling, which samples nodes from
97100
a graph based on the current node's neighbors, with a corresponding fan_out
@@ -194,6 +197,10 @@ def heterogeneous_uniform_temporal_neighbor_sample(ResourceHandle resource_handl
194197
defaults to a hash of process id, time, and hostname.
195198
(See pylibcugraph.random.CuGraphRandomState)
196199
200+
temporal_sampling_comparison: str (Optional)
201+
Options: 'strictly_increasing' (default), 'strictly_decreasing', 'monotonically_increasing', 'monotonically_decreasing', 'last'
202+
Sets the comparison operator for temporal sampling.
203+
197204
disjoint_sampling: bool (Optional)
198205
If True, enables disjoint sampling between seeds per hop when supported.
199206
Defaults to False.
@@ -400,6 +407,21 @@ def heterogeneous_uniform_temporal_neighbor_sample(ResourceHandle resource_handl
400407
cugraph_sampling_set_retain_seeds(sampling_options, retain_seeds)
401408
cugraph_sampling_set_disjoint_sampling(sampling_options, disjoint_sampling)
402409

410+
cdef cugraph_temporal_sampling_comparison_t temporal_sampling_comparison_e
411+
if temporal_sampling_comparison is None or temporal_sampling_comparison == 'strictly_increasing':
412+
temporal_sampling_comparison_e = cugraph_temporal_sampling_comparison_t.STRICTLY_INCREASING
413+
elif temporal_sampling_comparison == 'strictly_decreasing':
414+
temporal_sampling_comparison_e = cugraph_temporal_sampling_comparison_t.STRICTLY_DECREASING
415+
elif temporal_sampling_comparison == 'monotonically_increasing':
416+
temporal_sampling_comparison_e = cugraph_temporal_sampling_comparison_t.MONOTONICALLY_INCREASING
417+
elif temporal_sampling_comparison == 'monotonically_decreasing':
418+
temporal_sampling_comparison_e = cugraph_temporal_sampling_comparison_t.MONOTONICALLY_DECREASING
419+
elif temporal_sampling_comparison == "last":
420+
raise NotImplementedError('The "last" comparison type is currently unsupported.')
421+
else:
422+
raise ValueError(f'Invalid option {temporal_sampling_comparison} for temporal sampling comparison')
423+
cugraph_sampling_set_temporal_sampling_comparison(sampling_options, temporal_sampling_comparison_e)
424+
403425
error_code = cugraph_heterogeneous_uniform_temporal_neighbor_sample(
404426
c_resource_handle_ptr,
405427
rng_state_ptr,

python/pylibcugraph/pylibcugraph/homogeneous_biased_temporal_neighbor_sample.pyx

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ from pylibcugraph._cugraph_c.algorithms cimport (
4545
cugraph_sampling_set_compress_per_hop,
4646
cugraph_sampling_set_compression_type,
4747
cugraph_sampling_set_retain_seeds,
48+
cugraph_sampling_set_temporal_sampling_comparison,
49+
cugraph_temporal_sampling_comparison_t,
4850
cugraph_sampling_set_disjoint_sampling,
4951
)
5052
from pylibcugraph._cugraph_c.sampling_algorithms cimport (
@@ -92,7 +94,8 @@ def homogeneous_biased_temporal_neighbor_sample(ResourceHandle resource_handle,
9294
retain_seeds=False,
9395
compression='COO',
9496
compress_per_hop=False,
95-
random_state=None):
97+
random_state=None,
98+
temporal_sampling_comparison='strictly_increasing'):
9699
"""
97100
Performs biased temporal neighborhood sampling, which samples nodes from
98101
a graph based on the current node's neighbors, with a corresponding fan_out
@@ -191,6 +194,10 @@ def homogeneous_biased_temporal_neighbor_sample(ResourceHandle resource_handle,
191194
defaults to a hash of process id, time, and hostname.
192195
(See pylibcugraph.random.CuGraphRandomState)
193196
197+
temporal_sampling_comparison: str (Optional)
198+
Options: 'strictly_increasing' (default), 'strictly_decreasing', 'monotonically_increasing', 'monotonically_decreasing', 'last'
199+
Sets the comparison operator for temporal sampling.
200+
194201
Returns
195202
-------
196203
A tuple of device arrays, where the first and second items in the tuple
@@ -380,6 +387,21 @@ def homogeneous_biased_temporal_neighbor_sample(ResourceHandle resource_handle,
380387
cugraph_sampling_set_retain_seeds(sampling_options, retain_seeds)
381388
cugraph_sampling_set_disjoint_sampling(sampling_options, disjoint_sampling)
382389

390+
cdef cugraph_temporal_sampling_comparison_t temporal_sampling_comparison_e
391+
if temporal_sampling_comparison is None or temporal_sampling_comparison == 'strictly_increasing':
392+
temporal_sampling_comparison_e = cugraph_temporal_sampling_comparison_t.STRICTLY_INCREASING
393+
elif temporal_sampling_comparison == 'strictly_decreasing':
394+
temporal_sampling_comparison_e = cugraph_temporal_sampling_comparison_t.STRICTLY_DECREASING
395+
elif temporal_sampling_comparison == 'monotonically_increasing':
396+
temporal_sampling_comparison_e = cugraph_temporal_sampling_comparison_t.MONOTONICALLY_INCREASING
397+
elif temporal_sampling_comparison == 'monotonically_decreasing':
398+
temporal_sampling_comparison_e = cugraph_temporal_sampling_comparison_t.MONOTONICALLY_DECREASING
399+
elif temporal_sampling_comparison == "last":
400+
raise NotImplementedError('The "last" comparison type is currently unsupported.')
401+
else:
402+
raise ValueError(f'Invalid option {temporal_sampling_comparison} for temporal sampling comparison')
403+
cugraph_sampling_set_temporal_sampling_comparison(sampling_options, temporal_sampling_comparison_e)
404+
383405
error_code = cugraph_homogeneous_biased_temporal_neighbor_sample(
384406
c_resource_handle_ptr,
385407
rng_state_ptr,

python/pylibcugraph/pylibcugraph/homogeneous_uniform_temporal_neighbor_sample.pyx

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ from pylibcugraph._cugraph_c.algorithms cimport (
4242
cugraph_sampling_set_compress_per_hop,
4343
cugraph_sampling_set_compression_type,
4444
cugraph_sampling_set_retain_seeds,
45+
cugraph_sampling_set_temporal_sampling_comparison,
46+
cugraph_temporal_sampling_comparison_t,
4547
cugraph_sampling_set_disjoint_sampling,
4648
)
4749
from pylibcugraph._cugraph_c.sampling_algorithms cimport (
@@ -89,7 +91,8 @@ def homogeneous_uniform_temporal_neighbor_sample(ResourceHandle resource_handle,
8991
retain_seeds=False,
9092
compression='COO',
9193
compress_per_hop=False,
92-
random_state=None):
94+
random_state=None,
95+
temporal_sampling_comparison='strictly_increasing'):
9396
"""
9497
Performs uniform temporal neighborhood sampling, which samples nodes from
9598
a graph based on the current node's neighbors, with a corresponding fan_out
@@ -186,6 +189,9 @@ def homogeneous_uniform_temporal_neighbor_sample(ResourceHandle resource_handle,
186189
defaults to a hash of process id, time, and hostname.
187190
(See pylibcugraph.random.CuGraphRandomState)
188191
192+
temporal_sampling_comparison: str (Optional)
193+
Options: 'strictly_increasing' (default), 'strictly_decreasing', 'monotonically_increasing', 'monotonically_decreasing', 'last'
194+
Sets the comparison operator for temporal sampling.
189195
Returns
190196
-------
191197
A tuple of device arrays, where the first and second items in the tuple
@@ -379,6 +385,21 @@ def homogeneous_uniform_temporal_neighbor_sample(ResourceHandle resource_handle,
379385
cugraph_sampling_set_retain_seeds(sampling_options, retain_seeds)
380386
cugraph_sampling_set_disjoint_sampling(sampling_options, disjoint_sampling)
381387

388+
cdef cugraph_temporal_sampling_comparison_t temporal_sampling_comparison_e
389+
if temporal_sampling_comparison is None or temporal_sampling_comparison == 'strictly_increasing':
390+
temporal_sampling_comparison_e = cugraph_temporal_sampling_comparison_t.STRICTLY_INCREASING
391+
elif temporal_sampling_comparison == 'strictly_decreasing':
392+
temporal_sampling_comparison_e = cugraph_temporal_sampling_comparison_t.STRICTLY_DECREASING
393+
elif temporal_sampling_comparison == 'monotonically_increasing':
394+
temporal_sampling_comparison_e = cugraph_temporal_sampling_comparison_t.MONOTONICALLY_INCREASING
395+
elif temporal_sampling_comparison == 'monotonically_decreasing':
396+
temporal_sampling_comparison_e = cugraph_temporal_sampling_comparison_t.MONOTONICALLY_DECREASING
397+
elif temporal_sampling_comparison == "last":
398+
raise NotImplementedError('The "last" comparison type is currently unsupported.')
399+
else:
400+
raise ValueError(f'Invalid option {temporal_sampling_comparison} for temporal sampling comparison')
401+
cugraph_sampling_set_temporal_sampling_comparison(sampling_options, temporal_sampling_comparison_e)
402+
382403
error_code = cugraph_homogeneous_uniform_temporal_neighbor_sample(
383404
c_resource_handle_ptr,
384405
rng_state_ptr,

python/pylibcugraph/pylibcugraph/tests/test_temporal_neighbor_sample.py

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,16 @@ def _build_temporal_sg_with_edge_types(resource_handle: ResourceHandle) -> SGGra
6464
return G
6565

6666

67-
def test_homogeneous_uniform_temporal_none_times():
67+
@pytest.mark.parametrize(
68+
"temporal_sampling_comparison",
69+
[
70+
"strictly_increasing",
71+
"strictly_decreasing",
72+
"monotonically_increasing",
73+
"monotonically_decreasing",
74+
],
75+
)
76+
def test_homogeneous_uniform_temporal_none_times(temporal_sampling_comparison):
6877
rh = ResourceHandle()
6978
G = _build_temporal_sg(rh)
7079
starts = cp.asarray([1, 2], dtype=np.int32)
@@ -80,6 +89,7 @@ def test_homogeneous_uniform_temporal_none_times():
8089
fanout,
8190
with_replacement=False,
8291
do_expensive_check=True,
92+
temporal_sampling_comparison=temporal_sampling_comparison,
8393
)
8494
result = {k: v for k, v in result.items() if v is not None}
8595

@@ -89,7 +99,18 @@ def test_homogeneous_uniform_temporal_none_times():
8999
assert isinstance(result["edge_end_time"], cp.ndarray)
90100

91101

92-
def test_homogeneous_uniform_temporal_with_times_and_labels():
102+
@pytest.mark.parametrize(
103+
"temporal_sampling_comparison",
104+
[
105+
"strictly_increasing",
106+
"strictly_decreasing",
107+
"monotonically_increasing",
108+
"monotonically_decreasing",
109+
],
110+
)
111+
def test_homogeneous_uniform_temporal_with_times_and_labels(
112+
temporal_sampling_comparison,
113+
):
93114
rh = ResourceHandle()
94115
G = _build_temporal_sg(rh)
95116
starts = cp.asarray([1, 2, 1], dtype=np.int32)
@@ -107,14 +128,24 @@ def test_homogeneous_uniform_temporal_with_times_and_labels():
107128
fanout,
108129
with_replacement=False,
109130
do_expensive_check=True,
131+
temporal_sampling_comparison=temporal_sampling_comparison,
110132
)
111133
result = {k: v for k, v in result.items() if v is not None}
112134

113135
assert result["majors"].size == result["minors"].size
114136
assert result["edge_start_time"].size == result["edge_end_time"].size
115137

116138

117-
def test_homogeneous_biased_temporal_with_times():
139+
@pytest.mark.parametrize(
140+
"temporal_sampling_comparison",
141+
[
142+
"strictly_increasing",
143+
"strictly_decreasing",
144+
"monotonically_increasing",
145+
"monotonically_decreasing",
146+
],
147+
)
148+
def test_homogeneous_biased_temporal_with_times(temporal_sampling_comparison):
118149
rh = ResourceHandle()
119150
G = _build_temporal_sg(rh)
120151
starts = cp.asarray([0, 1], dtype=np.int32)
@@ -131,12 +162,22 @@ def test_homogeneous_biased_temporal_with_times():
131162
fanout,
132163
with_replacement=False,
133164
do_expensive_check=True,
165+
temporal_sampling_comparison=temporal_sampling_comparison,
134166
)
135167
result = {k: v for k, v in result.items() if v is not None}
136168
assert "edge_start_time" in result and "edge_end_time" in result
137169

138170

139-
def test_heterogeneous_uniform_temporal_none_times():
171+
@pytest.mark.parametrize(
172+
"temporal_sampling_comparison",
173+
[
174+
"strictly_increasing",
175+
"strictly_decreasing",
176+
"monotonically_increasing",
177+
"monotonically_decreasing",
178+
],
179+
)
180+
def test_heterogeneous_uniform_temporal_none_times(temporal_sampling_comparison):
140181
rh = ResourceHandle()
141182
G = _build_temporal_sg_with_edge_types(rh)
142183
starts = cp.asarray([1, 2], dtype=np.int32)
@@ -154,12 +195,22 @@ def test_heterogeneous_uniform_temporal_none_times():
154195
num_edge_types=2,
155196
with_replacement=False,
156197
do_expensive_check=True,
198+
temporal_sampling_comparison=temporal_sampling_comparison,
157199
)
158200
result = {k: v for k, v in result.items() if v is not None}
159201
assert "edge_type" in result and "edge_start_time" in result
160202

161203

162-
def test_heterogeneous_biased_temporal_with_times():
204+
@pytest.mark.parametrize(
205+
"temporal_sampling_comparison",
206+
[
207+
"strictly_increasing",
208+
"strictly_decreasing",
209+
"monotonically_increasing",
210+
"monotonically_decreasing",
211+
],
212+
)
213+
def test_heterogeneous_biased_temporal_with_times(temporal_sampling_comparison):
163214
rh = ResourceHandle()
164215
G = _build_temporal_sg_with_edge_types(rh)
165216
starts = cp.asarray([0, 1], dtype=np.int32)
@@ -178,6 +229,7 @@ def test_heterogeneous_biased_temporal_with_times():
178229
num_edge_types=2,
179230
with_replacement=False,
180231
do_expensive_check=True,
232+
temporal_sampling_comparison=temporal_sampling_comparison,
181233
)
182234
result = {k: v for k, v in result.items() if v is not None}
183235
assert (

0 commit comments

Comments
 (0)