Skip to content

Commit

Permalink
Merge pull request #1412 from IntelPython/merge_exp_spv_targt_into_sp…
Browse files Browse the repository at this point in the history
…v_target

Completely remove numba_dpex.experimental module
  • Loading branch information
ZzEeKkAa committed Mar 28, 2024
2 parents 8e3b63d + 7729548 commit 53bb704
Show file tree
Hide file tree
Showing 22 changed files with 228 additions and 432 deletions.
6 changes: 6 additions & 0 deletions numba_dpex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from .kernel_api_impl.spirv import target as spirv_kernel_target
from .numba_patches import patch_arrayexpr_tree_to_ir, patch_is_ufunc
from .register_kernel_api_overloads import init_kernel_api_spirv_overloads


def load_dpctl_sycl_interface():
Expand Down Expand Up @@ -136,11 +137,16 @@ def parse_sem_version(version_string: str) -> Tuple[int, int, int]:
__version__ = get_versions()["version"]
del get_versions

# Initialize the kernel_api SPIRV overloads
init_kernel_api_spirv_overloads()

__all__ = types.__all__ + [
"call_kernel",
"call_kernel_async",
"device_func",
"dpjit",
"kernel",
"prange",
"Range",
"NdRange",
]
12 changes: 7 additions & 5 deletions numba_dpex/core/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
)

from numba_dpex.core.targets.dpjit_target import DPEX_TARGET_NAME
from numba_dpex.experimental.target import DPEX_KERNEL_EXP_TARGET_NAME
from numba_dpex.kernel_api_impl.spirv.dispatcher import SPIRVKernelDispatcher
from numba_dpex.kernel_api_impl.spirv.target import CompilationMode
from numba_dpex.kernel_api_impl.spirv.target import (
SPIRV_TARGET_NAME,
CompilationMode,
)


def _parse_func_or_sig(signature_or_function):
Expand Down Expand Up @@ -154,7 +156,7 @@ def vecadd(item: kapi.Item, a, b, c):

# dispatcher is a type:
# <class 'numba_dpex.experimental.kernel_dispatcher.KernelDispatcher'>
dispatcher = resolve_dispatcher_from_str(DPEX_KERNEL_EXP_TARGET_NAME)
dispatcher = resolve_dispatcher_from_str(SPIRV_TARGET_NAME)
if "_compilation_mode" in options:
user_compilation_mode = options["_compilation_mode"]
warn(
Expand Down Expand Up @@ -280,7 +282,7 @@ def another_kernel(nd_item: NdItem, a):
dpex_exp.call_kernel(another_kernel, dpex.NdRange((N,), (N,)), b)
"""
dispatcher = resolve_dispatcher_from_str(DPEX_KERNEL_EXP_TARGET_NAME)
dispatcher = resolve_dispatcher_from_str(SPIRV_TARGET_NAME)

if "_compilation_mode" in options:
user_compilation_mode = options["_compilation_mode"]
Expand Down Expand Up @@ -342,4 +344,4 @@ def dpjit(*args, **kws):
# add it to the decorator registry, this is so e.g. @overload can look up a
# JIT function to do the compilation work.
jit_registry[target_registry[DPEX_TARGET_NAME]] = dpjit
jit_registry[target_registry[DPEX_KERNEL_EXP_TARGET_NAME]] = device_func
jit_registry[target_registry[SPIRV_TARGET_NAME]] = device_func
84 changes: 84 additions & 0 deletions numba_dpex/core/typing/typeof.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,15 @@
from numba.extending import typeof_impl
from numba.np import numpy_support

from numba_dpex.kernel_api import AtomicRef, Group, Item, LocalAccessor, NdItem
from numba_dpex.kernel_api.ranges import NdRange, Range
from numba_dpex.utils.constants import address_space

from ..types.dpctl_types import DpctlSyclEvent, DpctlSyclQueue
from ..types.dpnp_ndarray_type import DpnpNdArray
from ..types.kernel_api.atomic_ref import AtomicRefType
from ..types.kernel_api.index_space_ids import GroupType, ItemType, NdItemType
from ..types.kernel_api.local_accessor import LocalAccessorType
from ..types.kernel_api.ranges import NdRangeType, RangeType
from ..types.usm_ndarray_type import USMNdArray

Expand Down Expand Up @@ -150,3 +154,83 @@ def typeof_ndrange(val, c):
Returns: A numba_dpex.core.types.range_types.RangeType instance.
"""
return NdRangeType(val.global_range.ndim)


@typeof_impl.register(AtomicRef)
def typeof_atomic_ref(val: AtomicRef, ctx) -> AtomicRefType:
"""Returns a ``numba_dpex.experimental.dpctpp_types.AtomicRefType``
instance for a Python AtomicRef object.
Args:
val (AtomicRef): Instance of the AtomicRef type.
ctx : Numba typing context used for type inference.
Returns: AtomicRefType object corresponding to the AtomicRef object.
"""
dtype = typeof_impl(val.ref, ctx)

return AtomicRefType(
dtype=dtype,
memory_order=val.memory_order.value,
memory_scope=val.memory_scope.value,
address_space=val.address_space.value,
)


@typeof_impl.register(Group)
def typeof_group(val: Group, c):
"""Registers the type inference implementation function for a
numba_dpex.kernel_api.Group PyObject.
Args:
val : An instance of numba_dpex.kernel_api.Group.
c : Unused argument used to be consistent with Numba API.
Returns: A numba_dpex.experimental.core.types.kernel_api.items.GroupType
instance.
"""
return GroupType(val.ndim)


@typeof_impl.register(Item)
def typeof_item(val: Item, c):
"""Registers the type inference implementation function for a
numba_dpex.kernel_api.Item PyObject.
Args:
val : An instance of numba_dpex.kernel_api.Item.
c : Unused argument used to be consistent with Numba API.
Returns: A numba_dpex.experimental.core.types.kernel_api.items.ItemType
instance.
"""
return ItemType(val.dimensions)


@typeof_impl.register(NdItem)
def typeof_nditem(val: NdItem, c):
"""Registers the type inference implementation function for a
numba_dpex.kernel_api.NdItem PyObject.
Args:
val : An instance of numba_dpex.kernel_api.NdItem.
c : Unused argument used to be consistent with Numba API.
Returns: A numba_dpex.experimental.core.types.kernel_api.items.NdItemType
instance.
"""
return NdItemType(val.dimensions)


@typeof_impl.register(LocalAccessor)
def typeof_local_accessor(val: LocalAccessor, c) -> LocalAccessorType:
"""Returns a ``numba_dpex.experimental.dpctpp_types.LocalAccessorType``
instance for a Python LocalAccessor object.
Args:
val (LocalAccessor): Instance of the LocalAccessor type.
c : Numba typing context used for type inference.
Returns: LocalAccessorType object corresponding to the LocalAccessor object.
"""
# pylint: disable=protected-access
return LocalAccessorType(ndim=len(val._shape), dtype=val._dtype)
25 changes: 0 additions & 25 deletions numba_dpex/experimental/__init__.py

This file was deleted.

32 changes: 0 additions & 32 deletions numba_dpex/experimental/models.py

This file was deleted.

97 changes: 0 additions & 97 deletions numba_dpex/experimental/target.py

This file was deleted.

32 changes: 0 additions & 32 deletions numba_dpex/experimental/testing.py

This file was deleted.

Loading

0 comments on commit 53bb704

Please sign in to comment.