diff --git a/arkouda/array_api/utility_functions.py b/arkouda/array_api/utility_functions.py index 05eaf907d..ea79f4eb1 100644 --- a/arkouda/array_api/utility_functions.py +++ b/arkouda/array_api/utility_functions.py @@ -68,6 +68,9 @@ def clip(a: Array, a_min, a_max, /) -> Array: a_max : scalar The maximum value """ + if a.dtype == ak.bigint or a.dtype == ak.bool_: + raise RuntimeError(f"Error executing command: clip does not support dtype {a.dtype}") + return Array._new( create_pdarray( generic_msg( @@ -99,6 +102,9 @@ def diff(a: Array, /, n: int = 1, axis: int = -1, prepend=None, append=None) -> append : Array, optional Array to append to `a` along `axis` before calculating the difference. """ + if a.dtype == ak.bigint or a.dtype == ak.bool_: + raise RuntimeError(f"Error executing command: diff does not support dtype {a.dtype}") + if prepend is not None and append is not None: a_ = concat((prepend, a, append), axis=axis) elif prepend is not None: @@ -146,6 +152,9 @@ def pad( if mode != "constant": raise NotImplementedError(f"pad mode '{mode}' is not supported") + if array.dtype == ak.bigint: + raise RuntimeError("Error executing command: pad does not support dtype bigint") + if "constant_values" not in kwargs: cvals = 0 else: diff --git a/src/ArgSortMsg.chpl b/src/ArgSortMsg.chpl index 8e636255f..9ec94f2b2 100644 --- a/src/ArgSortMsg.chpl +++ b/src/ArgSortMsg.chpl @@ -434,17 +434,11 @@ module ArgSortMsg axis = msgArgs["axis"].toScalar(int), symEntry = st[msgArgs["name"]]: SymEntry(array_dtype, array_nd), vals = if (array_dtype == bool) then (symEntry.a:int) else (symEntry.a: array_dtype); - + const iv = argsortDefault(vals, algorithm=algorithm, axis); return st.insert(new shared SymEntry(iv)); } - proc argsort(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws - where (array_dtype == BigInteger.bigint) || (array_dtype == uint(8)) - { - return MsgTuple.error("argsort does not support the %s dtype".format(array_dtype:string)); - } - proc argsortStrings(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws { const name = msgArgs["name"].toScalar(string), strings = getSegString(name, st), diff --git a/src/AryUtil.chpl b/src/AryUtil.chpl index 71fa35100..a4cec46bc 100644 --- a/src/AryUtil.chpl +++ b/src/AryUtil.chpl @@ -945,9 +945,9 @@ module AryUtil flatten a multi-dimensional array into a 1D array */ @arkouda.registerCommand - proc flatten(const ref a: [?d] ?t): [] t throws - where a.rank > 1 - { + proc flatten(const ref a: [?d] ?t): [] t throws { + if a.rank == 1 then return a; + var flat = makeDistArray(d.size, t); // ranges of flat indices owned by each locale @@ -1004,12 +1004,6 @@ module AryUtil return flat; } - proc flatten(const ref a: [?d] ?t): [] t throws - where a.rank == 1 - { - return a; - } - // helper for computing an array element's index from its order record orderer { param rank: int; diff --git a/src/CastMsg.chpl b/src/CastMsg.chpl index 91453a316..b1417b7df 100644 --- a/src/CastMsg.chpl +++ b/src/CastMsg.chpl @@ -15,17 +15,13 @@ module CastMsg { private config const logChannel = ServerConfig.logChannel; const castLogger = new Logger(logLevel, logChannel); - proc isFloatingType(type t) param : bool { - return isRealType(t) || isImagType(t) || isComplexType(t); - } - @arkouda.instantiateAndRegister(prefix="cast") proc castArray(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype_from, type array_dtype_to, param array_nd: int ): MsgTuple throws - where !(isFloatingType(array_dtype_from) && array_dtype_to == bigint) && + where !((isRealType(array_dtype_from) || isImagType(array_dtype_from) || isComplexType(array_dtype_from)) && array_dtype_to == bigint) && !(array_dtype_from == bigint && array_dtype_to == bool) { const a = st[msgArgs["name"]]: SymEntry(array_dtype_from, array_nd); @@ -40,22 +36,6 @@ module CastMsg { } } - // cannot cast float types to bigint, cannot cast bigint to bool - proc castArray(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, - type array_dtype_from, - type array_dtype_to, - param array_nd: int - ): MsgTuple throws - where (isFloatingType(array_dtype_from) && array_dtype_to == bigint) || - (array_dtype_from == bigint && array_dtype_to == bool) - { - return MsgTuple.error( - "cannot cast array of type %s to %s".format( - type2str(array_dtype_from), - type2str(array_dtype_to) - )); - } - @arkouda.instantiateAndRegister(prefix="castToStrings") proc castArrayToStrings(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype): MsgTuple throws { const name = msgArgs["name"].toScalar(string); diff --git a/src/GenSymIO.chpl b/src/GenSymIO.chpl index 75a83bc1d..9efa25b31 100644 --- a/src/GenSymIO.chpl +++ b/src/GenSymIO.chpl @@ -43,12 +43,6 @@ module GenSymIO { return st.insert(new shared SymEntry(makeArrayFromBytes(msgArgs.payload, shape, array_dtype))); } - proc array(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws - where array_dtype == bigint - { - return MsgTuple.error("Array creation from binary payload is not supported for bigint arrays"); - } - proc makeArrayFromBytes(ref payload: bytes, shape: ?N*int, type t): [] t throws { var size = 1; for s in shape do size *= s; @@ -138,12 +132,6 @@ module GenSymIO { return MsgTuple.payload(bytes.createAdoptingBuffer(ptr:c_ptr(uint(8)), size, size)); } - proc tondarray(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws - where array_dtype == bigint - { - return MsgTuple.error("cannot create ndarray from bigint array"); - } - /* * Utility proc to test casting a string to a specified type * :arg c: String to cast diff --git a/src/IndexingMsg.chpl b/src/IndexingMsg.chpl index 96a65a925..be7c646dc 100644 --- a/src/IndexingMsg.chpl +++ b/src/IndexingMsg.chpl @@ -206,12 +206,6 @@ module IndexingMsg } } - proc multiPDArrayIndex(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype_a, type array_dtype_idx, param array_nd: int): MsgTuple throws - where array_dtype_idx != int && array_dtype_idx != uint - { - return MsgTuple.error("Invalid index type: %s; must be 'int' or 'uint'".format(type2str(array_dtype_idx))); - } - private proc multiIndexShape(inShape: ?N*int, idxDims: [?d] int, outSize: int): (bool, int, N*int) { var minShape: N*int = inShape, firstRank = -1; @@ -960,14 +954,6 @@ module IndexingMsg return st.insert(new shared SymEntry(y, x.max_bits)); } - proc takeAlongAxis(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, - type array_dtype_x, - type array_dtype_idx, - param array_nd: int - ): MsgTuple throws { - return MsgTuple.error("Cannot take along axis with non-integer index array"); - } - use CommandMap; registerFunction("arrayViewMixedIndex", arrayViewMixedIndexMsg, getModuleName()); registerFunction("[pdarray]", pdarrayIndexMsg, getModuleName()); diff --git a/src/LinalgMsg.chpl b/src/LinalgMsg.chpl index 855d35cf7..a4b6d4e30 100644 --- a/src/LinalgMsg.chpl +++ b/src/LinalgMsg.chpl @@ -61,13 +61,6 @@ module LinalgMsg { return st.insert(e); } - - proc eye(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype): MsgTuple throws - where array_dtype == BigInteger.bigint - { - return MsgTuple.error("eye does not support the bigint dtype"); - } - // tril and triu are identical except for the argument they pass to triluHandler (true for upper, false for lower) // The zeros are written into the upper (or lower) triangle of the array, offset by the value of diag. @@ -79,11 +72,6 @@ module LinalgMsg { return triluHandler(cmd, msgArgs, st, array_dtype, array_nd, false); } - proc tril(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws - where array_nd < 2 { - return MsgTuple.error("Array must be at least 2 dimensional for 'tril'"); - } - // Create an array from an existing array with its lower triangle zeroed out @arkouda.instantiateAndRegister @@ -92,13 +80,9 @@ module LinalgMsg { return triluHandler(cmd, msgArgs, st, array_dtype, array_nd, true); } - proc triu(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws - where array_nd < 2 { - return MsgTuple.error("Array must be at least 2 dimensional for 'triu'"); - } - - // Fetch the arguments, call zeroTri, return result. - + // Fetch the arguments, call zeroTri, return result. + // TODO: support instantiating param bools with 'true' and 'false' s.t. we'd have 'triluHandler' and 'triluHandler' + // cmds if this procedure were annotated instead of the two above. proc triluHandler(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int, param upper: bool ): MsgTuple throws { @@ -195,16 +179,6 @@ module LinalgMsg { } - proc matmul(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype_x1, type array_dtype_x2, param array_nd: int): MsgTuple throws - where (array_nd < 2) && (array_dtype_x1 != BigInteger.bigint) && (array_dtype_x2 != BigInteger.bigint) { - return MsgTuple.error("Matrix multiplication with arrays of dimension < 2 is not supported"); - } - - proc matmul(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype_x1, type array_dtype_x2, param array_nd: int): MsgTuple throws - where (array_dtype_x1 == BigInteger.bigint) || (array_dtype_x2 == BigInteger.bigint) { - return MsgTuple.error("Matrix multiplication with arrays of bigint type is not supported"); - } - proc compute_result_type_matmul(type t1, type t2) type { if t1 == real || t2 == real then return real; if t1 == int || t2 == int then return int; @@ -302,11 +276,6 @@ module LinalgMsg { return ret; } - proc transpose(array: [?d] ?t): [d] t throws - where d.rank < 2 { - throw new Error("Matrix transpose with arrays of dimension < 2 is not supported"); - } - /* Compute the generalized dot product of two tensors along the specified axis. @@ -366,22 +335,6 @@ module LinalgMsg { return bool; } - proc vecdot(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype_x1, type array_dtype_x2, param array_nd: int): MsgTuple throws - where (array_nd < 2) && ((array_dtype_x1 != bool) || (array_dtype_x2 != bool)) - && (array_dtype_x1 != BigInteger.bigint) && (array_dtype_x2 != BigInteger.bigint) { - return MsgTuple.error("VecDot with arrays of dimension < 2 is not supported"); - } - - proc vecdot(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype_x1, type array_dtype_x2, param array_nd: int): MsgTuple throws - where (array_dtype_x1 == bool) && (array_dtype_x2 == bool) { - return MsgTuple.error("VecDot with arrays both of type bool is not supported"); - } - - proc vecdot(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype_x1, type array_dtype_x2, param array_nd: int): MsgTuple throws - where (array_dtype_x1 == BigInteger.bigint) || (array_dtype_x2 == BigInteger.bigint) { - return MsgTuple.error("VecDot with arrays of type bigint is not supported"); - } - // @arkouda.registerND(???) // proc tensorDotMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd1: int, param nd2: int): MsgTuple throws { // if nd < 3 { diff --git a/src/MsgProcessing.chpl b/src/MsgProcessing.chpl index df95f5d7b..691782288 100644 --- a/src/MsgProcessing.chpl +++ b/src/MsgProcessing.chpl @@ -339,11 +339,6 @@ module MsgProcessing return msg; } - proc chunkInfoAsString(array: [?d] ?t): string throws - where (t != bool) && (t != int(64)) && (t != uint(64)) && (t != uint(8)) && (t != real){ - throw new Error("chunkInfo does not support dtype %s".format(t:string)); - } - @arkouda.registerCommand proc chunkInfoAsArray(array: [?d] ?t):[] int throws where (t == bool) || (t == int(64)) || (t == uint(64)) || (t == uint(8)) ||(t == real) { @@ -357,9 +352,4 @@ module MsgProcessing } return blockSizes; } - - proc chunkInfoAsArray(array: [?d] ?t): [d] int throws - where (t != bool) && (t != int(64)) && (t != uint(64)) && (t != uint(8)) && (t != real){ - throw new Error("chunkInfo does not support dtype %s".format(t:string)); - } } diff --git a/src/RandMsg.chpl b/src/RandMsg.chpl index 3d2e9b198..69fafb30d 100644 --- a/src/RandMsg.chpl +++ b/src/RandMsg.chpl @@ -75,12 +75,6 @@ module RandMsg return st.insert(e); } - proc randint(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws - where array_dtype == BigInteger.bigint - { - return MsgTuple.error("randint does not support the bigint dtype"); - } - @arkouda.instantiateAndRegister proc randomNormal(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param array_nd: int): MsgTuple throws { const shape = msgArgs["shape"].toScalarTuple(int, array_nd), @@ -117,12 +111,6 @@ module RandMsg return st.insert(new shared GeneratorSymEntry(generator, state)); } - proc createGenerator(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype): MsgTuple throws - where array_dtype == BigInteger.bigint - { - return MsgTuple.error("createGenerator does not support the bigint dtype"); - } - @arkouda.instantiateAndRegister proc uniformGenerator(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws where array_dtype != BigInteger.bigint @@ -151,12 +139,6 @@ module RandMsg return st.insert(uniformEntry); } - proc uniformGenerator(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws - where array_dtype == BigInteger.bigint - { - return MsgTuple.error("uniformGenerator does not support the bigint dtype"); - } - /* Use the ziggurat method (https://en.wikipedia.org/wiki/Ziggurat_algorithm#Theory_of_operation) @@ -252,6 +234,9 @@ module RandMsg @arkouda.instantiateAndRegister proc standardNormalGenerator(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param array_nd): MsgTuple throws + do return standardNormalGeneratorHelp(cmd, msgArgs, st, array_nd); + + proc standardNormalGeneratorHelp(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param array_nd): MsgTuple throws where array_nd == 1 { const name = msgArgs["name"], // generator name @@ -287,7 +272,7 @@ module RandMsg } - proc standardNormalGenerator(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param array_nd): MsgTuple throws + proc standardNormalGeneratorHelp(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param array_nd): MsgTuple throws where array_nd > 1 { const name = msgArgs["name"], // generator name @@ -387,6 +372,9 @@ module RandMsg @arkouda.instantiateAndRegister proc standardExponential(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param array_nd): MsgTuple throws + do return standardExponentialHelp(cmd, msgArgs, st, array_nd); + + proc standardExponentialHelp(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param array_nd): MsgTuple throws where array_nd == 1 { const name = msgArgs["name"], // generator name @@ -421,7 +409,7 @@ module RandMsg } } - proc standardExponential(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param array_nd): MsgTuple throws + proc standardExponentialHelp(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param array_nd): MsgTuple throws where array_nd > 1 { const name = msgArgs["name"], // generator name @@ -567,12 +555,6 @@ module RandMsg } } - proc choice(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype): MsgTuple throws - where array_dtype == BigInteger.bigint - { - return MsgTuple.error("choice does not support the bigint dtype"); - } - inline proc logisticGenerator(mu: real, scale: real, ref rs) { var U = rs.next(0, 1); @@ -693,7 +675,10 @@ module RandMsg } @arkouda.instantiateAndRegister - proc shuffle(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws + proc shuffle(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws + do return shuffleHelp(cmd, msgArgs, st, array_dtype, array_nd); + + proc shuffleHelp(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws where array_nd == 1 { const name = msgArgs["name"], @@ -715,7 +700,7 @@ module RandMsg return MsgTuple.success(); } - proc shuffle(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws + proc shuffleHelp(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws where array_nd != 1 { const name = msgArgs["name"], diff --git a/src/ReductionMsg.chpl b/src/ReductionMsg.chpl index dd69bf526..febaed486 100644 --- a/src/ReductionMsg.chpl +++ b/src/ReductionMsg.chpl @@ -380,16 +380,6 @@ module ReductionMsg // simple and efficient 'nonzero' implementation for 1D arrays - proc nonzero( - cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, - type array_dtype, - param array_nd: int - ): MsgTuple throws - where array_dtype == bigint - { - return MsgTuple.error("nonzero is not supported for bigint arrays"); - } - proc nonzero1D(x: [?d] ?t): [] int throws { const nTasksPerLoc = here.maxTaskPar; var nnzPerTask: [0.. 1) { - + proc sortHelp(array: [?d] ?t, alg: string, axis: int): [d] t throws + where d.rank > 1 + { var algorithm: SortingAlgorithm = ArgSortMsg.getSortingAlgorithm(alg); const itemsize = dtypeSize(whichDtype(t)); overMemLimit(radixSortLSD_keys_memEst(d.size, itemsize)); @@ -91,16 +95,11 @@ module SortMsg return sorted; } - proc sort(array: [?d] ?t, alg: string, axis: int): [d] t throws - where ((t != real) && (t!=int) && (t!=uint(64))) { - throw new Error("sort does not support type %s".format(type2str(t))); - } - // https://data-apis.org/array-api/latest/API_specification/generated/array_api.searchsorted.html#array_api.searchsorted @arkouda.registerCommand proc searchSorted(x1: [?d1] real, x2: [?d2] real, side: string): [d2] int throws - where (d1.rank == 1) { - + where d1.rank == 1 + { if side != "left" && side != "right" { throw new Error("searchSorted side must be a string with value 'left' or 'right'."); } @@ -123,11 +122,6 @@ module SortMsg return ret; } - proc searchSorted(x1: [?d1] real, x2: [?d2] real, side: string): [d2] int throws - where (d1.rank != 1){ - throw new Error("searchSorted only arrays x1 of dimension 1."); - } - record leftCmp: relativeComparator { proc compare(a: real, b: real): int { if a < b then return -1; diff --git a/src/SparseMatrixMsg.chpl b/src/SparseMatrixMsg.chpl index dff40372a..7ec09af88 100644 --- a/src/SparseMatrixMsg.chpl +++ b/src/SparseMatrixMsg.chpl @@ -63,7 +63,7 @@ module SparseMatrixMsg { return MsgTuple.fromResponses(responses); } - @arkouda.registerCommand("fill_sparse_vals") + @arkouda.registerCommand("fill_sparse_vals", ignoreWhereClause=true) proc fillSparseMatrixMsg(matrix: borrowed SparseSymEntry(?), vals: [?d] ?t /* matrix.etype */) throws where t == matrix.etype && d.rank == 1 do fillSparseMatrix(matrix.a, vals, matrix.matLayout); diff --git a/src/StatsMsg.chpl b/src/StatsMsg.chpl index 05a9bda2b..1e67e55fc 100644 --- a/src/StatsMsg.chpl +++ b/src/StatsMsg.chpl @@ -85,15 +85,6 @@ module StatsMsg { return (+ reduce ((x:real - mx) * (y:real - my))) / (dx.size - 1):real; } - // above registration will instantiate `cov` for all combinations of array ranks - // even though it is only valid when the ranks are the same - // (respecting the where clause in the signature is future work for 'registerCommand') - proc cov(const ref x: [?dx], const ref y: [?dy]): real throws - where dx.rank != dy.rank - { - throw new Error("x and y must have the same rank"); - } - @arkouda.registerCommand() proc corr(const ref x: [?dx] ?tx, const ref y: [?dy] ?ty): real throws where dx.rank == dy.rank @@ -107,15 +98,6 @@ module StatsMsg { return cov(x, y) / (std(x, 1) * std(y, 1)); } - // above registration will instantiate `corr` for all combinations of array ranks - // even though it is only valid when the ranks are the same - // (respecting the where clause in the signature is future work for 'registerCommand') - proc corr(const ref x: [?dx], const ref y: [?dy]): real throws - where dx.rank != dy.rank - { - throw new Error("x and y must have the same rank"); - } - @arkouda.registerCommand() proc cumSum(const ref x: [?d] ?t, axis: int, includeInitial: bool): [] t throws { if d.rank == 1 { diff --git a/src/UtilMsg.chpl b/src/UtilMsg.chpl index a23bf24a8..20ded2c7f 100644 --- a/src/UtilMsg.chpl +++ b/src/UtilMsg.chpl @@ -43,11 +43,6 @@ module UtilMsg { return y; } - proc clip(const ref x: [?d] ?t, min: real, max: real): [d] t throws - where (t != int) && (t != real) && (t != uint(8)) && (t != uint(64)){ - throw new Error("clip does not support dtype %s".format(t:string)); - } - /* Compute the n'th order discrete difference along a given axis @@ -95,11 +90,6 @@ module UtilMsg { } } - proc diff(x: [?d] ?t, n: int, axis: int): [d] t throws - where (t != real) && (t != int) && (t != uint(8)) && (t != uint(64)){ - throw new Error("diff does not support dtype %s".format(t:string)); - } - // helper to create a domain that's 'n' elements smaller in the 'axis' dimension private proc subDomain(shape: ?N*int, axis: int, n: int) { var rngs: N*range; @@ -164,9 +154,4 @@ module UtilMsg { return st.insert(new shared SymEntry(paddedArray)); } - proc pad(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws - where (array_dtype != int) && (array_dtype != uint(8)) && (array_dtype != uint(64)) && (array_dtype != real) && (array_dtype != bool) { - throw new Error("pad does not support dtype %s".format(array_dtype:string)); - } - } diff --git a/src/registry/register_commands.py b/src/registry/register_commands.py index e7c3436c9..825f5f24a 100644 --- a/src/registry/register_commands.py +++ b/src/registry/register_commands.py @@ -7,7 +7,7 @@ DEFAULT_MODS = ["MsgProcessing", "GenSymIO"] -registerAttr = ("arkouda.registerCommand", ["name"]) +registerAttr = ("arkouda.registerCommand", ["name", "ignoreWhereClause"]) instAndRegisterAttr = ("arkouda.instantiateAndRegister", ["prefix"]) # chapel types and their numpy equivalents @@ -105,6 +105,9 @@ def __init__(self, name): def name(self): return self.name + def __str__(self): + return f"?{self.name}" + class FormalQueryRef: def __init__(self, name): @@ -113,6 +116,9 @@ def __init__(self, name): def name(self): return self.name + def __str__(self): + return f"QRef: '{self.name}'" + class StaticTypeInfo: def __init__(self, value): @@ -121,6 +127,9 @@ def __init__(self, value): def value(self): return self.value + def __str__(self): + return f"static: '{self.value}'" + class formalKind(Enum): ARRAY = 1 @@ -170,6 +179,9 @@ def stringify(self) -> str: else f"{self.storage_kind} {self.name}" ) + def __str__(self): + return f"{self.kind} [{self.storage_kind} {self.name}: {self.type_str}] (info: {self.info})" + def get_formals(fn, require_type_annotations): """ @@ -339,7 +351,7 @@ def clean_enum_name(name): def stamp_generic_command( - generic_proc_name, prefix, module_name, formals, line_num, is_user_proc + generic_proc_name, prefix, module_name, formals, line_num, iar_annotation ): """ Create code to stamp out and register a generic command using a generic @@ -376,8 +388,8 @@ def stamp_generic_command( stamp_formal_args = ", ".join([f"{k}={v}" for k, v in formals.items()]) - # use qualified naming if generic_proc belongs in a use defined module to avoid name conflicts - call = f"{module_name}.{generic_proc_name}" if is_user_proc else generic_proc_name + # use qualified naming if generic_proc belongs in a user defined module to avoid name conflicts + call = f"{module_name}.{generic_proc_name}" if iar_annotation else generic_proc_name proc = ( f"proc {stamp_name}(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): {RESPONSE_TYPE_NAME} throws do\n" @@ -590,7 +602,9 @@ def unpack_array_arg(arg_name, array_count, finfo, domain_queries, dtype_queries ) -def unpack_generic_symbol_arg(arg_name, symbol_class_name, symbol_count, symbol_param_class): +def unpack_generic_symbol_arg( + arg_name, symbol_class_name, symbol_count, symbol_param_class +): """ Generate the code to unpack a non-array symbol-table entry class (a class that inherits from 'AbstractSymEntry'). @@ -739,7 +753,10 @@ def gen_arg_unpacking(formals, config): """ Generate argument unpacking code for a message handler procedure - Returns the chapel code to unpack the arguments, and a list of generic arguments + Returns a tuple containing: + * the chapel code to unpack the arguments + * a list of generic arguments + * a map of array domain/type queries to their corresponding generic arguments """ unpack_lines = [] generic_args = [] @@ -850,7 +867,11 @@ def gen_arg_unpacking(formals, config): generic_args += scalar_args scalar_arg_counter += 1 - return ("\n".join(unpack_lines), generic_args) + return ( + "\n".join(unpack_lines), + generic_args, + {**array_domain_queries, **array_dtype_queries}, + ) def gen_user_function_call(name, arg_names, mod_name, user_rt): @@ -924,8 +945,8 @@ def gen_command_proc(name, return_type, formals, mod_name, config): * the chapel code for the command procedure * the name of the command procedure * a boolean indicating whether the command has generic (param/type) formals - * a list of tuples in the format (name, storage kind, type expression) - representing the generic formals of the command procedure + * a list of FormalTypeSpec representing the command procedure's generic formals + * a table of domain/type queries used in array formals mapped to their respective generic arguments proc (cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, ): MsgTuple throws { () @@ -936,7 +957,7 @@ def gen_command_proc(name, return_type, formals, mod_name, config): """ - arg_unpack, command_formals = gen_arg_unpacking(formals, config) + arg_unpack, command_formals, query_table = gen_arg_unpacking(formals, config) is_generic_command = len(command_formals) > 0 signature, cmd_name = gen_signature(name, command_formals) fn_call, result_name = gen_user_function_call( @@ -987,11 +1008,192 @@ def gen_command_proc(name, return_type, formals, mod_name, config): [signature, arg_unpack, fn_call, symbol_creation, response, "}"] ) - return (command_proc, cmd_name, is_generic_command, command_formals) + return (command_proc, cmd_name, is_generic_command, command_formals, query_table) + + +# TODO: use the compiler's built-in support for where-clause evaluation and resolution +# instead of re-implementing it in a much less robust manner here +class WCNode: + def __init__(self, ast): + if isinstance(ast, chapel.OpCall): + if ast.is_binary_op(): + self.node = WCBinOP(ast) + else: + self.node = WCUnaryOP(ast) + elif isinstance(ast, chapel.FnCall): + # 'int(8)' for example should be treated as a literal type name, not a function call + call_name = ast.called_expression().name() + if call_name in chapel_scalar_types.keys(): + self.node = WCLiteral(call_name, list(ast.actuals())[0].text()) + else: + self.node = WCFunc(ast) + else: + self.node = WCLiteral(ast) + + def eval(self, args, translation_table=None): + return self.node.eval(args, translation_table) + + def __str__(self): + return self.node.__str__() + + def __repr__(self): + return self.node.__str__() + + +class WCBinOP(WCNode): + def __init__(self, ast): + self.op = ast.op() + actuals = list(ast.actuals()) + self.lhs = WCNode(actuals[0]) + self.rhs = WCNode(actuals[1]) + + def eval(self, args, translation_table=None): + lhse = self.lhs.eval(args, translation_table) + rhse = self.rhs.eval(args, translation_table) + + if self.op == "==": + return str(lhse) == str(rhse) + elif self.op == "!=": + return str(lhse) != str(rhse) + elif self.op == "<": + return int(lhse) < int(rhse) + elif self.op == "<=": + return int(lhse) <= int(rhse) + elif self.op == ">": + return int(lhse) > int(rhse) + elif self.op == ">=": + return int(lhse) >= int(rhse) + elif self.op == "&&": + return bool(lhse) and bool(rhse) + elif self.op == "||": + return bool(lhse) or bool(rhse) + else: + error_message( + "evaluating where-clause", + f"binary operator '{self.op}' not yet supported in where-clauses", + ) + return True + + def __str__(self): + return f"({self.lhs} {self.op} {self.rhs})" + + +class WCUnaryOP(WCNode): + def __init__(self, ast): + self.op = ast.op() + self.operand = WCNode(list(ast.actuals())[0]) + + def eval(self, args, translation_table=None): + if self.op == "!": + return not bool(self.operand.eval(args, translation_table)) + elif self.op == "-": + return -int(self.operand.eval(args, translation_table)) + else: + error_message( + "evaluating where-clause", + f"unary operator '{self.op}' not yet supported in where-clauses", + ) + return True + + def __str__(self): + return f"{self.op}{self.operand}" + + +class WCFunc(WCNode): + def __init__(self, ast): + self.name = ast.called_expression().name() + self.actuals = [WCNode(a) for a in list(ast.actuals())] + + def eval(self, args, translation_table=None): + # TODO: this is a really bad way to do this. the compiler should be leveraged much more heavily here + arg = self.actuals[0].eval(args, translation_table) + if self.name == "isIntegralType": + return arg in [ + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + ] + elif self.name == "isRealType": + return arg in [ + "float32", + "float64", + ] + elif self.name == "isComplexType": + return arg in [ + "complex", + "complex64", + "complex128", + ] + elif self.name == "isImagType": + return arg in [ + "imag", + "imag32", + "imag64", + ] + else: + error_message( + "evaluating where-clause", + f"general function calls not yet supported in where-clauses; ignoring function: {self.name}", + ) + return True + + def __str__(self): + return f"{self.name}({', '.join([str(a) for a in self.actuals])})" + + +def canonicalize_type_name(name): + if name in chapel_scalar_types: + return chapel_scalar_types[name] + else: + return name + + +class WCLiteral(WCNode): + def __init__(self, ast, width=None): + # note: scalar type names are canonicalized to ensure 'int' == 'int(64)' (for example) + if width is not None: + self.value = canonicalize_type_name(f"{ast}({width})") + elif isinstance(ast, chapel.Identifier): + self.value = canonicalize_type_name(ast.name()) + elif isinstance(ast, chapel.IntLiteral): + self.value = ast.text() + elif isinstance(ast, chapel.Dot): + self.value = ast.receiver().name() + "." + ast.field() + # 🥲 + if self.value == "BigInteger.bigint": + self.value = "bigint" + if self.value.endswith(".rank"): + self.value = self.value.split(".")[0] # ex: d1.rank -> d1 + else: + raise ValueError("invalid where-clause literal") + + def eval(self, args, translation_table=None): + if self.value in args: + return canonicalize_type_name(args[self.value]) + elif translation_table is not None and self.value in translation_table: + return canonicalize_type_name(args[translation_table[self.value]]) + else: + return self.value + + def __str__(self): + return self.value def stamp_out_command( - config, formals, name, cmd_prefix, mod_name, line_num, is_user_proc + config, + formals, + name, + cmd_prefix, + mod_name, + line_num, + iar_annotation, + wc, + query_table=None, ): """ Yield instantiations of a generic command with using the @@ -1007,15 +1209,28 @@ def stamp_out_command( * cmd_prefix: the prefix to use for the command names * mod_name: the name of the module containing the command procedure (or the user-defined procedure that the command calls) + * line_num: the line number of the annotated procedure + * iar_annotation: a boolean indicating whether the command procedure was annotated with 'instantiateAndRegister' + * wc: the where clause of the annotated procedure + * query_table: a dictionary mapping query names to their corresponding generic formal names The name of the instantiated command will be in the format: 'cmd_prefix' where v1, v2, ... are the values of the generic formals """ formal_perms = generic_permutations(config, formals) + if wc is not None: + wc_node = WCNode(wc) + else: + wc_node = None + for fp in formal_perms: + # skip instantiation for this permutation if the where clause evaluates to false + if wcn := wc_node: + if not wcn.eval(fp, query_table): + continue stamp = stamp_generic_command( - name, cmd_prefix, mod_name, fp, line_num, is_user_proc + name, cmd_prefix, mod_name, fp, line_num, iar_annotation ) yield stamp @@ -1079,6 +1294,10 @@ def register_commands(config, source_files): else: command_prefix = name + ignore_where_clause = False + if iwc := attr_call["ignoreWhereClause"]: + ignore_where_clause = bool(iwc.value()) + if len(gen_formals) > 0: error_message( f"registering '{name}'", @@ -1087,8 +1306,8 @@ def register_commands(config, source_files): ) continue - (cmd_proc, cmd_name, is_generic_cmd, cmd_gen_formals) = gen_command_proc( - name, fn.return_type(), con_formals, mod_name, config + (cmd_proc, cmd_name, is_generic_cmd, cmd_gen_formals, query_table) = ( + gen_command_proc(name, fn.return_type(), con_formals, mod_name, config) ) file_stamps.append(cmd_proc) @@ -1104,6 +1323,8 @@ def register_commands(config, source_files): mod_name, line_num, False, + fn.where_clause() if not ignore_where_clause else None, + query_table, ): file_stamps.append(stamp) except ValueError as e: @@ -1143,7 +1364,14 @@ def register_commands(config, source_files): try: for stamp in stamp_out_command( - config, gen_formals, name, command_prefix, mod_name, line_num, True + config, + gen_formals, + name, + command_prefix, + mod_name, + line_num, + True, + fn.where_clause(), ): file_stamps.append(stamp) count += 1