diff --git a/benchmarks/pytest.ini b/benchmarks/pytest.ini index fe7fc31b6d6..d692b78de37 100644 --- a/benchmarks/pytest.ini +++ b/benchmarks/pytest.ini @@ -8,6 +8,7 @@ testpaths = addopts = --benchmark-columns="min, max, mean, stddev, outliers" + --tb=native markers = managedmem_on: RMM managed memory enabled diff --git a/ci/build_wheel.sh b/ci/build_wheel.sh index 1976d8ff46f..f3979ab3049 100755 --- a/ci/build_wheel.sh +++ b/ci/build_wheel.sh @@ -17,7 +17,7 @@ cd "${package_dir}" python -m pip wheel \ -w dist \ - -vvv \ + -v \ --no-deps \ --disable-pip-version-check \ --extra-index-url https://pypi.nvidia.com \ diff --git a/ci/test_python.sh b/ci/test_python.sh index 810284b8c97..f21a06cf061 100755 --- a/ci/test_python.sh +++ b/ci/test_python.sh @@ -159,7 +159,7 @@ if [[ "${RAPIDS_CUDA_VERSION}" == "11.8.0" ]]; then cugraph \ cugraph-dgl \ 'dgl>=1.1.0.cu*,<=2.0.0.cu*' \ - 'pytorch>=2.0' \ + 'pytorch>=2.3,<2.4' \ 'cuda-version=11.8' rapids-print-env @@ -198,10 +198,10 @@ if [[ "${RAPIDS_CUDA_VERSION}" == "11.8.0" ]]; then # TODO re-enable logic once CUDA 12 is testable #if [[ "${RAPIDS_CUDA_VERSION}" == "11.8.0" ]]; then CONDA_CUDA_VERSION="11.8" - PYG_URL="https://data.pyg.org/whl/torch-2.1.0+cu118.html" + PYG_URL="https://data.pyg.org/whl/torch-2.3.0+cu118.html" #else # CONDA_CUDA_VERSION="12.1" - # PYG_URL="https://data.pyg.org/whl/torch-2.1.0+cu121.html" + # PYG_URL="https://data.pyg.org/whl/torch-2.3.0+cu121.html" #fi # Will automatically install built dependencies of cuGraph-PyG diff --git a/ci/test_wheel_cugraph-dgl.sh b/ci/test_wheel_cugraph-dgl.sh index 564b46cb07e..9b79cb17fe4 100755 --- a/ci/test_wheel_cugraph-dgl.sh +++ b/ci/test_wheel_cugraph-dgl.sh @@ -32,18 +32,8 @@ fi PYTORCH_URL="https://download.pytorch.org/whl/cu${PYTORCH_CUDA_VER}" DGL_URL="https://data.dgl.ai/wheels/cu${PYTORCH_CUDA_VER}/repo.html" -# Starting from 2.2, PyTorch wheels depend on nvidia-nccl-cuxx>=2.19 wheel and -# dynamically link to NCCL. RAPIDS CUDA 11 CI images have an older NCCL version that -# might shadow the newer NCCL required by PyTorch during import (when importing -# `cupy` before `torch`). 
-if [[ "${NCCL_VERSION}" < "2.19" ]]; then - PYTORCH_VER="2.1.0" -else - PYTORCH_VER="2.3.0" -fi - rapids-logger "Installing PyTorch and DGL" -rapids-retry python -m pip install "torch==${PYTORCH_VER}" --index-url ${PYTORCH_URL} +rapids-retry python -m pip install torch==2.3.0 --index-url ${PYTORCH_URL} rapids-retry python -m pip install dgl==2.0.0 --find-links ${DGL_URL} python -m pytest python/cugraph-dgl/tests diff --git a/ci/test_wheel_cugraph-pyg.sh b/ci/test_wheel_cugraph-pyg.sh index c55ae033344..8f4b16a2dec 100755 --- a/ci/test_wheel_cugraph-pyg.sh +++ b/ci/test_wheel_cugraph-pyg.sh @@ -29,13 +29,13 @@ export CI_RUN=1 if [[ "${CUDA_VERSION}" == "11.8.0" ]]; then PYTORCH_URL="https://download.pytorch.org/whl/cu118" - PYG_URL="https://data.pyg.org/whl/torch-2.1.0+cu118.html" + PYG_URL="https://data.pyg.org/whl/torch-2.3.0+cu118.html" else PYTORCH_URL="https://download.pytorch.org/whl/cu121" - PYG_URL="https://data.pyg.org/whl/torch-2.1.0+cu121.html" + PYG_URL="https://data.pyg.org/whl/torch-2.3.0+cu121.html" fi rapids-logger "Installing PyTorch and PyG dependencies" -rapids-retry python -m pip install torch==2.1.0 --index-url ${PYTORCH_URL} +rapids-retry python -m pip install torch==2.3.0 --index-url ${PYTORCH_URL} rapids-retry python -m pip install "torch-geometric>=2.5,<2.6" rapids-retry python -m pip install \ ogb \ diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index ef075767e70..a23c2395646 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -33,13 +33,13 @@ dependencies: - libraft==24.12.*,>=0.0.0a0 - librmm==24.12.*,>=0.0.0a0 - nbsphinx -- nccl>=2.18.1.1 +- nccl>=2.19 - networkx>=2.5.1 - networkx>=3.0 - ninja - notebook>=0.5.0 - numba>=0.57 -- numpy>=1.23,<2.0a0 +- numpy>=1.23,<3.0a0 - numpydoc - nvcc_linux-64=11.8 - ogb @@ -57,7 +57,7 @@ dependencies: - pytest-mpl - pytest-xdist - python-louvain -- pytorch>=2.0,<2.2.0a0 +- pytorch>=2.3,<2.4.0a0 - raft-dask==24.12.*,>=0.0.0a0 - rapids-build-backend>=0.3.1,<0.4.0.dev0 - rapids-dask-dependency==24.12.*,>=0.0.0a0 diff --git a/conda/environments/all_cuda-125_arch-x86_64.yaml b/conda/environments/all_cuda-125_arch-x86_64.yaml index 50eab69299d..eca10584304 100644 --- a/conda/environments/all_cuda-125_arch-x86_64.yaml +++ b/conda/environments/all_cuda-125_arch-x86_64.yaml @@ -39,13 +39,13 @@ dependencies: - libraft==24.12.*,>=0.0.0a0 - librmm==24.12.*,>=0.0.0a0 - nbsphinx -- nccl>=2.18.1.1 +- nccl>=2.19 - networkx>=2.5.1 - networkx>=3.0 - ninja - notebook>=0.5.0 - numba>=0.57 -- numpy>=1.23,<2.0a0 +- numpy>=1.23,<3.0a0 - numpydoc - ogb - openmpi @@ -62,7 +62,7 @@ dependencies: - pytest-mpl - pytest-xdist - python-louvain -- pytorch>=2.0,<2.2.0a0 +- pytorch>=2.3,<2.4.0a0 - raft-dask==24.12.*,>=0.0.0a0 - rapids-build-backend>=0.3.1,<0.4.0.dev0 - rapids-dask-dependency==24.12.*,>=0.0.0a0 diff --git a/conda/recipes/cugraph-dgl/meta.yaml b/conda/recipes/cugraph-dgl/meta.yaml index d1cf6fcd9e9..c80ca6890a8 100644 --- a/conda/recipes/cugraph-dgl/meta.yaml +++ b/conda/recipes/cugraph-dgl/meta.yaml @@ -27,11 +27,11 @@ requirements: - cugraph ={{ version }} - dgl >=1.1.0.cu* - numba >=0.57 - - numpy >=1.23,<2.0a0 + - numpy >=1.23,<3.0a0 - pylibcugraphops ={{ minor_version }} - tensordict >=0.1.2 - python - - pytorch >=2.0 + - pytorch >=2.3,<2.4.0a0 - cupy >=12.0.0 tests: diff --git a/conda/recipes/cugraph-pyg/meta.yaml b/conda/recipes/cugraph-pyg/meta.yaml index 2e1788ac0c6..38d4a3d7d15 100644 --- 
a/conda/recipes/cugraph-pyg/meta.yaml +++ b/conda/recipes/cugraph-pyg/meta.yaml @@ -29,9 +29,9 @@ requirements: run: - rapids-dask-dependency ={{ minor_version }} - numba >=0.57 - - numpy >=1.23,<2.0a0 + - numpy >=1.23,<3.0a0 - python - - pytorch >=2.0 + - pytorch >=2.3,<2.4.0a0 - cupy >=12.0.0 - cugraph ={{ version }} - pylibcugraphops ={{ minor_version }} diff --git a/conda/recipes/cugraph-service/meta.yaml b/conda/recipes/cugraph-service/meta.yaml index c1027582c78..7df7573e2d0 100644 --- a/conda/recipes/cugraph-service/meta.yaml +++ b/conda/recipes/cugraph-service/meta.yaml @@ -63,7 +63,7 @@ outputs: - dask-cuda ={{ minor_version }} - dask-cudf ={{ minor_version }} - numba >=0.57 - - numpy >=1.23,<2.0a0 + - numpy >=1.23,<3.0a0 - python - rapids-dask-dependency ={{ minor_version }} - thriftpy2 >=0.4.15,!=0.5.0,!=0.5.1 diff --git a/conda/recipes/libcugraph/conda_build_config.yaml b/conda/recipes/libcugraph/conda_build_config.yaml index 6b50d0aad23..55bd635c330 100644 --- a/conda/recipes/libcugraph/conda_build_config.yaml +++ b/conda/recipes/libcugraph/conda_build_config.yaml @@ -17,7 +17,7 @@ doxygen_version: - ">=1.8.11" nccl_version: - - ">=2.18.1.1" + - ">=2.19" c_stdlib: - sysroot diff --git a/cpp/src/c_api/graph_generators.cpp b/cpp/src/c_api/graph_generators.cpp index 7601f1508f9..a58a4d5db35 100644 --- a/cpp/src/c_api/graph_generators.cpp +++ b/cpp/src/c_api/graph_generators.cpp @@ -124,32 +124,41 @@ cugraph_error_code_t cugraph_generate_rmat_edgelists( extern "C" cugraph_type_erased_device_array_view_t* cugraph_coo_get_sources(cugraph_coo_t* coo) { auto internal_pointer = reinterpret_cast<cugraph::c_api::cugraph_coo_t*>(coo); - return reinterpret_cast<cugraph_type_erased_device_array_view_t*>(internal_pointer->src_->view()); + return (internal_pointer->src_) ? reinterpret_cast<cugraph_type_erased_device_array_view_t*>( + internal_pointer->src_->view()) + : nullptr; } extern "C" cugraph_type_erased_device_array_view_t* cugraph_coo_get_destinations(cugraph_coo_t* coo) { auto internal_pointer = reinterpret_cast<cugraph::c_api::cugraph_coo_t*>(coo); - return reinterpret_cast<cugraph_type_erased_device_array_view_t*>(internal_pointer->dst_->view()); + return (internal_pointer->dst_) ? reinterpret_cast<cugraph_type_erased_device_array_view_t*>( + internal_pointer->dst_->view()) + : nullptr; } extern "C" cugraph_type_erased_device_array_view_t* cugraph_coo_get_edge_weights(cugraph_coo_t* coo) { auto internal_pointer = reinterpret_cast<cugraph::c_api::cugraph_coo_t*>(coo); - return reinterpret_cast<cugraph_type_erased_device_array_view_t*>(internal_pointer->wgt_->view()); + return (internal_pointer->wgt_) ? reinterpret_cast<cugraph_type_erased_device_array_view_t*>( + internal_pointer->wgt_->view()) + : nullptr; } extern "C" cugraph_type_erased_device_array_view_t* cugraph_coo_get_edge_id(cugraph_coo_t* coo) { auto internal_pointer = reinterpret_cast<cugraph::c_api::cugraph_coo_t*>(coo); - return reinterpret_cast<cugraph_type_erased_device_array_view_t*>(internal_pointer->id_->view()); + return (internal_pointer->id_) ? reinterpret_cast<cugraph_type_erased_device_array_view_t*>( + internal_pointer->id_->view()) + : nullptr; } extern "C" cugraph_type_erased_device_array_view_t* cugraph_coo_get_edge_type(cugraph_coo_t* coo) { auto internal_pointer = reinterpret_cast<cugraph::c_api::cugraph_coo_t*>(coo); - return reinterpret_cast<cugraph_type_erased_device_array_view_t*>( - internal_pointer->type_->view()); + return (internal_pointer->type_) ?
reinterpret_cast<cugraph_type_erased_device_array_view_t*>( + internal_pointer->type_->view()) + : nullptr; } extern "C" size_t cugraph_coo_list_size(const cugraph_coo_list_t* coo_list) diff --git a/dependencies.yaml b/dependencies.yaml index b8699f708f9..640adf8099f 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -524,13 +524,13 @@ dependencies: - &dask rapids-dask-dependency==24.12.*,>=0.0.0a0 - &dask_cuda dask-cuda==24.12.*,>=0.0.0a0 - &numba numba>=0.57 - - &numpy numpy>=1.23,<2.0a0 + - &numpy numpy>=1.23,<3.0a0 - output_types: conda packages: - aiohttp - fsspec>=0.6.0 - requests - - nccl>=2.18.1.1 + - nccl>=2.19 - ucx-proc=*=gpu - &ucx_py_unsuffixed ucx-py==0.41.*,>=0.0.0a0 - output_types: pyproject @@ -695,7 +695,9 @@ dependencies: - output_types: [conda] packages: - *cugraph_unsuffixed - - pytorch>=2.0 + # ceiling could be removed when this is fixed: + # https://github.com/conda-forge/pytorch-cpu-feedstock/issues/254 + - &pytorch_conda pytorch>=2.3,<2.4.0a0 - pytorch-cuda==11.8 - &tensordict tensordict>=0.1.2 - dgl>=1.1.0.cu* @@ -704,7 +706,7 @@ dependencies: - output_types: [conda] packages: - *cugraph_unsuffixed - - pytorch>=2.0 + - *pytorch_conda - pytorch-cuda==11.8 - *tensordict - pyg>=2.5,<2.6 @@ -713,7 +715,7 @@ dependencies: common: - output_types: [conda] packages: - - &pytorch_unsuffixed pytorch>=2.0,<2.2.0a0 + - *pytorch_conda - torchdata - pydantic - ogb @@ -733,7 +735,7 @@ dependencies: matrices: - matrix: {cuda: "12.*"} packages: - - &pytorch_pip torch>=2.0,<2.2.0a0 + - &pytorch_pip torch>=2.3,<2.4.0a0 - *tensordict - matrix: {cuda: "11.*"} packages: diff --git a/python/cugraph-dgl/conda/cugraph_dgl_dev_cuda-118.yaml b/python/cugraph-dgl/conda/cugraph_dgl_dev_cuda-118.yaml index 4e6bdc20232..42cbcab5008 100644 --- a/python/cugraph-dgl/conda/cugraph_dgl_dev_cuda-118.yaml +++ b/python/cugraph-dgl/conda/cugraph_dgl_dev_cuda-118.yaml @@ -19,7 +19,7 @@ dependencies: - pytest-cov - pytest-xdist - pytorch-cuda==11.8 -- pytorch>=2.0 +- pytorch>=2.3,<2.4.0a0 - scipy - tensordict>=0.1.2 name: cugraph_dgl_dev_cuda-118 diff --git a/python/cugraph-dgl/cugraph_dgl/dataloading/neighbor_sampler.py b/python/cugraph-dgl/cugraph_dgl/dataloading/neighbor_sampler.py index 4ec513cbf9b..ecc51006995 100644 --- a/python/cugraph-dgl/cugraph_dgl/dataloading/neighbor_sampler.py +++ b/python/cugraph-dgl/cugraph_dgl/dataloading/neighbor_sampler.py @@ -197,10 +197,8 @@ def sample( if g.is_homogeneous: indices = torch.concat(list(indices)) - ds.sample_from_nodes(indices.long(), batch_size=batch_size) - return HomogeneousSampleReader( - ds.get_reader(), self.output_format, self.edge_dir - ) + reader = ds.sample_from_nodes(indices.long(), batch_size=batch_size) + return HomogeneousSampleReader(reader, self.output_format, self.edge_dir) raise ValueError( "Sampling heterogeneous graphs is currently" diff --git a/python/cugraph-dgl/cugraph_dgl/dataloading/sampler.py b/python/cugraph-dgl/cugraph_dgl/dataloading/sampler.py index 731ec1b8d6f..7ea608e7e53 100644 --- a/python/cugraph-dgl/cugraph_dgl/dataloading/sampler.py +++ b/python/cugraph-dgl/cugraph_dgl/dataloading/sampler.py @@ -20,7 +20,6 @@ create_homogeneous_sampled_graphs_from_tensors_csc, ) -from cugraph.gnn import DistSampleReader from cugraph.utilities.utils import import_optional @@ -33,14 +32,18 @@ class SampleReader: Iterator that processes results from the cuGraph distributed sampler.
""" - def __init__(self, base_reader: DistSampleReader, output_format: str = "dgl.Block"): + def __init__( + self, + base_reader: Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]], + output_format: str = "dgl.Block", + ): """ Constructs a new SampleReader. Parameters ---------- - base_reader: DistSampleReader - The reader responsible for loading saved samples produced by + base_reader: Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]] + The iterator responsible for loading saved samples produced by the cuGraph distributed sampler. """ self.__output_format = output_format @@ -83,7 +86,7 @@ class HomogeneousSampleReader(SampleReader): def __init__( self, - base_reader: DistSampleReader, + base_reader: Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]], output_format: str = "dgl.Block", edge_dir="in", ): @@ -92,7 +95,7 @@ def __init__( Parameters ---------- - base_reader: DistSampleReader + base_reader: Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]] The reader responsible for loading saved samples produced by the cuGraph distributed sampler. output_format: str diff --git a/python/cugraph-dgl/cugraph_dgl/graph.py b/python/cugraph-dgl/cugraph_dgl/graph.py index 138e645838a..88b93656fa8 100644 --- a/python/cugraph-dgl/cugraph_dgl/graph.py +++ b/python/cugraph-dgl/cugraph_dgl/graph.py @@ -620,9 +620,6 @@ def _get_n_emb( ) try: - print( - u, - ) return self.__ndata_storage[ntype, emb_name].fetch( _cast_to_torch_tensor(u), "cuda" ) diff --git a/python/cugraph-dgl/pyproject.toml b/python/cugraph-dgl/pyproject.toml index c1044efd7e7..e3e12216ac7 100644 --- a/python/cugraph-dgl/pyproject.toml +++ b/python/cugraph-dgl/pyproject.toml @@ -26,7 +26,7 @@ classifiers = [ dependencies = [ "cugraph==24.12.*,>=0.0.0a0", "numba>=0.57", - "numpy>=1.23,<2.0a0", + "numpy>=1.23,<3.0a0", "pylibcugraphops==24.12.*,>=0.0.0a0", ] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. @@ -40,7 +40,7 @@ test = [ "pytest-xdist", "scipy", "tensordict>=0.1.2", - "torch>=2.0,<2.2.0a0", + "torch>=2.3,<2.4.0a0", ] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. [project.urls] diff --git a/python/cugraph-equivariant/cugraph_equivariant/tests/pytest.ini b/python/cugraph-equivariant/cugraph_equivariant/tests/pytest.ini new file mode 100644 index 00000000000..7b0a9f29fb1 --- /dev/null +++ b/python/cugraph-equivariant/cugraph_equivariant/tests/pytest.ini @@ -0,0 +1,4 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. + +[pytest] +addopts = --tb=native diff --git a/python/cugraph-pyg/conda/cugraph_pyg_dev_cuda-118.yaml b/python/cugraph-pyg/conda/cugraph_pyg_dev_cuda-118.yaml index 6a0523f1c2d..39b1ab21edb 100644 --- a/python/cugraph-pyg/conda/cugraph_pyg_dev_cuda-118.yaml +++ b/python/cugraph-pyg/conda/cugraph_pyg_dev_cuda-118.yaml @@ -19,7 +19,7 @@ dependencies: - pytest-cov - pytest-xdist - pytorch-cuda==11.8 -- pytorch>=2.0 +- pytorch>=2.3,<2.4.0a0 - scipy - tensordict>=0.1.2 name: cugraph_pyg_dev_cuda-118 diff --git a/python/cugraph-pyg/cugraph_pyg/__init__.py b/python/cugraph-pyg/cugraph_pyg/__init__.py index 719751c966a..e566e6e9fdd 100644 --- a/python/cugraph-pyg/cugraph_pyg/__init__.py +++ b/python/cugraph-pyg/cugraph_pyg/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2023, NVIDIA CORPORATION. +# Copyright (c) 2019-2024, NVIDIA CORPORATION. 
# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,3 +12,8 @@ # limitations under the License. from cugraph_pyg._version import __git_commit__, __version__ + +import cugraph_pyg.data +import cugraph_pyg.loader +import cugraph_pyg.sampler +import cugraph_pyg.nn diff --git a/python/cugraph-pyg/cugraph_pyg/data/graph_store.py b/python/cugraph-pyg/cugraph_pyg/data/graph_store.py index 93ea5700c50..c47dda5eaa5 100644 --- a/python/cugraph-pyg/cugraph_pyg/data/graph_store.py +++ b/python/cugraph-pyg/cugraph_pyg/data/graph_store.py @@ -205,13 +205,18 @@ def _num_vertices(self) -> Dict[str, int]: else edge_attr.size[1] ) else: - if edge_attr.edge_type[0] not in num_vertices: + if edge_attr.edge_type[0] != edge_attr.edge_type[2]: + if edge_attr.edge_type[0] not in num_vertices: + num_vertices[edge_attr.edge_type[0]] = int( + self.__edge_indices[edge_attr.edge_type][0].max() + 1 + ) + if edge_attr.edge_type[2] not in num_vertices: + num_vertices[edge_attr.edge_type[1]] = int( + self.__edge_indices[edge_attr.edge_type][1].max() + 1 + ) + elif edge_attr.edge_type[0] not in num_vertices: num_vertices[edge_attr.edge_type[0]] = int( - self.__edge_indices[edge_attr.edge_type][0].max() + 1 - ) - if edge_attr.edge_type[2] not in num_vertices: - num_vertices[edge_attr.edge_type[1]] = int( - self.__edge_indices[edge_attr.edge_type][1].max() + 1 + self.__edge_indices[edge_attr.edge_type].max() + 1 ) if self.is_multi_gpu: diff --git a/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_mnmg.py b/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_mnmg.py index 7002d7ebded..127ca809d91 100644 --- a/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_mnmg.py +++ b/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_mnmg.py @@ -185,6 +185,8 @@ def run_train( wall_clock_start, tempdir=None, num_layers=3, + in_memory=False, + seeds_per_call=-1, ): optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0005) @@ -196,20 +198,23 @@ def run_train( from cugraph_pyg.loader import NeighborLoader ix_train = split_idx["train"].cuda() - train_path = os.path.join(tempdir, f"train_{global_rank}") - os.mkdir(train_path) + train_path = None if in_memory else os.path.join(tempdir, f"train_{global_rank}") + if train_path: + os.mkdir(train_path) train_loader = NeighborLoader( data, input_nodes=ix_train, directory=train_path, shuffle=True, drop_last=True, + local_seeds_per_call=seeds_per_call if seeds_per_call > 0 else None, **kwargs, ) ix_test = split_idx["test"].cuda() - test_path = os.path.join(tempdir, f"test_{global_rank}") - os.mkdir(test_path) + test_path = None if in_memory else os.path.join(tempdir, f"test_{global_rank}") + if test_path: + os.mkdir(test_path) test_loader = NeighborLoader( data, input_nodes=ix_test, @@ -221,14 +226,16 @@ def run_train( ) ix_valid = split_idx["valid"].cuda() - valid_path = os.path.join(tempdir, f"valid_{global_rank}") - os.mkdir(valid_path) + valid_path = None if in_memory else os.path.join(tempdir, f"valid_{global_rank}") + if valid_path: + os.mkdir(valid_path) valid_loader = NeighborLoader( data, input_nodes=ix_valid, directory=valid_path, shuffle=True, drop_last=True, + local_seeds_per_call=seeds_per_call if seeds_per_call > 0 else None, **kwargs, ) @@ -347,6 +354,9 @@ def parse_args(): parser.add_argument("--skip_partition", action="store_true") parser.add_argument("--wg_mem_type", type=str, default="distributed") + parser.add_argument("--in_memory", 
action="store_true", default=False) + parser.add_argument("--seeds_per_call", type=int, default=-1) + return parser.parse_args() @@ -429,6 +439,8 @@ def parse_args(): wall_clock_start, tempdir, args.num_layers, + args.in_memory, + args.seeds_per_call, ) else: warnings.warn("This script should be run with 'torchrun`. Exiting.") diff --git a/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_sg.py b/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_sg.py index 09d874bd87d..0f9c39bf04d 100644 --- a/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_sg.py +++ b/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_sg.py @@ -91,10 +91,20 @@ def test(loader: NeighborLoader, val_steps: Optional[int] = None): def create_loader( - data, num_neighbors, input_nodes, replace, batch_size, samples_dir, stage_name + data, + num_neighbors, + input_nodes, + replace, + batch_size, + samples_dir, + stage_name, + local_seeds_per_call, ): - directory = os.path.join(samples_dir, stage_name) - os.mkdir(directory) + if samples_dir is not None: + directory = os.path.join(samples_dir, stage_name) + os.mkdir(directory) + else: + directory = None return NeighborLoader( data, num_neighbors=num_neighbors, @@ -102,6 +112,7 @@ def create_loader( replace=replace, batch_size=batch_size, directory=directory, + local_seeds_per_call=local_seeds_per_call, ) @@ -147,6 +158,8 @@ def parse_args(): parser.add_argument("--tempdir_root", type=str, default=None) parser.add_argument("--dataset_root", type=str, default="dataset") parser.add_argument("--dataset", type=str, default="ogbn-products") + parser.add_argument("--in_memory", action="store_true", default=False) + parser.add_argument("--seeds_per_call", type=int, default=-1) return parser.parse_args() @@ -170,7 +183,10 @@ def parse_args(): "num_neighbors": [args.fan_out] * args.num_layers, "replace": False, "batch_size": args.batch_size, - "samples_dir": samples_dir, + "samples_dir": None if args.in_memory else samples_dir, + "local_seeds_per_call": None + if args.seeds_per_call <= 0 + else args.seeds_per_call, } train_loader = create_loader( diff --git a/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_snmg.py b/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_snmg.py index b1bb0240e71..73efbc92a24 100644 --- a/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_snmg.py +++ b/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_snmg.py @@ -86,6 +86,8 @@ def run_train( wall_clock_start, tempdir=None, num_layers=3, + in_memory=False, + seeds_per_call=-1, ): init_pytorch_worker( @@ -119,20 +121,23 @@ def run_train( dist.barrier() ix_train = torch.tensor_split(split_idx["train"], world_size)[rank].cuda() - train_path = os.path.join(tempdir, f"train_{rank}") - os.mkdir(train_path) + train_path = None if in_memory else os.path.join(tempdir, f"train_{rank}") + if train_path: + os.mkdir(train_path) train_loader = NeighborLoader( (feature_store, graph_store), input_nodes=ix_train, directory=train_path, shuffle=True, drop_last=True, + local_seeds_per_call=seeds_per_call if seeds_per_call > 0 else None, **kwargs, ) ix_test = torch.tensor_split(split_idx["test"], world_size)[rank].cuda() - test_path = os.path.join(tempdir, f"test_{rank}") - os.mkdir(test_path) + test_path = None if in_memory else os.path.join(tempdir, f"test_{rank}") + if test_path: + os.mkdir(test_path) test_loader = NeighborLoader( (feature_store, graph_store), input_nodes=ix_test, @@ -144,14 +149,16 @@ def run_train( ) ix_valid = torch.tensor_split(split_idx["valid"], world_size)[rank].cuda() - valid_path = os.path.join(tempdir, 
f"valid_{rank}") - os.mkdir(valid_path) + valid_path = None if in_memory else os.path.join(tempdir, f"valid_{rank}") + if valid_path: + os.mkdir(valid_path) valid_loader = NeighborLoader( (feature_store, graph_store), input_nodes=ix_valid, directory=valid_path, shuffle=True, drop_last=True, + local_seeds_per_call=seeds_per_call if seeds_per_call > 0 else None, **kwargs, ) @@ -269,6 +276,8 @@ def run_train( parser.add_argument("--tempdir_root", type=str, default=None) parser.add_argument("--dataset_root", type=str, default="dataset") parser.add_argument("--dataset", type=str, default="ogbn-products") + parser.add_argument("--in_memory", action="store_true", default=False) + parser.add_argument("--seeds_per_call", type=int, default=-1) parser.add_argument( "--n_devices", @@ -322,6 +331,8 @@ def run_train( wall_clock_start, tempdir, args.num_layers, + args.in_memory, + args.seeds_per_call, ), nprocs=world_size, join=True, diff --git a/python/cugraph-pyg/cugraph_pyg/examples/rgcn_link_class_mnmg.py b/python/cugraph-pyg/cugraph_pyg/examples/rgcn_link_class_mnmg.py new file mode 100644 index 00000000000..5c75e01e6f5 --- /dev/null +++ b/python/cugraph-pyg/cugraph_pyg/examples/rgcn_link_class_mnmg.py @@ -0,0 +1,418 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This example illustrates link classification using the ogbl-wikikg2 dataset. + +import os +import json +import argparse +import warnings + +import torch + +import torch.nn.functional as F +from torch.nn import Parameter +from torch_geometric.nn import FastRGCNConv, GAE +from torch.nn.parallel import DistributedDataParallel + +from ogb.linkproppred import PygLinkPropPredDataset + +import cugraph_pyg + +from cugraph.gnn import ( + cugraph_comms_init, + cugraph_comms_create_unique_id, + cugraph_comms_shutdown, +) + +from pylibwholegraph.torch.initialize import ( + init as wm_init, + finalize as wm_finalize, +) + + +# Enable cudf spilling to save gpu memory +from cugraph.testing.mg_utils import enable_spilling + +# Ensures that a CUDA context is not created on import of rapids. 
+# Allows pytorch to create the context instead +os.environ["RAPIDS_NO_INITIALIZE"] = "1" + + +def init_pytorch_worker(global_rank, local_rank, world_size, uid): + import rmm + + rmm.reinitialize(devices=[local_rank], pool_allocator=True, managed_memory=True) + + import cupy + from rmm.allocators.cupy import rmm_cupy_allocator + + cupy.cuda.set_allocator(rmm_cupy_allocator) + + cugraph_comms_init( + global_rank, + world_size, + uid, + local_rank, + ) + + wm_init(global_rank, world_size, local_rank, torch.cuda.device_count()) + + enable_spilling() + + +class RGCNEncoder(torch.nn.Module): + def __init__(self, num_nodes, hidden_channels, num_relations, num_bases=30): + super().__init__() + self.node_emb = Parameter(torch.empty(num_nodes, hidden_channels)) + self.conv1 = FastRGCNConv( + hidden_channels, hidden_channels, num_relations, num_bases=num_bases + ) + self.conv2 = FastRGCNConv( + hidden_channels, hidden_channels, num_relations, num_bases=num_bases + ) + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.xavier_uniform_(self.node_emb) + self.conv1.reset_parameters() + self.conv2.reset_parameters() + + def forward(self, edge_index, edge_type): + x = self.node_emb + x = self.conv1(x, edge_index, edge_type).relu_() + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv2(x, edge_index, edge_type) + return x + + +def train(epoch, model, optimizer, train_loader, edge_feature_store, num_steps=None): + model.train() + optimizer.zero_grad() + + for i, batch in enumerate(train_loader): + r = edge_feature_store[("n", "e", "n"), "rel"][batch.e_id].flatten().cuda() + z = model.encode(batch.edge_index, r) + + loss = model.recon_loss(z, batch.edge_index) + loss.backward() + optimizer.step() + + if i % 10 == 0: + print( + f"Epoch: {epoch:02d}, Iteration: {i:02d}, Loss: {loss:.4f}", flush=True + ) + if num_steps and i == num_steps: + break + + +def test(stage, epoch, model, loader, num_steps=None): + # TODO support ROC-AUC metric + # Predict probabilities of future edges + model.eval() + + rr = 0.0 + for i, (h, h_neg, t, t_neg, r) in enumerate(loader): + if num_steps and i >= num_steps: + break + + ei = torch.concatenate( + [ + torch.stack([h, t]).cuda(), + torch.stack([h_neg.flatten(), t_neg.flatten()]).cuda(), + ], + dim=-1, + ) + + r = torch.concatenate([r, torch.repeat_interleave(r, h_neg.shape[-1])]).cuda() + + z = model.encode(ei, r) + q = model.decode(z, ei) + + _, ix = torch.sort(q, descending=True) + rr += 1.0 / (1.0 + ix[0]) + + print(f"epoch {epoch:02d} {stage} mrr:", rr / i, flush=True) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--hidden_channels", type=int, default=128) + parser.add_argument("--num_layers", type=int, default=1) + parser.add_argument("--lr", type=float, default=0.001) + parser.add_argument("--epochs", type=int, default=4) + parser.add_argument("--batch_size", type=int, default=16384) + parser.add_argument("--num_neg", type=int, default=500) + parser.add_argument("--num_pos", type=int, default=-1) + parser.add_argument("--fan_out", type=int, default=10) + parser.add_argument("--dataset", type=str, default="ogbl-wikikg2") + parser.add_argument("--dataset_root", type=str, default="dataset") + parser.add_argument("--seeds_per_call", type=int, default=-1) + parser.add_argument("--n_devices", type=int, default=-1) + parser.add_argument("--skip_partition", action="store_true") + + return parser.parse_args() + + +def run_train(rank, world_size, model, data, edge_feature_store, meta, splits, args): + model = 
model.to(rank) + model = GAE(DistributedDataParallel(model, device_ids=[rank])) + optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) + + eli = torch.stack([splits["train"]["head"], splits["train"]["tail"]]) + + train_loader = cugraph_pyg.loader.LinkNeighborLoader( + data, + [args.fan_out] * args.num_layers, + edge_label_index=eli, + local_seeds_per_call=args.seeds_per_call if args.seeds_per_call > 0 else None, + batch_size=args.batch_size, + shuffle=True, + drop_last=True, + ) + + def get_eval_loader(stage: str): + head = splits[stage]["head"] + tail = splits[stage]["tail"] + + head_neg = splits[stage]["head_neg"][:, : args.num_neg] + tail_neg = splits[stage]["tail_neg"][:, : args.num_neg] + + rel = splits[stage]["relation"] + + return torch.utils.data.DataLoader( + torch.utils.data.TensorDataset( + head.pin_memory(), + head_neg.pin_memory(), + tail.pin_memory(), + tail_neg.pin_memory(), + rel.pin_memory(), + ), + batch_size=1, + shuffle=False, + drop_last=True, + ) + + test_loader = get_eval_loader("test") + valid_loader = get_eval_loader("valid") + + num_train_steps = (args.num_pos // args.batch_size) if args.num_pos > 0 else 100 + + for epoch in range(1, 1 + args.epochs): + train( + epoch, + model, + optimizer, + train_loader, + edge_feature_store, + num_steps=num_train_steps, + ) + test("validation", epoch, model, valid_loader, num_steps=1024) + + test("test", epoch, model, test_loader, num_steps=1024) + + wm_finalize() + cugraph_comms_shutdown() + + +def partition_data( + data, splits, meta, edge_path, rel_path, pos_path, neg_path, meta_path +): + # Split and save edge index + os.makedirs( + edge_path, + exist_ok=True, + ) + for (r, e) in enumerate(torch.tensor_split(data.edge_index, world_size, dim=1)): + rank_path = os.path.join(edge_path, f"rank={r}.pt") + torch.save( + e.clone(), + rank_path, + ) + + # Split and save edge reltypes + os.makedirs( + rel_path, + exist_ok=True, + ) + for (r, f) in enumerate(torch.tensor_split(data.edge_reltype, world_size)): + rank_path = os.path.join(rel_path, f"rank={r}.pt") + torch.save( + f.clone(), + rank_path, + ) + + # Split and save positive edges + os.makedirs( + pos_path, + exist_ok=True, + ) + for stage in ["train", "test", "valid"]: + for (r, n) in enumerate( + torch.tensor_split( + torch.stack([splits[stage]["head"], splits[stage]["tail"]]), + world_size, + dim=-1, + ) + ): + rank_path = os.path.join(pos_path, f"rank={r}_{stage}.pt") + torch.save( + n.clone(), + rank_path, + ) + + # Split and save negative edges + os.makedirs( + neg_path, + exist_ok=True, + ) + for stage in ["test", "valid"]: + for (r, n) in enumerate( + torch.tensor_split( + torch.stack([splits[stage]["head_neg"], splits[stage]["tail_neg"]]), + world_size, + dim=1, + ) + ): + rank_path = os.path.join(neg_path, f"rank={r}_{stage}.pt") + torch.save(n.clone(), rank_path) + for (r, n) in enumerate( + torch.tensor_split(splits[stage]["relation"], world_size, dim=-1) + ): + rank_path = os.path.join(neg_path, f"rank={r}_{stage}_relation.pt") + torch.save(n.clone(), rank_path) + + with open(meta_path, "w") as f: + json.dump(meta, f) + + +def load_partitioned_data(rank, edge_path, rel_path, pos_path, neg_path, meta_path): + from cugraph_pyg.data import GraphStore, WholeFeatureStore, TensorDictFeatureStore + + graph_store = GraphStore() + feature_store = TensorDictFeatureStore() + edge_feature_store = WholeFeatureStore() + + # Load edge index + graph_store[("n", "e", "n"), "coo"] = torch.load( + os.path.join(edge_path, f"rank={rank}.pt") + ) + + # Load edge
rel type + edge_feature_store[("n", "e", "n"), "rel"] = torch.load( + os.path.join(rel_path, f"rank={rank}.pt") + ) + + splits = {} + + # Load positive edges + for stage in ["train", "test", "valid"]: + head, tail = torch.load(os.path.join(pos_path, f"rank={rank}_{stage}.pt")) + splits[stage] = { + "head": head, + "tail": tail, + } + + # Load negative edges + for stage in ["test", "valid"]: + head_neg, tail_neg = torch.load( + os.path.join(neg_path, f"rank={rank}_{stage}.pt") + ) + relation = torch.load( + os.path.join(neg_path, f"rank={rank}_{stage}_relation.pt") + ) + splits[stage]["head_neg"] = head_neg + splits[stage]["tail_neg"] = tail_neg + splits[stage]["relation"] = relation + + with open(meta_path, "r") as f: + meta = json.load(f) + + return (feature_store, graph_store), edge_feature_store, splits, meta + + +if __name__ == "__main__": + args = parse_args() + + if "LOCAL_RANK" in os.environ: + torch.distributed.init_process_group("nccl") + world_size = torch.distributed.get_world_size() + global_rank = torch.distributed.get_rank() + local_rank = int(os.environ["LOCAL_RANK"]) + device = torch.device(local_rank) + + # Create the uid needed for cuGraph comms + if global_rank == 0: + cugraph_id = [cugraph_comms_create_unique_id()] + else: + cugraph_id = [None] + torch.distributed.broadcast_object_list(cugraph_id, src=0, device=device) + cugraph_id = cugraph_id[0] + + init_pytorch_worker(global_rank, local_rank, world_size, cugraph_id) + + # Split the data + edge_path = os.path.join(args.dataset_root, args.dataset + "_eix_part") + rel_path = os.path.join(args.dataset_root, args.dataset + "_rel_part") + pos_path = os.path.join(args.dataset_root, args.dataset + "_e_pos_part") + neg_path = os.path.join(args.dataset_root, args.dataset + "_e_neg_part") + meta_path = os.path.join(args.dataset_root, args.dataset + "_meta.json") + + if not args.skip_partition and global_rank == 0: + data = PygLinkPropPredDataset(args.dataset, root=args.dataset_root) + dataset = data[0] + + splits = data.get_edge_split() + + meta = {} + meta["num_nodes"] = int(dataset.num_nodes) + meta["num_rels"] = int(dataset.edge_reltype.max()) + 1 + + partition_data( + dataset, + splits, + meta, + edge_path=edge_path, + rel_path=rel_path, + pos_path=pos_path, + neg_path=neg_path, + meta_path=meta_path, + ) + del data + del dataset + del splits + torch.distributed.barrier() + + # Load partitions + data, edge_feature_store, splits, meta = load_partitioned_data( + rank=global_rank, + edge_path=edge_path, + rel_path=rel_path, + pos_path=pos_path, + neg_path=neg_path, + meta_path=meta_path, + ) + torch.distributed.barrier() + + model = RGCNEncoder( + meta["num_nodes"], + hidden_channels=args.hidden_channels, + num_relations=meta["num_rels"], + ) + + run_train( + global_rank, world_size, model, data, edge_feature_store, meta, splits, args + ) + else: + warnings.warn("This script should be run with 'torchrun'. Exiting.") diff --git a/python/cugraph-pyg/cugraph_pyg/examples/rgcn_link_class_sg.py b/python/cugraph-pyg/cugraph_pyg/examples/rgcn_link_class_sg.py new file mode 100644 index 00000000000..67d7eecc7c2 --- /dev/null +++ b/python/cugraph-pyg/cugraph_pyg/examples/rgcn_link_class_sg.py @@ -0,0 +1,219 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This example illustrates link classification using the ogbl-wikikg2 dataset. + +import argparse + +from typing import Tuple, Dict, Any + +import torch +import cupy + +import rmm +from rmm.allocators.cupy import rmm_cupy_allocator +from rmm.allocators.torch import rmm_torch_allocator + +# Must change allocators immediately upon import +# or else other imports will cause memory to be +# allocated and prevent changing the allocator +rmm.reinitialize(devices=[0], pool_allocator=True, managed_memory=True) +cupy.cuda.set_allocator(rmm_cupy_allocator) +torch.cuda.memory.change_current_allocator(rmm_torch_allocator) + +import torch.nn.functional as F # noqa: E402 +from torch.nn import Parameter # noqa: E402 +from torch_geometric.nn import FastRGCNConv, GAE # noqa: E402 +import torch_geometric # noqa: E402 +import cugraph_pyg # noqa: E402 + +# Enable cudf spilling to save gpu memory +from cugraph.testing.mg_utils import enable_spilling # noqa: E402 + +enable_spilling() + + +class RGCNEncoder(torch.nn.Module): + def __init__(self, num_nodes, hidden_channels, num_relations, num_bases=30): + super().__init__() + self.node_emb = Parameter(torch.empty(num_nodes, hidden_channels)) + self.conv1 = FastRGCNConv( + hidden_channels, hidden_channels, num_relations, num_bases=num_bases + ) + self.conv2 = FastRGCNConv( + hidden_channels, hidden_channels, num_relations, num_bases=num_bases + ) + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.xavier_uniform_(self.node_emb) + self.conv1.reset_parameters() + self.conv2.reset_parameters() + + def forward(self, edge_index, edge_type): + x = self.node_emb + x = self.conv1(x, edge_index, edge_type).relu_() + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv2(x, edge_index, edge_type) + return x + + +def load_data( + dataset_str, dataset_root: str +) -> Tuple[ + Tuple["torch_geometric.data.FeatureStore", "torch_geometric.data.GraphStore"], + "torch_geometric.data.FeatureStore", + Dict[str, Dict[str, "torch.Tensor"]], + Dict[str, Any], +]: + from ogb.linkproppred import PygLinkPropPredDataset + + data = PygLinkPropPredDataset(dataset_str, root=dataset_root) + dataset = data[0] + + splits = data.get_edge_split() + + from cugraph_pyg.data import GraphStore, TensorDictFeatureStore + + graph_store = GraphStore() + feature_store = TensorDictFeatureStore() + edge_feature_store = TensorDictFeatureStore() + meta = {} + + graph_store[("n", "e", "n"), "coo"] = dataset.edge_index + edge_feature_store[("n", "e", "n"), "rel"] = dataset.edge_reltype.pin_memory() + meta["num_nodes"] = dataset.num_nodes + meta["num_rels"] = dataset.edge_reltype.max() + 1 + + return (feature_store, graph_store), edge_feature_store, splits, meta + + +def train(epoch, model, optimizer, train_loader, edge_feature_store): + model.train() + optimizer.zero_grad() + + for i, batch in enumerate(train_loader): + r = edge_feature_store[("n", "e", "n"), "rel"][batch.e_id].flatten().cuda() + z = model.encode(batch.edge_index, r) + + loss = model.recon_loss(z, batch.edge_index) + loss.backward() + optimizer.step() + + if i % 10 == 0: + print(f"Epoch: {epoch:02d}, 
Iteration: {i:02d}, Loss: {loss:.4f}") + if i == 100: + break + + +def test(stage, epoch, model, loader, num_steps=None): + # TODO support ROC-AUC metric + # Predict probabilities of future edges + model.eval() + + rr = 0.0 + for i, (h, h_neg, t, t_neg, r) in enumerate(loader): + if num_steps and i >= num_steps: + break + + ei = torch.concatenate( + [ + torch.stack([h, t]).cuda(), + torch.stack([h_neg.flatten(), t_neg.flatten()]).cuda(), + ], + dim=-1, + ) + + r = torch.concatenate([r, torch.repeat_interleave(r, h_neg.shape[-1])]).cuda() + + z = model.encode(ei, r) + q = model.decode(z, ei) + + _, ix = torch.sort(q, descending=True) + rr += 1.0 / (1.0 + ix[0]) + + print(f"epoch {epoch:02d} {stage} mrr:", rr / i) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--hidden_channels", type=int, default=128) + parser.add_argument("--num_layers", type=int, default=1) + parser.add_argument("--lr", type=float, default=0.001) + parser.add_argument("--epochs", type=int, default=4) + parser.add_argument("--batch_size", type=int, default=16384) + parser.add_argument("--num_neg", type=int, default=500) + parser.add_argument("--fan_out", type=int, default=10) + parser.add_argument("--dataset", type=str, default="ogbl-wikikg2") + parser.add_argument("--dataset_root", type=str, default="dataset") + parser.add_argument("--seeds_per_call", type=int, default=-1) + + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + + data, edge_feature_store, splits, meta = load_data(args.dataset, args.dataset_root) + + model = GAE( + RGCNEncoder( + meta["num_nodes"], + hidden_channels=args.hidden_channels, + num_relations=meta["num_rels"], + ) + ).cuda() + optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) + + train_loader = cugraph_pyg.loader.LinkNeighborLoader( + data, + [args.fan_out] * args.num_layers, + edge_label_index=torch.stack( + [splits["train"]["head"], splits["train"]["tail"]] + ), + local_seeds_per_call=args.seeds_per_call if args.seeds_per_call > 0 else None, + batch_size=args.batch_size, + shuffle=True, + drop_last=True, + ) + + def get_eval_loader(stage: str): + head = splits[stage]["head"] + tail = splits[stage]["tail"] + + head_neg = splits[stage]["head_neg"][:, : args.num_neg] + tail_neg = splits[stage]["tail_neg"][:, : args.num_neg] + + rel = splits[stage]["relation"] + + return torch.utils.data.DataLoader( + torch.utils.data.TensorDataset( + head.pin_memory(), + head_neg.pin_memory(), + tail.pin_memory(), + tail_neg.pin_memory(), + rel.pin_memory(), + ), + batch_size=1, + shuffle=False, + drop_last=True, + ) + + test_loader = get_eval_loader("test") + valid_loader = get_eval_loader("valid") + + for epoch in range(1, 1 + args.epochs): + train(epoch, model, optimizer, train_loader, edge_feature_store) + test("validation", epoch, model, valid_loader, num_steps=1024) + + test("test", epoch, model, test_loader, num_steps=1024) diff --git a/python/cugraph-pyg/cugraph_pyg/examples/rgcn_link_class_snmg.py b/python/cugraph-pyg/cugraph_pyg/examples/rgcn_link_class_snmg.py new file mode 100644 index 00000000000..2c0ae53a08e --- /dev/null +++ b/python/cugraph-pyg/cugraph_pyg/examples/rgcn_link_class_snmg.py @@ -0,0 +1,320 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This example illustrates link classification using the ogbl-wikikg2 dataset. + +import os +import argparse +import warnings + +from typing import Tuple, Any + +import torch + +import torch.nn.functional as F +from torch.nn import Parameter +from torch_geometric.nn import FastRGCNConv, GAE +from torch.nn.parallel import DistributedDataParallel + +import torch_geometric +import cugraph_pyg + +from cugraph.gnn import ( + cugraph_comms_init, + cugraph_comms_create_unique_id, + cugraph_comms_shutdown, +) + +from pylibwholegraph.torch.initialize import ( + init as wm_init, + finalize as wm_finalize, +) + + +# Enable cudf spilling to save gpu memory +from cugraph.testing.mg_utils import enable_spilling + +# Ensures that a CUDA context is not created on import of rapids. +# Allows pytorch to create the context instead +os.environ["RAPIDS_NO_INITIALIZE"] = "1" + + +def init_pytorch_worker(rank, world_size, uid): + import rmm + + rmm.reinitialize(devices=[rank], pool_allocator=True, managed_memory=True) + + import cupy + from rmm.allocators.cupy import rmm_cupy_allocator + + cupy.cuda.set_allocator(rmm_cupy_allocator) + + cugraph_comms_init( + rank, + world_size, + uid, + rank, + ) + + wm_init(rank, world_size, rank, world_size) + + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + torch.distributed.init_process_group( + "nccl", + rank=rank, + world_size=world_size, + ) + + enable_spilling() + + +class RGCNEncoder(torch.nn.Module): + def __init__(self, num_nodes, hidden_channels, num_relations, num_bases=30): + super().__init__() + self.node_emb = Parameter(torch.empty(num_nodes, hidden_channels)) + self.conv1 = FastRGCNConv( + hidden_channels, hidden_channels, num_relations, num_bases=num_bases + ) + self.conv2 = FastRGCNConv( + hidden_channels, hidden_channels, num_relations, num_bases=num_bases + ) + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.xavier_uniform_(self.node_emb) + self.conv1.reset_parameters() + self.conv2.reset_parameters() + + def forward(self, edge_index, edge_type): + x = self.node_emb + x = self.conv1(x, edge_index, edge_type).relu_() + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv2(x, edge_index, edge_type) + return x + + +def load_data( + rank: int, + world_size: int, + data: Any, +) -> Tuple[ + Tuple["torch_geometric.data.FeatureStore", "torch_geometric.data.GraphStore"], + "torch_geometric.data.FeatureStore", +]: + from cugraph_pyg.data import GraphStore, WholeFeatureStore, TensorDictFeatureStore + + graph_store = GraphStore() + feature_store = TensorDictFeatureStore() # empty fs required by PyG + edge_feature_store = WholeFeatureStore() + + graph_store[("n", "e", "n"), "coo"] = torch.tensor_split( + data.edge_index.cuda(), world_size, dim=1 + )[rank] + + edge_feature_store[("n", "e", "n"), "rel"] = torch.tensor_split( + data.edge_reltype.cuda(), + world_size, + )[rank] + + return (feature_store, graph_store), edge_feature_store + + +def train(epoch, model, optimizer, train_loader, edge_feature_store, num_steps=None): + model.train() + optimizer.zero_grad() + + for i, batch in 
enumerate(train_loader): + r = edge_feature_store[("n", "e", "n"), "rel"][batch.e_id].flatten().cuda() + z = model.encode(batch.edge_index, r) + + loss = model.recon_loss(z, batch.edge_index) + loss.backward() + optimizer.step() + + if i % 10 == 0: + print( + f"Epoch: {epoch:02d}, Iteration: {i:02d}, Loss: {loss:.4f}", flush=True + ) + if num_steps and i == num_steps: + break + + +def test(stage, epoch, model, loader, num_steps=None): + # TODO support ROC-AUC metric + # Predict probabilities of future edges + model.eval() + + rr = 0.0 + for i, (h, h_neg, t, t_neg, r) in enumerate(loader): + if num_steps and i >= num_steps: + break + + ei = torch.concatenate( + [ + torch.stack([h, t]).cuda(), + torch.stack([h_neg.flatten(), t_neg.flatten()]).cuda(), + ], + dim=-1, + ) + + r = torch.concatenate([r, torch.repeat_interleave(r, h_neg.shape[-1])]).cuda() + + z = model.encode(ei, r) + q = model.decode(z, ei) + + _, ix = torch.sort(q, descending=True) + rr += 1.0 / (1.0 + ix[0]) + + print(f"epoch {epoch:02d} {stage} mrr:", rr / i, flush=True) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--hidden_channels", type=int, default=128) + parser.add_argument("--num_layers", type=int, default=1) + parser.add_argument("--lr", type=float, default=0.001) + parser.add_argument("--epochs", type=int, default=4) + parser.add_argument("--batch_size", type=int, default=16384) + parser.add_argument("--num_neg", type=int, default=500) + parser.add_argument("--num_pos", type=int, default=-1) + parser.add_argument("--fan_out", type=int, default=10) + parser.add_argument("--dataset", type=str, default="ogbl-wikikg2") + parser.add_argument("--dataset_root", type=str, default="dataset") + parser.add_argument("--seeds_per_call", type=int, default=-1) + parser.add_argument("--n_devices", type=int, default=-1) + + return parser.parse_args() + + +def run_train(rank, world_size, uid, model, data, meta, splits, args): + init_pytorch_worker( + rank, + world_size, + uid, + ) + + model = model.to(rank) + model = GAE(DistributedDataParallel(model, device_ids=[rank])) + optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) + + data, edge_feature_store = load_data(rank, world_size, data) + + eli = torch.stack( + [ + torch.tensor_split(splits["train"]["head"], world_size)[rank], + torch.tensor_split(splits["train"]["tail"], world_size)[rank], + ] + ) + + train_loader = cugraph_pyg.loader.LinkNeighborLoader( + data, + [args.fan_out] * args.num_layers, + edge_label_index=eli, + local_seeds_per_call=args.seeds_per_call if args.seeds_per_call > 0 else None, + batch_size=args.batch_size, + shuffle=True, + drop_last=True, + ) + + def get_eval_loader(stage: str): + head = torch.tensor_split(splits[stage]["head"], world_size)[rank] + tail = torch.tensor_split(splits[stage]["tail"], world_size)[rank] + + head_neg = torch.tensor_split( + splits[stage]["head_neg"][:, : args.num_neg], world_size + )[rank] + tail_neg = torch.tensor_split( + splits[stage]["tail_neg"][:, : args.num_neg], world_size + )[rank] + + rel = torch.tensor_split(splits[stage]["relation"], world_size)[rank] + + return torch.utils.data.DataLoader( + torch.utils.data.TensorDataset( + head.pin_memory(), + head_neg.pin_memory(), + tail.pin_memory(), + tail_neg.pin_memory(), + rel.pin_memory(), + ), + batch_size=1, + shuffle=False, + drop_last=True, + ) + + test_loader = get_eval_loader("test") + valid_loader = get_eval_loader("valid") + + num_train_steps = (args.num_pos // args.batch_size) if args.num_pos > 0 else 100 + + for epoch in 
range(1, 1 + args.epochs): + train( + epoch, + model, + optimizer, + train_loader, + edge_feature_store, + num_steps=num_train_steps, + ) + test("validation", epoch, model, valid_loader, num_steps=1024) + + test("test", epoch, model, test_loader, num_steps=1024) + + wm_finalize() + cugraph_comms_shutdown() + + +if __name__ == "__main__": + if "CI_RUN" in os.environ and os.environ["CI_RUN"] == "1": + warnings.warn("Skipping SNMG example in CI due to memory limit") + else: + args = parse_args() + + # change the allocator before any allocations are made + from rmm.allocators.torch import rmm_torch_allocator + + torch.cuda.memory.change_current_allocator(rmm_torch_allocator) + + # import ogb here to stop it from creating a context and breaking pytorch/rmm + from ogb.linkproppred import PygLinkPropPredDataset + + data = PygLinkPropPredDataset(args.dataset, root=args.dataset_root) + dataset = data[0] + + splits = data.get_edge_split() + + meta = {} + meta["num_nodes"] = dataset.num_nodes + meta["num_rels"] = dataset.edge_reltype.max() + 1 + + model = RGCNEncoder( + meta["num_nodes"], + hidden_channels=args.hidden_channels, + num_relations=meta["num_rels"], + ) + + print("Data =", data) + if args.n_devices == -1: + world_size = torch.cuda.device_count() + else: + world_size = args.n_devices + print("Using", world_size, "GPUs...") + + uid = cugraph_comms_create_unique_id() + torch.multiprocessing.spawn( + run_train, + (world_size, uid, model, data, meta, splits, args), + nprocs=world_size, + join=True, + ) diff --git a/python/cugraph-pyg/cugraph_pyg/loader/__init__.py b/python/cugraph-pyg/cugraph_pyg/loader/__init__.py index cad66aaa183..c804b3d1f97 100644 --- a/python/cugraph-pyg/cugraph_pyg/loader/__init__.py +++ b/python/cugraph-pyg/cugraph_pyg/loader/__init__.py @@ -16,6 +16,9 @@ from cugraph_pyg.loader.node_loader import NodeLoader from cugraph_pyg.loader.neighbor_loader import NeighborLoader +from cugraph_pyg.loader.link_loader import LinkLoader +from cugraph_pyg.loader.link_neighbor_loader import LinkNeighborLoader + from cugraph_pyg.loader.dask_node_loader import DaskNeighborLoader from cugraph_pyg.loader.dask_node_loader import BulkSampleLoader diff --git a/python/cugraph-pyg/cugraph_pyg/loader/link_loader.py b/python/cugraph-pyg/cugraph_pyg/loader/link_loader.py new file mode 100644 index 00000000000..77e2ac4f99d --- /dev/null +++ b/python/cugraph-pyg/cugraph_pyg/loader/link_loader.py @@ -0,0 +1,205 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings + +import cugraph_pyg +from typing import Union, Tuple, Callable, Optional + +from cugraph.utilities.utils import import_optional + +torch_geometric = import_optional("torch_geometric") +torch = import_optional("torch") + + +class LinkLoader: + """ + Duck-typed version of torch_geometric.loader.LinkLoader. + Loads samples from batches of input edges using a + `~cugraph_pyg.sampler.BaseSampler.sample_from_edges` + function.
+ """ + + def __init__( + self, + data: Union[ + "torch_geometric.data.Data", + "torch_geometric.data.HeteroData", + Tuple[ + "torch_geometric.data.FeatureStore", "torch_geometric.data.GraphStore" + ], + ], + link_sampler: "cugraph_pyg.sampler.BaseSampler", + edge_label_index: "torch_geometric.typing.InputEdges" = None, + edge_label: "torch_geometric.typing.OptTensor" = None, + edge_label_time: "torch_geometric.typing.OptTensor" = None, + neg_sampling: Optional["torch_geometric.sampler.NegativeSampling"] = None, + neg_sampling_ratio: Optional[Union[int, float]] = None, + transform: Optional[Callable] = None, + transform_sampler_output: Optional[Callable] = None, + filter_per_worker: Optional[bool] = None, + custom_cls: Optional["torch_geometric.data.HeteroData"] = None, + input_id: "torch_geometric.typing.OptTensor" = None, + batch_size: int = 1, # refers to number of edges in batch + shuffle: bool = False, + drop_last: bool = False, + **kwargs, + ): + """ + Parameters + ---------- + data: Data, HeteroData, or Tuple[FeatureStore, GraphStore] + See torch_geometric.loader.NodeLoader. + link_sampler: BaseSampler + See torch_geometric.loader.LinkLoader. + edge_label_index: InputEdges + See torch_geometric.loader.LinkLoader. + edge_label: OptTensor + See torch_geometric.loader.LinkLoader. + edge_label_time: OptTensor + See torch_geometric.loader.LinkLoader. + neg_sampling: Optional[NegativeSampling] + Type of negative sampling to perform, if desired. + See torch_geometric.loader.LinkLoader. + neg_sampling_ratio: Optional[Union[int, float]] + Negative sampling ratio. Affects how many negative + samples are generated. + See torch_geometric.loader.LinkLoader. + transform: Callable (optional, default=None) + This argument currently has no effect. + transform_sampler_output: Callable (optional, default=None) + This argument currently has no effect. + filter_per_worker: bool (optional, default=False) + This argument currently has no effect. + custom_cls: HeteroData + This argument currently has no effect. This loader will + always return a Data or HeteroData object. + input_id: OptTensor + See torch_geometric.loader.LinkLoader. + + """ + if not isinstance(data, (list, tuple)) or not isinstance( + data[1], cugraph_pyg.data.GraphStore + ): + # Will eventually automatically convert these objects to cuGraph objects. + raise NotImplementedError("Currently can't accept non-cugraph graphs") + + if not isinstance(link_sampler, cugraph_pyg.sampler.BaseSampler): + raise NotImplementedError("Must provide a cuGraph sampler") + + if edge_label_time is not None: + raise ValueError("Temporal sampling is currently unsupported") + + if filter_per_worker: + warnings.warn("filter_per_worker is currently ignored") + + if custom_cls is not None: + warnings.warn("custom_cls is currently ignored") + + if transform is not None: + warnings.warn("transform is currently ignored.") + + if transform_sampler_output is not None: + warnings.warn("transform_sampler_output is currently ignored.") + + if neg_sampling_ratio is not None: + warnings.warn( + "The 'neg_sampling_ratio' argument is deprecated in PyG" + " and is not supported in cuGraph-PyG." 
+ ) + + neg_sampling = torch_geometric.sampler.NegativeSampling.cast(neg_sampling) + + ( + input_type, + edge_label_index, + ) = torch_geometric.loader.utils.get_edge_label_index( + data, + (None, edge_label_index), + ) + + self.__input_data = torch_geometric.sampler.EdgeSamplerInput( + input_id=torch.arange( + edge_label_index[0].numel(), dtype=torch.int64, device="cuda" + ) + if input_id is None + else input_id, + row=edge_label_index[0], + col=edge_label_index[1], + label=edge_label, + time=edge_label_time, + input_type=input_type, + ) + + # Edge label check from torch_geometric.loader.LinkLoader + if ( + neg_sampling is not None + and neg_sampling.is_binary() + and edge_label is not None + and edge_label.min() == 0 + ): + edge_label = edge_label + 1 + + if ( + neg_sampling is not None + and neg_sampling.is_triplet() + and edge_label is not None + ): + raise ValueError( + "'edge_label' needs to be undefined for " + "'triplet'-based negative sampling. Please use " + "`src_index`, `dst_pos_index` and " + "`dst_neg_index` of the returned mini-batch " + "instead to differentiate between positive and " + "negative samples." + ) + + self.__data = data + + self.__link_sampler = link_sampler + self.__neg_sampling = neg_sampling + + self.__batch_size = batch_size + self.__shuffle = shuffle + self.__drop_last = drop_last + + def __iter__(self): + if self.__shuffle: + perm = torch.randperm(self.__input_data.row.numel()) + else: + perm = torch.arange(self.__input_data.row.numel()) + + if self.__drop_last: + d = perm.numel() % self.__batch_size + if d > 0: + perm = perm[:-d] + + input_data = torch_geometric.sampler.EdgeSamplerInput( + input_id=self.__input_data.input_id[perm], + row=self.__input_data.row[perm], + col=self.__input_data.col[perm], + label=None + if self.__input_data.label is None + else self.__input_data.label[perm], + time=None + if self.__input_data.time is None + else self.__input_data.time[perm], + input_type=self.__input_data.input_type, + ) + + return cugraph_pyg.sampler.SampleIterator( + self.__data, + self.__link_sampler.sample_from_edges( + input_data, + neg_sampling=self.__neg_sampling, + ), + ) diff --git a/python/cugraph-pyg/cugraph_pyg/loader/link_neighbor_loader.py b/python/cugraph-pyg/cugraph_pyg/loader/link_neighbor_loader.py new file mode 100644 index 00000000000..080565368c4 --- /dev/null +++ b/python/cugraph-pyg/cugraph_pyg/loader/link_neighbor_loader.py @@ -0,0 +1,243 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+
+import warnings
+
+from typing import Union, Tuple, Optional, Callable, List, Dict
+
+import cugraph_pyg
+from cugraph_pyg.loader import LinkLoader
+from cugraph_pyg.sampler import BaseSampler
+
+from cugraph.gnn import NeighborSampler, DistSampleWriter
+from cugraph.utilities.utils import import_optional
+
+torch_geometric = import_optional("torch_geometric")
+
+
+class LinkNeighborLoader(LinkLoader):
+    """
+    Duck-typed version of torch_geometric.loader.LinkNeighborLoader
+
+    Link loader that implements the neighbor sampling
+    algorithm used in GraphSAGE.
+    """
+
+    def __init__(
+        self,
+        data: Union[
+            "torch_geometric.data.Data",
+            "torch_geometric.data.HeteroData",
+            Tuple[
+                "torch_geometric.data.FeatureStore", "torch_geometric.data.GraphStore"
+            ],
+        ],
+        num_neighbors: Union[
+            List[int], Dict["torch_geometric.typing.EdgeType", List[int]]
+        ],
+        edge_label_index: "torch_geometric.typing.InputEdges" = None,
+        edge_label: "torch_geometric.typing.OptTensor" = None,
+        edge_label_time: "torch_geometric.typing.OptTensor" = None,
+        replace: bool = False,
+        subgraph_type: Union[
+            "torch_geometric.typing.SubgraphType", str
+        ] = "directional",
+        disjoint: bool = False,
+        temporal_strategy: str = "uniform",
+        neg_sampling: Optional["torch_geometric.sampler.NegativeSampling"] = None,
+        neg_sampling_ratio: Optional[Union[int, float]] = None,
+        time_attr: Optional[str] = None,
+        weight_attr: Optional[str] = None,
+        transform: Optional[Callable] = None,
+        transform_sampler_output: Optional[Callable] = None,
+        is_sorted: bool = False,
+        filter_per_worker: Optional[bool] = None,
+        neighbor_sampler: Optional["torch_geometric.sampler.NeighborSampler"] = None,
+        directed: bool = True,  # Deprecated.
+        batch_size: int = 16,  # Refers to number of edges per batch.
+        directory: Optional[str] = None,
+        batches_per_partition: int = 256,
+        format: str = "parquet",
+        compression: Optional[str] = None,
+        local_seeds_per_call: Optional[int] = None,
+        **kwargs,
+    ):
+        """
+        Parameters
+        ----------
+        data: Data, HeteroData, or Tuple[FeatureStore, GraphStore]
+            See torch_geometric.loader.LinkNeighborLoader.
+        num_neighbors: List[int] or Dict[EdgeType, List[int]]
+            Fanout values.
+            See torch_geometric.loader.LinkNeighborLoader.
+        edge_label_index: InputEdges
+            Input edges for sampling.
+            See torch_geometric.loader.LinkNeighborLoader.
+        edge_label: OptTensor
+            Labels for input edges.
+            See torch_geometric.loader.LinkNeighborLoader.
+        edge_label_time: OptTensor
+            Time attribute for input edges.
+            See torch_geometric.loader.LinkNeighborLoader.
+        replace: bool (optional, default=False)
+            Whether to sample with replacement.
+            See torch_geometric.loader.LinkNeighborLoader.
+        subgraph_type: Union[SubgraphType, str] (optional, default='directional')
+            The type of subgraph to return.
+            Currently only 'directional' is supported.
+            See torch_geometric.loader.LinkNeighborLoader.
+        disjoint: bool (optional, default=False)
+            Whether to perform disjoint sampling.
+            Currently unsupported.
+            See torch_geometric.loader.LinkNeighborLoader.
+        temporal_strategy: str (optional, default='uniform')
+            Currently only 'uniform' is supported.
+            See torch_geometric.loader.LinkNeighborLoader.
+        neg_sampling: NegativeSampling (optional, default=None)
+            The negative sampling configuration, if desired.
+            See torch_geometric.loader.LinkNeighborLoader.
+        neg_sampling_ratio: int or float (optional, default=None)
+            Deprecated in PyG and not supported in cuGraph-PyG;
+            use neg_sampling instead.
+            See torch_geometric.loader.LinkNeighborLoader.
+        time_attr: str (optional, default=None)
+            Used for temporal sampling.
+            See torch_geometric.loader.LinkNeighborLoader.
+        weight_attr: str (optional, default=None)
+            Used for biased sampling.
+            See torch_geometric.loader.LinkNeighborLoader.
+        transform: Callable (optional, default=None)
+            See torch_geometric.loader.LinkNeighborLoader.
+        transform_sampler_output: Callable (optional, default=None)
+            See torch_geometric.loader.LinkNeighborLoader.
+        is_sorted: bool (optional, default=False)
+            Ignored by cuGraph.
+            See torch_geometric.loader.LinkNeighborLoader.
+        filter_per_worker: bool (optional, default=None)
+            Currently ignored by cuGraph, but this may
+            change once in-memory sampling is implemented.
+            See torch_geometric.loader.LinkNeighborLoader.
+        neighbor_sampler: torch_geometric.sampler.NeighborSampler
+            (optional, default=None)
+            Not supported by cuGraph.
+            See torch_geometric.loader.LinkNeighborLoader.
+        directed: bool (optional, default=True)
+            Deprecated.
+            See torch_geometric.loader.LinkNeighborLoader.
+        batch_size: int (optional, default=16)
+            The number of input edges per output minibatch.
+            See torch.utils.dataloader.
+        directory: str (optional, default=None)
+            The directory where samples will be temporarily stored,
+            if spilling samples to disk. If None, this loader
+            will perform buffered in-memory sampling.
+            If writing to disk, setting this argument
+            to a tempfile.TemporaryDirectory with a context
+            manager is a good option but depending on the filesystem,
+            you may want to choose an alternative location with fast I/O
+            instead.
+            See cugraph.gnn.DistSampleWriter.
+        batches_per_partition: int (optional, default=256)
+            The number of batches per partition if writing samples to
+            disk. Manually tuning this parameter is not recommended
+            but reducing it may help conserve GPU memory.
+            See cugraph.gnn.DistSampleWriter.
+        format: str (optional, default='parquet')
+            If writing samples to disk, they will be written in this
+            file format.
+            See cugraph.gnn.DistSampleWriter.
+        compression: str (optional, default=None)
+            The compression type to use if writing samples to disk.
+            If not provided, it is automatically chosen.
+        local_seeds_per_call: int (optional, default=None)
+            The number of seeds to process within a single sampling call.
+            Manually tuning this parameter is not recommended but reducing
+            it may conserve GPU memory. The total number of seeds processed
+            per sampling call is equal to the sum of this parameter across
+            all workers. If not provided, it will be automatically
+            calculated.
+            See cugraph.gnn.DistSampler.
+        **kwargs
+            Other keyword arguments passed to the superclass.
+        """
+
+        subgraph_type = torch_geometric.sampler.base.SubgraphType(subgraph_type)
+
+        if not directed:
+            subgraph_type = torch_geometric.sampler.base.SubgraphType.induced
+            warnings.warn(
+                "The 'directed' argument is deprecated. "
+                "Use subgraph_type='induced' instead."
+            )
+        if subgraph_type != torch_geometric.sampler.base.SubgraphType.directional:
+            raise ValueError("Only directional subgraphs are currently supported")
+        if disjoint:
+            raise ValueError("Disjoint sampling is currently unsupported")
+        if temporal_strategy != "uniform":
+            warnings.warn("Only the uniform temporal strategy is currently supported")
+        if neighbor_sampler is not None:
+            raise ValueError("Passing a neighbor sampler is currently unsupported")
+        if time_attr is not None:
+            raise ValueError("Temporal sampling is currently unsupported")
+        if is_sorted:
+            warnings.warn("The 'is_sorted' argument is ignored by cuGraph.")
+        if not isinstance(data, (list, tuple)) or not isinstance(
+            data[1], cugraph_pyg.data.GraphStore
+        ):
+            # Will eventually automatically convert these objects to cuGraph objects.
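+            # For now, callers must pass a (FeatureStore, GraphStore) tuple
+            # built with cugraph_pyg.data; see the loader tests for examples.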
+ raise NotImplementedError("Currently can't accept non-cugraph graphs") + + if compression is None: + compression = "CSR" + elif compression not in ["CSR", "COO"]: + raise ValueError("Invalid value for compression (expected 'CSR' or 'COO')") + + writer = ( + None + if directory is None + else DistSampleWriter( + directory=directory, + batches_per_partition=batches_per_partition, + format=format, + ) + ) + + feature_store, graph_store = data + + if weight_attr is not None: + graph_store._set_weight_attr((feature_store, weight_attr)) + + sampler = BaseSampler( + NeighborSampler( + graph_store._graph, + writer, + retain_original_seeds=True, + fanout=num_neighbors, + prior_sources_behavior="exclude", + deduplicate_sources=True, + compression=compression, + compress_per_hop=False, + with_replacement=replace, + local_seeds_per_call=local_seeds_per_call, + biased=(weight_attr is not None), + ), + (feature_store, graph_store), + batch_size=batch_size, + ) + # TODO add heterogeneous support and pass graph_store._vertex_offsets + + super().__init__( + (feature_store, graph_store), + sampler, + edge_label_index=edge_label_index, + edge_label=edge_label, + edge_label_time=edge_label_time, + neg_sampling=neg_sampling, + neg_sampling_ratio=neg_sampling_ratio, + transform=transform, + transform_sampler_output=transform_sampler_output, + filter_per_worker=filter_per_worker, + batch_size=batch_size, + **kwargs, + ) diff --git a/python/cugraph-pyg/cugraph_pyg/loader/neighbor_loader.py b/python/cugraph-pyg/cugraph_pyg/loader/neighbor_loader.py index 7f12bbb3fe6..1da2c6dc381 100644 --- a/python/cugraph-pyg/cugraph_pyg/loader/neighbor_loader.py +++ b/python/cugraph-pyg/cugraph_pyg/loader/neighbor_loader.py @@ -12,7 +12,6 @@ # limitations under the License. import warnings -import tempfile from typing import Union, Tuple, Optional, Callable, List, Dict @@ -123,14 +122,14 @@ def __init__( The number of input nodes per output minibatch. See torch.utils.dataloader. directory: str (optional, default=None) - The directory where samples will be temporarily stored. - It is recommend that this be set by the user, usually - setting it to a tempfile.TemporaryDirectory with a context + The directory where samples will be temporarily stored, + if spilling samples to disk. If None, this loader + will perform buffered in-memory sampling. + If writing to disk, setting this argument + to a tempfile.TemporaryDirectory with a context manager is a good option but depending on the filesystem, you may want to choose an alternative location with fast I/O intead. - If not set, this will create a TemporaryDirectory that will - persist until this object is garbage collected. See cugraph.gnn.DistSampleWriter. batches_per_partition: int (optional, default=256) The number of batches per partition if writing samples to @@ -182,20 +181,19 @@ def __init__( # Will eventually automatically convert these objects to cuGraph objects. 
raise NotImplementedError("Currently can't accept non-cugraph graphs") - if directory is None: - warnings.warn("Setting a directory to store samples is recommended.") - self._tempdir = tempfile.TemporaryDirectory() - directory = self._tempdir.name - if compression is None: compression = "CSR" elif compression not in ["CSR", "COO"]: raise ValueError("Invalid value for compression (expected 'CSR' or 'COO')") - writer = DistSampleWriter( - directory=directory, - batches_per_partition=batches_per_partition, - format=format, + writer = ( + None + if directory is None + else DistSampleWriter( + directory=directory, + batches_per_partition=batches_per_partition, + format=format, + ) ) feature_store, graph_store = data diff --git a/python/cugraph-pyg/cugraph_pyg/loader/node_loader.py b/python/cugraph-pyg/cugraph_pyg/loader/node_loader.py index 49923783d6b..4b236f75885 100644 --- a/python/cugraph-pyg/cugraph_pyg/loader/node_loader.py +++ b/python/cugraph-pyg/cugraph_pyg/loader/node_loader.py @@ -110,8 +110,10 @@ def __init__( input_id, ) - self.__input_data = torch_geometric.loader.node_loader.NodeSamplerInput( - input_id=input_id, + self.__input_data = torch_geometric.sampler.NodeSamplerInput( + input_id=torch.arange(len(input_nodes), dtype=torch.int64, device="cuda") + if input_id is None + else input_id, node=input_nodes, time=None, input_type=input_type, @@ -135,10 +137,8 @@ def __iter__(self): d = perm.numel() % self.__batch_size perm = perm[:-d] - input_data = torch_geometric.loader.node_loader.NodeSamplerInput( - input_id=None - if self.__input_data.input_id is None - else self.__input_data.input_id[perm], + input_data = torch_geometric.sampler.NodeSamplerInput( + input_id=self.__input_data.input_id[perm], node=self.__input_data.node[perm], time=None if self.__input_data.time is None diff --git a/python/cugraph-pyg/cugraph_pyg/sampler/sampler.py b/python/cugraph-pyg/cugraph_pyg/sampler/sampler.py index 268e9ffebbd..bc3d4fd8d3c 100644 --- a/python/cugraph-pyg/cugraph_pyg/sampler/sampler.py +++ b/python/cugraph-pyg/cugraph_pyg/sampler/sampler.py @@ -14,9 +14,9 @@ from typing import Optional, Iterator, Union, Dict, Tuple from cugraph.utilities.utils import import_optional -from cugraph.gnn import DistSampler, DistSampleReader +from cugraph.gnn import DistSampler -from .sampler_utils import filter_cugraph_pyg_store +from .sampler_utils import filter_cugraph_pyg_store, neg_sample, neg_cat torch = import_optional("torch") torch_geometric = import_optional("torch_geometric") @@ -60,7 +60,12 @@ def __next__(self): next_sample = next(self.__output_iter) if isinstance(next_sample, torch_geometric.sampler.SamplerOutput): sz = next_sample.edge.numel() - if sz == next_sample.col.numel(): + if sz == next_sample.col.numel() and ( + next_sample.node.numel() > next_sample.col[-1] + ): + # This will only trigger on very small batches and will have minimal + # performance impact. If COO output is removed, then this condition + # can be avoided. 
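+                # ("col" holds per-edge column ids in COO form but compressed
+                # offsets in CSR form; the extra node-count check disambiguates
+                # the rare case where both would have the same length.)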
col = next_sample.col else: col = torch_geometric.edge_index.ptr2index( @@ -101,10 +106,20 @@ def __next__(self): data.num_sampled_nodes = next_sample.num_sampled_nodes data.num_sampled_edges = next_sample.num_sampled_edges - data.input_id = data.batch - data.seed_time = None + data.input_id = next_sample.metadata[0] data.batch_size = data.input_id.size(0) + if len(next_sample.metadata) == 2: + data.seed_time = next_sample.metadata[1] + elif len(next_sample.metadata) == 4: + ( + data.edge_label_index, + data.edge_label, + data.seed_time, + ) = next_sample.metadata[1:] + else: + raise ValueError("Invalid metadata") + elif isinstance(next_sample, torch_geometric.sampler.HeteroSamplerOutput): col = {} for edge_type, col_idx in next_sample.col: @@ -152,13 +167,15 @@ class SampleReader: Iterator that processes results from the cuGraph distributed sampler. """ - def __init__(self, base_reader: DistSampleReader): + def __init__( + self, base_reader: Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]] + ): """ Constructs a new SampleReader. Parameters ---------- - base_reader: DistSampleReader + base_reader: Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]] The reader responsible for loading saved samples produced by the cuGraph distributed sampler. """ @@ -173,6 +190,9 @@ def __next__(self): self.__base_reader ) + self.__raw_sample_data["input_offsets"] -= self.__raw_sample_data[ + "input_offsets" + ][0].clone() self.__raw_sample_data["label_hop_offsets"] -= self.__raw_sample_data[ "label_hop_offsets" ][0].clone() @@ -202,14 +222,16 @@ class HomogeneousSampleReader(SampleReader): produced by the cuGraph distributed sampler. """ - def __init__(self, base_reader: DistSampleReader): + def __init__( + self, base_reader: Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]] + ): """ Constructs a new HomogeneousSampleReader Parameters ---------- - base_reader: DistSampleReader - The reader responsible for loading saved samples produced by + base_reader: Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]] + The iterator responsible for loading saved samples produced by the cuGraph distributed sampler. 
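+
+        Example (illustrative sketch; assumes samples were written by a
+        DistSampleWriter and `writer` is that instance)::
+
+            reader = writer.get_reader(rank=0)
+            for sampler_output in HomogeneousSampleReader(reader):
+                ...  # torch_geometric.sampler.SamplerOutput objects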
""" super().__init__(base_reader) @@ -262,6 +284,52 @@ def __decode_csc(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int): [num_seeds, num_sampled_nodes_hops.diff(prepend=num_seeds)] ) + input_index = raw_sample_data["input_index"][ + raw_sample_data["input_offsets"][index] : raw_sample_data["input_offsets"][ + index + 1 + ] + ] + + num_seeds = input_index.numel() + input_index = input_index[input_index >= 0] + + num_pos = input_index.numel() + num_neg = num_seeds - num_pos + if num_neg > 0: + edge_label = torch.concat( + [ + torch.full((num_pos,), 1.0), + torch.full((num_neg,), 0.0), + ] + ) + else: + edge_label = None + + edge_inverse = ( + ( + raw_sample_data["edge_inverse"][ + (raw_sample_data["input_offsets"][index] * 2) : ( + raw_sample_data["input_offsets"][index + 1] * 2 + ) + ] + ) + if "edge_inverse" in raw_sample_data + else None + ) + + if edge_inverse is None: + metadata = ( + input_index, + None, # TODO this will eventually include time + ) + else: + metadata = ( + input_index, + edge_inverse.view(2, -1), + edge_label, + None, # TODO this will eventually include time + ) + return torch_geometric.sampler.SamplerOutput( node=renumber_map.cpu(), row=minors, @@ -270,6 +338,7 @@ def __decode_csc(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int): batch=renumber_map[:num_seeds], num_sampled_nodes=num_sampled_nodes.cpu(), num_sampled_edges=num_sampled_edges.cpu(), + metadata=metadata, ) def __decode_coo(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int): @@ -315,6 +384,37 @@ def __decode_coo(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int): [num_seeds, num_sampled_nodes_hops.diff(prepend=num_seeds)] ) + input_index = raw_sample_data["input_index"][ + raw_sample_data["input_offsets"][index] : raw_sample_data["input_offsets"][ + index + 1 + ] + ] + + edge_inverse = ( + ( + raw_sample_data["edge_inverse"][ + (raw_sample_data["input_offsets"][index] * 2) : ( + raw_sample_data["input_offsets"][index + 1] * 2 + ) + ] + ) + if "edge_inverse" in raw_sample_data + else None + ) + + if edge_inverse is None: + metadata = ( + input_index, + None, # TODO this will eventually include time + ) + else: + metadata = ( + input_index, + edge_inverse.view(2, -1), + None, + None, # TODO this will eventually include time + ) + return torch_geometric.sampler.SamplerOutput( node=renumber_map.cpu(), row=minors, @@ -323,6 +423,7 @@ def __decode_coo(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int): batch=renumber_map[:num_seeds], num_sampled_nodes=num_sampled_nodes, num_sampled_edges=num_sampled_edges, + metadata=metadata, ) def _decode(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int): @@ -353,8 +454,8 @@ def sample_from_nodes( "torch_geometric.sampler.SamplerOutput", ] ]: - self.__sampler.sample_from_nodes( - index.node, batch_size=self.__batch_size, **kwargs + reader = self.__sampler.sample_from_nodes( + index.node, batch_size=self.__batch_size, input_id=index.input_id, **kwargs ) edge_attrs = self.__graph_store.get_all_edge_attrs() @@ -362,7 +463,7 @@ def sample_from_nodes( len(edge_attrs) == 1 and edge_attrs[0].edge_type[0] == edge_attrs[0].edge_type[2] ): - return HomogeneousSampleReader(self.__sampler.get_reader()) + return HomogeneousSampleReader(reader) else: # TODO implement heterogeneous sampling raise NotImplementedError( @@ -381,4 +482,59 @@ def sample_from_edges( "torch_geometric.sampler.SamplerOutput", ] ]: - raise NotImplementedError("Edge sampling is currently unimplemented.") + src = index.row + dst = index.col + 
input_id = index.input_id + neg_batch_size = 0 + if neg_sampling: + # Sample every negative subset at once. + # TODO handle temporal sampling (node_time) + src_neg, dst_neg = neg_sample( + self.__graph_store, + index.row, + index.col, + self.__batch_size, + neg_sampling, + None, # src_time, + None, # src_node_time, + ) + if neg_sampling.is_binary(): + src, _ = neg_cat(src.cuda(), src_neg, self.__batch_size) + else: + # triplet, cat dst to src so length is the same; will + # result in the same set of unique vertices + src, _ = neg_cat(src.cuda(), dst_neg, self.__batch_size) + dst, neg_batch_size = neg_cat(dst.cuda(), dst_neg, self.__batch_size) + + # Concatenate -1s so the input id tensor lines up and can + # be processed by the dist sampler. + # When loading the output batch, '-1' will be dropped. + input_id, _ = neg_cat( + input_id, + torch.full( + (dst_neg.numel(),), -1, dtype=torch.int64, device=input_id.device + ), + self.__batch_size, + ) + + # TODO for temporal sampling, node times have to be + # adjusted here. + reader = self.__sampler.sample_from_edges( + torch.stack([src, dst]), # reverse of usual convention + input_id=input_id, + batch_size=self.__batch_size + neg_batch_size, + **kwargs, + ) + + edge_attrs = self.__graph_store.get_all_edge_attrs() + if ( + len(edge_attrs) == 1 + and edge_attrs[0].edge_type[0] == edge_attrs[0].edge_type[2] + ): + return HomogeneousSampleReader(reader) + else: + # TODO implement heterogeneous sampling + raise NotImplementedError( + "Sampling heterogeneous graphs is currently" + " unsupported in the non-dask API" + ) diff --git a/python/cugraph-pyg/cugraph_pyg/sampler/sampler_utils.py b/python/cugraph-pyg/cugraph_pyg/sampler/sampler_utils.py index dba7c146b01..b3d56ef9992 100644 --- a/python/cugraph-pyg/cugraph_pyg/sampler/sampler_utils.py +++ b/python/cugraph-pyg/cugraph_pyg/sampler/sampler_utils.py @@ -14,10 +14,14 @@ from typing import Sequence, Dict, Tuple -from cugraph_pyg.data import DaskGraphStore +from math import ceil + +from cugraph_pyg.data import GraphStore, DaskGraphStore from cugraph.utilities.utils import import_optional import cudf +import cupy +import pylibcugraph dask_cudf = import_optional("dask_cudf") torch_geometric = import_optional("torch_geometric") @@ -429,3 +433,99 @@ def filter_cugraph_pyg_store( data[attr.attr_name] = tensors[i] return data + + +def neg_sample( + graph_store: GraphStore, + seed_src: "torch.Tensor", + seed_dst: "torch.Tensor", + batch_size: int, + neg_sampling: "torch_geometric.sampler.NegativeSampling", + time: "torch.Tensor", + node_time: "torch.Tensor", +) -> Tuple["torch.Tensor", "torch.Tensor"]: + try: + # Compatibility for PyG 2.5 + src_weight = neg_sampling.src_weight + dst_weight = neg_sampling.dst_weight + except AttributeError: + src_weight = neg_sampling.weight + dst_weight = neg_sampling.weight + unweighted = src_weight is None and dst_weight is None + + # Require at least one negative edge per batch + num_neg = max( + int(ceil(neg_sampling.amount * seed_src.numel())), + int(ceil(seed_src.numel() / batch_size)), + ) + + if graph_store.is_multi_gpu: + num_neg_global = torch.tensor([num_neg], device="cuda") + torch.distributed.all_reduce(num_neg_global, op=torch.distributed.ReduceOp.SUM) + num_neg = int(num_neg_global) + else: + num_neg_global = num_neg + + if node_time is None: + result_dict = pylibcugraph.negative_sampling( + graph_store._resource_handle, + graph_store._graph, + num_neg_global, + vertices=None + if unweighted + else cupy.arange(src_weight.numel(), dtype="int64"), + 
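+            # The vertex list and bias arrays below, when provided, are
+            # assumed to be indexed by vertex id.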
src_bias=None if src_weight is None else cupy.asarray(src_weight),
+            dst_bias=None if dst_weight is None else cupy.asarray(dst_weight),
+            remove_duplicates=False,
+            remove_false_negatives=False,
+            exact_number_of_samples=True,
+            do_expensive_check=False,
+        )
+
+        src_neg = torch.as_tensor(result_dict["sources"], device="cuda")[:num_neg]
+        dst_neg = torch.as_tensor(result_dict["destinations"], device="cuda")[:num_neg]
+
+        # TODO modify the C API so this condition is impossible
+        if src_neg.numel() < num_neg:
+            num_gen = num_neg - src_neg.numel()
+            src_neg = torch.concat(
+                [
+                    src_neg,
+                    torch.randint(
+                        0, src_neg.max(), (num_gen,), device="cuda", dtype=torch.int64
+                    ),
+                ]
+            )
+            dst_neg = torch.concat(
+                [
+                    dst_neg,
+                    torch.randint(
+                        0, dst_neg.max(), (num_gen,), device="cuda", dtype=torch.int64
+                    ),
+                ]
+            )
+        return src_neg, dst_neg
+    raise NotImplementedError(
+        "Temporal negative sampling is currently unimplemented in cuGraph-PyG"
+    )
+
+
+def neg_cat(
+    seed_pos: "torch.Tensor", seed_neg: "torch.Tensor", pos_batch_size: int
+) -> Tuple["torch.Tensor", int]:
+    num_seeds = seed_pos.numel()
+    num_batches = int(ceil(num_seeds / pos_batch_size))
+    neg_batch_size = int(ceil(seed_neg.numel() / num_batches))
+
+    batch_pos_offsets = torch.full((num_batches,), pos_batch_size).cumsum(-1)[:-1]
+    seed_pos_splits = torch.tensor_split(seed_pos, batch_pos_offsets)
+
+    batch_neg_offsets = torch.full((num_batches,), neg_batch_size).cumsum(-1)[:-1]
+    seed_neg_splits = torch.tensor_split(seed_neg, batch_neg_offsets)
+
+    return (
+        torch.concatenate(
+            [torch.concatenate(s) for s in zip(seed_pos_splits, seed_neg_splits)]
+        ),
+        neg_batch_size,
+    )
diff --git a/python/cugraph-pyg/cugraph_pyg/tests/loader/test_neighbor_loader.py b/python/cugraph-pyg/cugraph_pyg/tests/loader/test_neighbor_loader.py
index c4ad941de7a..8ee18a826f7 100644
--- a/python/cugraph-pyg/cugraph_pyg/tests/loader/test_neighbor_loader.py
+++ b/python/cugraph-pyg/cugraph_pyg/tests/loader/test_neighbor_loader.py
@@ -16,6 +16,7 @@
 from cugraph.datasets import karate
 from cugraph.utilities.utils import import_optional, MissingModule

+import cugraph_pyg
 from cugraph_pyg.data import TensorDictFeatureStore, GraphStore
 from cugraph_pyg.loader import NeighborLoader

@@ -86,3 +87,110 @@ def test_neighbor_loader_biased():

     assert out.edge_index.shape[1] == 2
     assert (out.edge_index.cpu() == torch.tensor([[3, 4], [1, 2]])).all()
+
+
+@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
+@pytest.mark.sg
+@pytest.mark.parametrize("num_nodes", [10, 25])
+@pytest.mark.parametrize("num_edges", [64, 128])
+@pytest.mark.parametrize("batch_size", [2, 4])
+@pytest.mark.parametrize("select_edges", [16, 32])
+@pytest.mark.parametrize("depth", [1, 3])
+@pytest.mark.parametrize("num_neighbors", [1, 4])
+def test_link_neighbor_loader_basic(
+    num_nodes, num_edges, batch_size, select_edges, num_neighbors, depth
+):
+    graph_store = GraphStore()
+    feature_store = TensorDictFeatureStore()
+
+    eix = torch.randperm(num_edges)[:select_edges]
+    graph_store[("n", "e", "n"), "coo"] = torch.stack(
+        [
+            torch.randint(0, num_nodes, (num_edges,)),
+            torch.randint(0, num_nodes, (num_edges,)),
+        ]
+    )
+
+    elx = graph_store[("n", "e", "n"), "coo"][:, eix]
+    loader = cugraph_pyg.loader.LinkNeighborLoader(
+        (feature_store, graph_store),
+        num_neighbors=[num_neighbors] * depth,
+        edge_label_index=elx,
+        batch_size=batch_size,
+        shuffle=False,
+    )
+
+    elx = torch.tensor_split(elx, eix.numel() // batch_size, dim=1)
+    for i, batch in enumerate(loader):
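+        # Batches arrive in order: input ids are sequential, and
+        # edge_label_index indexes into n_id to recover the original
+        # seed edges.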
assert ( + batch.input_id.cpu() == torch.arange(i * batch_size, (i + 1) * batch_size) + ).all() + assert (elx[i] == batch.n_id[batch.edge_label_index.cpu()]).all() + + +@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.sg +@pytest.mark.parametrize("batch_size", [1, 2]) +def test_link_neighbor_loader_negative_sampling_basic(batch_size): + num_edges = 62 + num_nodes = 19 + select_edges = 17 + + graph_store = GraphStore() + feature_store = TensorDictFeatureStore() + + eix = torch.randperm(num_edges)[:select_edges] + graph_store[("n", "e", "n"), "coo"] = torch.stack( + [ + torch.randint(0, num_nodes, (num_edges,)), + torch.randint(0, num_nodes, (num_edges,)), + ] + ) + + elx = graph_store[("n", "e", "n"), "coo"][:, eix] + loader = cugraph_pyg.loader.LinkNeighborLoader( + (feature_store, graph_store), + num_neighbors=[3, 3, 3], + edge_label_index=elx, + batch_size=batch_size, + neg_sampling="binary", + shuffle=False, + ) + + elx = torch.tensor_split(elx, eix.numel() // batch_size, dim=1) + for i, batch in enumerate(loader): + assert batch.edge_label[0] == 1.0 + + +@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.sg +@pytest.mark.parametrize("batch_size", [1, 2]) +def test_link_neighbor_loader_negative_sampling_uneven(batch_size): + num_edges = 62 + num_nodes = 19 + select_edges = 17 + + graph_store = GraphStore() + feature_store = TensorDictFeatureStore() + + eix = torch.randperm(num_edges)[:select_edges] + graph_store[("n", "e", "n"), "coo"] = torch.stack( + [ + torch.randint(0, num_nodes, (num_edges,)), + torch.randint(0, num_nodes, (num_edges,)), + ] + ) + + elx = graph_store[("n", "e", "n"), "coo"][:, eix] + loader = cugraph_pyg.loader.LinkNeighborLoader( + (feature_store, graph_store), + num_neighbors=[3, 3, 3], + edge_label_index=elx, + batch_size=batch_size, + neg_sampling=torch_geometric.sampler.NegativeSampling("binary", amount=0.1), + shuffle=False, + ) + + elx = torch.tensor_split(elx, eix.numel() // batch_size, dim=1) + for i, batch in enumerate(loader): + assert batch.edge_label[0] == 1.0 diff --git a/python/cugraph-pyg/cugraph_pyg/tests/loader/test_neighbor_loader_mg.py b/python/cugraph-pyg/cugraph_pyg/tests/loader/test_neighbor_loader_mg.py index b8089bb901d..d1dee01a508 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/loader/test_neighbor_loader_mg.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/loader/test_neighbor_loader_mg.py @@ -19,7 +19,7 @@ from cugraph.utilities.utils import import_optional, MissingModule from cugraph_pyg.data import TensorDictFeatureStore, GraphStore -from cugraph_pyg.loader import NeighborLoader +from cugraph_pyg.loader import NeighborLoader, LinkNeighborLoader from cugraph.gnn import ( cugraph_comms_init, @@ -96,6 +96,7 @@ def run_test_neighbor_loader_mg(rank, uid, world_size, specify_size): cugraph_comms_shutdown() +@pytest.mark.skip(reason="deleteme") @pytest.mark.parametrize("specify_size", [True, False]) @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") @pytest.mark.mg @@ -165,6 +166,7 @@ def run_test_neighbor_loader_biased_mg(rank, uid, world_size): cugraph_comms_shutdown() +@pytest.mark.skip(reason="deleteme") @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") @pytest.mark.mg def test_neighbor_loader_biased_mg(): @@ -179,3 +181,184 @@ def test_neighbor_loader_biased_mg(): ), nprocs=world_size, ) + + +def run_test_link_neighbor_loader_basic_mg( + rank, + uid, + world_size, + num_nodes: 
int, + num_edges: int, + select_edges: int, + batch_size: int, + num_neighbors: int, + depth: int, +): + init_pytorch_worker(rank, world_size, uid) + + graph_store = GraphStore(is_multi_gpu=True) + feature_store = TensorDictFeatureStore() + + eix = torch.randperm(num_edges)[:select_edges] + graph_store[("n", "e", "n"), "coo"] = torch.stack( + [ + torch.randint(0, num_nodes, (num_edges,)), + torch.randint(0, num_nodes, (num_edges,)), + ] + ) + + elx = graph_store[("n", "e", "n"), "coo"][:, eix] + loader = LinkNeighborLoader( + (feature_store, graph_store), + num_neighbors=[num_neighbors] * depth, + edge_label_index=elx, + batch_size=batch_size, + shuffle=False, + ) + + elx = torch.tensor_split(elx, eix.numel() // batch_size, dim=1) + for i, batch in enumerate(loader): + assert ( + batch.input_id.cpu() == torch.arange(i * batch_size, (i + 1) * batch_size) + ).all() + assert (elx[i] == batch.n_id[batch.edge_label_index.cpu()]).all() + + cugraph_comms_shutdown() + + +@pytest.mark.skip(reason="deleteme") +@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.mg +@pytest.mark.parametrize("select_edges", [64, 128]) +@pytest.mark.parametrize("batch_size", [2, 4]) +@pytest.mark.parametrize("depth", [1, 3]) +def test_link_neighbor_loader_basic_mg(select_edges, batch_size, depth): + num_nodes = 25 + num_edges = 128 + num_neighbors = 2 + + uid = cugraph_comms_create_unique_id() + world_size = torch.cuda.device_count() + + torch.multiprocessing.spawn( + run_test_link_neighbor_loader_basic_mg, + args=( + uid, + world_size, + num_nodes, + num_edges, + select_edges, + batch_size, + num_neighbors, + depth, + ), + nprocs=world_size, + ) + + +def run_test_link_neighbor_loader_uneven_mg(rank, uid, world_size, edge_index): + init_pytorch_worker(rank, world_size, uid) + + graph_store = GraphStore(is_multi_gpu=True) + feature_store = TensorDictFeatureStore() + + batch_size = 1 + graph_store[("n", "e", "n"), "coo"] = torch.tensor_split( + edge_index, world_size, dim=-1 + )[rank] + + elx = graph_store[("n", "e", "n"), "coo"] # select all edges on each worker + loader = LinkNeighborLoader( + (feature_store, graph_store), + num_neighbors=[2, 2, 2], + edge_label_index=elx, + batch_size=batch_size, + shuffle=False, + ) + + for i, batch in enumerate(loader): + assert ( + batch.input_id.cpu() == torch.arange(i * batch_size, (i + 1) * batch_size) + ).all() + + assert (elx[:, [i]] == batch.n_id[batch.edge_label_index.cpu()]).all() + + cugraph_comms_shutdown() + + +@pytest.mark.skip(reason="deleteme") +@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.mg +def test_link_neighbor_loader_uneven_mg(): + edge_index = torch.tensor( + [ + [0, 1, 3, 4, 7], + [1, 0, 8, 9, 12], + ] + ) + + uid = cugraph_comms_create_unique_id() + world_size = torch.cuda.device_count() + + torch.multiprocessing.spawn( + run_test_link_neighbor_loader_uneven_mg, + args=( + uid, + world_size, + edge_index, + ), + nprocs=world_size, + ) + + +def run_test_link_neighbor_loader_negative_sampling_basic_mg( + rank, world_size, uid, batch_size +): + num_edges = 62 * world_size + num_nodes = 19 * world_size + select_edges = 17 + + init_pytorch_worker(rank, world_size, uid) + + graph_store = GraphStore(is_multi_gpu=True) + feature_store = TensorDictFeatureStore() + + eix = torch.randperm(num_edges)[:select_edges] + graph_store[("n", "e", "n"), "coo"] = torch.stack( + [ + torch.randint(0, num_nodes, (num_edges,)), + torch.randint(0, num_nodes, (num_edges,)), + ] + ) + + 
elx = graph_store[("n", "e", "n"), "coo"][:, eix] + loader = LinkNeighborLoader( + (feature_store, graph_store), + num_neighbors=[3, 3, 3], + edge_label_index=elx, + batch_size=batch_size, + neg_sampling="binary", + shuffle=False, + ) + + elx = torch.tensor_split(elx, eix.numel() // batch_size, dim=1) + for i, batch in enumerate(loader): + assert batch.edge_label[0] == 1.0 + + +@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.mg +@pytest.mark.parametrize("batch_size", [1, 2]) +def test_link_neighbor_loader_negative_sampling_basic_mg(batch_size): + uid = cugraph_comms_create_unique_id() + world_size = torch.cuda.device_count() + + torch.multiprocessing.spawn( + run_test_link_neighbor_loader_negative_sampling_basic_mg, + args=( + world_size, + uid, + batch_size, + ), + nprocs=world_size, + ) diff --git a/python/cugraph-pyg/cugraph_pyg/tests/pytest.ini b/python/cugraph-pyg/cugraph_pyg/tests/pytest.ini new file mode 100644 index 00000000000..7b0a9f29fb1 --- /dev/null +++ b/python/cugraph-pyg/cugraph_pyg/tests/pytest.ini @@ -0,0 +1,4 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. + +[pytest] +addopts = --tb=native diff --git a/python/cugraph-pyg/pyproject.toml b/python/cugraph-pyg/pyproject.toml index e3fb0eee98f..e157f36f8f6 100644 --- a/python/cugraph-pyg/pyproject.toml +++ b/python/cugraph-pyg/pyproject.toml @@ -31,7 +31,7 @@ classifiers = [ dependencies = [ "cugraph==24.12.*,>=0.0.0a0", "numba>=0.57", - "numpy>=1.23,<2.0a0", + "numpy>=1.23,<3.0a0", "pylibcugraphops==24.12.*,>=0.0.0a0", ] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. @@ -49,7 +49,7 @@ test = [ "pytest-xdist", "scipy", "tensordict>=0.1.2", - "torch>=2.0,<2.2.0a0", + "torch>=2.3,<2.4.0a0", ] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. 
[tool.setuptools] diff --git a/python/cugraph-pyg/pytest.ini b/python/cugraph-pyg/pytest.ini index db99a54ae49..07c4ffa0958 100644 --- a/python/cugraph-pyg/pytest.ini +++ b/python/cugraph-pyg/pytest.ini @@ -17,6 +17,7 @@ addopts = --benchmark-max-time=0 --benchmark-min-rounds=1 --benchmark-columns="mean, rounds" + --tb=native ## do not run slow tests/benchmarks by default -m "not slow" diff --git a/python/cugraph-service/pytest.ini b/python/cugraph-service/pytest.ini index 6a0dd36ecec..f2ba9175f82 100644 --- a/python/cugraph-service/pytest.ini +++ b/python/cugraph-service/pytest.ini @@ -16,6 +16,7 @@ addopts = --benchmark-warmup=off --benchmark-max-time=0 --benchmark-min-rounds=1 --benchmark-columns="min, max, mean, rounds" + --tb=native ## for use with rapids-pytest-benchmark plugin #--benchmark-gpu-disable ## for use with pytest-cov plugin diff --git a/python/cugraph-service/server/pyproject.toml b/python/cugraph-service/server/pyproject.toml index c850397b6fc..f388fd4c126 100644 --- a/python/cugraph-service/server/pyproject.toml +++ b/python/cugraph-service/server/pyproject.toml @@ -27,7 +27,7 @@ dependencies = [ "dask-cuda==24.12.*,>=0.0.0a0", "dask-cudf==24.12.*,>=0.0.0a0", "numba>=0.57", - "numpy>=1.23,<2.0a0", + "numpy>=1.23,<3.0a0", "rapids-dask-dependency==24.12.*,>=0.0.0a0", "rmm==24.12.*,>=0.0.0a0", "thriftpy2!=0.5.0,!=0.5.1", @@ -47,7 +47,7 @@ cugraph-service-server = "cugraph_service_server.__main__:main" [project.optional-dependencies] test = [ "networkx>=2.5.1", - "numpy>=1.23,<2.0a0", + "numpy>=1.23,<3.0a0", "pandas", "pytest", "pytest-benchmark", diff --git a/python/cugraph-service/tests/pytest.ini b/python/cugraph-service/tests/pytest.ini new file mode 100644 index 00000000000..7b0a9f29fb1 --- /dev/null +++ b/python/cugraph-service/tests/pytest.ini @@ -0,0 +1,4 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. 
+ +[pytest] +addopts = --tb=native diff --git a/python/cugraph/cugraph/gnn/data_loading/__init__.py b/python/cugraph/cugraph/gnn/data_loading/__init__.py index 9e2c81ec749..25f58be88aa 100644 --- a/python/cugraph/cugraph/gnn/data_loading/__init__.py +++ b/python/cugraph/cugraph/gnn/data_loading/__init__.py @@ -14,9 +14,12 @@ from cugraph.gnn.data_loading.bulk_sampler import BulkSampler from cugraph.gnn.data_loading.dist_sampler import ( DistSampler, + NeighborSampler, +) +from cugraph.gnn.data_loading.dist_io import ( DistSampleWriter, DistSampleReader, - NeighborSampler, + BufferedSampleReader, ) diff --git a/python/cugraph/cugraph/gnn/data_loading/bulk_sampler_io.py b/python/cugraph/cugraph/gnn/data_loading/bulk_sampler_io.py index 6abbd82647b..222fb49a836 100644 --- a/python/cugraph/cugraph/gnn/data_loading/bulk_sampler_io.py +++ b/python/cugraph/cugraph/gnn/data_loading/bulk_sampler_io.py @@ -33,10 +33,12 @@ def create_df_from_disjoint_series(series_list: List[cudf.Series]): def create_df_from_disjoint_arrays(array_dict: Dict[str, cupy.array]): + series_dict = {} for k in list(array_dict.keys()): - array_dict[k] = cudf.Series(array_dict[k], name=k) + if array_dict[k] is not None: + series_dict[k] = cudf.Series(array_dict[k], name=k) - return create_df_from_disjoint_series(list(array_dict.values())) + return create_df_from_disjoint_series(list(series_dict.values())) def _write_samples_to_parquet_csr( diff --git a/python/cugraph/cugraph/gnn/data_loading/dist_io/__init__.py b/python/cugraph/cugraph/gnn/data_loading/dist_io/__init__.py new file mode 100644 index 00000000000..29bb5489be2 --- /dev/null +++ b/python/cugraph/cugraph/gnn/data_loading/dist_io/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .reader import BufferedSampleReader, DistSampleReader +from .writer import DistSampleWriter diff --git a/python/cugraph/cugraph/gnn/data_loading/dist_io/reader.py b/python/cugraph/cugraph/gnn/data_loading/dist_io/reader.py new file mode 100644 index 00000000000..69f909e7a8d --- /dev/null +++ b/python/cugraph/cugraph/gnn/data_loading/dist_io/reader.py @@ -0,0 +1,144 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
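+
+# Readers in this module yield (tensors, start_batch_id, end_batch_id) triples,
+# where `tensors` maps sampler output columns (e.g. "majors", "minors", "map")
+# to GPU tensors. A minimal consumption sketch (the path and `process` callback
+# are illustrative placeholders):
+#
+#     reader = DistSampleReader("/path/to/samples", rank=0)
+#     for tensors, start_batch, end_batch in reader:
+#         process(tensors, start_batch, end_batch)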
+
+
+import os
+import re
+
+import cudf
+
+from typing import Callable, Iterator, Tuple, Dict, Optional
+
+from cugraph.utilities.utils import import_optional, MissingModule
+
+# Prevent PyTorch from being imported and causing an OOM error
+torch = MissingModule("torch")
+
+
+class DistSampleReader:
+    def __init__(
+        self,
+        directory: str,
+        *,
+        format: str = "parquet",
+        rank: Optional[int] = None,
+        filelist=None,
+    ):
+        torch = import_optional("torch")
+
+        self.__format = format
+        self.__directory = directory
+
+        if format != "parquet":
+            raise ValueError("Invalid format (currently supported: 'parquet')")
+
+        if filelist is None:
+            files = os.listdir(directory)
+            ex = re.compile(r"batch\=([0-9]+)\.([0-9]+)\-([0-9]+)\.([0-9]+)\.parquet")
+            filematch = [ex.match(f) for f in files]
+            filematch = [f for f in filematch if f]
+
+            if rank is not None:
+                filematch = [f for f in filematch if int(f[1]) == rank]
+
+            filematch = sorted(filematch, key=lambda f: int(f[2]), reverse=True)
+
+            self.__files = filematch
+        else:
+            self.__files = list(filelist)
+
+        # Each file covers an inclusive range of batch ids encoded in its name
+        # (groups 2 and 4 of the pattern above). Computing the count here also
+        # covers the explicit-filelist path, whose entries are assumed to
+        # expose the same groups.
+        batch_count = sum(int(f[4]) - int(f[2]) + 1 for f in self.__files)
+
+        if rank is None:
+            self.__batch_count = batch_count
+        else:
+            # TODO maybe remove this in favor of warning users that they are
+            # probably going to cause a hang, instead of attempting to resolve
+            # the hang for them by dropping batches.
+            batch_count = torch.tensor([batch_count], device="cuda")
+            torch.distributed.all_reduce(batch_count, torch.distributed.ReduceOp.MIN)
+            self.__batch_count = int(batch_count)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self) -> Tuple[Dict[str, "torch.Tensor"], int, int]:
+        torch = import_optional("torch")
+
+        if len(self.__files) > 0:
+            f = self.__files.pop()
+            fname = f[0]
+            start_inclusive = int(f[2])
+            end_inclusive = int(f[4])
+
+            if (end_inclusive - start_inclusive + 1) > self.__batch_count:
+                end_inclusive = start_inclusive + self.__batch_count - 1
+                self.__batch_count = 0
+            else:
+                self.__batch_count -= end_inclusive - start_inclusive + 1
+
+            df = cudf.read_parquet(os.path.join(self.__directory, fname))
+            tensors = {}
+            for col in list(df.columns):
+                s = df[col].dropna()
+                if len(s) > 0:
+                    tensors[col] = torch.as_tensor(s, device="cuda")
+                df.drop(col, axis=1, inplace=True)
+
+            return tensors, start_inclusive, end_inclusive
+
+        raise StopIteration
+
+
+class BufferedSampleReader:
+    def __init__(
+        self,
+        nodes_call_groups: list["torch.Tensor"],
+        sample_fn: Callable[..., Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]]],
+        *args,
+        **kwargs,
+    ):
+        self.__sample_args = args
+        self.__sample_kwargs = kwargs
+
+        self.__nodes_call_groups = iter(nodes_call_groups)
+        self.__sample_fn = sample_fn
+        self.__current_call_id = 0
+        self.__current_reader = None
+
+    def __next__(self) -> Tuple[Dict[str, "torch.Tensor"], int, int]:
+        new_reader = False
+
+        if self.__current_reader is None:
+            new_reader = True
+        else:
+            try:
+                out = next(self.__current_reader)
+            except StopIteration:
+                new_reader = True
+
+        if new_reader:
+            # Will trigger StopIteration if there are no more call groups
+            self.__current_reader = self.__sample_fn(
+                self.__current_call_id,
+                next(self.__nodes_call_groups),
+                *self.__sample_args,
+                **self.__sample_kwargs,
+            )
+
+            self.__current_call_id += 1
+            out = next(self.__current_reader)
+
+        return out
+
+    def __iter__(self) -> Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]]:
+        return self
diff --git a/python/cugraph/cugraph/gnn/data_loading/dist_io/writer.py
b/python/cugraph/cugraph/gnn/data_loading/dist_io/writer.py new file mode 100644 index 00000000000..f8ad4719a76 --- /dev/null +++ b/python/cugraph/cugraph/gnn/data_loading/dist_io/writer.py @@ -0,0 +1,321 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from math import ceil + + +import cupy + +from cugraph.utilities.utils import MissingModule +from cugraph.gnn.data_loading.dist_io import DistSampleReader + +from cugraph.gnn.data_loading.bulk_sampler_io import create_df_from_disjoint_arrays + +from typing import Iterator, Tuple, Dict + +torch = MissingModule("torch") + + +class DistSampleWriter: + def __init__( + self, + directory: str, + *, + batches_per_partition: int = 256, + format: str = "parquet", + ): + """ + Parameters + ---------- + directory: str (required) + The directory where samples will be written. This + writer can only write to disk. + batches_per_partition: int (optional, default=256) + The number of batches to write in a single file. + format: str (optional, default='parquet') + The file format of the output files containing the + sampled minibatches. Currently, only parquet format + is supported. + """ + if format != "parquet": + raise ValueError("Invalid format (currently supported: 'parquet')") + + self.__format = format + self.__directory = directory + self.__batches_per_partition = batches_per_partition + + @property + def _format(self): + return self.__format + + @property + def _directory(self): + return self.__directory + + @property + def _batches_per_partition(self): + return self.__batches_per_partition + + def get_reader( + self, rank: int + ) -> Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]]: + """ + Returns an iterator over sampled data. + """ + + # currently only disk reading is supported + return DistSampleReader(self._directory, format=self._format, rank=rank) + + def __write_minibatches_coo(self, minibatch_dict): + has_edge_ids = minibatch_dict["edge_id"] is not None + has_edge_types = minibatch_dict["edge_type"] is not None + has_weights = minibatch_dict["weight"] is not None + + if minibatch_dict["renumber_map"] is None: + raise ValueError( + "Distributed sampling without renumbering is not supported" + ) + + # Quit if there are no batches to write. 
+ if len(minibatch_dict["batch_id"]) == 0: + return + + fanout_length = (len(minibatch_dict["label_hop_offsets"]) - 1) // len( + minibatch_dict["batch_id"] + ) + + for p in range( + 0, int(ceil(len(minibatch_dict["batch_id"]) / self.__batches_per_partition)) + ): + partition_start = p * (self.__batches_per_partition) + partition_end = (p + 1) * (self.__batches_per_partition) + + label_hop_offsets_array_p = minibatch_dict["label_hop_offsets"][ + partition_start * fanout_length : partition_end * fanout_length + 1 + ] + + batch_id_array_p = minibatch_dict["batch_id"][partition_start:partition_end] + start_batch_id = batch_id_array_p[0] + + input_offsets_p = minibatch_dict["input_offsets"][ + partition_start : (partition_end + 1) + ] + input_index_p = minibatch_dict["input_index"][ + input_offsets_p[0] : input_offsets_p[-1] + ] + edge_inverse_p = ( + minibatch_dict["edge_inverse"][ + (input_offsets_p[0] * 2) : (input_offsets_p[-1] * 2) + ] + if "edge_inverse" in minibatch_dict + else None + ) + + start_ix, end_ix = label_hop_offsets_array_p[[0, -1]] + majors_array_p = minibatch_dict["majors"][start_ix:end_ix] + minors_array_p = minibatch_dict["minors"][start_ix:end_ix] + edge_id_array_p = ( + minibatch_dict["edge_id"][start_ix:end_ix] + if has_edge_ids + else cupy.array([], dtype="int64") + ) + edge_type_array_p = ( + minibatch_dict["edge_type"][start_ix:end_ix] + if has_edge_types + else cupy.array([], dtype="int32") + ) + weight_array_p = ( + minibatch_dict["weight"][start_ix:end_ix] + if has_weights + else cupy.array([], dtype="float32") + ) + + # create the renumber map offsets + renumber_map_offsets_array_p = minibatch_dict["renumber_map_offsets"][ + partition_start : partition_end + 1 + ] + + renumber_map_start_ix, renumber_map_end_ix = renumber_map_offsets_array_p[ + [0, -1] + ] + + renumber_map_array_p = minibatch_dict["renumber_map"][ + renumber_map_start_ix:renumber_map_end_ix + ] + + results_dataframe_p = create_df_from_disjoint_arrays( + { + "majors": majors_array_p, + "minors": minors_array_p, + "map": renumber_map_array_p, + "label_hop_offsets": label_hop_offsets_array_p, + "weight": weight_array_p, + "edge_id": edge_id_array_p, + "edge_type": edge_type_array_p, + "renumber_map_offsets": renumber_map_offsets_array_p, + "input_index": input_index_p, + "input_offsets": input_offsets_p, + "edge_inverse": edge_inverse_p, + } + ) + + end_batch_id = start_batch_id + len(batch_id_array_p) - 1 + rank = minibatch_dict["rank"] if "rank" in minibatch_dict else 0 + + full_output_path = os.path.join( + self.__directory, + f"batch={rank:05d}.{start_batch_id:08d}-" + f"{rank:05d}.{end_batch_id:08d}.parquet", + ) + + results_dataframe_p.to_parquet( + full_output_path, + compression=None, + index=False, + force_nullable_schema=True, + ) + + def __write_minibatches_csr(self, minibatch_dict): + has_edge_ids = minibatch_dict["edge_id"] is not None + has_edge_types = minibatch_dict["edge_type"] is not None + has_weights = minibatch_dict["weight"] is not None + + if minibatch_dict["renumber_map"] is None: + raise ValueError( + "Distributed sampling without renumbering is not supported" + ) + + # Quit if there are no batches to write. 
+ if len(minibatch_dict["batch_id"]) == 0: + return + + fanout_length = (len(minibatch_dict["label_hop_offsets"]) - 1) // len( + minibatch_dict["batch_id"] + ) + + for p in range( + 0, int(ceil(len(minibatch_dict["batch_id"]) / self.__batches_per_partition)) + ): + partition_start = p * (self.__batches_per_partition) + partition_end = (p + 1) * (self.__batches_per_partition) + + label_hop_offsets_array_p = minibatch_dict["label_hop_offsets"][ + partition_start * fanout_length : partition_end * fanout_length + 1 + ] + + batch_id_array_p = minibatch_dict["batch_id"][partition_start:partition_end] + start_batch_id = batch_id_array_p[0] + + input_offsets_p = minibatch_dict["input_offsets"][ + partition_start : (partition_end + 1) + ] + input_index_p = minibatch_dict["input_index"][ + input_offsets_p[0] : input_offsets_p[-1] + ] + edge_inverse_p = ( + minibatch_dict["edge_inverse"][ + (input_offsets_p[0] * 2) : (input_offsets_p[-1] * 2) + ] + if "edge_inverse" in minibatch_dict + else None + ) + + # major offsets and minors + ( + major_offsets_start_incl, + major_offsets_end_incl, + ) = label_hop_offsets_array_p[[0, -1]] + + start_ix, end_ix = minibatch_dict["major_offsets"][ + [major_offsets_start_incl, major_offsets_end_incl] + ] + + major_offsets_array_p = minibatch_dict["major_offsets"][ + major_offsets_start_incl : major_offsets_end_incl + 1 + ] + + minors_array_p = minibatch_dict["minors"][start_ix:end_ix] + edge_id_array_p = ( + minibatch_dict["edge_id"][start_ix:end_ix] + if has_edge_ids + else cupy.array([], dtype="int64") + ) + edge_type_array_p = ( + minibatch_dict["edge_type"][start_ix:end_ix] + if has_edge_types + else cupy.array([], dtype="int32") + ) + weight_array_p = ( + minibatch_dict["weight"][start_ix:end_ix] + if has_weights + else cupy.array([], dtype="float32") + ) + + # create the renumber map offsets + renumber_map_offsets_array_p = minibatch_dict["renumber_map_offsets"][ + partition_start : partition_end + 1 + ] + + renumber_map_start_ix, renumber_map_end_ix = renumber_map_offsets_array_p[ + [0, -1] + ] + + renumber_map_array_p = minibatch_dict["renumber_map"][ + renumber_map_start_ix:renumber_map_end_ix + ] + + results_dataframe_p = create_df_from_disjoint_arrays( + { + "major_offsets": major_offsets_array_p, + "minors": minors_array_p, + "map": renumber_map_array_p, + "label_hop_offsets": label_hop_offsets_array_p, + "weight": weight_array_p, + "edge_id": edge_id_array_p, + "edge_type": edge_type_array_p, + "renumber_map_offsets": renumber_map_offsets_array_p, + "input_index": input_index_p, + "input_offsets": input_offsets_p, + "edge_inverse": edge_inverse_p, + } + ) + + end_batch_id = start_batch_id + len(batch_id_array_p) - 1 + rank = minibatch_dict["rank"] if "rank" in minibatch_dict else 0 + + full_output_path = os.path.join( + self.__directory, + f"batch={rank:05d}.{start_batch_id:08d}-" + f"{rank:05d}.{end_batch_id:08d}.parquet", + ) + + results_dataframe_p.to_parquet( + full_output_path, + compression=None, + index=False, + force_nullable_schema=True, + ) + + def write_minibatches(self, minibatch_dict): + if (minibatch_dict["majors"] is not None) and ( + minibatch_dict["minors"] is not None + ): + self.__write_minibatches_coo(minibatch_dict) + elif (minibatch_dict["major_offsets"] is not None) and ( + minibatch_dict["minors"] is not None + ): + self.__write_minibatches_csr(minibatch_dict) + else: + raise ValueError("invalid columns") diff --git a/python/cugraph/cugraph/gnn/data_loading/dist_sampler.py b/python/cugraph/cugraph/gnn/data_loading/dist_sampler.py 
index 52ffd8fadfd..0ff38741e1a 100644 --- a/python/cugraph/cugraph/gnn/data_loading/dist_sampler.py +++ b/python/cugraph/cugraph/gnn/data_loading/dist_sampler.py @@ -11,8 +11,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -import re import warnings from math import ceil from functools import reduce @@ -27,348 +25,19 @@ from cugraph.utilities.utils import import_optional, MissingModule from cugraph.gnn.comms import cugraph_comms_get_raft_handle -from cugraph.gnn.data_loading.bulk_sampler_io import create_df_from_disjoint_arrays + +from cugraph.gnn.data_loading.dist_io import BufferedSampleReader +from cugraph.gnn.data_loading.dist_io import DistSampleWriter torch = MissingModule("torch") TensorType = Union["torch.Tensor", cupy.ndarray, cudf.Series] -class DistSampleReader: - def __init__( - self, - directory: str, - *, - format: str = "parquet", - rank: Optional[int] = None, - filelist=None, - ): - torch = import_optional("torch") - - self.__format = format - self.__directory = directory - - if format != "parquet": - raise ValueError("Invalid format (currently supported: 'parquet')") - - if filelist is None: - files = os.listdir(directory) - ex = re.compile(r"batch\=([0-9]+)\.([0-9]+)\-([0-9]+)\.([0-9]+)\.parquet") - filematch = [ex.match(f) for f in files] - filematch = [f for f in filematch if f] - - if rank is not None: - filematch = [f for f in filematch if int(f[1]) == rank] - - batch_count = sum([int(f[4]) - int(f[2]) + 1 for f in filematch]) - filematch = sorted(filematch, key=lambda f: int(f[2]), reverse=True) - - self.__files = filematch - else: - self.__files = list(filelist) - - if rank is None: - self.__batch_count = batch_count - else: - batch_count = torch.tensor([batch_count], device="cuda") - torch.distributed.all_reduce(batch_count, torch.distributed.ReduceOp.MIN) - self.__batch_count = int(batch_count) - - def __iter__(self): - return self - - def __next__(self): - torch = import_optional("torch") - - if len(self.__files) > 0: - f = self.__files.pop() - fname = f[0] - start_inclusive = int(f[2]) - end_inclusive = int(f[4]) - - if (end_inclusive - start_inclusive + 1) > self.__batch_count: - end_inclusive = start_inclusive + self.__batch_count - 1 - self.__batch_count = 0 - else: - self.__batch_count -= end_inclusive - start_inclusive + 1 - - df = cudf.read_parquet(os.path.join(self.__directory, fname)) - tensors = {} - for col in list(df.columns): - s = df[col].dropna() - if len(s) > 0: - tensors[col] = torch.as_tensor(s, device="cuda") - df.drop(col, axis=1, inplace=True) - - return tensors, start_inclusive, end_inclusive - - raise StopIteration - - -class DistSampleWriter: - def __init__( - self, - directory: str, - *, - batches_per_partition: int = 256, - format: str = "parquet", - ): - """ - Parameters - ---------- - directory: str (required) - The directory where samples will be written. This - writer can only write to disk. - batches_per_partition: int (optional, default=256) - The number of batches to write in a single file. - format: str (optional, default='parquet') - The file format of the output files containing the - sampled minibatches. Currently, only parquet format - is supported. 
- """ - if format != "parquet": - raise ValueError("Invalid format (currently supported: 'parquet')") - - self.__format = format - self.__directory = directory - self.__batches_per_partition = batches_per_partition - - @property - def _format(self): - return self.__format - - @property - def _directory(self): - return self.__directory - - @property - def _batches_per_partition(self): - return self.__batches_per_partition - - def get_reader( - self, rank: int - ) -> Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]]: - """ - Returns an iterator over sampled data. - """ - - # currently only disk reading is supported - return DistSampleReader(self._directory, format=self._format, rank=rank) - - def __write_minibatches_coo(self, minibatch_dict): - has_edge_ids = minibatch_dict["edge_id"] is not None - has_edge_types = minibatch_dict["edge_type"] is not None - has_weights = minibatch_dict["weight"] is not None - - if minibatch_dict["renumber_map"] is None: - raise ValueError( - "Distributed sampling without renumbering is not supported" - ) - - # Quit if there are no batches to write. - if len(minibatch_dict["batch_id"]) == 0: - return - - fanout_length = (len(minibatch_dict["label_hop_offsets"]) - 1) // len( - minibatch_dict["batch_id"] - ) - rank_batch_offset = minibatch_dict["batch_id"][0] - - for p in range( - 0, int(ceil(len(minibatch_dict["batch_id"]) / self.__batches_per_partition)) - ): - partition_start = p * (self.__batches_per_partition) - partition_end = (p + 1) * (self.__batches_per_partition) - - label_hop_offsets_array_p = minibatch_dict["label_hop_offsets"][ - partition_start * fanout_length : partition_end * fanout_length + 1 - ] - - batch_id_array_p = minibatch_dict["batch_id"][partition_start:partition_end] - start_batch_id = batch_id_array_p[0] - rank_batch_offset - - start_ix, end_ix = label_hop_offsets_array_p[[0, -1]] - majors_array_p = minibatch_dict["majors"][start_ix:end_ix] - minors_array_p = minibatch_dict["minors"][start_ix:end_ix] - edge_id_array_p = ( - minibatch_dict["edge_id"][start_ix:end_ix] - if has_edge_ids - else cupy.array([], dtype="int64") - ) - edge_type_array_p = ( - minibatch_dict["edge_type"][start_ix:end_ix] - if has_edge_types - else cupy.array([], dtype="int32") - ) - weight_array_p = ( - minibatch_dict["weight"][start_ix:end_ix] - if has_weights - else cupy.array([], dtype="float32") - ) - - # create the renumber map offsets - renumber_map_offsets_array_p = minibatch_dict["renumber_map_offsets"][ - partition_start : partition_end + 1 - ] - - renumber_map_start_ix, renumber_map_end_ix = renumber_map_offsets_array_p[ - [0, -1] - ] - - renumber_map_array_p = minibatch_dict["renumber_map"][ - renumber_map_start_ix:renumber_map_end_ix - ] - - results_dataframe_p = create_df_from_disjoint_arrays( - { - "majors": majors_array_p, - "minors": minors_array_p, - "map": renumber_map_array_p, - "label_hop_offsets": label_hop_offsets_array_p, - "weight": weight_array_p, - "edge_id": edge_id_array_p, - "edge_type": edge_type_array_p, - "renumber_map_offsets": renumber_map_offsets_array_p, - } - ) - - end_batch_id = start_batch_id + len(batch_id_array_p) - 1 - rank = minibatch_dict["rank"] if "rank" in minibatch_dict else 0 - - full_output_path = os.path.join( - self.__directory, - f"batch={rank:05d}.{start_batch_id:08d}-" - f"{rank:05d}.{end_batch_id:08d}.parquet", - ) - - results_dataframe_p.to_parquet( - full_output_path, - compression=None, - index=False, - force_nullable_schema=True, - ) - - def __write_minibatches_csr(self, minibatch_dict): - 
has_edge_ids = minibatch_dict["edge_id"] is not None - has_edge_types = minibatch_dict["edge_type"] is not None - has_weights = minibatch_dict["weight"] is not None - - if minibatch_dict["renumber_map"] is None: - raise ValueError( - "Distributed sampling without renumbering is not supported" - ) - - # Quit if there are no batches to write. - if len(minibatch_dict["batch_id"]) == 0: - return - - fanout_length = (len(minibatch_dict["label_hop_offsets"]) - 1) // len( - minibatch_dict["batch_id"] - ) - - for p in range( - 0, int(ceil(len(minibatch_dict["batch_id"]) / self.__batches_per_partition)) - ): - partition_start = p * (self.__batches_per_partition) - partition_end = (p + 1) * (self.__batches_per_partition) - - label_hop_offsets_array_p = minibatch_dict["label_hop_offsets"][ - partition_start * fanout_length : partition_end * fanout_length + 1 - ] - - batch_id_array_p = minibatch_dict["batch_id"][partition_start:partition_end] - start_batch_id = batch_id_array_p[0] - - # major offsets and minors - ( - major_offsets_start_incl, - major_offsets_end_incl, - ) = label_hop_offsets_array_p[[0, -1]] - - start_ix, end_ix = minibatch_dict["major_offsets"][ - [major_offsets_start_incl, major_offsets_end_incl] - ] - - major_offsets_array_p = minibatch_dict["major_offsets"][ - major_offsets_start_incl : major_offsets_end_incl + 1 - ] - - minors_array_p = minibatch_dict["minors"][start_ix:end_ix] - edge_id_array_p = ( - minibatch_dict["edge_id"][start_ix:end_ix] - if has_edge_ids - else cupy.array([], dtype="int64") - ) - edge_type_array_p = ( - minibatch_dict["edge_type"][start_ix:end_ix] - if has_edge_types - else cupy.array([], dtype="int32") - ) - weight_array_p = ( - minibatch_dict["weight"][start_ix:end_ix] - if has_weights - else cupy.array([], dtype="float32") - ) - - # create the renumber map offsets - renumber_map_offsets_array_p = minibatch_dict["renumber_map_offsets"][ - partition_start : partition_end + 1 - ] - - renumber_map_start_ix, renumber_map_end_ix = renumber_map_offsets_array_p[ - [0, -1] - ] - - renumber_map_array_p = minibatch_dict["renumber_map"][ - renumber_map_start_ix:renumber_map_end_ix - ] - - results_dataframe_p = create_df_from_disjoint_arrays( - { - "major_offsets": major_offsets_array_p, - "minors": minors_array_p, - "map": renumber_map_array_p, - "label_hop_offsets": label_hop_offsets_array_p, - "weight": weight_array_p, - "edge_id": edge_id_array_p, - "edge_type": edge_type_array_p, - "renumber_map_offsets": renumber_map_offsets_array_p, - } - ) - - end_batch_id = start_batch_id + len(batch_id_array_p) - 1 - rank = minibatch_dict["rank"] if "rank" in minibatch_dict else 0 - - full_output_path = os.path.join( - self.__directory, - f"batch={rank:05d}.{start_batch_id:08d}-" - f"{rank:05d}.{end_batch_id:08d}.parquet", - ) - - results_dataframe_p.to_parquet( - full_output_path, - compression=None, - index=False, - force_nullable_schema=True, - ) - - def write_minibatches(self, minibatch_dict): - if (minibatch_dict["majors"] is not None) and ( - minibatch_dict["minors"] is not None - ): - self.__write_minibatches_coo(minibatch_dict) - elif (minibatch_dict["major_offsets"] is not None) and ( - minibatch_dict["minors"] is not None - ): - self.__write_minibatches_csr(minibatch_dict) - else: - raise ValueError("invalid columns") - - class DistSampler: def __init__( self, graph: Union[pylibcugraph.SGGraph, pylibcugraph.MGGraph], - writer: DistSampleWriter, + writer: Optional[DistSampleWriter], local_seeds_per_call: int, retain_original_seeds: bool = False, ): @@ -379,7 +48,8 
@@ def __init__( The pylibcugraph graph object that will be sampled. writer: DistSampleWriter (required) The writer responsible for writing samples to disk - or, in the future, device or host memory. + or, if None, samples will be written to memory + instead. local_seeds_per_call: int The number of seeds on this rank this sampler will process in a single sampling call. Batches will @@ -402,14 +72,6 @@ def __init__( self.__handle = None self.__retain_original_seeds = retain_original_seeds - def get_reader(self) -> Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]]: - """ - Returns an iterator over sampled data. - """ - torch = import_optional("torch") - rank = torch.distributed.get_rank() if self.is_multi_gpu else None - return self.__writer.get_reader(rank) - def sample_batches( self, seeds: TensorType, @@ -564,6 +226,108 @@ def get_start_batch_offset( else: return 0, input_size_is_equal + def __sample_from_nodes_func( + self, + call_id: int, + current_seeds_and_ix: Tuple["torch.Tensor", "torch.Tensor"], + batch_id_start: int, + batch_size: int, + batches_per_call: int, + random_state: int, + assume_equal_input_size: bool, + ) -> Union[None, Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]]]: + torch = import_optional("torch") + + current_seeds, current_ix = current_seeds_and_ix + + current_batches = torch.arange( + batch_id_start + call_id * batches_per_call, + batch_id_start + + call_id * batches_per_call + + int(ceil(len(current_seeds) / batch_size)) + + 1, + device="cuda", + dtype=torch.int32, + ) + + current_batches = current_batches.repeat_interleave(batch_size)[ + : len(current_seeds) + ] + + # use divmod (quotient/remainder division) to get the number of full + # batch_size batches and the size of the last, smaller batch + num_full, last_count = divmod(len(current_seeds), batch_size) + input_offsets = torch.concatenate( + [ + torch.tensor([0], device="cuda", dtype=torch.int64), + torch.full((num_full,), batch_size, device="cuda", dtype=torch.int64), + torch.tensor([last_count], device="cuda", dtype=torch.int64) + if last_count > 0 + else torch.tensor([], device="cuda", dtype=torch.int64), + ] + ).cumsum(-1) + + minibatch_dict = self.sample_batches( + seeds=current_seeds, + batch_ids=current_batches, + random_state=random_state, + assume_equal_input_size=assume_equal_input_size, + ) + minibatch_dict["input_index"] = current_ix.cuda() + minibatch_dict["input_offsets"] = input_offsets + + if self.__writer is None: + # rename renumber_map -> map to match unbuffered format + minibatch_dict["map"] = minibatch_dict["renumber_map"] + del minibatch_dict["renumber_map"] + minibatch_dict = { + k: torch.as_tensor(v, device="cuda") + for k, v in minibatch_dict.items() + if v is not None + } + + return iter([(minibatch_dict, current_batches[0], current_batches[-1])]) + else: + self.__writer.write_minibatches(minibatch_dict) + return None + + def __get_call_groups( + self, + seeds: TensorType, + input_id: TensorType, + seeds_per_call: int, + assume_equal_input_size: bool = False, + ): + torch = import_optional("torch") + + # Split the input seeds into call groups. Each call group + # corresponds to one sampling call. A call group contains + # many batches. + seeds_call_groups = torch.split(seeds, seeds_per_call, dim=-1) + index_call_groups = torch.split(input_id, seeds_per_call, dim=-1) + + # Need to add empties to the list of call groups to handle the case + # where not all ranks have the same number of call groups. This + # prevents a hang since we need all ranks to make the same number + # of calls.
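+ # For example, with seeds_per_call=5, a rank holding 12 seeds forms + # ceil(12 / 5) == 3 call groups; if another rank holds 17 seeds (4 call + # groups), the MAX all-reduce below raises this rank's count to 4 by + # appending one empty tensor, keeping the collective calls aligned.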
+ if not assume_equal_input_size: + num_call_groups = torch.tensor( + [len(seeds_call_groups)], device="cuda", dtype=torch.int32 + ) + torch.distributed.all_reduce( + num_call_groups, op=torch.distributed.ReduceOp.MAX + ) + seeds_call_groups = list(seeds_call_groups) + ( + [torch.tensor([], dtype=seeds.dtype, device="cuda")] + * (int(num_call_groups) - len(seeds_call_groups)) + ) + index_call_groups = list(index_call_groups) + ( + [torch.tensor([], dtype=torch.int64, device=input_id.device)] + * (int(num_call_groups) - len(index_call_groups)) + ) + + return seeds_call_groups, index_call_groups + def sample_from_nodes( self, nodes: TensorType, @@ -571,7 +335,8 @@ def sample_from_nodes( batch_size: int = 16, random_state: int = 62, assume_equal_input_size: bool = False, - ): + input_id: Optional[TensorType] = None, + ) -> Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]]: """ Performs node-based sampling. Accepts a list of seed nodes, and batch size. Splits the seed list into batches, then divides the batches into call groups @@ -587,64 +352,301 @@ def sample_from_nodes( The size of each batch. random_state: int The random seed to use for sampling. + assume_equal_input_size: bool + Whether the inputs across workers should be assumed to be equal in + dimension. Skips some checks if True. + input_id: Optional[TensorType] + Input ids corresponding to the original batch tensor, if it + was permuted prior to calling this function. If present, + will be saved with the samples. """ torch = import_optional("torch") nodes = torch.as_tensor(nodes, device="cuda") + num_seeds = nodes.numel() batches_per_call = self._local_seeds_per_call // batch_size actual_seeds_per_call = batches_per_call * batch_size - # Split the input seeds into call groups. Each call group - # corresponds to one sampling call. A call group contains - # many batches. - num_seeds = len(nodes) - nodes_call_groups = torch.split(nodes, actual_seeds_per_call) + if input_id is None: + input_id = torch.arange(num_seeds, dtype=torch.int64, device="cpu") local_num_batches = int(ceil(num_seeds / batch_size)) batch_id_start, input_size_is_equal = self.get_start_batch_offset( local_num_batches, assume_equal_input_size=assume_equal_input_size ) - # Need to add empties to the list of call groups to handle the case - # where not all nodes have the same number of call groups. This - # prevents a hang since we need all ranks to make the same number - # of calls. 
- if not input_size_is_equal: - num_call_groups = torch.tensor( - [len(nodes_call_groups)], device="cuda", dtype=torch.int32 - ) - torch.distributed.all_reduce( - num_call_groups, op=torch.distributed.ReduceOp.MAX + nodes_call_groups, index_call_groups = self.__get_call_groups( + nodes, + input_id, + actual_seeds_per_call, + assume_equal_input_size=input_size_is_equal, + ) + + sample_args = ( + batch_id_start, + batch_size, + batches_per_call, + random_state, + input_size_is_equal, + ) + + if self.__writer is None: + # Buffered sampling + return BufferedSampleReader( + zip(nodes_call_groups, index_call_groups), + self.__sample_from_nodes_func, + *sample_args, ) - nodes_call_groups = list(nodes_call_groups) + ( - [torch.tensor([], dtype=nodes.dtype, device="cuda")] - * (int(num_call_groups) - len(nodes_call_groups)) + else: + # Unbuffered sampling + for i, current_seeds_and_ix in enumerate( + zip(nodes_call_groups, index_call_groups) + ): + self.__sample_from_nodes_func( + i, + current_seeds_and_ix, + *sample_args, + ) + + # Return a reader that points to the stored samples + rank = torch.distributed.get_rank() if self.is_multi_gpu else None + return self.__writer.get_reader(rank) + + def __sample_from_edges_func( + self, + call_id: int, + current_seeds_and_ix: Tuple["torch.Tensor", "torch.Tensor"], + batch_id_start: int, + batch_size: int, + batches_per_call: int, + random_state: int, + assume_equal_input_size: bool, + ) -> Union[None, Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]]]: + torch = import_optional("torch") + + current_seeds, current_ix = current_seeds_and_ix + num_seed_edges = current_ix.numel() + + # The index gets stored as-is regardless of what makes it into + # the final batch and in what order. + # use divmod (quotient/remainder division) to get the number of whole + # batch_size batches and the size of the last, smaller batch + num_whole_batches, last_count = divmod(num_seed_edges, batch_size) + input_offsets = torch.concatenate( + [ + torch.tensor([0], device="cuda", dtype=torch.int64), + torch.full( + (num_whole_batches,), batch_size, device="cuda", dtype=torch.int64 + ), + torch.tensor([last_count], device="cuda", dtype=torch.int64) + if last_count > 0 + else torch.tensor([], device="cuda", dtype=torch.int64), + ] + ).cumsum(-1) + + current_seeds, leftover_seeds = ( + current_seeds[:, : (batch_size * num_whole_batches)], + current_seeds[:, (batch_size * num_whole_batches) :], + ) + + # For input edges, we need to translate this into unique vertices + # for each batch. + # We start by reorganizing the seed and index tensors so we can + # determine the unique vertices. This results in the expected + # src-to-dst concatenation for each batch. + current_seeds = torch.concat( + [ + current_seeds[0].reshape((-1, batch_size)), + current_seeds[1].reshape((-1, batch_size)), + ], + axis=-1, + ) + + # The returned unique values must be sorted or else the inverse won't line up. + # In the future this may be a good target for a C++ function. + # Each element of u below is a tuple of ((unique values, inverse), sort index). + # The seeds must be presorted with a stable sort prior to calling + # unique_consecutive in order to support negative sampling. This is + # because if we put positive edges after negative ones, then we may + # inadvertently turn a true positive into a false negative.
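+ # e.g. for one row t = [3, 1, 3, 2]: the stable sort gives values + # [1, 2, 3, 3] with sort indices [1, 3, 0, 2]; argsorting those indices + # yields the inverse permutation [2, 0, 3, 1]; unique_consecutive returns + # uniques [1, 2, 3] with inverse [0, 1, 2, 2]; composing the two maps the + # original entries to renumbered ids [2, 0, 2, 1].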
+ y = ( + torch.sort( + t, + stable=True, ) + for t in current_seeds + ) + z = ((v, torch.sort(i)[1]) for v, i in y) - # Make a call to sample_batches for each call group - for i, current_seeds in enumerate(nodes_call_groups): - current_batches = torch.arange( - batch_id_start + i * batches_per_call, - batch_id_start - + i * batches_per_call - + int(ceil(len(current_seeds))) - + 1, - device="cuda", - dtype=torch.int32, + u = [ + ( + torch.unique_consecutive( + t, + return_inverse=True, + ), + i, ) + for t, i in z + ] - current_batches = current_batches.repeat_interleave(batch_size)[ - : len(current_seeds) + if len(u) > 0: + current_seeds = torch.concat([a[0] for a, _ in u]) + current_inv = torch.concat([a[1][i] for a, i in u]) + current_batches = torch.concat( + [ + torch.full( + (a[0].numel(),), + i + batch_id_start + (call_id * batches_per_call), + device="cuda", + dtype=torch.int32, + ) + for i, (a, _) in enumerate(u) + ] + ) + else: + current_seeds = torch.tensor([], device="cuda", dtype=torch.int64) + current_inv = torch.tensor([], device="cuda", dtype=torch.int64) + current_batches = torch.tensor([], device="cuda", dtype=torch.int32) + del u + + # Join with the leftovers + leftover_seeds, lyi = torch.sort( + leftover_seeds.flatten(), + stable=True, + ) + lz = torch.sort(lyi)[1] + leftover_seeds, lui = leftover_seeds.unique_consecutive(return_inverse=True) + leftover_inv = lui[lz] + + current_seeds = torch.concat([current_seeds, leftover_seeds]) + current_inv = torch.concat([current_inv, leftover_inv]) + current_batches = torch.concat( + [ + current_batches, + torch.full( + (leftover_seeds.numel(),), + (current_batches[-1] + 1) if current_batches.numel() > 0 else 0, + device="cuda", + dtype=torch.int32, + ), ] + ) + del leftover_seeds + del lz + del lui + + minibatch_dict = self.sample_batches( + seeds=current_seeds, + batch_ids=current_batches, + random_state=random_state, + assume_equal_input_size=assume_equal_input_size, + ) + minibatch_dict["input_index"] = current_ix.cuda() + minibatch_dict["input_offsets"] = input_offsets + minibatch_dict[ + "edge_inverse" + ] = current_inv # (2 * batch_size) entries per batch + + if self.__writer is None: + # rename renumber_map -> map to match unbuffered format + minibatch_dict["map"] = minibatch_dict["renumber_map"] + del minibatch_dict["renumber_map"] + minibatch_dict = { + k: torch.as_tensor(v, device="cuda") + for k, v in minibatch_dict.items() + if v is not None + } + + return iter([(minibatch_dict, current_batches[0], current_batches[-1])]) + else: + self.__writer.write_minibatches(minibatch_dict) + return None - minibatch_dict = self.sample_batches( - seeds=current_seeds, - batch_ids=current_batches, - random_state=random_state, - assume_equal_input_size=input_size_is_equal, + def sample_from_edges( + self, + edges: TensorType, + *, + batch_size: int = 16, + random_state: int = 62, + assume_equal_input_size: bool = False, + input_id: Optional[TensorType] = None, + ) -> Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]]: + """ + Performs sampling starting from seed edges. + + Parameters + ---------- + edges: TensorType + 2 x (# edges) tensor of edges to sample from. + Standard src/dst format. This will be converted + to a list of seed nodes. + batch_size: int + The size of each batch. + random_state: int + The random seed to use for sampling. + assume_equal_input_size: bool + Whether this function should assume that inputs + are equal across ranks. Skips some potentially + slow steps if True. 
+ input_id: Optional[TensorType] + Input ids corresponding to the original batch tensor, if it + was permuted prior to calling this function. If present, + will be saved with the samples. + """ + + torch = import_optional("torch") + + edges = torch.as_tensor(edges, device="cuda") + num_seed_edges = edges.shape[-1] + + batches_per_call = self._local_seeds_per_call // batch_size + actual_seed_edges_per_call = batches_per_call * batch_size + + if input_id is None: + input_id = torch.arange(num_seed_edges, dtype=torch.int64, device="cpu") + + local_num_batches = int(ceil(num_seed_edges / batch_size)) + batch_id_start, input_size_is_equal = self.get_start_batch_offset( + local_num_batches, assume_equal_input_size=assume_equal_input_size + ) + + edges_call_groups, index_call_groups = self.__get_call_groups( + edges, + input_id, + actual_seed_edges_per_call, + assume_equal_input_size=input_size_is_equal, + ) + + sample_args = ( + batch_id_start, + batch_size, + batches_per_call, + random_state, + input_size_is_equal, + ) + + if self.__writer is None: + # Buffered sampling + return BufferedSampleReader( + zip(edges_call_groups, index_call_groups), + self.__sample_from_edges_func, + *sample_args, ) - self.__writer.write_minibatches(minibatch_dict) + else: + # Unbuffered sampling + for i, current_seeds_and_ix in enumerate( + zip(edges_call_groups, index_call_groups) + ): + self.__sample_from_edges_func( + i, + current_seeds_and_ix, + *sample_args, + ) + + # Return a reader that points to the stored samples + rank = torch.distributed.get_rank() if self.is_multi_gpu else None + return self.__writer.get_reader(rank) @property def is_multi_gpu(self): @@ -709,6 +711,8 @@ def __init__( # sampling. So setting the function here is safe. In the future, # if libcugraph allows setting a new attribute, this API might # change. + # TODO allow func to be a call to a future remote sampling API + # if the provided graph is in another process (rapidsai/cugraph#4623). self.__func = ( pylibcugraph.biased_neighbor_sample if biased diff --git a/python/cugraph/cugraph/tests/data_store/test_property_graph.py b/python/cugraph/cugraph/tests/data_store/test_property_graph.py index da5608e0193..50f08cdf3d0 100644 --- a/python/cugraph/cugraph/tests/data_store/test_property_graph.py +++ b/python/cugraph/cugraph/tests/data_store/test_property_graph.py @@ -2576,9 +2576,10 @@ def bench_extract_subgraph_for_rmat(gpubenchmark, rmat_PropertyGraph): scn = PropertyGraph.src_col_name dcn = PropertyGraph.dst_col_name - verts = [] - for i in range(0, 10000, 10): - verts.append(generated_df["src"].iloc[i]) + # Build a query string to extract a graph with only specific edges based on + # the integer vertex IDs. Other edge and/or vertex properties can be + # included in the query as well. + verts = [int(generated_df["src"].iloc[i]) for i in range(0, 10000, 10)] selected_edges = pG.select_edges(f"{scn}.isin({verts}) | {dcn}.isin({verts})") gpubenchmark( @@ -2618,9 +2619,10 @@ def bench_extract_subgraph_for_rmat_detect_duplicate_edges( scn = PropertyGraph.src_col_name dcn = PropertyGraph.dst_col_name - verts = [] - for i in range(0, 10000, 10): - verts.append(generated_df["src"].iloc[i]) + # Build a query string to extract a graph with only specific edges based on + # the integer vertex IDs. Other edge and/or vertex properties can be + # included in the query as well.
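+ # e.g. the query built below evaluates to something like + # "_SRC_.isin([4, 7, ...]) | _DST_.isin([4, 7, ...])", assuming the + # default "_SRC_"/"_DST_" property column names.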
+ verts = [int(generated_df["src"].iloc[i]) for i in range(0, 10000, 10)] selected_edges = pG.select_edges(f"{scn}.isin({verts}) | {dcn}.isin({verts})") diff --git a/python/cugraph/cugraph/tests/sampling/test_dist_sampler.py b/python/cugraph/cugraph/tests/sampling/test_dist_sampler.py index 70b20e7baec..64db0232fb1 100644 --- a/python/cugraph/cugraph/tests/sampling/test_dist_sampler.py +++ b/python/cugraph/cugraph/tests/sampling/test_dist_sampler.py @@ -20,6 +20,7 @@ from cugraph.datasets import karate from cugraph.gnn import UniformNeighborSampler, DistSampleWriter +from cugraph.gnn.data_loading.bulk_sampler_io import create_df_from_disjoint_arrays from pylibcugraph import SGGraph, ResourceHandle, GraphProperties @@ -41,7 +42,7 @@ @pytest.fixture -def karate_graph(): +def karate_graph() -> SGGraph: el = karate.get_edgelist().reset_index().rename(columns={"index": "eid"}) G = SGGraph( ResourceHandle(), @@ -101,3 +102,60 @@ def test_dist_sampler_simple( assert original_el.dst.iloc[edge_id.iloc[i]] == dst.iloc[i] shutil.rmtree(samples_path) + + +@pytest.mark.sg +@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.parametrize("seeds_per_call", [4, 5, 10]) +@pytest.mark.parametrize("compression", ["COO", "CSR"]) +def test_dist_sampler_buffered_in_memory( + scratch_dir: str, karate_graph: SGGraph, seeds_per_call: int, compression: str +): + G = karate_graph + + samples_path = os.path.join(scratch_dir, "test_bulk_sampler_buffered_in_memory") + create_directory_with_overwrite(samples_path) + + seeds = cupy.arange(10, dtype="int64") + + unbuffered_sampler = UniformNeighborSampler( + G, + writer=DistSampleWriter(samples_path), + local_seeds_per_call=seeds_per_call, + compression=compression, + ) + + buffered_sampler = UniformNeighborSampler( + G, + writer=None, + local_seeds_per_call=seeds_per_call, + compression=compression, + ) + + unbuffered_results = unbuffered_sampler.sample_from_nodes( + seeds, + batch_size=4, + ) + + unbuffered_results = [ + (create_df_from_disjoint_arrays(r[0]), r[1], r[2]) for r in unbuffered_results + ] + + buffered_results = buffered_sampler.sample_from_nodes(seeds, batch_size=4) + buffered_results = [ + (create_df_from_disjoint_arrays(r[0]), r[1], r[2]) for r in buffered_results + ] + + assert len(buffered_results) == len(unbuffered_results) + + for k in range(len(buffered_results)): + br, bs, be = buffered_results[k] + ur, us, ue = unbuffered_results[k] + + assert bs == us + assert be == ue + + for col in ur.columns: + assert (br[col].dropna() == ur[col].dropna()).all() + + shutil.rmtree(samples_path) diff --git a/python/cugraph/cugraph/tests/sampling/test_dist_sampler_mg.py b/python/cugraph/cugraph/tests/sampling/test_dist_sampler_mg.py index a1c32938994..5bb541d6cf3 100644 --- a/python/cugraph/cugraph/tests/sampling/test_dist_sampler_mg.py +++ b/python/cugraph/cugraph/tests/sampling/test_dist_sampler_mg.py @@ -18,6 +18,8 @@ import cupy import cudf +from typing import Any + from cugraph.datasets import karate from cugraph.gnn import ( UniformNeighborSampler, @@ -27,6 +29,7 @@ cugraph_comms_init, cugraph_comms_shutdown, ) +from cugraph.gnn.data_loading.bulk_sampler_io import create_df_from_disjoint_arrays from pylibcugraph import MGGraph, ResourceHandle, GraphProperties from cugraph.utilities.utils import ( @@ -235,3 +238,80 @@ def test_dist_sampler_uneven(scratch_dir, batch_size, fanout, seeds_per_call): assert original_el.dst.iloc[edge_id.iloc[i]] == dst.iloc[i] shutil.rmtree(samples_path) + + +def 
run_test_dist_sampler_buffered_in_memory( + rank: int, + world_size: int, + uid: Any, + samples_path: str, + seeds_per_call: int, + compression: str, +): + init_pytorch(rank, world_size) + cugraph_comms_init(rank, world_size, uid, device=rank) + + G = karate_mg_graph(rank, world_size) + + num_seeds = 8 + seeds = cupy.random.randint(0, 34, num_seeds, dtype="int64") + + unbuffered_sampler = UniformNeighborSampler( + G, + writer=DistSampleWriter(samples_path), + local_seeds_per_call=seeds_per_call, + compression=compression, + ) + + buffered_sampler = UniformNeighborSampler( + G, + writer=None, + local_seeds_per_call=seeds_per_call, + compression=compression, + ) + + unbuffered_results = unbuffered_sampler.sample_from_nodes( + seeds, + batch_size=4, + ) + + unbuffered_results = [ + (create_df_from_disjoint_arrays(r[0]), r[1], r[2]) for r in unbuffered_results + ] + + buffered_results = buffered_sampler.sample_from_nodes(seeds, batch_size=4) + buffered_results = [ + (create_df_from_disjoint_arrays(r[0]), r[1], r[2]) for r in buffered_results + ] + + assert len(buffered_results) == len(unbuffered_results) + + for k in range(len(buffered_results)): + br, bs, be = buffered_results[k] + ur, us, ue = unbuffered_results[k] + + assert bs == us + assert be == ue + + for col in ur.columns: + assert (br[col].dropna() == ur[col].dropna()).all() + + +@pytest.mark.mg +@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.parametrize("seeds_per_call", [4, 5, 10]) +@pytest.mark.parametrize("compression", ["COO", "CSR"]) +def test_dist_sampler_buffered_in_memory(scratch_dir, seeds_per_call, compression): + uid = cugraph_comms_create_unique_id() + + samples_path = os.path.join(scratch_dir, "test_bulk_sampler_buffered_in_memory_mg") + create_directory_with_overwrite(samples_path) + + world_size = torch.cuda.device_count() + torch.multiprocessing.spawn( + run_test_dist_sampler_buffered_in_memory, + args=(world_size, uid, samples_path, seeds_per_call, compression), + nprocs=world_size, + ) + + shutil.rmtree(samples_path) diff --git a/python/cugraph/pyproject.toml b/python/cugraph/pyproject.toml index 142d0bcd5fa..8185a8d915d 100644 --- a/python/cugraph/pyproject.toml +++ b/python/cugraph/pyproject.toml @@ -29,7 +29,7 @@ dependencies = [ "dask-cudf==24.12.*,>=0.0.0a0", "fsspec[http]>=0.6.0", "numba>=0.57", - "numpy>=1.23,<2.0a0", + "numpy>=1.23,<3.0a0", "pylibcugraph==24.12.*,>=0.0.0a0", "raft-dask==24.12.*,>=0.0.0a0", "rapids-dask-dependency==24.12.*,>=0.0.0a0", @@ -47,7 +47,7 @@ classifiers = [ [project.optional-dependencies] test = [ "networkx>=2.5.1", - "numpy>=1.23,<2.0a0", + "numpy>=1.23,<3.0a0", "pandas", "pylibwholegraph==24.12.*,>=0.0.0a0", "pytest", diff --git a/python/cugraph/pytest.ini b/python/cugraph/pytest.ini index 675a6cf8fde..bca148538d9 100644 --- a/python/cugraph/pytest.ini +++ b/python/cugraph/pytest.ini @@ -17,6 +17,7 @@ addopts = --benchmark-max-time=0 --benchmark-min-rounds=1 --benchmark-columns="mean, rounds" + --tb=native ## do not run the slow tests/benchmarks by default -m "not slow" ## for use with rapids-pytest-benchmark plugin diff --git a/python/nx-cugraph/nx_cugraph/tests/pytest.ini b/python/nx-cugraph/nx_cugraph/tests/pytest.ini new file mode 100644 index 00000000000..7b0a9f29fb1 --- /dev/null +++ b/python/nx-cugraph/nx_cugraph/tests/pytest.ini @@ -0,0 +1,4 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. 
+ +[pytest] +addopts = --tb=native diff --git a/python/nx-cugraph/pyproject.toml b/python/nx-cugraph/pyproject.toml index ef2d9a8eda9..d145aa549da 100644 --- a/python/nx-cugraph/pyproject.toml +++ b/python/nx-cugraph/pyproject.toml @@ -34,7 +34,7 @@ classifiers = [ dependencies = [ "cupy-cuda11x>=12.0.0", "networkx>=3.0", - "numpy>=1.23,<2.0a0", + "numpy>=1.23,<3.0a0", "pylibcugraph==24.12.*,>=0.0.0a0", ] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. diff --git a/python/pylibcugraph/pylibcugraph/CMakeLists.txt b/python/pylibcugraph/pylibcugraph/CMakeLists.txt index 514caeac6aa..9f1b9924336 100644 --- a/python/pylibcugraph/pylibcugraph/CMakeLists.txt +++ b/python/pylibcugraph/pylibcugraph/CMakeLists.txt @@ -55,6 +55,7 @@ set(cython_sources two_hop_neighbors.pyx uniform_neighbor_sample.pyx biased_neighbor_sample.pyx + negative_sampling.pyx uniform_random_walks.pyx utils.pyx weakly_connected_components.pyx diff --git a/python/pylibcugraph/pylibcugraph/__init__.py b/python/pylibcugraph/pylibcugraph/__init__.py index 8a8923827b8..26fa3f64ddd 100644 --- a/python/pylibcugraph/pylibcugraph/__init__.py +++ b/python/pylibcugraph/pylibcugraph/__init__.py @@ -41,6 +41,8 @@ from pylibcugraph.biased_neighbor_sample import biased_neighbor_sample +from pylibcugraph.negative_sampling import negative_sampling + from pylibcugraph.core_number import core_number from pylibcugraph.k_core import k_core diff --git a/python/pylibcugraph/pylibcugraph/_cugraph_c/sampling_algorithms.pxd b/python/pylibcugraph/pylibcugraph/_cugraph_c/sampling_algorithms.pxd index 4a707db03c5..c982b12665a 100644 --- a/python/pylibcugraph/pylibcugraph/_cugraph_c/sampling_algorithms.pxd +++ b/python/pylibcugraph/pylibcugraph/_cugraph_c/sampling_algorithms.pxd @@ -112,10 +112,10 @@ cdef extern from "cugraph_c/sampling_algorithms.h": const cugraph_resource_handle_t* handle, cugraph_rng_state_t* rng_state, cugraph_graph_t* graph, - size_t num_samples, const cugraph_type_erased_device_array_view_t* vertices, const cugraph_type_erased_device_array_view_t* src_bias, const cugraph_type_erased_device_array_view_t* dst_bias, + size_t num_samples, bool_t remove_duplicates, bool_t remove_false_negatives, bool_t exact_number_of_samples, diff --git a/python/pylibcugraph/pylibcugraph/internal_types/CMakeLists.txt b/python/pylibcugraph/pylibcugraph/internal_types/CMakeLists.txt index 1ca169c5869..22f07939db0 100644 --- a/python/pylibcugraph/pylibcugraph/internal_types/CMakeLists.txt +++ b/python/pylibcugraph/pylibcugraph/internal_types/CMakeLists.txt @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at @@ -14,6 +14,7 @@ set(cython_sources sampling_result.pyx + coo.pyx ) set(linked_libraries cugraph::cugraph;cugraph::cugraph_c) diff --git a/python/pylibcugraph/pylibcugraph/internal_types/coo.pxd b/python/pylibcugraph/pylibcugraph/internal_types/coo.pxd new file mode 100644 index 00000000000..129b0be4dbe --- /dev/null +++ b/python/pylibcugraph/pylibcugraph/internal_types/coo.pxd @@ -0,0 +1,28 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. 
+# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Have cython use python 3 syntax +# cython: language_level = 3 + + +from pylibcugraph._cugraph_c.coo cimport ( + cugraph_coo_t, +) +from pylibcugraph._cugraph_c.array cimport ( + cugraph_type_erased_device_array_view_t, +) + +cdef class COO: + cdef cugraph_coo_t* c_coo_ptr + cdef set_ptr(self, cugraph_coo_t* ptr) + cdef get_array(self, cugraph_type_erased_device_array_view_t* ptr) diff --git a/python/pylibcugraph/pylibcugraph/internal_types/coo.pyx b/python/pylibcugraph/pylibcugraph/internal_types/coo.pyx new file mode 100644 index 00000000000..64d10c22eaf --- /dev/null +++ b/python/pylibcugraph/pylibcugraph/internal_types/coo.pyx @@ -0,0 +1,96 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Have cython use python 3 syntax +# cython: language_level = 3 + +from pylibcugraph._cugraph_c.coo cimport ( + cugraph_coo_t, + cugraph_coo_free, + cugraph_coo_get_sources, + cugraph_coo_get_destinations, + cugraph_coo_get_edge_weights, + cugraph_coo_get_edge_id, + cugraph_coo_get_edge_type, +) +from pylibcugraph._cugraph_c.array cimport ( + cugraph_type_erased_device_array_view_t, +) +from pylibcugraph.utils cimport create_cupy_array_view_for_device_ptr + +cdef class COO: + """ + Cython interface to a cugraph_coo_t pointer. Instances of this + class take ownership of the pointer and free it under standard python + GC rules (i.e., when all references to it are no longer present). + + This class provides methods to return non-owning cupy ndarrays for the + corresponding array members. Returning these cupy arrays increments the ref + count on the COO instance that the cupy arrays reference. + """ + def __cinit__(self): + # This COO instance owns c_coo_ptr now.
It will be + # freed when this instance is deleted (see __dealloc__()) + self.c_coo_ptr = NULL + + def __dealloc__(self): + if self.c_coo_ptr is not NULL: + cugraph_coo_free(self.c_coo_ptr) + + cdef set_ptr(self, cugraph_coo_t* ptr): + self.c_coo_ptr = ptr + + cdef get_array(self, cugraph_type_erased_device_array_view_t* ptr): + if ptr is NULL: + return None + + return create_cupy_array_view_for_device_ptr( + ptr, + self, + ) + + def get_sources(self): + if self.c_coo_ptr is NULL: + raise ValueError("pointer not set, must call set_ptr() with a " + "non-NULL value first.") + cdef cugraph_type_erased_device_array_view_t* ptr = cugraph_coo_get_sources(self.c_coo_ptr) + return self.get_array(ptr) + + def get_destinations(self): + if self.c_coo_ptr is NULL: + raise ValueError("pointer not set, must call set_ptr() with a " + "non-NULL value first.") + cdef cugraph_type_erased_device_array_view_t* ptr = cugraph_coo_get_destinations(self.c_coo_ptr) + return self.get_array(ptr) + + def get_edge_ids(self): + if self.c_coo_ptr is NULL: + raise ValueError("pointer not set, must call set_ptr() with a " + "non-NULL value first.") + cdef cugraph_type_erased_device_array_view_t* ptr = cugraph_coo_get_edge_id(self.c_coo_ptr) + return self.get_array(ptr) + + def get_edge_types(self): + if self.c_coo_ptr is NULL: + raise ValueError("pointer not set, must call set_ptr() with a " + "non-NULL value first.") + cdef cugraph_type_erased_device_array_view_t* ptr = cugraph_coo_get_edge_type(self.c_coo_ptr) + return self.get_array(ptr) + + def get_edge_weights(self): + if self.c_coo_ptr is NULL: + raise ValueError("pointer not set, must call set_ptr() with a " + "non-NULL value first.") + cdef cugraph_type_erased_device_array_view_t* ptr = cugraph_coo_get_edge_weights(self.c_coo_ptr) + return self.get_array(ptr) diff --git a/python/pylibcugraph/pylibcugraph/negative_sampling.pyx b/python/pylibcugraph/pylibcugraph/negative_sampling.pyx new file mode 100644 index 00000000000..610cfa90ccf --- /dev/null +++ b/python/pylibcugraph/pylibcugraph/negative_sampling.pyx @@ -0,0 +1,184 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
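+ +# This module wraps the C API's cugraph_negative_sampling and returns the +# generated edges as cupy arrays via the COO wrapper defined in +# internal_types/coo.pyx.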
+ +# Have cython use python 3 syntax +# cython: language_level = 3 + +from libc.stdint cimport uintptr_t + +from pylibcugraph._cugraph_c.resource_handle cimport ( + cugraph_resource_handle_t, + bool_t, +) +from pylibcugraph._cugraph_c.error cimport ( + cugraph_error_code_t, + cugraph_error_t, +) +from pylibcugraph._cugraph_c.array cimport ( + cugraph_type_erased_device_array_view_t, + cugraph_type_erased_device_array_view_create, + cugraph_type_erased_device_array_view_free, + cugraph_type_erased_host_array_view_t, + cugraph_type_erased_host_array_view_create, + cugraph_type_erased_host_array_view_free, +) +from pylibcugraph.resource_handle cimport ( + ResourceHandle, +) +from pylibcugraph.graphs cimport ( + _GPUGraph, +) +from pylibcugraph._cugraph_c.graph cimport ( + cugraph_graph_t, +) +from pylibcugraph._cugraph_c.sampling_algorithms cimport ( + cugraph_negative_sampling, +) +from pylibcugraph._cugraph_c.coo cimport ( + cugraph_coo_t, +) +from pylibcugraph.internal_types.coo cimport ( + COO, +) +from pylibcugraph.utils cimport ( + assert_success, + assert_CAI_type, + create_cugraph_type_erased_device_array_view_from_py_obj, +) +from pylibcugraph._cugraph_c.random cimport ( + cugraph_rng_state_t +) +from pylibcugraph.random cimport ( + CuGraphRandomState +) + +def negative_sampling(ResourceHandle resource_handle, + _GPUGraph graph, + size_t num_samples, + random_state=None, + vertices=None, + src_bias=None, + dst_bias=None, + remove_duplicates=False, + remove_false_negatives=False, + exact_number_of_samples=False, + do_expensive_check=False): + """ + Performs negative sampling, which is essentially a form of graph generation. + + By setting vertices, src_bias, and dst_bias, this function can perform + biased negative sampling. + + Parameters + ---------- + resource_handle: ResourceHandle + Handle to the underlying device and host resources needed for + referencing data and running algorithms. + graph: SGGraph or MGGraph + The stored cuGraph graph to create negative samples for. + num_samples: int + The total number of negative edges to generate. + random_state: int (Optional) + Random state to use when generating samples. Optional argument, + defaults to a hash of process id, time, and hostname. + (See pylibcugraph.random.CuGraphRandomState) + vertices: device array type (Optional) + Vertex ids corresponding to the src/dst biases, if provided. + Ignored if src/dst biases are not provided. + src_bias: device array type (Optional) + Probability, per vertex in vertices, that the vertex is selected + as a source vertex. + Does not have to be normalized. Uses a uniform distribution if + not provided. + dst_bias: device array type (Optional) + Probability, per vertex in vertices, that the vertex is selected + as a destination vertex. + Does not have to be normalized. Uses a uniform distribution if + not provided. + remove_duplicates: bool (Optional) + Whether to remove duplicate edges from the generated edgelist. + Defaults to False (does not remove duplicates). + remove_false_negatives: bool (Optional) + Whether to remove false negatives from the generated edgelist. + Defaults to False (does not check for and remove false negatives). + exact_number_of_samples: bool (Optional) + Whether to manually regenerate samples until the desired number + as specified by num_samples has been generated. + Defaults to False (does not regenerate if enough samples are not + produced in the initial round). + do_expensive_check: bool (Optional) + Whether to perform an expensive error check at the C++ level.
+ Defaults to False (no error check). + + Returns + ------- + dict[str, cupy.ndarray] + Generated edges in COO format. + """ + + assert_CAI_type(vertices, "vertices", True) + assert_CAI_type(src_bias, "src_bias", True) + assert_CAI_type(dst_bias, "dst_bias", True) + + cdef cugraph_resource_handle_t* c_resource_handle_ptr = ( + resource_handle.c_resource_handle_ptr + ) + + cdef cugraph_graph_t* c_graph_ptr = graph.c_graph_ptr + + cdef bool_t c_remove_duplicates = remove_duplicates + cdef bool_t c_remove_false_negatives = remove_false_negatives + cdef bool_t c_exact_number_of_samples = exact_number_of_samples + cdef bool_t c_do_expensive_check = do_expensive_check + + cg_rng_state = CuGraphRandomState(resource_handle, random_state) + + cdef cugraph_rng_state_t* rng_state_ptr = \ + cg_rng_state.rng_state_ptr + + cdef cugraph_type_erased_device_array_view_t* vertices_ptr = \ + create_cugraph_type_erased_device_array_view_from_py_obj(vertices) + cdef cugraph_type_erased_device_array_view_t* src_bias_ptr = \ + create_cugraph_type_erased_device_array_view_from_py_obj(src_bias) + cdef cugraph_type_erased_device_array_view_t* dst_bias_ptr = \ + create_cugraph_type_erased_device_array_view_from_py_obj(dst_bias) + + cdef cugraph_coo_t* result_ptr + cdef cugraph_error_t* err_ptr + cdef cugraph_error_code_t error_code + + error_code = cugraph_negative_sampling( + c_resource_handle_ptr, + rng_state_ptr, + c_graph_ptr, + vertices_ptr, + src_bias_ptr, + dst_bias_ptr, + num_samples, + c_remove_duplicates, + c_remove_false_negatives, + c_exact_number_of_samples, + c_do_expensive_check, + &result_ptr, + &err_ptr, + ) + assert_success(error_code, err_ptr, "cugraph_negative_sampling") + + coo = COO() + coo.set_ptr(result_ptr) + + return { + 'sources': coo.get_sources(), + 'destinations': coo.get_destinations(), + 'edge_id': coo.get_edge_ids(), + 'edge_type': coo.get_edge_types(), + 'weight': coo.get_edge_weights(), + } diff --git a/python/pylibcugraph/pyproject.toml b/python/pylibcugraph/pyproject.toml index 98bbe255e3e..c12280473b5 100644 --- a/python/pylibcugraph/pyproject.toml +++ b/python/pylibcugraph/pyproject.toml @@ -41,7 +41,7 @@ classifiers = [ [project.optional-dependencies] test = [ "cudf==24.12.*,>=0.0.0a0", - "numpy>=1.23,<2.0a0", + "numpy>=1.23,<3.0a0", "pandas", "pytest", "pytest-benchmark", diff --git a/python/pylibcugraph/pytest.ini b/python/pylibcugraph/pytest.ini index 573628de680..d5ade9f4836 100644 --- a/python/pylibcugraph/pytest.ini +++ b/python/pylibcugraph/pytest.ini @@ -14,3 +14,5 @@ [pytest] markers = cugraph_ops: Tests requiring cugraph-ops + +addopts = --tb=native
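A minimal usage sketch for the negative_sampling API added above (illustrative only: `handle` and `G` stand in for an existing pylibcugraph ResourceHandle and SGGraph/MGGraph built elsewhere):

    from pylibcugraph import negative_sampling

    # Generate up to 1000 candidate negative edges from G, dropping duplicates
    # and any sampled pairs that are actually edges in the graph. Fewer edges
    # may be returned since exact_number_of_samples defaults to False.
    result = negative_sampling(
        handle,
        G,
        num_samples=1000,
        remove_duplicates=True,
        remove_false_negatives=True,
    )
    src = result["sources"]        # cupy.ndarray
    dst = result["destinations"]   # cupy.ndarray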