Skip to content

Commit

Permalink
Update all rmm imports to use pylibrmm/librmm
Browse files Browse the repository at this point in the history
  • Loading branch information
Matt711 committed Sep 25, 2024
1 parent 03f8025 commit 623666a
Show file tree
Hide file tree
Showing 12 changed files with 25 additions and 28 deletions.
6 changes: 3 additions & 3 deletions python/pylibraft/pylibraft/common/handle.pxd
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2022-2023, NVIDIA CORPORATION.
# Copyright (c) 2022-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -22,8 +22,8 @@

from libcpp.memory cimport shared_ptr, unique_ptr

from rmm._lib.cuda_stream_pool cimport cuda_stream_pool
from rmm._lib.cuda_stream_view cimport cuda_stream_view
from rmm.librmm.cuda_stream_pool cimport cuda_stream_pool
from rmm.librmm.cuda_stream_view cimport cuda_stream_view


# Keeping `handle_t` around for backwards compatibility at the
Expand Down
7 changes: 5 additions & 2 deletions python/pylibraft/pylibraft/common/handle.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2022-2023, NVIDIA CORPORATION.
# Copyright (c) 2022-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -24,7 +24,10 @@ import functools
from cuda.ccudart cimport cudaStream_t
from libc.stdint cimport uintptr_t

from rmm._lib.cuda_stream_view cimport cuda_stream_per_thread, cuda_stream_view
from rmm.librmm.cuda_stream_view cimport (
cuda_stream_per_thread,
cuda_stream_view,
)

from .cuda cimport Stream

Expand Down
4 changes: 2 additions & 2 deletions python/pylibraft/pylibraft/common/interruptible.pxd
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2021-2022, NVIDIA CORPORATION.
# Copyright (c) 2021-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -21,7 +21,7 @@

from libcpp.memory cimport shared_ptr

from rmm._lib.cuda_stream_view cimport cuda_stream_view
from rmm.librmm.cuda_stream_view cimport cuda_stream_view


cdef extern from "raft/core/interruptible.hpp" namespace "raft" nogil:
Expand Down
4 changes: 2 additions & 2 deletions python/pylibraft/pylibraft/common/interruptible.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2021-2022, NVIDIA CORPORATION.
# Copyright (c) 2021-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -25,7 +25,7 @@ import signal
from cuda.ccudart cimport cudaStream_t
from cython.operator cimport dereference

from rmm._lib.cuda_stream_view cimport cuda_stream_view
from rmm.librmm.cuda_stream_view cimport cuda_stream_view

from .cuda cimport Stream

Expand Down
6 changes: 2 additions & 4 deletions python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,8 @@ from pylibraft.common.handle cimport device_resources
from pylibraft.common.handle import auto_sync_handle
from pylibraft.common.input_validation import is_c_contiguous

from rmm._lib.memory_resource cimport (
DeviceMemoryResource,
device_memory_resource,
)
from rmm.librmm.memory_resource cimport device_memory_resource
from rmm.pylibrmm.memory_resource cimport DeviceMemoryResource

cimport pylibraft.neighbors.cagra.cpp.c_cagra as c_cagra
from pylibraft.common.optional cimport make_optional, optional
Expand Down
2 changes: 1 addition & 1 deletion python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ from libc.stdint cimport int8_t, int64_t, uint8_t, uint32_t, uint64_t
from libcpp cimport bool, nullptr
from libcpp.string cimport string

from rmm._lib.memory_resource cimport device_memory_resource
from rmm.librmm.memory_resource cimport device_memory_resource

from pylibraft.common.cpp.mdspan cimport (
device_matrix_view,
Expand Down
2 changes: 1 addition & 1 deletion python/pylibraft/pylibraft/neighbors/cpp/brute_force.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ from libcpp cimport bool, nullptr
from libcpp.string cimport string
from libcpp.vector cimport vector

from rmm._lib.memory_resource cimport device_memory_resource
from rmm.librmm.memory_resource cimport device_memory_resource

from pylibraft.common.cpp.mdspan cimport (
device_matrix_view,
Expand Down
2 changes: 1 addition & 1 deletion python/pylibraft/pylibraft/neighbors/cpp/rbc.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ from libcpp cimport bool, nullptr
from libcpp.string cimport string
from libcpp.vector cimport vector

from rmm._lib.memory_resource cimport device_memory_resource
from rmm.librmm.memory_resource cimport device_memory_resource

from pylibraft.common.cpp.mdspan cimport (
device_matrix_view,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2023, NVIDIA CORPORATION.
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -27,7 +27,7 @@ from libc.stdint cimport int8_t, int64_t, uint8_t, uint32_t, uintptr_t
from libcpp cimport bool, nullptr
from libcpp.string cimport string

from rmm._lib.memory_resource cimport device_memory_resource
from rmm.librmm.memory_resource cimport device_memory_resource

from pylibraft.common.cpp.mdspan cimport (
device_matrix_view,
Expand Down
8 changes: 3 additions & 5 deletions python/pylibraft/pylibraft/neighbors/ivf_flat/ivf_flat.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2023, NVIDIA CORPORATION.
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -51,10 +51,8 @@ from pylibraft.common.handle cimport device_resources
from pylibraft.common.handle import auto_sync_handle
from pylibraft.common.input_validation import is_c_contiguous

from rmm._lib.memory_resource cimport (
DeviceMemoryResource,
device_memory_resource,
)
from rmm.librmm.memory_resource cimport device_memory_resource
from rmm.pylibrmm.memory_resource cimport DeviceMemoryResource

cimport pylibraft.neighbors.ivf_flat.cpp.c_ivf_flat as c_ivf_flat
from pylibraft.common.cpp.optional cimport optional
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ from libc.stdint cimport int8_t, int64_t, uint8_t, uint32_t, uintptr_t
from libcpp cimport bool, nullptr
from libcpp.string cimport string

from rmm._lib.memory_resource cimport device_memory_resource
from rmm.librmm.memory_resource cimport device_memory_resource

from pylibraft.common.cpp.mdspan cimport (
device_matrix_view,
Expand Down
6 changes: 2 additions & 4 deletions python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,8 @@ from pylibraft.common.handle cimport device_resources
from pylibraft.common.handle import auto_sync_handle
from pylibraft.common.input_validation import is_c_contiguous

from rmm._lib.memory_resource cimport (
DeviceMemoryResource,
device_memory_resource,
)
from rmm.librmm.memory_resource cimport device_memory_resource
from rmm.pylibrmm.memory_resource cimport DeviceMemoryResource

cimport pylibraft.neighbors.ivf_flat.cpp.c_ivf_flat as c_ivf_flat
cimport pylibraft.neighbors.ivf_pq.cpp.c_ivf_pq as c_ivf_pq
Expand Down

0 comments on commit 623666a

Please sign in to comment.