12 changes: 9 additions & 3 deletions python/gigl/distributed/graph_store/compute.py
@@ -1,4 +1,5 @@
import os
from typing import Optional

import graphlearn_torch as glt
import torch
@@ -9,7 +10,11 @@
logger = Logger()


def init_compute_process(local_rank: int, cluster_info: GraphStoreInfo) -> None:
def init_compute_process(
local_rank: int,
cluster_info: GraphStoreInfo,
compute_world_backend: Optional[str] = None,
) -> None:
"""
Initializes distributed setup for a compute node in a Graph Store cluster.

@@ -18,6 +23,7 @@ def init_compute_process(local_rank: int, cluster_info: GraphStoreInfo) -> None:
Args:
local_rank (int): The local (process) rank on the compute node.
cluster_info (GraphStoreInfo): The cluster information.
compute_world_backend (Optional[str]): The backend for the compute Torch Distributed process group.

Raises:
ValueError: If the process group is already initialized.
@@ -31,11 +37,11 @@ def init_compute_process(local_rank: int, cluster_info: GraphStoreInfo) -> None:
+ local_rank
)
logger.info(
f"Initializing compute process group {compute_cluster_rank} / {cluster_info.compute_cluster_world_size}. on {cluster_info.compute_cluster_master_ip}:{cluster_info.compute_cluster_master_port}."
f"Initializing compute process group {compute_cluster_rank} / {cluster_info.compute_cluster_world_size}. on {cluster_info.compute_cluster_master_ip}:{cluster_info.compute_cluster_master_port} with backend {compute_world_backend}."
f" OS rank: {os.environ['RANK']}, local client rank: {local_rank}"
)
torch.distributed.init_process_group(
backend="nccl" if torch.cuda.is_available() else "gloo",
backend=compute_world_backend,
world_size=cluster_info.compute_cluster_world_size,
rank=compute_cluster_rank,
init_method=f"tcp://{cluster_info.compute_cluster_master_ip}:{cluster_info.compute_cluster_master_port}",
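
A minimal usage sketch of the new compute_world_backend parameter (illustrative only, not part of this diff; it assumes get_graph_store_info() can be called with no arguments to build the GraphStoreInfo):

from gigl.distributed.graph_store.compute import init_compute_process
from gigl.distributed.utils import get_graph_store_info

cluster_info = get_graph_store_info()
# Passing None lets torch.distributed pick its default backend; "gloo" is shown
# here purely as an illustrative override for CPU-only compute nodes.
init_compute_process(local_rank=0, cluster_info=cluster_info, compute_world_backend="gloo")
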
13 changes: 12 additions & 1 deletion python/gigl/distributed/graph_store/remote_dataset.py
@@ -18,7 +18,7 @@

[1]: https://github.com/alibaba/graphlearn-for-pytorch/blob/main/graphlearn_torch/python/distributed/dist_server.py#L38
"""
from typing import Optional, Union
from typing import Literal, Optional, Union

import torch

@@ -88,6 +88,17 @@ def get_edge_feature_info() -> Union[FeatureInfo, dict[EdgeType, FeatureInfo], None]:
return _dataset.edge_feature_info


def get_edge_dir() -> Literal["in", "out"]:
"""Get the edge direction from the registered dataset.

Returns:
The edge direction.
"""
if _dataset is None:
raise _NO_DATASET_ERROR
return _dataset.edge_dir


def get_node_ids_for_rank(
rank: int,
world_size: int,
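
A hedged sketch of the new accessor on a storage (server) process; `dataset` below is a placeholder for a DistDataset built elsewhere, so this is illustrative rather than runnable as-is:

from gigl.distributed.graph_store.remote_dataset import get_edge_dir, register_dataset

register_dataset(dataset)  # dataset: a previously built gigl DistDataset (placeholder)
edge_dir = get_edge_dir()  # "in" or "out"; raises if no dataset has been registered
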
168 changes: 168 additions & 0 deletions python/gigl/distributed/graph_store/remote_dist_dataset.py
@@ -0,0 +1,168 @@
from typing import Literal, Optional, Union

import torch
from graphlearn_torch.distributed import async_request_server, request_server

from gigl.common.logger import Logger
from gigl.distributed.graph_store.remote_dataset import (
get_edge_dir,
get_edge_feature_info,
get_node_feature_info,
get_node_ids_for_rank,
)
from gigl.distributed.utils.networking import get_free_ports
from gigl.env.distributed import GraphStoreInfo
from gigl.src.common.types.graph_data import EdgeType, NodeType
from gigl.types.graph import FeatureInfo

logger = Logger()


class RemoteDistDataset:
def __init__(self, cluster_info: GraphStoreInfo, local_rank: int):
"""
Represents a dataset that is stored on a different storage cluster.
*Must* be used in the GiGL graph-store distributed setup.

This class *must* be used on the compute (client) side of the graph-store distributed setup.

Args:
cluster_info (GraphStoreInfo): The cluster information.
local_rank (int): The local rank of the process on the compute node.
"""
self._cluster_info = cluster_info
self._local_rank = local_rank

@property
def cluster_info(self) -> GraphStoreInfo:
return self._cluster_info

def get_node_feature_info(
self,
) -> Union[FeatureInfo, dict[NodeType, FeatureInfo], None]:
"""Get node feature information from the registered dataset.

Returns:
Node feature information, which can be:
- A single FeatureInfo object for homogeneous graphs
- A dict mapping NodeType to FeatureInfo for heterogeneous graphs
- None if no node features are available
"""
return request_server(
0,
get_node_feature_info,
)

def get_edge_feature_info(
self,
) -> Union[FeatureInfo, dict[EdgeType, FeatureInfo], None]:
"""Get edge feature information from the registered dataset.

Returns:
Edge feature information, which can be:
- A single FeatureInfo object for homogeneous graphs
- A dict mapping EdgeType to FeatureInfo for heterogeneous graphs
- None if no edge features are available
"""
return request_server(
0,
get_edge_feature_info,
)

def get_edge_dir(self) -> Union[str, Literal["in", "out"]]:
"""Get the edge direction from the registered dataset.

Returns:
The edge direction.
"""
return request_server(
0,
get_edge_dir,
)

def get_node_ids(
self,
node_type: Optional[NodeType] = None,
) -> list[torch.Tensor]:
"""
Fetches node ids from the storage nodes for the current compute node (machine).

The returned list contains the node ids for the current compute node, ordered by storage rank.

For example, suppose there are two storage ranks, two compute ranks, and 16 total nodes.
In this scenario, the node ids are sharded as follows:
Storage rank 0: [0, 1, 2, 3, 4, 5, 6, 7]
Storage rank 1: [8, 9, 10, 11, 12, 13, 14, 15]

NOTE: The GLT sampling engine expects all processes on a given compute machine
to have the same sampling input (node ids).
As such, the input tensors will be duplicated across all processes on a given compute machine.
TODO(kmonte): Come up with a solution to avoid this duplication.

Then, for compute rank 0 (node 0, process 0), the returned list will be:
[
[0, 1, 2, 3], # From storage rank 0
[8, 9, 10, 11] # From storage rank 1
]

Args:
node_type (Optional[NodeType]): The type of nodes to get.
Must be provided for heterogeneous datasets.

Returns:
list[torch.Tensor]: A list of node IDs for the given node type.
"""
futures: list[torch.futures.Future[torch.Tensor]] = []
rank = self.cluster_info.compute_node_rank
world_size = self.cluster_info.num_storage_nodes
logger.info(
f"Getting node ids for rank {rank} / {world_size} with node type {node_type}"
)

for server_rank in range(self.cluster_info.num_storage_nodes):
futures.append(
async_request_server(
server_rank,
get_node_ids_for_rank,
rank=rank,
world_size=world_size,
node_type=node_type,
)
)
node_ids = torch.futures.wait_all(futures)
return node_ids

def get_free_ports_on_storage_cluster(self, num_ports: int) -> list[int]:
"""
Get free ports from the storage master node.

This *must* be used with a torch.distributed process group initialized for the *entire* training cluster.

All compute ranks will receive the same free ports.

Args:
num_ports (int): Number of free ports to get.

Returns:
list[int]: The free ports; identical across all compute ranks.
"""
if not torch.distributed.is_initialized():
raise ValueError(
"torch.distributed process group must be initialized for the entire training cluster"
)
compute_cluster_rank = (
self.cluster_info.compute_node_rank
* self.cluster_info.num_processes_per_compute
+ self._local_rank
)
if compute_cluster_rank == 0:
ports = request_server(
0,
get_free_ports,
num_ports=num_ports,
)
logger.info(
f"Compute rank {compute_cluster_rank} found free ports: {ports}"
)
else:
ports = [None] * num_ports
torch.distributed.broadcast_object_list(ports, src=0)
logger.info(f"Compute rank {compute_cluster_rank} received free ports: {ports}")
return ports
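
A compute-side usage sketch (an illustration, not part of the diff): it assumes init_compute_process has already run in this process so the GLT RPC channel and the torch.distributed process group are available, and that get_graph_store_info() needs no arguments here.

from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset
from gigl.distributed.utils import get_graph_store_info

cluster_info = get_graph_store_info()
remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0)

node_feature_info = remote_dataset.get_node_feature_info()
edge_dir = remote_dataset.get_edge_dir()  # "in" or "out"
# One tensor of node ids per storage rank, for this compute node:
node_ids_per_storage_rank = remote_dataset.get_node_ids()
# Identical port list on every compute rank (requires the full process group):
ports = remote_dataset.get_free_ports_on_storage_cluster(num_ports=2)
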
19 changes: 17 additions & 2 deletions python/gigl/distributed/graph_store/storage_main.py
@@ -5,6 +5,7 @@
"""
import argparse
import os
from typing import Optional

import graphlearn_torch as glt
import torch
@@ -15,6 +16,7 @@
from gigl.distributed.dist_dataset import DistDataset
from gigl.distributed.graph_store.remote_dataset import register_dataset
from gigl.distributed.utils import get_graph_store_info
from gigl.distributed.utils.networking import get_free_ports_from_master_node
from gigl.env.distributed import GraphStoreInfo

logger = Logger()
@@ -24,11 +26,19 @@ def _run_storage_process(
storage_rank: int,
cluster_info: GraphStoreInfo,
dataset: DistDataset,
torch_process_port: int,
storage_world_backend: Optional[str],
) -> None:
register_dataset(dataset)
logger.info(
f"Initializing storage node {storage_rank} / {cluster_info.num_storage_nodes } on {cluster_info.cluster_master_ip}:{cluster_info.cluster_master_port}. Cluster rank: {os.environ.get('RANK')}"
f"Initializing storage node {storage_rank} / {cluster_info.num_storage_nodes} with backend {storage_world_backend} on {cluster_info.cluster_master_ip}:{torch_process_port}"
)
torch.distributed.init_process_group(
backend=storage_world_backend,
world_size=cluster_info.num_storage_nodes,
rank=storage_rank,
init_method=f"tcp://{cluster_info.cluster_master_ip}:{torch_process_port}",
)
register_dataset(dataset)
glt.distributed.init_server(
num_servers=cluster_info.num_storage_nodes,
server_rank=storage_rank,
@@ -51,6 +61,7 @@ def storage_node_process(
task_config_uri: Uri,
is_inference: bool,
tf_record_uri_pattern: str = ".*-of-.*\.tfrecord(\.gz)?$",
storage_world_backend: Optional[str] = None,
) -> None:
"""Run a storage node process

@@ -62,6 +73,7 @@
task_config_uri (Uri): The task config URI.
is_inference (bool): Whether the process is an inference process.
tf_record_uri_pattern (str): The TF Record URI pattern.
storage_world_backend (Optional[str]): The backend for the storage Torch Distributed process group.
"""
init_method = f"tcp://{cluster_info.storage_cluster_master_ip}:{cluster_info.storage_cluster_master_port}"
logger.info(
@@ -82,6 +94,7 @@
is_inference=is_inference,
_tfrecord_uri_pattern=tf_record_uri_pattern,
)
torch_process_port = get_free_ports_from_master_node(num_ports=1)[0]
server_processes = []
mp_context = torch.multiprocessing.get_context("spawn")
# TODO(kmonte): Enable more than one server process per machine
@@ -92,6 +105,8 @@
storage_rank + i, # storage_rank
cluster_info, # cluster_info
dataset, # dataset
torch_process_port, # torch_process_port
storage_world_backend, # storage_world_backend
),
)
server_processes.append(server_process)
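
One way a launcher might choose the two backends (an assumption for illustration, not something this change prescribes): storage machines are often CPU-only, while compute machines may have GPUs.

import torch

# Illustrative policy only; passing None instead lets torch.distributed choose a default.
storage_world_backend = "gloo"
compute_world_backend = "nccl" if torch.cuda.is_available() else "gloo"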