Skip to content

Commit

Permalink
Merge pull request #1274 from IntelPython/feature/use_dpjit_specific_…
Browse files Browse the repository at this point in the history
…data_model

Use dpjit specific data model
  • Loading branch information
ZzEeKkAa committed Mar 29, 2024
2 parents af226f3 + 9bb4a47 commit f6a79b4
Show file tree
Hide file tree
Showing 10 changed files with 62 additions and 57 deletions.
5 changes: 3 additions & 2 deletions numba_dpex/core/boxing/ranges.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from contextlib import ExitStack

from numba.core import cgutils, types
from numba.core.datamodel import default_manager
from numba.extending import NativeValue, box, unbox

from numba_dpex.core.types import NdRangeType, RangeType
Expand Down Expand Up @@ -78,7 +77,9 @@ def unbox_ndrange(typ, obj, c):
].value
local_range_struct = ndrange_attr_native_value_map["local_range"].value

range_datamodel = default_manager.lookup(RangeType(typ.ndim))
range_datamodel = c.context.data_model_manager.lookup(
RangeType(typ.ndim)
)
ndrange_struct.ndim = c.builder.extract_value(
global_range_struct,
range_datamodel.get_field_position("ndim"),
Expand Down
1 change: 0 additions & 1 deletion numba_dpex/core/boxing/usm_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from contextlib import ExitStack

from numba.core import cgutils, types
from numba.core.datamodel import default_manager
from numba.core.errors import NumbaNotImplementedError
from numba.extending import NativeValue, box, unbox
from numba.np import numpy_support
Expand Down
64 changes: 35 additions & 29 deletions numba_dpex/core/datamodel/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from llvmlite import ir as llvmir
from numba.core import datamodel, types
from numba.core.datamodel.models import OpaqueModel, PrimitiveModel, StructModel
from numba.core.extending import register_model

from numba_dpex.core.exceptions import UnreachableError
from numba_dpex.core.types.kernel_api.atomic_ref import AtomicRefType
Expand Down Expand Up @@ -316,7 +315,7 @@ def __init__(self, dmm, fe_type):
super().__init__(dmm, fe_type, members)


def _init_data_model_manager() -> datamodel.DataModelManager:
def _init_kernel_data_model_manager() -> datamodel.DataModelManager:
"""Initializes a data model manager used by the SPRIVTarget.
SPIRV kernel functions for certain types of devices require an explicit
Expand Down Expand Up @@ -370,43 +369,50 @@ def _init_data_model_manager() -> datamodel.DataModelManager:
return dmm


dpex_data_model_manager = _init_data_model_manager()
def _init_dpjit_data_model_manager() -> datamodel.DataModelManager:
# TODO: copy manager
dmm = datamodel.default_manager

# Register the USMNdArray type to USMArrayHostModel in numba's default data
# model manager
dmm.register(USMNdArray, USMArrayHostModel)

# Register the USMNdArray type to USMArrayDeviceModel in numba's default data
# model manager
register_model(USMNdArray)(USMArrayHostModel)
# Register the DpnpNdArray type to USMArrayHostModel in numba's default data
# model manager
dmm.register(DpnpNdArray, USMArrayHostModel)

# Register the DpnpNdArray type to USMArrayHostModel in numba's default data
# model manager
register_model(DpnpNdArray)(USMArrayHostModel)
# Register the DpctlSyclQueue type
dmm.register(DpctlSyclQueue, SyclQueueModel)

# Register the DpctlSyclQueue type
register_model(DpctlSyclQueue)(SyclQueueModel)
# Register the DpctlSyclEvent type
dmm.register(DpctlSyclEvent, SyclEventModel)

# Register the DpctlSyclEvent type
register_model(DpctlSyclEvent)(SyclEventModel)
# Register the RangeType type
dmm.register(RangeType, RangeModel)

# Register the RangeType type
register_model(RangeType)(RangeModel)
# Register the NdRangeType type
dmm.register(NdRangeType, NdRangeModel)

# Register the NdRangeType type
register_model(NdRangeType)(NdRangeModel)
# Register the GroupType type
dmm.register(GroupType, EmptyStructModel)

# Register the GroupType type
register_model(GroupType)(EmptyStructModel)
# Register the ItemType type
dmm.register(ItemType, EmptyStructModel)

# Register the ItemType type
register_model(ItemType)(EmptyStructModel)
# Register the NdItemType type
dmm.register(NdItemType, EmptyStructModel)

# Register the MDLocalAccessorType type
dmm.register(DpctlMDLocalAccessorType, DpctlMDLocalAccessorModel)

# Register the NdItemType type
register_model(NdItemType)(EmptyStructModel)
# Register the LocalAccessorType type
dmm.register(LocalAccessorType, LocalAccessorModel)

# Register the MDLocalAccessorType type
register_model(DpctlMDLocalAccessorType)(DpctlMDLocalAccessorModel)
# Register the KernelDispatcherType type
dmm.register(KernelDispatcherType, OpaqueModel)

return dmm

# Register the LocalAccessorType type
register_model(LocalAccessorType)(LocalAccessorModel)

# Register the KernelDispatcherType type
register_model(KernelDispatcherType)(OpaqueModel)
dpex_data_model_manager = _init_kernel_data_model_manager()
dpjit_data_model_manager = _init_dpjit_data_model_manager()
4 changes: 2 additions & 2 deletions numba_dpex/core/kernel_interface/ranges_overloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from llvmlite import ir as llvmir
from numba.core import cgutils, errors, types
from numba.core.datamodel import default_manager
from numba.extending import intrinsic, overload

from numba_dpex.kernel_api import NdRange, Range
Expand Down Expand Up @@ -60,11 +59,12 @@ def _intrin_ndrange_alloc(
ty_local_range,
ty_ndrange,
)
range_datamodel = default_manager.lookup(ty_global_range)

def codegen(context, builder, sig, args):
typ = sig.return_type

range_datamodel = context.data_model_manager.lookup(ty_global_range)

global_range, local_range, _ = args
ndrange_struct = cgutils.create_struct_proxy(typ)(context, builder)
ndrange_struct.ndim = llvmir.Constant(
Expand Down
3 changes: 3 additions & 0 deletions numba_dpex/core/targets/dpjit_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from numba.core.imputils import Registry
from numba.core.target_extension import CPU, target_registry

from numba_dpex.core.datamodel.models import _init_dpjit_data_model_manager
from numba_dpex.dpnp_iface import dpnp_ufunc_db


Expand Down Expand Up @@ -49,6 +50,8 @@ def init(self):
self.lower_extensions = {}
super().init()

self.data_model_manager = _init_dpjit_data_model_manager()

# TODO: initialize nrt once switched to nrt from drt. Most likely we
# call it somewhere. Double check.
# https://github.com/IntelPython/numba-dpex/issues/1175
Expand Down
3 changes: 1 addition & 2 deletions numba_dpex/dpctl_iface/_intrinsic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import dpctl
from llvmlite.ir import IRBuilder
from numba import types
from numba.core.datamodel import default_manager
from numba.extending import intrinsic, overload, overload_method

import numba_dpex.dpctl_iface.libsyclinterface_bindings as sycl
Expand Down Expand Up @@ -45,7 +44,7 @@ def sycl_event_wait(typingctx, ty_event: dpex_types.DpctlSyclEvent):

# defines the custom code generation
def codegen(context, builder, signature, args):
sycl_event_dm = default_manager.lookup(ty_event)
sycl_event_dm = context.data_model_manager.lookup(ty_event)
event_ref = builder.extract_value(
args[0],
sycl_event_dm.get_field_position("event_ref"),
Expand Down
4 changes: 2 additions & 2 deletions numba_dpex/kernel_api_impl/spirv/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from numba.core.types.scalars import IntEnumClass
from numba.core.typing import cmathdecl, enumdecl

from numba_dpex.core.datamodel.models import _init_data_model_manager
from numba_dpex.core.datamodel.models import _init_kernel_data_model_manager
from numba_dpex.core.types import IntEnumLiteral
from numba_dpex.core.typing import dpnpdecl
from numba_dpex.kernel_api.flag_enum import FlagEnum
Expand Down Expand Up @@ -154,7 +154,7 @@ def init(self):
)

# Override data model manager to SPIR model
self.data_model_manager = _init_data_model_manager()
self.data_model_manager = _init_kernel_data_model_manager()
self.extra_compile_options = {}

_lazy_init_dpnp_db()
Expand Down
8 changes: 3 additions & 5 deletions numba_dpex/tests/core/types/DpctlSyclEvent/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@
#
# SPDX-License-Identifier: Apache-2.0

import dpctl
from numba import types
from numba.core.datamodel import default_manager, models
from numba.core.datamodel import models

from numba_dpex.core.datamodel.models import (
SyclEventModel,
dpex_data_model_manager,
dpjit_data_model_manager,
)
from numba_dpex.core.types.dpctl_types import DpctlSyclEvent

Expand All @@ -18,7 +16,7 @@ def test_model_for_DpctlSyclEvent():
default data model manager.
"""
sycl_event = DpctlSyclEvent()
default_model = default_manager.lookup(sycl_event)
default_model = dpjit_data_model_manager.lookup(sycl_event)
assert isinstance(default_model, SyclEventModel)


Expand Down
22 changes: 10 additions & 12 deletions numba_dpex/tests/core/types/range_types/test_data_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
from numba.core.datamodel import default_manager
from numba.core.registry import cpu_target

from numba_dpex.core.datamodel.models import (
NdRangeModel,
RangeModel,
dpex_data_model_manager,
dpjit_data_model_manager,
)
from numba_dpex.core.descriptor import dpex_kernel_target
from numba_dpex.core.descriptor import dpex_kernel_target, dpex_target
from numba_dpex.core.types.kernel_api.ranges import NdRangeType, RangeType

rfields = ["ndim", "dim0", "dim1", "dim2"]
Expand All @@ -30,8 +29,8 @@ def test_datamodel_registration():
dpex_data_model_manager.lookup(range_ty)
dpex_data_model_manager.lookup(ndrange_ty)

default_range_model = default_manager.lookup(range_ty)
default_ndrange_model = default_manager.lookup(ndrange_ty)
default_range_model = dpjit_data_model_manager.lookup(range_ty)
default_ndrange_model = dpjit_data_model_manager.lookup(ndrange_ty)

assert isinstance(default_range_model, RangeModel)
assert isinstance(default_ndrange_model, NdRangeModel)
Expand All @@ -43,7 +42,7 @@ def test_range_model_fields(field):
RangeType
"""
range_ty = RangeType(ndim=1)
dm = default_manager.lookup(range_ty)
dm = dpjit_data_model_manager.lookup(range_ty)
try:
dm.get_field_position(field)
except:
Expand All @@ -56,7 +55,7 @@ def test_ndrange_model_fields(field):
NdRangeType
"""
ndrange_ty = NdRangeType(ndim=1)
dm = default_manager.lookup(ndrange_ty)
dm = dpjit_data_model_manager.lookup(ndrange_ty)
try:
dm.get_field_position(field)
except:
Expand All @@ -69,15 +68,14 @@ def test_flattened_member_count(range_type):
flattened args generated by the CpuTarget's ArgPacker.
"""

cputargetctx = cpu_target.target_context
kerneltargetctx = dpex_kernel_target.target_context
dpex_dmm = kerneltargetctx.data_model_manager
dpjit_target_ctx = dpex_target.target_context
dpjit_dmm = dpjit_target_ctx.data_model_manager

for ndim in range(1, 3):
dty = range_type(ndim)
argty_tuple = tuple([dty])
datamodel = dpex_dmm.lookup(dty)
datamodel = dpjit_dmm.lookup(dty)
num_flattened_args = datamodel.flattened_field_count
ap = cputargetctx.get_arg_packer(argty_tuple)
ap = dpjit_target_ctx.get_arg_packer(argty_tuple)

assert num_flattened_args == len(ap._be_args)
5 changes: 3 additions & 2 deletions numba_dpex/tests/core/types/test_array_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@

import pytest
from numba import types
from numba.core.datamodel import default_manager, models
from numba.core.datamodel import models
from numba.core.registry import cpu_target

from numba_dpex.core.datamodel.models import (
USMArrayDeviceModel,
USMArrayHostModel,
dpex_data_model_manager,
dpjit_data_model_manager,
)
from numba_dpex.core.types.dpnp_ndarray_type import DpnpNdArray, USMNdArray

Expand All @@ -32,7 +33,7 @@ def test_model_for_array(nd_array):
"""
device_model = dpex_data_model_manager.lookup(nd_array)
assert isinstance(device_model, USMArrayDeviceModel)
host_model = default_manager.lookup(nd_array)
host_model = dpjit_data_model_manager.lookup(nd_array)
assert isinstance(host_model, USMArrayHostModel)


Expand Down

0 comments on commit f6a79b4

Please sign in to comment.