Make pylibcugraphops optional imports in cugraph-dgl and -pyg (#3693)

Fixes #3691.
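For context: `import_optional` (from `cugraph.utilities.utils`) lets a module import succeed even when the target package is absent, deferring the failure to the first attribute access. A minimal sketch of the pattern — not cugraph's exact implementation:

import importlib
from types import ModuleType
from typing import Union


class MissingModule:
    """Placeholder returned when the requested package is not installed."""

    def __init__(self, name: str):
        self._name = name

    def __getattr__(self, attr: str):
        # The error surfaces only when the placeholder is actually used.
        raise ModuleNotFoundError(
            f"'{self._name}' is required for this feature but is not installed."
        )


def import_optional(name: str) -> Union[ModuleType, MissingModule]:
    """Return the real module if importable, else a lazily-failing placeholder."""
    try:
        return importlib.import_module(name)
    except ImportError:
        return MissingModule(name)

With this pattern, `ops_torch = import_optional("pylibcugraphops.pytorch")` at module scope never raises; only code paths that actually touch `ops_torch.CSC` or `ops_torch.operators` require the package to be installed.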

Authors:
  - Tingyu Wang (https://github.com/tingyu66)

Approvers:
  - Vibhu Jawa (https://github.com/VibhuJawa)
  - Rick Ratzel (https://github.com/rlratzel)
  - Ray Douglass (https://github.com/raydouglass)

URL: #3693
tingyu66 authored Jul 18, 2023
1 parent 32e6e51 commit 917b98b
Showing 10 changed files with 60 additions and 56 deletions.
2 changes: 2 additions & 0 deletions conda/environments/all_cuda-118_arch-x86_64.yaml
@@ -6,6 +6,8 @@ channels:
- dask/label/dev
- conda-forge
- nvidia
+- pytorch
+- dglteam/label/cu118
dependencies:
- aiohttp
- c-compiler
21 changes: 21 additions & 0 deletions dependencies.yaml
@@ -170,12 +170,21 @@ files:
# list is really minimal or if it is a superset.
- test_python_common
- test_python_cugraph
+  cugraph_dgl_dev:
+    matrix:
+      cuda: ["11.8"]
+    output: conda
+    conda_dir: python/cugraph-dgl/conda
+    includes:
+      - cugraph_dgl_dev
channels:
- rapidsai
- rapidsai-nightly
- dask/label/dev
- conda-forge
- nvidia
+- pytorch
+- dglteam/label/cu118
dependencies:
checks:
common:
@@ -418,3 +427,15 @@ dependencies:
- output_types: [conda, pyproject]
packages:
- *cudf
+  cugraph_dgl_dev:
+    common:
+      - output_types: [conda]
+        packages:
+          - cugraph==23.8.*
+          - pylibcugraphops==23.8.*
+          - pytorch>=2.0
+          - pytorch-cuda==11.8
+          - dgl>=1.1.0.cu*
+          - setuptools
+          - pre-commit
+          - pytest
18 changes: 0 additions & 18 deletions python/cugraph-dgl/conda/cugraph_dgl_dev_11.6.yml

This file was deleted.

20 changes: 20 additions & 0 deletions python/cugraph-dgl/conda/cugraph_dgl_dev_cuda-118.yaml
@@ -0,0 +1,20 @@
# This file is generated by `rapids-dependency-file-generator`.
# To make changes, edit ../../../dependencies.yaml and run `rapids-dependency-file-generator`.
channels:
- rapidsai
- rapidsai-nightly
- dask/label/dev
- conda-forge
- nvidia
- pytorch
- dglteam/label/cu118
dependencies:
- cugraph==23.8.*
- dgl>=1.1.0.cu*
- pre-commit
- pylibcugraphops==23.8.*
- pytest
- pytorch-cuda==11.8
- pytorch>=2.0
- setuptools
name: cugraph_dgl_dev_cuda-118
8 changes: 3 additions & 5 deletions python/cugraph-dgl/cugraph_dgl/nn/conv/gatconv.py
@@ -19,12 +19,10 @@
from cugraph_dgl.nn.conv.base import BaseConv
from cugraph.utilities.utils import import_optional

-from pylibcugraphops.pytorch import CSC
-from pylibcugraphops.pytorch.operators import mha_gat_n2n

dgl = import_optional("dgl")
torch = import_optional("torch")
nn = import_optional("torch.nn")
+ops_torch = import_optional("pylibcugraphops.pytorch")


class GATConv(BaseConv):
@@ -179,7 +177,7 @@ def forward(
bipartite = not isinstance(nfeat, torch.Tensor)
offsets, indices, _ = g.adj_tensors("csc")

-graph = CSC(
+graph = ops_torch.CSC(
offsets=offsets,
indices=indices,
num_src_nodes=g.num_src_nodes(),
@@ -212,7 +210,7 @@
)
nfeat = self.fc(nfeat)

-out = mha_gat_n2n(
+out = ops_torch.operators.mha_gat_n2n(
(nfeat_src, nfeat_dst) if bipartite else nfeat,
self.attn_weights,
graph,
8 changes: 3 additions & 5 deletions python/cugraph-dgl/cugraph_dgl/nn/conv/transformerconv.py
@@ -15,12 +15,10 @@
from cugraph_dgl.nn.conv.base import BaseConv
from cugraph.utilities.utils import import_optional

-from pylibcugraphops.pytorch import CSC
-from pylibcugraphops.pytorch.operators import mha_simple_n2n

dgl = import_optional("dgl")
torch = import_optional("torch")
nn = import_optional("torch.nn")
+ops_torch = import_optional("pylibcugraphops.pytorch")


class TransformerConv(BaseConv):
@@ -133,7 +131,7 @@ def forward(
Edge feature tensor. Default: ``None``.
"""
offsets, indices, _ = g.adj_tensors("csc")
-graph = CSC(
+graph = ops_torch.CSC(
offsets=offsets,
indices=indices,
num_src_nodes=g.num_src_nodes(),
@@ -155,7 +153,7 @@
)
efeat = self.lin_edge(efeat)

-out = mha_simple_n2n(
+out = ops_torch.operators.mha_simple_n2n(
key_emb=key,
query_emb=query,
value_emb=value,
24 changes: 5 additions & 19 deletions python/cugraph-pyg/cugraph_pyg/nn/conv/base.py
@@ -18,26 +18,12 @@

torch = import_optional("torch")
torch_geometric = import_optional("torch_geometric")

-try:  # pragma: no cover
-    from pylibcugraphops.pytorch import CSC, HeteroCSC
-
-    HAS_PYLIBCUGRAPHOPS = True
-except ImportError:
-    HAS_PYLIBCUGRAPHOPS = False
+ops_torch = import_optional("pylibcugraphops.pytorch")


class BaseConv(torch.nn.Module): # pragma: no cover
r"""An abstract base class for implementing cugraph-ops message passing layers."""

-    def __init__(self):
-        super().__init__()
-
-        if HAS_PYLIBCUGRAPHOPS is False:
-            raise ModuleNotFoundError(
-                f"'{self.__class__.__name__}' requires " f"'pylibcugraphops>=23.04'"
-            )

def reset_parameters(self):
r"""Resets all learnable parameters of the module."""
pass
@@ -88,7 +74,7 @@ def get_cugraph(
csc: Tuple[torch.Tensor, torch.Tensor, int],
bipartite: bool = False,
max_num_neighbors: Optional[int] = None,
-) -> CSC:
+) -> ops_torch.CSC:
r"""Constructs a :obj:`cugraph-ops` graph object from CSC representation.
Supports both bipartite and non-bipartite graphs.
@@ -116,7 +102,7 @@ def get_cugraph(
if max_num_neighbors is None:
max_num_neighbors = -1

-return CSC(
+return ops_torch.CSC(
offsets=colptr,
indices=row,
num_src_nodes=num_src_nodes,
@@ -131,7 +117,7 @@ def get_typed_cugraph(
num_edge_types: Optional[int] = None,
bipartite: bool = False,
max_num_neighbors: Optional[int] = None,
-) -> HeteroCSC:
+) -> ops_torch.HeteroCSC:
r"""Constructs a typed :obj:`cugraph` graph object from a CSC
representation where each edge corresponds to a given edge type.
Supports both bipartite and non-bipartite graphs.
@@ -162,7 +148,7 @@
row, colptr, num_src_nodes = csc
edge_type = edge_type.int()

-return HeteroCSC(
+return ops_torch.HeteroCSC(
offsets=colptr,
indices=row,
edge_types=edge_type,
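The behavioral consequence of removing the eager check in BaseConv.__init__, demonstrated with the `import_optional` sketch from the commit description above (`nonexistent_package` stands in for a missing pylibcugraphops):

# Module import and layer construction time: nothing raises.
ops_torch = import_optional("nonexistent_package")
print("import OK")

# First real use: attribute access raises, mirroring what now happens in
# BaseConv subclasses when forward() touches ops_torch.CSC.
try:
    ops_torch.CSC
except ModuleNotFoundError as err:
    print(err)

So instead of every layer failing at construction with "requires 'pylibcugraphops>=23.04'", the error now appears only on the code path that genuinely needs cugraph-ops.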
5 changes: 2 additions & 3 deletions python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py
@@ -12,15 +12,14 @@
# limitations under the License.
from typing import Optional, Tuple, Union

-from pylibcugraphops.pytorch.operators import mha_gat_n2n

from cugraph.utilities.utils import import_optional

from .base import BaseConv

torch = import_optional("torch")
nn = import_optional("torch.nn")
torch_geometric = import_optional("torch_geometric")
+ops_torch = import_optional("pylibcugraphops.pytorch")


class GATConv(BaseConv):
@@ -211,7 +210,7 @@ def forward(
)
x = self.lin(x)

-out = mha_gat_n2n(
+out = ops_torch.operators.mha_gat_n2n(
(x_src, x_dst) if bipartite else x,
self.att,
graph,
5 changes: 2 additions & 3 deletions python/cugraph-pyg/cugraph_pyg/nn/conv/gatv2_conv.py
@@ -12,15 +12,14 @@
# limitations under the License.
from typing import Optional, Tuple, Union

-from pylibcugraphops.pytorch.operators import mha_gat_v2_n2n

from cugraph.utilities.utils import import_optional

from .base import BaseConv

torch = import_optional("torch")
nn = import_optional("torch.nn")
torch_geometric = import_optional("torch_geometric")
+ops_torch = import_optional("pylibcugraphops.pytorch")


class GATv2Conv(BaseConv):
@@ -208,7 +207,7 @@ def forward(
else:
x = self.lin_src(x)

-out = mha_gat_v2_n2n(
+out = ops_torch.operators.mha_gat_v2_n2n(
(x_src, x_dst) if bipartite else x,
self.att,
graph,
5 changes: 2 additions & 3 deletions python/cugraph-pyg/cugraph_pyg/nn/conv/transformer_conv.py
@@ -12,15 +12,14 @@
# limitations under the License.
from typing import Optional, Tuple, Union

-from pylibcugraphops.pytorch.operators import mha_simple_n2n

from cugraph.utilities.utils import import_optional

from .base import BaseConv

torch = import_optional("torch")
nn = import_optional("torch.nn")
torch_geometric = import_optional("torch_geometric")
+ops_torch = import_optional("pylibcugraphops.pytorch")


class TransformerConv(BaseConv):
@@ -186,7 +185,7 @@ def forward(
)
edge_attr = self.lin_edge(edge_attr)

-out = mha_simple_n2n(
+out = ops_torch.operators.mha_simple_n2n(
key,
query,
value,
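Taken together: every module-level `from pylibcugraphops...` import in these files now goes through `import_optional`, so importing the conv modules should succeed without the package, and only executing a layer requires it. An illustrative check, not part of this diff (assumes an environment where pylibcugraphops is absent but the other optional dependencies are installed):

# Illustrative: the module import itself no longer needs pylibcugraphops.
import cugraph_dgl.nn.conv.gatconv  # succeeds; the package is untouched here

# Running a layer still requires it: forward() constructs ops_torch.CSC and
# calls ops_torch.operators.mha_gat_n2n, which raise if it is missing.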
