Overview
In collaboration with the engineering teams at Intel and Kumo, we are thrilled to announce the first in-house distributed training solution for PyTorch* Geometric (PyG) on the latest Intel® Xeon® platforms, available in PyG 2.5 and later. Developers and researchers can now take full advantage of distributed training on large-scale datasets that cannot fit into the memory of a single machine. In addition, the solution reduces training time as the number of compute nodes increases. Let's take a closer look at how our distributed solution is implemented and how it performs at scale.
In real-life applications, graphs often consist of billions of nodes that cannot fit into the memory of a single system. This is where distributed training of graph neural networks becomes necessary. By distributing partitions of the graph across a cluster of CPU machines, developers can run synchronized model training on the whole dataset at once using PyTorch DistributedDataParallel (DDP). The architecture distributes GNN training across multiple nodes by combining Remote Procedure Calls (RPC), for efficient sampling and retrieval of non-local features, with traditional DDP for model training.
Key Advantages of Distributed Training for PyG
The PyG in-house distributed training scheme, integrated through the torch_geometric.distributed subpackage, provides the following key advantages:
- Balanced graph partitioning via METIS, a graph partitioning algorithm, minimizes communication overhead when sampling subgraphs across compute nodes.
- DDP is used for model training and RPC for remote sampling and feature calls, using TCP and the "gloo" back end, chosen specifically for CPU-based sampling. This combination improves the efficiency and scalability of the training process.
- The implementation via custom GraphStore/FeatureStore APIs provides a flexible and tailored interface for large graph structure information and feature storage.
- Distributed neighbor sampling can sample from both local and remote partitions through RPC communication channels. All advanced features of single-node sampling, such as heterogeneous sampling, link-level sampling, and temporal sampling, are also available for distributed training.
- Distributed data loaders offer a high-level abstraction for managing sampler processes, ensuring simplicity and seamless integration with standard PyG data loaders.
- Incorporating the Python* asyncio library for asynchronous processing on top of torch RPC further enhances the system's responsiveness and overall performance.
Architecture Components
There are two distributed groups for distributed training in PyG. Torch RPC is used for distributed sampling over multiple partitions including node sampling and feature lookup. Torch DDP is used for distributed training over multiple nodes.
Overall, torch_geometric.distributed is divided into the following components:
- Graph Partitioner: Partitions the graph into multiple parts so that each machine only needs to load its local data into memory.
- GraphStore and FeatureStore: Stores the graph topology and features per partition, respectively. In addition, they maintain a mapping between local and global IDs for efficient assignment of nodes and feature lookup.
- DistNeighborSampler: Implements the distributed sampling algorithm, which includes local and remote sampling, and the final merge between local and remote sampling results based on PyTorch's RPC mechanisms.
- DistNeighborLoader: Manages the distributed neighbor sampling and feature-fetching processes via multiple RPC workers, and assembles the sampled nodes, edges, and their features into PyG data objects.
Graph Partitioning
The first step of distributed training is to split the graph into multiple smaller partitions, each of which is loaded locally on one machine of the cluster. Partitioning is built on top of the pyg-lib implementation of the METIS algorithm, which partitions graphs efficiently even at large scale. By default, METIS tries to balance the number of nodes of each type across partitions while minimizing the number of edges between partitions. The resulting partitions therefore maximize local access to neighbors, so samplers can do most of their work locally without communicating with other machines. In this scheme, each edge is assigned to exactly one partition, while halo nodes (1-hop neighbors that fall into a different partition) are replicated. Halo nodes ensure that single-layer neighbor sampling for a node stays purely local (see Figure 1).
Figure 1. Graph partitioning with halo nodes. Nodes 1, 4, 5, 10 are halo nodes to keep required graph information local even after partitioning.
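To make the halo-node idea concrete, here is a minimal, pure-Python sketch (not the pyg-lib or torch_geometric.distributed implementation) that takes a node-to-partition assignment, such as one produced by METIS, and derives each partition's edges, its halo nodes, and the resulting edge cut. The function name partition_summary is hypothetical, and assigning each edge to the partition of its destination node is an assumption made for illustration.

```python
from collections import defaultdict


def partition_summary(edges, node_part):
    """Illustrative only: per-partition edges, halo nodes, and edge cut.

    edges: list of (src, dst) global node IDs, messages flow src -> dst.
    node_part: dict mapping global node ID -> partition ID (e.g. from METIS).
    Assumption for this sketch: each edge belongs to the partition that
    owns its destination node.
    """
    part_edges = defaultdict(list)
    halo = defaultdict(set)
    edge_cut = 0
    for src, dst in edges:
        owner = node_part[dst]
        part_edges[owner].append((src, dst))
        if node_part[src] != owner:
            # src is owned elsewhere, so it is replicated as a halo node
            halo[owner].add(src)
            edge_cut += 1
    return dict(part_edges), {p: sorted(h) for p, h in halo.items()}, edge_cut


# Six nodes on a ring, split into two partitions of three nodes each.
edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 0)]
node_part = {0: 0, 1: 0, 2: 0, 3: 1, 4: 1, 5: 1}
part_edges, halo, cut = partition_summary(edges, node_part)
```

With this ownership rule, every incoming neighbor of a locally owned node is either local or a halo node, which is why one hop of sampling needs no communication.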
In our distributed training example, the partition_graph.py script demonstrates how to partition a selection of homogeneous and heterogeneous graphs. The partitioner can also preserve node features, edge features, and any temporal attributes at the node or edge level.
The results of a two-part split of the homogeneous ogbn-products and heterogeneous ogbn-mag graphs are shown in Figure 2.
Figure 2. Two partitions for the ogbn-products and ogbn-mag datasets.
In distributed training, each node in the cluster then owns a partition of the graph.
LocalGraphStore and LocalFeatureStore
To maintain the distributed data partitions, we use implementations of the PyG torch_geometric.data.GraphStore and torch_geometric.data.FeatureStore remote interfaces, as shown in Figure 3. Together with an integrated API for sending and receiving RPC requests, they provide a powerful tool for interconnected distributed data storage. Both stores can be populated from torch_geometric.data.Data and torch_geometric.data.HeteroData objects or initialized directly from generated partition files.
LocalGraphStore
The torch_geometric.distributed.LocalGraphStore class is designed to function as a container for graph topology information. It holds the edge indices that define relationships between nodes in a graph and is implemented on top of the torch_geometric.data.GraphStore interface. It offers methods that provide mapping information for nodes and edges to individual partitions and supports both homogeneous and heterogeneous data formats.
Key features:
- It stores only the local graph connections of its partition, together with the partition's halo nodes.
- Remote connectivity: The assignment of individual nodes and edges to partitions (both local and global) can be retrieved through node and edge partition books, which map global node and edge IDs to partition IDs.
- It maintains global identifiers for nodes and edges, allowing for consistent mapping across partitions.
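The partition book and the local/global ID mapping can be illustrated with a small toy class. This is a hypothetical sketch: ToyPartitionBook is not a PyG class, the partition book is modeled as a plain list indexed by global node ID, and halo nodes are ignored for brevity.

```python
from collections import defaultdict


class ToyPartitionBook:
    """Hypothetical stand-in for a node partition book plus ID mapping."""

    def __init__(self, node_part):
        # node_part[g] is the partition ID that owns global node ID g
        self.node_part = list(node_part)
        self.global_to_local = {}
        self.local_to_global = defaultdict(list)
        for g, p in enumerate(self.node_part):
            # local IDs are assigned in order of appearance within a partition
            self.global_to_local[g] = len(self.local_to_global[p])
            self.local_to_global[p].append(g)

    def owner(self, global_id):
        """Partition that owns this global node ID."""
        return self.node_part[global_id]

    def to_local(self, global_id):
        """Return (partition ID, local ID) for a global node ID."""
        return self.owner(global_id), self.global_to_local[global_id]

    def to_global(self, part, local_id):
        """Invert to_local for a node owned by `part`."""
        return self.local_to_global[part][local_id]


pb = ToyPartitionBook([0, 1, 0, 1, 1])
```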
LocalFeatureStore
torch_geometric.distributed.LocalFeatureStore is a class that serves as both node-level and edge-level feature storage. It is implemented on top of the PyG torch_geometric.data.FeatureStore remote interface and provides efficient put and get routines for attribute retrieval for both local and remote node and edge IDs. The local feature store is responsible for retrieving and updating features across different partitions and machines during the training process.
Key features:
- It provides functionalities for storing, retrieving, and distributing node and edge features. Within the partition managed by each machine or device, node and edge features are stored locally.
- Remote feature lookup: It implements mechanisms for looking up features in both local and remote nodes during distributed training processes through RPC requests. The class is designed to work seamlessly in distributed training scenarios, allowing for efficient feature handling across partitions.
- It maintains global identifiers for nodes and edges, allowing for consistent mapping across partitions.
Figure 3. Components and APIs for LocalGraphStore and LocalFeatureStore.
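Remote feature lookup can be sketched as follows: requested global IDs are split by owning partition, local features are read directly, and the remaining IDs are grouped so that each remote partition receives a single request. The function lookup_features and the callable fetch_remote are hypothetical; fetch_remote stands in for an RPC call and is not the LocalFeatureStore API.

```python
from collections import defaultdict


def lookup_features(ids, node_part, local_pid, local_feats, fetch_remote):
    """Gather features for global node `ids`, one remote request per partition.

    node_part: list mapping global node ID -> owning partition ID.
    local_feats: dict global ID -> feature vector for locally owned nodes.
    fetch_remote: hypothetical callable (pid, ids) -> list of features,
        standing in for an RPC to the machine that owns partition `pid`.
    Returns features in the same order as `ids`.
    """
    out = [None] * len(ids)
    remote = defaultdict(list)  # pid -> [(position in ids, global id)]
    for pos, gid in enumerate(ids):
        pid = node_part[gid]
        if pid == local_pid:
            out[pos] = local_feats[gid]
        else:
            remote[pid].append((pos, gid))
    for pid, items in remote.items():
        # a single batched request per remote partition
        feats = fetch_remote(pid, [gid for _, gid in items])
        for (pos, _), feat in zip(items, feats):
            out[pos] = feat
    return out


node_part = [0, 1, 0, 1]
store = {0: {0: [0.0], 2: [2.0]}, 1: {1: [1.0], 3: [3.0]}}
calls = []


def fake_rpc(pid, gids):
    calls.append((pid, gids))
    return [store[pid][g] for g in gids]


result = lookup_features([3, 0, 1, 2], node_part, 0, store[0], fake_rpc)
```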
Distributed Neighbor Sampling
The DistNeighborSampler is a module designed for efficient distributed training of graph neural networks (GNNs). It addresses the challenges of neighbor sampling in a distributed environment, where graph data is partitioned across multiple machines or devices, so that GNNs can learn effectively from large-scale graphs while maintaining scalability and performance.
Asynchronous Neighbor Sampling and Feature Collection
Distributed neighbor sampling is implemented using asynchronous torch.distributed.rpc calls, which allow machines to sample neighbors independently, without strict synchronization. Each machine selects neighbors from its local graph partition without waiting for other machines to complete their sampling. This increases parallelism, since machines can progress asynchronously, which leads to faster training. In addition to asynchronous sampling, distributed neighbor sampling also provides asynchronous feature collection.
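The effect of asynchronous sampling can be approximated with plain asyncio: one coroutine per partition, launched together with asyncio.gather, so a slow remote partition does not block the local one. This is only an illustration under stated assumptions: it does not use torch.distributed.rpc, asyncio.sleep merely simulates remote latency, and the function names are hypothetical.

```python
import asyncio
import random


async def sample_partition(pid, seeds, adj, k, delay, rng):
    """Hypothetical: sample up to k neighbors per seed from partition `pid`.

    `delay` simulates RPC latency for a remote partition (0 for local).
    """
    await asyncio.sleep(delay)
    out = {}
    for s in seeds:
        nbrs = adj[pid].get(s, [])
        out[s] = rng.sample(nbrs, min(k, len(nbrs)))
    return out


async def sample_one_hop(seeds_by_pid, adj, k, delays, seed=0):
    rng = random.Random(seed)
    tasks = [
        sample_partition(pid, seeds, adj, k, delays[pid], rng)
        for pid, seeds in seeds_by_pid.items()
    ]
    # All partitions are sampled concurrently; none waits for another.
    results = await asyncio.gather(*tasks)
    merged = {}
    for partial in results:
        merged.update(partial)  # merge local and remote sampling results
    return merged


adj = {0: {0: [1, 2, 3]}, 1: {5: [4, 6]}}
res = asyncio.run(sample_one_hop({0: [0], 1: [5]}, adj, 2, {0: 0.0, 1: 0.01}))
```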
Customizable Sampling Strategies
Users can customize sampling strategies based on their specific requirements. The module provides flexibility in defining sampling techniques, such as:
- Node sampling versus edge sampling
- Homogeneous versus heterogeneous sampling
- Temporal sampling versus static sampling
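To illustrate the temporal versus static distinction, the following hypothetical helper samples neighbors of a seed node while discarding any neighbor whose timestamp is later than the seed's time, so the model never sees information from the future. It assumes per-node timestamps and plain adjacency lists and is not the PyG sampler API; passing seed_time=None falls back to static sampling.

```python
import random


def temporal_sample(seed, seed_time, adj, node_time, k, rng):
    """Sample up to k neighbors of `seed` that are not newer than `seed_time`.

    adj: dict node -> list of neighbor IDs; node_time: dict node -> timestamp.
    seed_time=None disables the time filter (static sampling).
    """
    candidates = [
        n for n in adj.get(seed, [])
        if seed_time is None or node_time[n] <= seed_time
    ]
    return rng.sample(candidates, min(k, len(candidates)))


adj = {0: [1, 2, 3]}
node_time = {1: 5, 2: 10, 3: 20}
```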
Distributed Neighbor Sampling Workflow
A batch of seed indices follows three main steps before it is made available for the model's forward pass by the data loader:
- Distributed node sampling: The set of seed indices may originate from either local or remote partitions. For nodes within a local partition, the sampling occurs on the local machine. Conversely, for nodes associated with a remote partition, the neighbor sampling is conducted on the machine responsible for storing the respective partition. Sampling then happens layer-wise, where sampled nodes act as seed nodes in follow-up layers.
- Distributed feature lookup: Each partition stores an array of features of nodes and edges that are within that partition. Consequently, if the output of a sampler on a specific machine includes sampled nodes or edges that do not belong to its partition, the machine initiates an RPC request to the remote server that owns these nodes (or edges).
- Data conversion: Based on the sampler output and the acquired node and edge features, a PyG homogeneous (or heterogeneous) graph object is constructed. This object forms a batch used in subsequent computational operations of the model.
Figure 4. Distributed sampler (local and remote node sampling)
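The three workflow steps can be sketched end to end in pure Python. In this hypothetical example, each seed is routed to the partition that owns it, the first k neighbors are kept per hop (instead of random sampling, to keep the output deterministic), newly sampled nodes become the next layer's seeds, features are fetched through a caller-supplied function, and edges are relabeled to batch-local indices as a PyG batch would be. None of these names come from torch_geometric.distributed.

```python
from collections import defaultdict


def distributed_sample(seeds, num_neighbors, node_part, part_adj, fetch_feats):
    """Hypothetical multi-hop sampler following the three workflow steps.

    node_part: list mapping global node ID -> owning partition.
    part_adj: dict pid -> {global node ID: [neighbor global IDs]}; each
        partition holds the neighbor lists of the nodes it owns.
    num_neighbors: e.g. [2, 1] keeps at most 2 neighbors in hop 1, 1 in hop 2.
    fetch_feats: callable(list of node IDs) -> list of features.
    """
    nodes = list(seeds)  # sampled node IDs, seeds first
    seen = set(seeds)
    edges = []           # (neighbor, seed) pairs, i.e. message src -> dst
    frontier = list(seeds)
    for k in num_neighbors:  # step 1: layer-wise node sampling
        by_part = defaultdict(list)
        for n in frontier:   # route each seed to the partition that owns it
            by_part[node_part[n]].append(n)
        next_frontier = []
        for pid, part_seeds in by_part.items():
            adj = part_adj[pid]
            for n in part_seeds:
                for nbr in adj.get(n, [])[:k]:
                    edges.append((nbr, n))
                    if nbr not in seen:
                        seen.add(nbr)
                        nodes.append(nbr)
                        next_frontier.append(nbr)
        frontier = next_frontier  # sampled nodes seed the next layer
    x = fetch_feats(nodes)        # step 2: feature lookup
    index = {g: i for i, g in enumerate(nodes)}
    edge_index = [(index[s], index[d]) for s, d in edges]  # step 3: batch
    return {"node_ids": nodes, "edge_index": edge_index, "x": x}


node_part = [0, 0, 1, 1]
part_adj = {0: {0: [1, 2], 1: [3]}, 1: {2: [3], 3: [0]}}
batch = distributed_sample([2], [1, 1], node_part, part_adj,
                           lambda ids: [float(i) for i in ids])
```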
Distributed Data Loading
Distributed data loaders such as DistNeighborLoader and DistLinkNeighborLoader provide a simple API for the sampling engine previously described (see Figure 6) because they entirely wrap initialization and cleanup of sampler processes internally. Notably, the distributed classes inherit from the standard PyG single-node torch_geometric.loader.NodeLoader and torch_geometric.loader.LinkLoader loaders, respectively, so using them inside training scripts is nearly identical.
Batch generation differs slightly from the single-node case: the (local and remote) feature fetching happens inside the sampler rather than as a separate step after sampling. This reduces the number of RPCs. Because all sampler subprocesses run asynchronously, each sampler returns its output through a torch.multiprocessing.Queue.
Figure 5. Breakdown of the structure for DistNeighborLoader
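The loader's producer/consumer structure can be sketched with threads and a queue. Here threading.Thread and queue.Queue stand in for sampler subprocesses and torch.multiprocessing.Queue, and sample_fn is a hypothetical callable that performs both sampling and feature fetching inside the worker, mirroring the single combined step described above.

```python
import queue
import threading


def run_loader(batches_of_seeds, sample_fn, num_workers=2):
    """Toy loader: workers sample batches concurrently and push to a queue.

    Threads and queue.Queue stand in for sampler subprocesses and
    torch.multiprocessing.Queue. Output order is not guaranteed.
    """
    tasks = queue.Queue()
    results = queue.Queue()
    for i, seeds in enumerate(batches_of_seeds):
        tasks.put((i, seeds))

    def worker():
        while True:
            try:
                i, seeds = tasks.get_nowait()
            except queue.Empty:
                return  # no more seed batches to process
            # sampling and feature fetching both happen inside the worker
            results.put((i, sample_fn(seeds)))

    threads = [threading.Thread(target=worker) for _ in range(num_workers)]
    for t in threads:
        t.start()
    for _ in range(len(batches_of_seeds)):
        yield results.get()  # blocks until some worker has finished a batch
    for t in threads:
        t.join()


out = list(run_loader([[0], [1, 2], [3]], lambda s: [v * 10 for v in s]))
```

Because workers finish in any order, batches can come out of the queue out of order; the batch index is carried along so a consumer can tell them apart.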
Torch RPC and DDP
In this distributed training implementation, two torch.distributed communication technologies are used:
- torch.nn.parallel.DistributedDataParallel (DDP) for data-parallel model training
- torch.distributed.rpc for remote sampling calls and feature retrieval in a distributed environment
Our solution uses torch.distributed.rpc rather than alternatives such as gRPC because PyTorch RPC natively understands tensor data. Other RPC frameworks such as gRPC require data to be serialized into JSON or other user-defined formats and then converted back into tensors; PyTorch RPC avoids this additional serialization and conversion overhead.
The DDP group is initialized in the standard way in the main training script. RPC group initialization is more complicated because it happens in each sampler subprocess. This is done through the data loader's worker_init_fn, which PyTorch calls when each worker process starts. This function first defines a distributed context for each worker, assigning it a group and rank, then initializes the worker's own distributed neighbor sampler, and finally registers the worker as a new member of the RPC group (see Figure 6). The RPC connection remains open for the lifetime of the subprocess. Additionally, we use the atexit module to register cleanup handlers that run when the process exits.
Figure 6. RPC and DDP groups for distributed PyG
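The worker initialization sequence can be sketched as a factory that returns a worker_init_fn. The rank scheme (machine rank times number of workers plus worker ID), the group name, and the registry dict that stands in for RPC group membership are all assumptions for illustration; a real worker would join the group with torch.distributed.rpc.init_rpc instead.

```python
import atexit


def make_worker_init_fn(node_rank, num_nodes, num_workers, registry):
    """Hypothetical factory for a data-loader worker_init_fn.

    Each (machine, worker) pair gets a unique rank in a sampler group of
    size num_nodes * num_workers. `registry` is a plain dict standing in
    for RPC group membership.
    """
    def worker_init_fn(worker_id):
        # 1) distributed context for this worker
        ctx = {
            "group": "sampler_workers",  # hypothetical group name
            "rank": node_rank * num_workers + worker_id,
            "world_size": num_nodes * num_workers,
        }
        # 2) the worker's own sampler (a placeholder string here)
        ctx["sampler"] = f"sampler-{ctx['rank']}"
        # 3) join the (stand-in) RPC group
        registry[ctx["rank"]] = ctx

        def cleanup():
            registry.pop(ctx["rank"], None)  # leave the group on exit

        atexit.register(cleanup)
        # PyTorch ignores the return value; returned here for inspection
        return ctx

    return worker_init_fn


reg = {}
fns = [make_worker_init_fn(node, 2, 4, reg) for node in range(2)]
ranks = [fns[node](w)["rank"] for node in range(2) for w in range(4)]
```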
Results and Performance
We collected the benchmarking results using the system configuration at the bottom of this blog.
Figure 7 shows the scaling performance on the ogbn-products dataset (4.2 Gbytes) under different partition configurations (1/2/4/8/16). We observe that performance scales as the number of partitions increases. Figure 8 shows the training loss and test accuracy over the epochs, which validates the correctness of our distributed solution. We also verified the scaling performance on heterogeneous data, such as the ogbn-mag dataset.
Figure 7. Scaling performance over partition numbers (1/2/4/8/16)
Figure 8. Training loss and test accuracy for the homogeneous ogbn-products dataset
Future Plans
The next steps for ongoing distributed work for PyG are:
- Extending the current distributed framework to Intel® Data Center GPU Max Series (code-named Ponte Vecchio, or PVC)
- Engaging with customers on real-world use cases
Product and Performance Information
- Bare metal: node, 2x Intel Xeon Platinum 8360Y CPU at 2.40 GHz, 36 cores, Intel® Hyper-Threading Technology on, turbo on, NUMA 2, integrated accelerators available (used): DLB 0 [0], DSA 0 [0], IAA 0 [0], QAT 0 [0], total memory 256 GB (16x16 GB DDR4 3200 MT/s [3200 MT/s]), BIOS SE5C620.86B.01.01.0003.2104260124, microcode 0xd000389, 2x Ethernet controller X710 for 10 GbE SFP+, 1x MT28908 family [ConnectX-6], 1x 894.3G INTEL SSDSC2KG96, Rocky Linux* 8.8 (Green Obsidian), 4.18.0-477.21.1.el8_8.x86_64
- Multi-nodes: 1/2/4/8/16 nodes each (same as the previous bare metal bullet)
- Software: Python* 3.9.16, PyTorch 2.1.2, PyG 2.4, pyg-lib 0.3.1, ogbn-products dataset (4.2 Gbytes), GraphSAGE model
- Performance varies by use, configuration and other factors. Learn more at www.Intel.com/PerformanceIndex.