From 2dd12007d0db8ddc9d03a6dd28606f5b06a87bc1 Mon Sep 17 00:00:00 2001 From: Aryan Pandey Date: Wed, 20 Sep 2023 11:13:34 +0530 Subject: [PATCH 1/3] fix bug #23044 --- .../array/experimental/manipulation.py | 5 +++++ ivy/functional/backends/jax/general.py | 22 ++++++++++++++++++- 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/ivy/data_classes/array/experimental/manipulation.py b/ivy/data_classes/array/experimental/manipulation.py index 48d9586c24dba..5d09c746ad7a5 100644 --- a/ivy/data_classes/array/experimental/manipulation.py +++ b/ivy/data_classes/array/experimental/manipulation.py @@ -173,6 +173,11 @@ def vstack( ivy.array([[1, 2], [5, 6], [7, 8]]) + + x = ivy.array([1, 2]) + y = [ivy.array([[5, 6]]), ivy.array([[7, 8]])] + ivy.vstack((x, y)) + ivy.vstack((x, y, x, y)) """ if not isinstance(arrays, (list, tuple)): arrays = [arrays] diff --git a/ivy/functional/backends/jax/general.py b/ivy/functional/backends/jax/general.py index 9221319cfeaed..0b3d38a25f106 100644 --- a/ivy/functional/backends/jax/general.py +++ b/ivy/functional/backends/jax/general.py @@ -120,6 +120,26 @@ def to_list(x: JaxArray, /) -> list: return _to_array(x).tolist() +# ivy/utils/assertions.py + +def get_positive_axis_for_gather(axis, ndims): + if not isinstance(axis, int): + raise TypeError(f"{axis} must be an int; got {type(axis).__name__}") + if ndims is not None: + if 0 <= axis < ndims: + return axis + elif -ndims <= axis < 0: + return axis + ndims + else: + raise ValueError(f"{axis}={axis} out of bounds: " + f"expected {-ndims}<={axis}<{ndims}") + elif axis < 0: + raise ValueError(f"{axis} may only be negative " + f"if {ndims} is statically known.") + return axis + +# ivy/functional/backends/jax/general.py + def gather( params: JaxArray, indices: JaxArray, @@ -129,7 +149,7 @@ def gather( batch_dims: int = 0, out: Optional[JaxArray] = None, ) -> JaxArray: - axis = axis % len(params.shape) + axis = get_positive_axis_for_gather(axis, params.ndim) batch_dims = batch_dims % len(params.shape) ivy.utils.assertions.check_gather_input_valid(params, indices, axis, batch_dims) result = [] From d15439434865939055ee2e7d61987d51624bac59 Mon Sep 17 00:00:00 2001 From: Aryan Pandey Date: Wed, 20 Sep 2023 21:14:59 +0530 Subject: [PATCH 2/3] fix the bug #22842 --- docs/overview/related_work/what_does_ivy_add.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/overview/related_work/what_does_ivy_add.rst b/docs/overview/related_work/what_does_ivy_add.rst index 14a407d24a751..8effb386c5edd 100644 --- a/docs/overview/related_work/what_does_ivy_add.rst +++ b/docs/overview/related_work/what_does_ivy_add.rst @@ -105,6 +105,6 @@ Firstly, we are adhering to the `Array API Standard`_ defined by Quansight. In essence, they have written the standard and we have implemented it, which is pretty much as complementary as it gets. Similarly, OctoML makes it easy for anyone to *deploy* their model anywhere, while Ivy makes it easy for anyone to mix and match any code from any frameworks and versions to *train* their model anywhere. Again very complementary objectives. -Finally, Modular will perhaps make it possible for developers to make changes at various levels of the stack when creating ML models using their "", and this would also be a great addition to the field. +Finally, Modular will perhaps make it possible for developers to make changes at various levels of the stack when creating ML models using their own, and this would also be a great addition to the field. Compared to Modular which focuses on the lower levels of the stack, Ivy instead unifies the ML frameworks at the functional API level, enabling code conversions to and from the user-facing APIs themselves, without diving into any of the lower level details. All of these features are entirely complementary, and together would form a powerful suite of unifying tools for ML practitioners. From 8be108d52cf0b26de4300630f8e423894a818fdc Mon Sep 17 00:00:00 2001 From: ivy-branch Date: Mon, 2 Oct 2023 19:44:38 +0000 Subject: [PATCH 3/3] =?UTF-8?q?=F0=9F=A4=96=20Lint=20code?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ivy/functional/backends/jax/general.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/ivy/functional/backends/jax/general.py b/ivy/functional/backends/jax/general.py index 0b3d38a25f106..19fec6839248e 100644 --- a/ivy/functional/backends/jax/general.py +++ b/ivy/functional/backends/jax/general.py @@ -122,6 +122,7 @@ def to_list(x: JaxArray, /) -> list: # ivy/utils/assertions.py + def get_positive_axis_for_gather(axis, ndims): if not isinstance(axis, int): raise TypeError(f"{axis} must be an int; got {type(axis).__name__}") @@ -131,15 +132,17 @@ def get_positive_axis_for_gather(axis, ndims): elif -ndims <= axis < 0: return axis + ndims else: - raise ValueError(f"{axis}={axis} out of bounds: " - f"expected {-ndims}<={axis}<{ndims}") + raise ValueError( + f"{axis}={axis} out of bounds: expected {-ndims}<={axis}<{ndims}" + ) elif axis < 0: - raise ValueError(f"{axis} may only be negative " - f"if {ndims} is statically known.") + raise ValueError(f"{axis} may only be negative if {ndims} is statically known.") return axis + # ivy/functional/backends/jax/general.py + def gather( params: JaxArray, indices: JaxArray,