Skip to content

Commit

Permalink
Move to Python3.10 type hints (remove use of Optional and Union)
Browse files Browse the repository at this point in the history
  • Loading branch information
hunhoffe committed Sep 30, 2024
1 parent 0c9dd9f commit b4a7759
Show file tree
Hide file tree
Showing 13 changed files with 147 additions and 141 deletions.
46 changes: 22 additions & 24 deletions python/dialects/aie.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from dataclasses import dataclass
import inspect
from typing import List, Optional, Tuple, Union, Dict, Any
from typing import List, Tuple, Dict, Any
import contextlib

import numpy as np
Expand Down Expand Up @@ -33,10 +33,6 @@
transaction_binary_to_mlir,
)
from ..extras import types as T
from ..extras.dialects.ext.arith import constant

# noinspection PyUnresolvedReferences
from ..extras.dialects.ext import memref
from ..extras.meta import region_op

# this is inside the aie-python-extras (shared) namespace package
Expand All @@ -50,20 +46,19 @@
)

from ..ir import (
ArrayAttr,
Attribute,
Block,
BlockList,
DenseElementsAttr,
DictAttr,
FlatSymbolRefAttr,
FunctionType,
InsertionPoint,
IntegerAttr,
IntegerType,
TypeAttr,
UnitAttr,
_i32ArrayAttr,
Location,
)

# Comes from _aie
Expand All @@ -72,7 +67,7 @@


class external_func(FuncOp):
def __init__(self, name, inputs, outputs=None, visibility="private"):
def __init__(self, name: str, inputs, outputs=None, visibility="private"):
if outputs is None:
outputs = []
super().__init__(
Expand All @@ -88,9 +83,7 @@ def bd_dim_layout(size, stride):


@register_attribute_builder("BDDimLayoutArrayAttr")
def bd_dim_layout_array_attr_builder(
tups: List[Union[Attribute, Tuple[int]]], context=None
):
def bd_dim_layout_array_attr_builder(tups: List[Attribute | Tuple[int]], context=None):
if isinstance(tups, list) and all(isinstance(t, tuple) for t in tups):
tups = list(map(lambda t: bd_dim_layout(*t), tups))
return Attribute.parse(
Expand Down Expand Up @@ -208,7 +201,7 @@ def bds(parent):

class Core(CoreOp):
# Until https://github.com/llvm/llvm-project/pull/73620 gets figured out.
def __init__(self, tile, link_with=None):
def __init__(self, tile, link_with: str | None = None):
super().__init__(result=T.index(), tile=tile, link_with=link_with)


Expand All @@ -220,7 +213,14 @@ def __init__(self):
raise ValueError("Should never be called")

def __new__(
cls, tile, shape, datatype, name=None, initial_value=None, loc=None, ip=None
cls,
tile,
shape,
datatype,
name=None,
initial_value: np.ndarray | None = None,
loc=None,
ip=None,
):
if initial_value is not None:
assert isinstance(initial_value, np.ndarray)
Expand All @@ -247,7 +247,7 @@ class ExternalBuffer(MemRef):
def __init__(self):
raise ValueError("Should never be called")

def __new__(cls, shape, datatype, name=None, loc=None, ip=None):
def __new__(cls, shape, datatype, name: str | None = None, loc=None, ip=None):
my_buffer = ExternalBufferOp(
buffer=T.memref(*shape, datatype),
sym_name=name,
Expand Down Expand Up @@ -361,7 +361,7 @@ def __init__(
dest,
dest_port,
dest_channel,
keep_pkt_header: Optional[bool] = None,
keep_pkt_header: bool | None = None,
):
super().__init__(ID=pkt_id, keep_pkt_header=keep_pkt_header)
bb = Block.create_at_start(self.ports)
Expand Down Expand Up @@ -435,9 +435,9 @@ def __init__(
channel_dir,
channel_index,
*,
dest: Optional[Union[Successor, Block]] = None,
chain: Optional[Union[Successor, Block]] = None,
repeat_count: Optional[int] = None,
dest: Successor | Block | None = None,
chain: Successor | Block | None = None,
repeat_count: int | None = None,
loc=None,
ip=None,
):
Expand Down Expand Up @@ -472,8 +472,8 @@ def dma_start(
channel_dir,
channel_index,
*,
dest: Optional[Union[Successor, Block]] = None,
chain: Optional[Union[Successor, Block]] = None,
dest: Successor | Block | None = None,
chain: Successor | Block | None = None,
loc=None,
ip=None,
):
Expand All @@ -483,9 +483,7 @@ def dma_start(

@_cext.register_operation(_Dialect, replace=True)
class NextBDOp(NextBDOp):
def __init__(
self, dest: Optional[Union[Successor, Block]] = None, *, loc=None, ip=None
):
def __init__(self, dest: Successor | Block | None = None, *, loc=None, ip=None):
if isinstance(dest, Successor):
dest = dest.block
if dest is None:
Expand All @@ -500,7 +498,7 @@ def dest(self):


def next_bd(
dest: Optional[Union[Successor, Block, ContextManagedBlock]] = None,
dest: Successor | Block | ContextManagedBlock | None = None,
loc=None,
ip=None,
):
Expand Down
12 changes: 5 additions & 7 deletions python/dialects/aiex.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from functools import partial
import itertools
from operator import itemgetter
from typing import Union, Optional

import numpy as np

Expand All @@ -16,7 +15,6 @@
LockAction,
Neighbors,
TileOp,
object_fifo,
find_matching_buffers,
find_matching_flows,
find_matching_locks,
Expand All @@ -37,7 +35,7 @@
npu_sync = partial(npu_sync, column_num=1, row_num=1)


def dma_wait(*args: Union[ObjectFifoCreateOp, str]):
def dma_wait(*args: ObjectFifoCreateOp | str):
if len(args) == 0:
raise ValueError(
"dma_wait must receive at least one dma_meta information to wait for"
Expand All @@ -54,13 +52,13 @@ class NpuDmaMemcpyNd(NpuDmaMemcpyNdOp):

def __init__(
self,
metadata: Union[str, ObjectFifoCreateOp],
metadata: str | ObjectFifoCreateOp,
bd_id,
mem,
offsets: MixedValues = None,
sizes: MixedValues = None,
strides: MixedValues = None,
issue_token: Optional[bool] = None,
issue_token: bool | None = None,
):
x = 0
y = 0
Expand Down Expand Up @@ -700,8 +698,8 @@ def __repr__(self):


def broadcast_flow(
source: Union[np.ndarray, TileOp],
dest: Union[np.ndarray, TileOp],
source: np.ndarray | TileOp,
dest: np.ndarray | TileOp,
source_bundle=None,
source_channel=None,
dest_bundle=None,
Expand Down
9 changes: 4 additions & 5 deletions python/extras/context.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import contextlib
from contextlib import ExitStack, contextmanager
from dataclasses import dataclass
from typing import Optional

from .. import ir

Expand All @@ -17,9 +16,9 @@ def __str__(self):

@contextmanager
def mlir_mod_ctx(
src: Optional[str] = None,
context: ir.Context = None,
location: ir.Location = None,
src: str | None = None,
context: ir.Context | None = None,
location: ir.Location | None = None,
allow_unregistered_dialects=False,
) -> MLIRContext:
if context is None:
Expand All @@ -45,7 +44,7 @@ class RAIIMLIRContext:
context: ir.Context
location: ir.Location

def __init__(self, location: Optional[ir.Location] = None):
def __init__(self, location: ir.Location | None = None):
self.context = ir.Context()
self.context.__enter__()
if location is None:
Expand Down
Empty file.
29 changes: 15 additions & 14 deletions python/extras/dialects/ext/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from abc import abstractmethod
from copy import deepcopy
from functools import cached_property, partialmethod
from typing import Optional, Tuple
import numpy as np
from typing import Tuple

from ...util import get_user_code_loc, infer_mlir_type, mlir_type_to_np_dtype
from ...._mlir_libs._mlir import register_value_caster
Expand Down Expand Up @@ -43,13 +44,13 @@


def constant(
value: Union[int, float, bool, np.ndarray],
type: Optional[Type] = None,
index: Optional[bool] = None,
value: int | float | bool | np.ndarray,
type: Type | None = None,
index: bool | None = None,
*,
vector: Optional[bool] = False,
loc: Location = None,
ip: InsertionPoint = None,
vector: bool | None = False,
loc: Location | None = None,
ip: InsertionPoint | None = None,
) -> Value:
"""Instantiate arith.constant with value `value`.
Expand Down Expand Up @@ -215,7 +216,7 @@ def __call__(cls, *args, **kwargs):


@register_attribute_builder("Arith_CmpIPredicateAttr", replace=True)
def _arith_CmpIPredicateAttr(predicate: Union[str, Attribute], context: Context):
def _arith_CmpIPredicateAttr(predicate: str | Attribute, context: Context):
predicates = {
"eq": CmpIPredicate.eq,
"ne": CmpIPredicate.ne,
Expand All @@ -235,7 +236,7 @@ def _arith_CmpIPredicateAttr(predicate: Union[str, Attribute], context: Context)


@register_attribute_builder("Arith_CmpFPredicateAttr", replace=True)
def _arith_CmpFPredicateAttr(predicate: Union[str, Attribute], context: Context):
def _arith_CmpFPredicateAttr(predicate: str | Attribute, context: Context):
predicates = {
"false": CmpFPredicate.AlwaysFalse,
# ordered comparison
Expand Down Expand Up @@ -271,10 +272,10 @@ def _binary_op(
lhs: "ArithValue",
rhs: "ArithValue",
op: str,
predicate: str = None,
signedness: str = None,
predicate: str | None = None,
signedness: str | None = None,
*,
loc: Location = None,
loc: Location | None = None,
) -> "ArithValue":
"""Generic for handling infix binary operator dispatch.
Expand Down Expand Up @@ -373,7 +374,7 @@ class ArithValue(Value, metaclass=ArithValueMeta):
Value.__init__
"""

def __init__(self, val, *, fold: Optional[bool] = None):
def __init__(self, val, *, fold: bool | None = None):
self._fold = fold if fold is not None else False
super().__init__(val)

Expand Down Expand Up @@ -464,7 +465,7 @@ def dtype(self) -> Type:
return self.type

@cached_property
def literal_value(self) -> Union[int, float, bool]:
def literal_value(self) -> int | float | bool:
if not self.is_constant():
raise ValueError("Can't build literal from non-constant Scalar")
return self.owner.opview.literal_value
Expand Down
Loading

0 comments on commit b4a7759

Please sign in to comment.