Add RemoteDistDataset to operate on the dataset from compute nodes in graph store mode
#404
Merged
Commits (16, showing changes from all commits)
08d19ca  Add server_main and graph store local integration test (kmonte)
d7f1761  update test name (kmonte)
98579a3  add tood (kmonte)
54b67a8  Teardown process group (kmonte)
0f75d43  Address comments (kmonte)
f96bda9  address comments (kmonte)
33a5c5f  rename (kmonte)
a433102  comments (kmonte)
7a2b29b  Merge branch 'main' into kmonte/add-remote-dist-dataset (kmonte)
0b1ea3a  Merge branch 'main' into kmonte/add-remote-dist-dataset (kmonte)
9ec95b5  update doc comments (kmonte)
c46b7d8  address comments (kmonte)
84220ef  fixes (kmonte)
e75a07a  fixes (kmonte)
7095c4f  fix type check (kmonte)
6ffc7ef  Merge branch 'main' into kmonte/add-remote-dist-dataset (kmontemayor2-sc)
python/gigl/distributed/graph_store/remote_dist_dataset.py (168 additions, 0 deletions)
from typing import Literal, Optional, Union

import torch
from graphlearn_torch.distributed import async_request_server, request_server

from gigl.common.logger import Logger
from gigl.distributed.graph_store.remote_dataset import (
    get_edge_dir,
    get_edge_feature_info,
    get_node_feature_info,
    get_node_ids_for_rank,
)
from gigl.distributed.utils.networking import get_free_ports
from gigl.env.distributed import GraphStoreInfo
from gigl.src.common.types.graph_data import EdgeType, NodeType
from gigl.types.graph import FeatureInfo

logger = Logger()


class RemoteDistDataset:
    def __init__(self, cluster_info: GraphStoreInfo, local_rank: int):
        """
        Represents a dataset that is stored on a different storage cluster.
        *Must* be used in the GiGL graph-store distributed setup.

        This class *must* be used on the compute (client) side of the graph-store distributed setup.

        Args:
            cluster_info (GraphStoreInfo): The cluster information.
            local_rank (int): The local rank of the process on the compute node.
        """
        self._cluster_info = cluster_info
        self._local_rank = local_rank

    @property
    def cluster_info(self) -> GraphStoreInfo:
        return self._cluster_info

    def get_node_feature_info(
        self,
    ) -> Union[FeatureInfo, dict[NodeType, FeatureInfo], None]:
        """Get node feature information from the registered dataset.

        Returns:
            Node feature information, which can be:
            - A single FeatureInfo object for homogeneous graphs
            - A dict mapping NodeType to FeatureInfo for heterogeneous graphs
            - None if no node features are available
        """
        return request_server(
            0,
            get_node_feature_info,
        )

    def get_edge_feature_info(
        self,
    ) -> Union[FeatureInfo, dict[EdgeType, FeatureInfo], None]:
        """Get edge feature information from the registered dataset.

        Returns:
            Edge feature information, which can be:
            - A single FeatureInfo object for homogeneous graphs
            - A dict mapping EdgeType to FeatureInfo for heterogeneous graphs
            - None if no edge features are available
        """
        return request_server(
            0,
            get_edge_feature_info,
        )

    def get_edge_dir(self) -> Union[str, Literal["in", "out"]]:
        """Get the edge direction from the registered dataset.

        Returns:
            The edge direction.
        """
        return request_server(
            0,
            get_edge_dir,
        )

    def get_node_ids(
        self,
        node_type: Optional[NodeType] = None,
    ) -> list[torch.Tensor]:
        """
        Fetches node ids from the storage nodes for the current compute node (machine).

        The returned list contains the node ids for the current compute node, indexed by storage rank.

        For example, suppose there are two storage ranks, two compute ranks, and 16 total nodes.
        In this scenario, the node ids are sharded as follows:
            Storage rank 0: [0, 1, 2, 3, 4, 5, 6, 7]
            Storage rank 1: [8, 9, 10, 11, 12, 13, 14, 15]

        NOTE: The GLT sampling engine expects all processes on a given compute machine
        to have the same sampling input (node ids).
        As such, the input tensors will be duplicated across all processes on a given compute machine.
        TODO(kmonte): Come up with a solution to avoid this duplication.

        Then, for compute rank 0 (node 0, process 0), the returned list will be:
        [
            [0, 1, 2, 3],  # From storage rank 0
            [8, 9, 10, 11],  # From storage rank 1
        ]

        Args:
            node_type (Optional[NodeType]): The type of nodes to get.
                Must be provided for heterogeneous datasets.

        Returns:
            list[torch.Tensor]: A list of node IDs for the given node type.
        """
        futures: list[torch.futures.Future[torch.Tensor]] = []
        rank = self.cluster_info.compute_node_rank
        world_size = self.cluster_info.num_storage_nodes
        logger.info(
            f"Getting node ids for rank {rank} / {world_size} with node type {node_type}"
        )

        for server_rank in range(self.cluster_info.num_storage_nodes):
            futures.append(
                async_request_server(
                    server_rank,
                    get_node_ids_for_rank,
                    rank=rank,
                    world_size=world_size,
                    node_type=node_type,
                )
            )
        node_ids = torch.futures.wait_all(futures)
        return node_ids

    def get_free_ports_on_storage_cluster(self, num_ports: int) -> list[int]:
        """
        Get free ports from the storage master node.

        This *must* be used with a torch.distributed process group initialized for the *entire* training cluster.

        All compute ranks will receive the same free ports.

        Args:
            num_ports (int): Number of free ports to get.

        Returns:
            list[int]: The free ports.
        """
        if not torch.distributed.is_initialized():
            raise ValueError(
                "torch.distributed process group must be initialized for the entire training cluster"
            )
        compute_cluster_rank = (
            self.cluster_info.compute_node_rank
            * self.cluster_info.num_processes_per_compute
            + self._local_rank
        )
        if compute_cluster_rank == 0:
            ports = request_server(
                0,
                get_free_ports,
                num_ports=num_ports,
            )
            logger.info(
                f"Compute rank {compute_cluster_rank} found free ports: {ports}"
            )
        else:
            ports = [None] * num_ports
        torch.distributed.broadcast_object_list(ports, src=0)
        logger.info(f"Compute rank {compute_cluster_rank} received free ports: {ports}")
        return ports
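
For context, here is a minimal compute-side usage sketch of the new class. The surrounding setup is assumed rather than taken from this PR: run_compute_process is a hypothetical entry point, the GraphStoreInfo is assumed to be built elsewhere, and the GLT client-to-server RPC channel is assumed to already be connected for this process; only the RemoteDistDataset calls themselves come from the code above.

import torch

from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset
from gigl.env.distributed import GraphStoreInfo


def run_compute_process(cluster_info: GraphStoreInfo, local_rank: int) -> None:
    # Hypothetical compute-side entry point; assumes the GLT RPC channel to the
    # storage servers has already been initialized for this process.
    dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=local_rank)

    # Dataset metadata is fetched from storage server rank 0.
    edge_dir = dataset.get_edge_dir()
    node_feature_info = dataset.get_node_feature_info()
    edge_feature_info = dataset.get_edge_feature_info()

    # One tensor of node ids per storage rank, for this compute node.
    # Pass node_type=... for heterogeneous datasets.
    node_ids_per_storage_rank = dataset.get_node_ids()
    sampling_input = torch.cat(node_ids_per_storage_rank)
    print(f"edge_dir={edge_dir}, num sampling ids={sampling_input.numel()}")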
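
Similarly, a sketch of the free-port helper, continuing from the sketch above (dataset refers to the RemoteDistDataset constructed there). It illustrates the prerequisite the method enforces: a torch.distributed process group spanning the entire training cluster. The init_process_group arguments shown are placeholders, not the actual cluster setup.

import torch.distributed as dist

# Placeholder initialization; the real setup supplies the cluster's master
# address, world size, and rank (covering both storage and compute nodes).
dist.init_process_group(backend="gloo", init_method="env://")

# Compute rank 0 asks storage server 0 for free ports and broadcasts them,
# so every compute process receives the same list.
ports = dataset.get_free_ports_on_storage_cluster(num_ports=2)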