diff --git a/python/gigl/distributed/graph_store/compute.py b/python/gigl/distributed/graph_store/compute.py
index 719c580d..36a3b66d 100644
--- a/python/gigl/distributed/graph_store/compute.py
+++ b/python/gigl/distributed/graph_store/compute.py
@@ -1,4 +1,5 @@
 import os
+from typing import Optional
 
 import graphlearn_torch as glt
 import torch
@@ -9,7 +10,11 @@
 logger = Logger()
 
 
-def init_compute_process(local_rank: int, cluster_info: GraphStoreInfo) -> None:
+def init_compute_process(
+    local_rank: int,
+    cluster_info: GraphStoreInfo,
+    compute_world_backend: Optional[str] = None,
+) -> None:
     """
     Initializes distributed setup for a compute node in a Graph Store cluster.
 
@@ -18,6 +23,7 @@ def init_compute_process(local_rank: int, cluster_info: GraphStoreInfo) -> None:
     Args:
         local_rank (int): The local (process) rank on the compute node.
        cluster_info (GraphStoreInfo): The cluster information.
+        compute_world_backend (Optional[str]): The backend for the compute Torch Distributed process group.
 
     Raises:
         ValueError: If the process group is already initialized.
@@ -31,11 +37,11 @@ def init_compute_process(local_rank: int, cluster_info: GraphStoreInfo) -> None:
         + local_rank
     )
     logger.info(
-        f"Initializing compute process group {compute_cluster_rank} / {cluster_info.compute_cluster_world_size}. on {cluster_info.compute_cluster_master_ip}:{cluster_info.compute_cluster_master_port}."
+        f"Initializing compute process group {compute_cluster_rank} / {cluster_info.compute_cluster_world_size} on {cluster_info.compute_cluster_master_ip}:{cluster_info.compute_cluster_master_port} with backend {compute_world_backend}."
         f" OS rank: {os.environ['RANK']}, local client rank: {local_rank}"
     )
     torch.distributed.init_process_group(
-        backend="nccl" if torch.cuda.is_available() else "gloo",
+        backend=compute_world_backend,
         world_size=cluster_info.compute_cluster_world_size,
         rank=compute_cluster_rank,
         init_method=f"tcp://{cluster_info.compute_cluster_master_ip}:{cluster_info.compute_cluster_master_port}",
diff --git a/python/gigl/distributed/graph_store/remote_dataset.py b/python/gigl/distributed/graph_store/remote_dataset.py
index ac0947bd..61e4d4bc 100644
--- a/python/gigl/distributed/graph_store/remote_dataset.py
+++ b/python/gigl/distributed/graph_store/remote_dataset.py
@@ -18,7 +18,7 @@
 [1]: https://github.com/alibaba/graphlearn-for-pytorch/blob/main/graphlearn_torch/python/distributed/dist_server.py#L38
 """
 
-from typing import Optional, Union
+from typing import Literal, Optional, Union
 
 import torch
 
@@ -88,6 +88,17 @@ def get_edge_feature_info() -> Union[FeatureInfo, dict[EdgeType, FeatureInfo], N
     return _dataset.edge_feature_info
 
 
+def get_edge_dir() -> Literal["in", "out"]:
+    """Get the edge direction from the registered dataset.
+
+    Returns:
+        The edge direction.
+    """
+    if _dataset is None:
+        raise _NO_DATASET_ERROR
+    return _dataset.edge_dir
+
+
 def get_node_ids_for_rank(
     rank: int,
     world_size: int,
diff --git a/python/gigl/distributed/graph_store/remote_dist_dataset.py b/python/gigl/distributed/graph_store/remote_dist_dataset.py
new file mode 100644
index 00000000..c88424ae
--- /dev/null
+++ b/python/gigl/distributed/graph_store/remote_dist_dataset.py
@@ -0,0 +1,168 @@
+from typing import Literal, Optional, Union
+
+import torch
+from graphlearn_torch.distributed import async_request_server, request_server
+
+from gigl.common.logger import Logger
+from gigl.distributed.graph_store.remote_dataset import (
+    get_edge_dir,
+    get_edge_feature_info,
+    get_node_feature_info,
+    get_node_ids_for_rank,
+)
+from gigl.distributed.utils.networking import get_free_ports
+from gigl.env.distributed import GraphStoreInfo
+from gigl.src.common.types.graph_data import EdgeType, NodeType
+from gigl.types.graph import FeatureInfo
+
+logger = Logger()
+
+
+class RemoteDistDataset:
+    def __init__(self, cluster_info: GraphStoreInfo, local_rank: int):
+        """
+        Represents a dataset that is stored on a different storage cluster.
+        *Must* be used in the GiGL graph-store distributed setup.
+
+        This class *must* be used on the compute (client) side of the graph-store distributed setup.
+
+        Args:
+            cluster_info (GraphStoreInfo): The cluster information.
+            local_rank (int): The local rank of the process on the compute node.
+        """
+        self._cluster_info = cluster_info
+        self._local_rank = local_rank
+
+    @property
+    def cluster_info(self) -> GraphStoreInfo:
+        return self._cluster_info
+
+    def get_node_feature_info(
+        self,
+    ) -> Union[FeatureInfo, dict[NodeType, FeatureInfo], None]:
+        """Get node feature information from the registered dataset.
+
+        Returns:
+            Node feature information, which can be:
+            - A single FeatureInfo object for homogeneous graphs
+            - A dict mapping NodeType to FeatureInfo for heterogeneous graphs
+            - None if no node features are available
+        """
+        return request_server(
+            0,
+            get_node_feature_info,
+        )
+
+    def get_edge_feature_info(
+        self,
+    ) -> Union[FeatureInfo, dict[EdgeType, FeatureInfo], None]:
+        """Get edge feature information from the registered dataset.
+
+        Returns:
+            Edge feature information, which can be:
+            - A single FeatureInfo object for homogeneous graphs
+            - A dict mapping EdgeType to FeatureInfo for heterogeneous graphs
+            - None if no edge features are available
+        """
+        return request_server(
+            0,
+            get_edge_feature_info,
+        )
+
+    def get_edge_dir(self) -> Union[str, Literal["in", "out"]]:
+        """Get the edge direction from the registered dataset.
+
+        Returns:
+            The edge direction.
+        """
+        return request_server(
+            0,
+            get_edge_dir,
+        )
+
+    def get_node_ids(
+        self,
+        node_type: Optional[NodeType] = None,
+    ) -> list[torch.Tensor]:
+        """
+        Fetches node ids from the storage nodes for the current compute node (machine).
+
+        The returned list contains the node ids for the current compute node, ordered by storage rank.
+
+        For example, suppose there are two storage ranks, two compute ranks, and 16 total nodes.
+        In this scenario, the node ids are sharded across storage as follows:
+        Storage rank 0: [0, 1, 2, 3, 4, 5, 6, 7]
+        Storage rank 1: [8, 9, 10, 11, 12, 13, 14, 15]
+
+        NOTE: The GLT sampling engine expects all processes on a given compute machine
+        to have the same sampling input (node ids).
+        As such, the input tensors will be duplicated across all processes on a given compute machine.
+        TODO(kmonte): Come up with a solution to avoid this duplication.
+
+        Then, for compute rank 0 (node 0, process 0), the returned list will be:
+        [
+            [0, 1, 2, 3], # From storage rank 0
+            [8, 9, 10, 11] # From storage rank 1
+        ]
+
+        Args:
+            node_type (Optional[NodeType]): The type of nodes to get.
+                Must be provided for heterogeneous datasets.
+
+        Returns:
+            list[torch.Tensor]: A list of node IDs for the given node type.
+        """
+        futures: list[torch.futures.Future[torch.Tensor]] = []
+        rank = self.cluster_info.compute_node_rank
+        world_size = self.cluster_info.num_storage_nodes
+        logger.info(
+            f"Getting node ids for rank {rank} / {world_size} with node type {node_type}"
+        )
+
+        for server_rank in range(self.cluster_info.num_storage_nodes):
+            futures.append(
+                async_request_server(
+                    server_rank,
+                    get_node_ids_for_rank,
+                    rank=rank,
+                    world_size=world_size,
+                    node_type=node_type,
+                )
+            )
+        node_ids = torch.futures.wait_all(futures)
+        return node_ids
+
+    def get_free_ports_on_storage_cluster(self, num_ports: int) -> list[int]:
+        """
+        Get free ports from the storage master node.
+
+        This *must* be called with a torch.distributed process group initialized for the *entire* training cluster.
+
+        All compute ranks will receive the same free ports.
+
+        Args:
+            num_ports (int): Number of free ports to get.
+        """
+        if not torch.distributed.is_initialized():
+            raise ValueError(
+                "torch.distributed process group must be initialized for the entire training cluster"
+            )
+        compute_cluster_rank = (
+            self.cluster_info.compute_node_rank
+            * self.cluster_info.num_processes_per_compute
+            + self._local_rank
+        )
+        if compute_cluster_rank == 0:
+            ports = request_server(
+                0,
+                get_free_ports,
+                num_ports=num_ports,
+            )
+            logger.info(
+                f"Compute rank {compute_cluster_rank} found free ports: {ports}"
+            )
+        else:
+            ports = [None] * num_ports
+        torch.distributed.broadcast_object_list(ports, src=0)
+        logger.info(f"Compute rank {compute_cluster_rank} received free ports: {ports}")
+        return ports
diff --git a/python/gigl/distributed/graph_store/storage_main.py b/python/gigl/distributed/graph_store/storage_main.py
index 18699bd3..fd7c1cb5 100644
--- a/python/gigl/distributed/graph_store/storage_main.py
+++ b/python/gigl/distributed/graph_store/storage_main.py
@@ -5,6 +5,7 @@
 """
 import argparse
 import os
+from typing import Optional
 
 import graphlearn_torch as glt
 import torch
@@ -15,6 +16,7 @@
 from gigl.distributed.dist_dataset import DistDataset
 from gigl.distributed.graph_store.remote_dataset import register_dataset
 from gigl.distributed.utils import get_graph_store_info
+from gigl.distributed.utils.networking import get_free_ports_from_master_node
 from gigl.env.distributed import GraphStoreInfo
 
 logger = Logger()
@@ -24,11 +26,19 @@ def _run_storage_process(
     storage_rank: int,
     cluster_info: GraphStoreInfo,
     dataset: DistDataset,
+    torch_process_port: int,
+    storage_world_backend: Optional[str],
 ) -> None:
+    register_dataset(dataset)
     logger.info(
-        f"Initializing storage node {storage_rank} / {cluster_info.num_storage_nodes } on {cluster_info.cluster_master_ip}:{cluster_info.cluster_master_port}. Cluster rank: {os.environ.get('RANK')}"
+        f"Initializing storage node {storage_rank} / {cluster_info.num_storage_nodes} with backend {storage_world_backend} on {cluster_info.cluster_master_ip}:{torch_process_port}"
+    )
+    torch.distributed.init_process_group(
+        backend=storage_world_backend,
+        world_size=cluster_info.num_storage_nodes,
+        rank=storage_rank,
+        init_method=f"tcp://{cluster_info.cluster_master_ip}:{torch_process_port}",
     )
-    register_dataset(dataset)
     glt.distributed.init_server(
         num_servers=cluster_info.num_storage_nodes,
         server_rank=storage_rank,
@@ -51,6 +61,7 @@
     task_config_uri: Uri,
     is_inference: bool,
     tf_record_uri_pattern: str = ".*-of-.*\.tfrecord(\.gz)?$",
+    storage_world_backend: Optional[str] = None,
 ) -> None:
     """Run a storage node process
 
@@ -62,6 +73,7 @@
         task_config_uri (Uri): The task config URI.
         is_inference (bool): Whether the process is an inference process.
        tf_record_uri_pattern (str): The TF Record URI pattern.
+        storage_world_backend (Optional[str]): The backend for the storage Torch Distributed process group.
     """
     init_method = f"tcp://{cluster_info.storage_cluster_master_ip}:{cluster_info.storage_cluster_master_port}"
     logger.info(
@@ -82,6 +94,7 @@
         is_inference=is_inference,
         _tfrecord_uri_pattern=tf_record_uri_pattern,
     )
+    torch_process_port = get_free_ports_from_master_node(num_ports=1)[0]
     server_processes = []
     mp_context = torch.multiprocessing.get_context("spawn")
     # TODO(kmonte): Enable more than one server process per machine
@@ -92,6 +105,8 @@
                 storage_rank + i,  # storage_rank
                 cluster_info,  # cluster_info
                 dataset,  # dataset
+                torch_process_port,  # torch_process_port
+                storage_world_backend,  # storage_world_backend
             ),
         )
         server_processes.append(server_process)
diff --git a/python/tests/integration/distributed/graph_store/graph_store_integration_test.py b/python/tests/integration/distributed/graph_store/graph_store_integration_test.py
index 06d9f1a3..c0d44abe 100644
--- a/python/tests/integration/distributed/graph_store/graph_store_integration_test.py
+++ b/python/tests/integration/distributed/graph_store/graph_store_integration_test.py
@@ -1,3 +1,4 @@
+import collections
 import os
 import unittest
 from unittest import mock
@@ -11,8 +12,10 @@
     init_compute_process,
     shutdown_compute_proccess,
 )
+from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset
 from gigl.distributed.graph_store.storage_main import storage_node_process
-from gigl.distributed.utils import get_free_port
+from gigl.distributed.utils.neighborloader import shard_nodes_by_process
+from gigl.distributed.utils.networking import get_free_ports
 from gigl.env.distributed import (
     COMPUTE_CLUSTER_LOCAL_WORLD_SIZE_ENV_KEY,
     GraphStoreInfo,
@@ -21,6 +24,7 @@
 from gigl.src.mocking.mocking_assets.mocked_datasets_for_pipeline_tests import (
     CORA_USER_DEFINED_NODE_ANCHOR_MOCKED_DATASET_INFO,
 )
+from tests.test_assets.distributed.utils import assert_tensor_equality
 
 logger = Logger()
 
@@ -28,22 +32,64 @@
 def _run_client_process(
     client_rank: int,
     cluster_info: GraphStoreInfo,
+    expected_sampler_input: dict[int, list[torch.Tensor]],
 ) -> None:
-    client_global_rank = (
-        cluster_info.compute_node_rank * cluster_info.num_processes_per_compute
-        + client_rank
+    init_compute_process(client_rank, cluster_info, compute_world_backend="gloo")
+
+    remote_dist_dataset = RemoteDistDataset(
+        cluster_info=cluster_info, local_rank=client_rank
     )
-    init_compute_process(client_rank, cluster_info)
+    assert (
+        remote_dist_dataset.get_edge_dir() == "in"
+    ), f"Edge direction must be 'in' for the test dataset. Got {remote_dist_dataset.get_edge_dir()}"
+    assert (
+        remote_dist_dataset.get_edge_feature_info() is not None
+    ), "Edge feature info must not be None for the test dataset"
+    assert (
+        remote_dist_dataset.get_node_feature_info() is not None
+    ), "Node feature info must not be None for the test dataset"
+    ports = remote_dist_dataset.get_free_ports_on_storage_cluster(num_ports=2)
+    assert len(ports) == 2, "Expected 2 free ports"
+    if torch.distributed.get_rank() == 0:
+        all_ports = [None] * torch.distributed.get_world_size()
+    else:
+        all_ports = None
+    torch.distributed.gather_object(ports, all_ports)
+    logger.info(f"All ports: {all_ports}")
+
+    if torch.distributed.get_rank() == 0:
+        assert isinstance(all_ports, list)
+        for i, received_ports in enumerate(all_ports):
+            assert (
+                received_ports == ports
+            ), f"Expected {ports} free ports, got {received_ports}"
+
+    torch.distributed.barrier()
+    logger.info("Verified that all ranks received the same free ports")
+
+    sampler_input = remote_dist_dataset.get_node_ids()
+
+    rank_expected_sampler_input = expected_sampler_input[cluster_info.compute_node_rank]
+    for i in range(cluster_info.num_compute_nodes):
+        if i == cluster_info.compute_node_rank:
+            logger.info(f"Verifying sampler input for rank {i}")
+            logger.info("--------------------------------")
+            assert len(sampler_input) == len(rank_expected_sampler_input)
+            for j, expected in enumerate(rank_expected_sampler_input):
+                assert_tensor_equality(sampler_input[j], expected)
+            logger.info(
+                f"{i} / {cluster_info.num_compute_nodes} compute node rank input nodes verified"
+            )
+        torch.distributed.barrier()
+    torch.distributed.barrier()
-    logger.info(
-        f"{client_global_rank} / {cluster_info.compute_cluster_world_size} Shutting down client"
-    )
     shutdown_compute_proccess()
 
 
 def _client_process(
     client_rank: int,
     cluster_info: GraphStoreInfo,
+    expected_sampler_input: dict[int, list[torch.Tensor]],
 ) -> None:
     logger.info(
         f"Initializing client node {client_rank} / {cluster_info.num_compute_nodes}. OS rank: {os.environ['RANK']}, OS world size: {os.environ['WORLD_SIZE']}, local client rank: {client_rank}"
@@ -51,12 +97,14 @@ def _client_process(
     mp_context = torch.multiprocessing.get_context("spawn")
     client_processes = []
 
+    logger.info("Starting client processes")
     for i in range(cluster_info.num_processes_per_compute):
         client_process = mp_context.Process(
             target=_run_client_process,
             args=[
                 i,  # client_rank
                 cluster_info,  # cluster_info
+                expected_sampler_input,  # expected_sampler_input
             ],
         )
         client_processes.append(client_process)
@@ -80,9 +128,55 @@ def _run_server_processes(
         task_config_uri=task_config_uri,
         is_inference=is_inference,
         tf_record_uri_pattern=".*tfrecord",
+        storage_world_backend="gloo",
     )
 
 
+def _get_expected_input_nodes_by_rank(
+    num_nodes: int, cluster_info: GraphStoreInfo
+) -> dict[int, list[torch.Tensor]]:
+    """Get the expected sampler input for each compute rank.
+
+    We shard each storage rank's nodes across the compute ranks,
+    and append each compute rank's shard to its expected sampler input.
+    Example for num_nodes = 16, num_processes_per_compute = 1, num_compute_nodes = 2, num_storage_nodes = 2:
+    {
+        0:  # compute rank 0
+        [
+            [0, 1, 2, 3],  # From storage rank 0
+            [8, 9, 10, 11],  # From storage rank 1
+        ],
+        1:  # compute rank 1
+        [
+            [4, 5, 6, 7],  # From storage rank 0
+            [12, 13, 14, 15],  # From storage rank 1
+        ],
+    }
+
+
+    Args:
+        num_nodes (int): The number of nodes in the graph.
+        cluster_info (GraphStoreInfo): The cluster information.
+
+    Returns:
+        dict[int, list[torch.Tensor]]: The expected sampler input for each compute rank.
+    """
+    expected_sampler_input = collections.defaultdict(list)
+    all_nodes = torch.arange(num_nodes, dtype=torch.int64)
+    for server_rank in range(cluster_info.num_storage_nodes):
+        server_node_start = server_rank * num_nodes // cluster_info.num_storage_nodes
+        server_node_end = (
+            (server_rank + 1) * num_nodes // cluster_info.num_storage_nodes
+        )
+        server_nodes = all_nodes[server_node_start:server_node_end]
+        for compute_rank in range(cluster_info.num_compute_nodes):
+            generated_nodes = shard_nodes_by_process(
+                server_nodes, compute_rank, cluster_info.num_processes_per_compute
+            )
+            expected_sampler_input[compute_rank].append(generated_nodes)
+    return dict(expected_sampler_input)
+
+
 class TestUtils(unittest.TestCase):
     def test_graph_store_locally(self):
         # Simulating two server machine, two compute machines.
@@ -91,6 +185,12 @@
             CORA_USER_DEFINED_NODE_ANCHOR_MOCKED_DATASET_INFO.name
         ]
         task_config_uri = cora_supervised_info.frozen_gbml_config_uri
+        (
+            cluster_master_port,
+            storage_cluster_master_port,
+            compute_cluster_master_port,
+            master_port,
+        ) = get_free_ports(num_ports=4)
         cluster_info = GraphStoreInfo(
             num_storage_nodes=2,
             num_compute_nodes=2,
@@ -98,12 +198,16 @@
             cluster_master_ip="localhost",
             storage_cluster_master_ip="localhost",
             compute_cluster_master_ip="localhost",
-            cluster_master_port=get_free_port(),
-            storage_cluster_master_port=get_free_port(),
-            compute_cluster_master_port=get_free_port(),
+            cluster_master_port=cluster_master_port,
+            storage_cluster_master_port=storage_cluster_master_port,
+            compute_cluster_master_port=compute_cluster_master_port,
+        )
+
+        num_cora_nodes = 2708
+        expected_sampler_input = _get_expected_input_nodes_by_rank(
+            num_cora_nodes, cluster_info
         )
-        master_port = get_free_port()
 
         ctx = mp.get_context("spawn")
         client_processes: list = []
         for i in range(cluster_info.num_compute_nodes):
@@ -125,6 +229,7 @@
                 args=[
                     i,  # client_rank
                     cluster_info,  # cluster_info
+                    expected_sampler_input,  # expected_sampler_input
                 ],
             )
             client_process.start()
diff --git a/python/tests/unit/distributed/graph_store/remote_dataset_test.py b/python/tests/unit/distributed/graph_store/remote_dataset_test.py
index c83e8095..fdfc0fe5 100644
--- a/python/tests/unit/distributed/graph_store/remote_dataset_test.py
+++ b/python/tests/unit/distributed/graph_store/remote_dataset_test.py
@@ -272,6 +272,27 @@ def test_get_node_ids_for_rank_with_heterogeneous_dataset_and_no_node_type(
             str(context.exception),
         )
 
+    def test_get_edge_dir(self) -> None:
+        """Test get_edge_dir with a registered dataset."""
+        dataset = self._create_homogeneous_dataset()
+        remote_dataset.register_dataset(dataset)
+        edge_dir = remote_dataset.get_edge_dir()
+        self.assertEqual(edge_dir, dataset.edge_dir)
+
+    def test_get_node_feature_info(self) -> None:
+        """Test get_node_feature_info with a registered dataset."""
+        dataset = self._create_homogeneous_dataset()
+        remote_dataset.register_dataset(dataset)
+        node_feature_info = remote_dataset.get_node_feature_info()
+        self.assertEqual(node_feature_info, dataset.node_feature_info)
+
+    def test_get_edge_feature_info(self) -> None:
+        """Test get_edge_feature_info with a registered dataset."""
+        dataset = self._create_homogeneous_dataset()
+        remote_dataset.register_dataset(dataset)
+        edge_feature_info = remote_dataset.get_edge_feature_info()
+        self.assertEqual(edge_feature_info, dataset.edge_feature_info)
+
 
 if __name__ == "__main__":
     unittest.main()
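Example usage (illustrative sketch only, not part of the patch): one way a compute-side worker might exercise the RemoteDistDataset API introduced above. It assumes the storage servers are already running and that cluster_info and local_rank come from the launcher; the run_compute_worker name and the "gloo" backend choice are assumptions mirroring the integration test, not library requirements.

from gigl.distributed.graph_store.compute import (
    init_compute_process,
    shutdown_compute_proccess,
)
from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset
from gigl.env.distributed import GraphStoreInfo


def run_compute_worker(local_rank: int, cluster_info: GraphStoreInfo) -> None:
    # Hypothetical driver: join the compute-side process group ("gloo" mirrors the test setup).
    init_compute_process(local_rank, cluster_info, compute_world_backend="gloo")

    remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=local_rank)

    # Metadata lookups are served by storage rank 0.
    edge_dir = remote_dataset.get_edge_dir()
    node_feature_info = remote_dataset.get_node_feature_info()
    edge_feature_info = remote_dataset.get_edge_feature_info()

    # One tensor of node ids per storage rank, for this compute node.
    node_ids_per_storage_rank = remote_dataset.get_node_ids()

    # Ports agreed on by all compute ranks; requires the process group initialized above.
    ports = remote_dataset.get_free_ports_on_storage_cluster(num_ports=2)

    shutdown_compute_proccess()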