Skip to content

Commit

Permalink
[WIP] Make edge ids optional (#3702)
Browse files Browse the repository at this point in the history
This PR makes edge IDs optional for the cugraph-dgl dataloaders.


TODO:
- [ ] Add tests

Authors:
  - Vibhu Jawa (https://github.com/VibhuJawa)
  - Brad Rees (https://github.com/BradReesWork)

Approvers:
  - Alex Barghi (https://github.com/alexbarghi-nv)

URL: #3702
  • Loading branch information
VibhuJawa authored Jul 19, 2023
1 parent 59b0eb7 commit a280986
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,11 @@ def _get_tensor_ls_from_sampled_df(df):
batch_indices = torch.searchsorted(batch_id_tensor, batch_indices)

split_d = {}

for column in ["sources", "destinations", "edge_id", "hop_id"]:
tensor = cast_to_tensor(df[column])
split_d[column] = torch.tensor_split(tensor, batch_indices.cpu())
if column in df.columns:
tensor = cast_to_tensor(df[column])
split_d[column] = torch.tensor_split(tensor, batch_indices.cpu())

result_tensor_ls = []
for i, hop_id_tensor in enumerate(split_d["hop_id"]):
Expand All @@ -66,7 +68,11 @@ def _get_tensor_ls_from_sampled_df(df):
hop_indices = torch.searchsorted(hop_id_tensor, hop_indices)
s = torch.tensor_split(split_d["sources"][i], hop_indices.cpu())
d = torch.tensor_split(split_d["destinations"][i], hop_indices.cpu())
eid = torch.tensor_split(split_d["edge_id"][i], hop_indices.cpu())
if "edge_id" in split_d:
eid = torch.tensor_split(split_d["edge_id"][i], hop_indices.cpu())
else:
eid = [None] * len(s)

result_tensor_ls.append((x, y, z) for x, y, z in zip(s, d, eid))

return result_tensor_ls
Expand Down Expand Up @@ -125,15 +131,16 @@ def _create_homogeneous_sampled_graphs_from_tensors_perhop(
def create_homogeneous_dgl_block_from_tensors_ls(
src_ids: torch.Tensor,
dst_ids: torch.Tensor,
edge_ids: torch.Tensor,
edge_ids: Optional[torch.Tensor],
seed_nodes: Optional[torch.Tensor],
total_number_of_nodes: int,
):
sampled_graph = dgl.graph(
(src_ids, dst_ids),
num_nodes=total_number_of_nodes,
)
sampled_graph.edata[dgl.EID] = edge_ids
if edge_ids is not None:
sampled_graph.edata[dgl.EID] = edge_ids
# TODO: Check if unique is needed
if seed_nodes is None:
seed_nodes = dst_ids.unique()
Expand All @@ -144,7 +151,8 @@ def create_homogeneous_dgl_block_from_tensors_ls(
src_nodes=src_ids.unique(),
include_dst_in_src=True,
)
block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
if edge_ids is not None:
block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
return block


Expand Down
2 changes: 2 additions & 0 deletions python/cugraph/cugraph/gnn/data_loading/bulk_sampler_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ def _write_samples_to_parquet(
raise ValueError("Invalid value of partition_info")

max_batch_id = offsets.batch_id.max()
results.dropna(axis=1, how="all", inplace=True)
results["hop_id"] = results["hop_id"].astype("uint8")

for p in range(0, len(offsets), batches_per_partition):
offsets_p = offsets.iloc[p : p + batches_per_partition]
Expand Down

0 comments on commit a280986

Please sign in to comment.