Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support where-clause evaluation in registration annotations #3841

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
9 changes: 9 additions & 0 deletions arkouda/array_api/utility_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ def clip(a: Array, a_min, a_max, /) -> Array:
a_max : scalar
The maximum value
"""
if a.dtype == ak.bigint or a.dtype == ak.bool_:
raise RuntimeError(f"Error executing command: clip does not support dtype {a.dtype}")

return Array._new(
create_pdarray(
generic_msg(
Expand Down Expand Up @@ -99,6 +102,9 @@ def diff(a: Array, /, n: int = 1, axis: int = -1, prepend=None, append=None) ->
append : Array, optional
Array to append to `a` along `axis` before calculating the difference.
"""
if a.dtype == ak.bigint or a.dtype == ak.bool_:
raise RuntimeError(f"Error executing command: diff does not support dtype {a.dtype}")

if prepend is not None and append is not None:
a_ = concat((prepend, a, append), axis=axis)
elif prepend is not None:
Expand Down Expand Up @@ -146,6 +152,9 @@ def pad(
if mode != "constant":
raise NotImplementedError(f"pad mode '{mode}' is not supported")

if array.dtype == ak.bigint:
raise RuntimeError("Error executing command: pad does not support dtype bigint")

if "constant_values" not in kwargs:
cvals = 0
else:
Expand Down
8 changes: 1 addition & 7 deletions src/ArgSortMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -434,17 +434,11 @@ module ArgSortMsg
axis = msgArgs["axis"].toScalar(int),
symEntry = st[msgArgs["name"]]: SymEntry(array_dtype, array_nd),
vals = if (array_dtype == bool) then (symEntry.a:int) else (symEntry.a: array_dtype);

const iv = argsortDefault(vals, algorithm=algorithm, axis);
return st.insert(new shared SymEntry(iv));
}

proc argsort(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
where (array_dtype == BigInteger.bigint) || (array_dtype == uint(8))
{
return MsgTuple.error("argsort does not support the %s dtype".format(array_dtype:string));
}

proc argsortStrings(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws {
const name = msgArgs["name"].toScalar(string),
strings = getSegString(name, st),
Expand Down
12 changes: 3 additions & 9 deletions src/AryUtil.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -945,9 +945,9 @@ module AryUtil
flatten a multi-dimensional array into a 1D array
*/
@arkouda.registerCommand
proc flatten(const ref a: [?d] ?t): [] t throws
where a.rank > 1
{
proc flatten(const ref a: [?d] ?t): [] t throws {
if a.rank == 1 then return a;

var flat = makeDistArray(d.size, t);

// ranges of flat indices owned by each locale
Expand Down Expand Up @@ -1004,12 +1004,6 @@ module AryUtil
return flat;
}

proc flatten(const ref a: [?d] ?t): [] t throws
where a.rank == 1
{
return a;
}

// helper for computing an array element's index from its order
record orderer {
param rank: int;
Expand Down
22 changes: 1 addition & 21 deletions src/CastMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,13 @@ module CastMsg {
private config const logChannel = ServerConfig.logChannel;
const castLogger = new Logger(logLevel, logChannel);

proc isFloatingType(type t) param : bool {
return isRealType(t) || isImagType(t) || isComplexType(t);
}

@arkouda.instantiateAndRegister(prefix="cast")
proc castArray(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab,
type array_dtype_from,
type array_dtype_to,
param array_nd: int
): MsgTuple throws
where !(isFloatingType(array_dtype_from) && array_dtype_to == bigint) &&
where !((isRealType(array_dtype_from) || isImagType(array_dtype_from) || isComplexType(array_dtype_from)) && array_dtype_to == bigint) &&
!(array_dtype_from == bigint && array_dtype_to == bool)
{
const a = st[msgArgs["name"]]: SymEntry(array_dtype_from, array_nd);
Expand All @@ -40,22 +36,6 @@ module CastMsg {
}
}

// cannot cast float types to bigint, cannot cast bigint to bool
proc castArray(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab,
type array_dtype_from,
type array_dtype_to,
param array_nd: int
): MsgTuple throws
where (isFloatingType(array_dtype_from) && array_dtype_to == bigint) ||
(array_dtype_from == bigint && array_dtype_to == bool)
{
return MsgTuple.error(
"cannot cast array of type %s to %s".format(
type2str(array_dtype_from),
type2str(array_dtype_to)
));
}

@arkouda.instantiateAndRegister(prefix="castToStrings")
proc castArrayToStrings(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype): MsgTuple throws {
const name = msgArgs["name"].toScalar(string);
Expand Down
12 changes: 0 additions & 12 deletions src/GenSymIO.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,6 @@ module GenSymIO {
return st.insert(new shared SymEntry(makeArrayFromBytes(msgArgs.payload, shape, array_dtype)));
}

proc array(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
where array_dtype == bigint
{
return MsgTuple.error("Array creation from binary payload is not supported for bigint arrays");
}

proc makeArrayFromBytes(ref payload: bytes, shape: ?N*int, type t): [] t throws {
var size = 1;
for s in shape do size *= s;
Expand Down Expand Up @@ -138,12 +132,6 @@ module GenSymIO {
return MsgTuple.payload(bytes.createAdoptingBuffer(ptr:c_ptr(uint(8)), size, size));
}

proc tondarray(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
where array_dtype == bigint
{
return MsgTuple.error("cannot create ndarray from bigint array");
}

/*
* Utility proc to test casting a string to a specified type
* :arg c: String to cast
Expand Down
14 changes: 0 additions & 14 deletions src/IndexingMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -206,12 +206,6 @@ module IndexingMsg
}
}

proc multiPDArrayIndex(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype_a, type array_dtype_idx, param array_nd: int): MsgTuple throws
where array_dtype_idx != int && array_dtype_idx != uint
{
return MsgTuple.error("Invalid index type: %s; must be 'int' or 'uint'".format(type2str(array_dtype_idx)));
}

private proc multiIndexShape(inShape: ?N*int, idxDims: [?d] int, outSize: int): (bool, int, N*int) {
var minShape: N*int = inShape,
firstRank = -1;
Expand Down Expand Up @@ -960,14 +954,6 @@ module IndexingMsg
return st.insert(new shared SymEntry(y, x.max_bits));
}

proc takeAlongAxis(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab,
type array_dtype_x,
type array_dtype_idx,
param array_nd: int
): MsgTuple throws {
return MsgTuple.error("Cannot take along axis with non-integer index array");
}

use CommandMap;
registerFunction("arrayViewMixedIndex", arrayViewMixedIndexMsg, getModuleName());
registerFunction("[pdarray]", pdarrayIndexMsg, getModuleName());
Expand Down
53 changes: 3 additions & 50 deletions src/LinalgMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,6 @@ module LinalgMsg {
return st.insert(e);
}


proc eye(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype): MsgTuple throws
where array_dtype == BigInteger.bigint
{
return MsgTuple.error("eye does not support the bigint dtype");
}

// tril and triu are identical except for the argument they pass to triluHandler (true for upper, false for lower)
// The zeros are written into the upper (or lower) triangle of the array, offset by the value of diag.

Expand All @@ -79,11 +72,6 @@ module LinalgMsg {
return triluHandler(cmd, msgArgs, st, array_dtype, array_nd, false);
}

proc tril(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
where array_nd < 2 {
return MsgTuple.error("Array must be at least 2 dimensional for 'tril'");
}

// Create an array from an existing array with its lower triangle zeroed out

@arkouda.instantiateAndRegister
Expand All @@ -92,13 +80,9 @@ module LinalgMsg {
return triluHandler(cmd, msgArgs, st, array_dtype, array_nd, true);
}

proc triu(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
where array_nd < 2 {
return MsgTuple.error("Array must be at least 2 dimensional for 'triu'");
}

// Fetch the arguments, call zeroTri, return result.

// Fetch the arguments, call zeroTri, return result.
// TODO: support instantiating param bools with 'true' and 'false' s.t. we'd have 'triluHandler<true>' and 'triluHandler<false>'
// cmds if this procedure were annotated instead of the two above.
proc triluHandler(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab,
type array_dtype, param array_nd: int, param upper: bool
): MsgTuple throws {
Expand Down Expand Up @@ -195,16 +179,6 @@ module LinalgMsg {

}

proc matmul(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype_x1, type array_dtype_x2, param array_nd: int): MsgTuple throws
where (array_nd < 2) && (array_dtype_x1 != BigInteger.bigint) && (array_dtype_x2 != BigInteger.bigint) {
return MsgTuple.error("Matrix multiplication with arrays of dimension < 2 is not supported");
}

proc matmul(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype_x1, type array_dtype_x2, param array_nd: int): MsgTuple throws
where (array_dtype_x1 == BigInteger.bigint) || (array_dtype_x2 == BigInteger.bigint) {
return MsgTuple.error("Matrix multiplication with arrays of bigint type is not supported");
}

proc compute_result_type_matmul(type t1, type t2) type {
if t1 == real || t2 == real then return real;
if t1 == int || t2 == int then return int;
Expand Down Expand Up @@ -302,11 +276,6 @@ module LinalgMsg {
return ret;
}

proc transpose(array: [?d] ?t): [d] t throws
where d.rank < 2 {
throw new Error("Matrix transpose with arrays of dimension < 2 is not supported");
}

/*
Compute the generalized dot product of two tensors along the specified axis.

Expand Down Expand Up @@ -366,22 +335,6 @@ module LinalgMsg {
return bool;
}

proc vecdot(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype_x1, type array_dtype_x2, param array_nd: int): MsgTuple throws
where (array_nd < 2) && ((array_dtype_x1 != bool) || (array_dtype_x2 != bool))
&& (array_dtype_x1 != BigInteger.bigint) && (array_dtype_x2 != BigInteger.bigint) {
return MsgTuple.error("VecDot with arrays of dimension < 2 is not supported");
}

proc vecdot(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype_x1, type array_dtype_x2, param array_nd: int): MsgTuple throws
where (array_dtype_x1 == bool) && (array_dtype_x2 == bool) {
return MsgTuple.error("VecDot with arrays both of type bool is not supported");
}

proc vecdot(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype_x1, type array_dtype_x2, param array_nd: int): MsgTuple throws
where (array_dtype_x1 == BigInteger.bigint) || (array_dtype_x2 == BigInteger.bigint) {
return MsgTuple.error("VecDot with arrays of type bigint is not supported");
}

// @arkouda.registerND(???)
// proc tensorDotMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd1: int, param nd2: int): MsgTuple throws {
// if nd < 3 {
Expand Down
10 changes: 0 additions & 10 deletions src/MsgProcessing.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -339,11 +339,6 @@ module MsgProcessing
return msg;
}

proc chunkInfoAsString(array: [?d] ?t): string throws
where (t != bool) && (t != int(64)) && (t != uint(64)) && (t != uint(8)) && (t != real){
throw new Error("chunkInfo does not support dtype %s".format(t:string));
}

@arkouda.registerCommand
proc chunkInfoAsArray(array: [?d] ?t):[] int throws
where (t == bool) || (t == int(64)) || (t == uint(64)) || (t == uint(8)) ||(t == real) {
Expand All @@ -357,9 +352,4 @@ module MsgProcessing
}
return blockSizes;
}

proc chunkInfoAsArray(array: [?d] ?t): [d] int throws
where (t != bool) && (t != int(64)) && (t != uint(64)) && (t != uint(8)) && (t != real){
throw new Error("chunkInfo does not support dtype %s".format(t:string));
}
}
41 changes: 13 additions & 28 deletions src/RandMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,6 @@ module RandMsg
return st.insert(e);
}

proc randint(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
where array_dtype == BigInteger.bigint
{
return MsgTuple.error("randint does not support the bigint dtype");
}

@arkouda.instantiateAndRegister
proc randomNormal(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param array_nd: int): MsgTuple throws {
const shape = msgArgs["shape"].toScalarTuple(int, array_nd),
Expand Down Expand Up @@ -117,12 +111,6 @@ module RandMsg
return st.insert(new shared GeneratorSymEntry(generator, state));
}

proc createGenerator(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype): MsgTuple throws
where array_dtype == BigInteger.bigint
{
return MsgTuple.error("createGenerator does not support the bigint dtype");
}

@arkouda.instantiateAndRegister
proc uniformGenerator(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
where array_dtype != BigInteger.bigint
Expand Down Expand Up @@ -151,12 +139,6 @@ module RandMsg
return st.insert(uniformEntry);
}

proc uniformGenerator(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
where array_dtype == BigInteger.bigint
{
return MsgTuple.error("uniformGenerator does not support the bigint dtype");
}


/*
Use the ziggurat method (https://en.wikipedia.org/wiki/Ziggurat_algorithm#Theory_of_operation)
Expand Down Expand Up @@ -252,6 +234,9 @@ module RandMsg

@arkouda.instantiateAndRegister
proc standardNormalGenerator(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param array_nd): MsgTuple throws
do return standardNormalGeneratorHelp(cmd, msgArgs, st, array_nd);

proc standardNormalGeneratorHelp(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param array_nd): MsgTuple throws
where array_nd == 1
{
const name = msgArgs["name"], // generator name
Expand Down Expand Up @@ -287,7 +272,7 @@ module RandMsg
}


proc standardNormalGenerator(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param array_nd): MsgTuple throws
proc standardNormalGeneratorHelp(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param array_nd): MsgTuple throws
where array_nd > 1
{
const name = msgArgs["name"], // generator name
Expand Down Expand Up @@ -387,6 +372,9 @@ module RandMsg

@arkouda.instantiateAndRegister
proc standardExponential(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param array_nd): MsgTuple throws
do return standardExponentialHelp(cmd, msgArgs, st, array_nd);

proc standardExponentialHelp(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param array_nd): MsgTuple throws
where array_nd == 1
{
const name = msgArgs["name"], // generator name
Expand Down Expand Up @@ -421,7 +409,7 @@ module RandMsg
}
}

proc standardExponential(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param array_nd): MsgTuple throws
proc standardExponentialHelp(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param array_nd): MsgTuple throws
where array_nd > 1
{
const name = msgArgs["name"], // generator name
Expand Down Expand Up @@ -567,12 +555,6 @@ module RandMsg
}
}

proc choice(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype): MsgTuple throws
where array_dtype == BigInteger.bigint
{
return MsgTuple.error("choice does not support the bigint dtype");
}

inline proc logisticGenerator(mu: real, scale: real, ref rs) {
var U = rs.next(0, 1);

Expand Down Expand Up @@ -693,7 +675,10 @@ module RandMsg
}

@arkouda.instantiateAndRegister
proc shuffle(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
proc shuffle(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
do return shuffleHelp(cmd, msgArgs, st, array_dtype, array_nd);

proc shuffleHelp(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
where array_nd == 1
{
const name = msgArgs["name"],
Expand All @@ -715,7 +700,7 @@ module RandMsg
return MsgTuple.success();
}

proc shuffle(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
proc shuffleHelp(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
where array_nd != 1
{
const name = msgArgs["name"],
Expand Down
Loading
Loading