feat(frontend): adding __truediv__ to paddle.tensor.tensor.Tensor #28113

Closed
wants to merge 6 commits (changes from 5 commits shown)
143 changes: 87 additions & 56 deletions ivy/functional/frontends/jax/numpy/__init__.py
@@ -1,7 +1,7 @@
# global
from numbers import Number
from typing import Union, Tuple, Iterable

import jax.numpy as jnp

# local
import ivy
@@ -19,101 +19,123 @@
_uint16 = ivy.UintDtype("uint16")
_uint32 = ivy.UintDtype("uint32")
_uint64 = ivy.UintDtype("uint64")
_bfloat16 = ivy.FloatDtype("bfloat16")
_float16 = ivy.FloatDtype("float16")
_float32 = ivy.FloatDtype("float32")
_float64 = ivy.FloatDtype("float64")
_complex64 = ivy.ComplexDtype("complex64")
_complex128 = ivy.ComplexDtype("complex128")
_bool = ivy.Dtype("bool")
_bfloat16 = jnp.bfloat16
_float16 = jnp.float16
_float32 = jnp.float32
_float64 = jnp.float64
_complex64 = jnp.complex64
_complex128 = jnp.complex128
_bool = jnp.bool_

# jax-numpy casting table
jax_numpy_casting_table = {
_bool: [
_bool,
_int8,
_int16,
_int32,
_int64,
_uint8,
_uint16,
_uint32,
_uint64,
jnp.int8,
jnp.int16,
jnp.int32,
jnp.int64,
jnp.uint8,
jnp.uint16,
jnp.uint32,
jnp.uint64,
_float16,
_float32,
_float64,
_complex64,
_complex128,
_bfloat16,
],
_int8: [
_int8,
_int16,
_int32,
_int64,
jnp.int8: [
jnp.int8,
jnp.int16,
jnp.int32,
jnp.int64,
_float16,
_float32,
_float64,
_complex64,
_complex128,
_bfloat16,
],
_int16: [
_int16,
_int32,
_int64,
jnp.int16: [
jnp.int16,
jnp.int32,
jnp.int64,
_float16,
_float32,
_float64,
_complex64,
_complex128,
_bfloat16,
],
_int32: [
_int32,
_int64,
jnp.int32: [
jnp.int32,
jnp.int64,
_float16,
_float32,
_float64,
_complex64,
_complex128,
_bfloat16,
],
_int64: [
_int64,
jnp.int64: [
jnp.int64,
_float16,
_float32,
_float64,
_complex64,
_complex128,
_bfloat16,
],
_uint8: [
_int16,
_int32,
_int64,
_uint8,
_uint16,
_uint32,
_uint64,
jnp.uint8: [
jnp.uint8,
jnp.uint16,
jnp.uint32,
jnp.uint64,
_float16,
_float32,
_float64,
_complex64,
_complex128,
_bfloat16,
],
_uint16: [
_int32,
_int64,
_uint16,
_uint32,
_uint64,
jnp.uint16: [
jnp.uint16,
jnp.uint32,
jnp.uint64,
_float16,
_float32,
_float64,
_complex64,
_complex128,
_bfloat16,
],
_uint32: [
_int64,
_uint32,
_uint64,
jnp.uint32: [
jnp.uint32,
jnp.uint64,
_float16,
_float32,
_float64,
_complex64,
_complex128,
_bfloat16,
],
_uint64: [
_uint64,
jnp.uint64: [
jnp.uint64,
_float16,
_float32,
_float64,
_complex64,
_complex128,
_bfloat16,
],
_bfloat16: [
_bfloat16,
_float16,
_float32,
_float64,
_complex64,
_complex128,
],
_float16: [
@@ -131,17 +153,16 @@
],
_float64: [
_float64,
_complex64,
_complex128,
],
_complex64: [_complex64, ivy.complex128],
_complex128: [_complex128],
_bfloat16: [
_bfloat16,
_float32,
_float64,
_complex64: [
_complex64,
_complex128,
],
_complex128: [
_complex128,
],
}
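
For orientation only (not part of this diff): a casting table of this shape is normally consulted by checking whether the target dtype appears in the list stored under the source dtype. A minimal sketch of such a lookup, assuming the table above is available as jax_numpy_casting_table; the helper name is hypothetical:

# Illustrative sketch, not part of this PR: membership lookup in the table above.
def _can_cast_sketch(from_dtype, to_dtype, table=jax_numpy_casting_table):
    # A cast is treated as allowed when the target dtype is listed under the
    # source dtype; unknown source dtypes allow nothing.
    return to_dtype in table.get(from_dtype, [])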


@@ -384,6 +405,16 @@
}


def array_repr(arr):
shape = arr.shape
dtype = arr.dtype
device = arr.device
data = ivy.to_numpy(arr)

repr_str = f"Ivy Array (shape={shape}, dtype={dtype}, device={device}):\n{data}" # takes an Ivy array as input and generates a string representation

return repr_str
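
A quick usage sketch for the array_repr helper above (illustrative only, not part of this diff), assuming a plain Ivy array as input:

# Illustrative sketch, not part of this PR.
x = ivy.array([1.0, 2.0, 3.0])
print(array_repr(x))
# e.g. Ivy Array (shape=(3,), dtype=float32, device=cpu):
# [1. 2. 3.]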

@handle_exceptions
def promote_types_jax(
type1: Union[ivy.Dtype, ivy.NativeDtype],
6 changes: 6 additions & 0 deletions ivy/functional/frontends/paddle/tensor/tensor.py
@@ -65,6 +65,12 @@ def ivy_array(self, array):
{"2.6.0 and below": ("bool", "unsigned", "int8", "float16", "bfloat16")},
"paddle",
)
def __truediv__(self, other):
if isinstance(other, Tensor):
return Tensor(self.ivy_array / other.ivy_array)
else:
return Tensor(self.ivy_array / other)

def __add__(self, y, /, name=None):
return paddle_frontend.add(self, y)
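
A short usage sketch for the __truediv__ method added above (illustrative only, not part of this diff); it assumes the frontend Tensor can be constructed directly from an Ivy array, as the method body itself does:

# Illustrative sketch, not part of this PR.
import ivy
from ivy.functional.frontends.paddle.tensor.tensor import Tensor

a = Tensor(ivy.array([4.0, 9.0]))
b = Tensor(ivy.array([2.0, 3.0]))
print(a / b)    # dispatches to Tensor.__truediv__ -> values [2.0, 3.0]
print(a / 2.0)  # a non-Tensor divisor takes the else branch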
