From 1f312a4bf6af1632cc668d4d46ab9ddacdd7913a Mon Sep 17 00:00:00 2001 From: asistradition Date: Thu, 29 Oct 2020 18:33:07 -0400 Subject: [PATCH] Correct boundry type test --- sparse_dot_mkl/_mkl_interface.py | 11 ++++++++++- sparse_dot_mkl/_sparse_vector.py | 6 ++++-- sparse_dot_mkl/tests/test_sparse_vector.py | 20 ++++++++++++++++++++ 3 files changed, 34 insertions(+), 3 deletions(-) diff --git a/sparse_dot_mkl/_mkl_interface.py b/sparse_dot_mkl/_mkl_interface.py index 2037474..2011ae6 100644 --- a/sparse_dot_mkl/_mkl_interface.py +++ b/sparse_dot_mkl/_mkl_interface.py @@ -1107,7 +1107,16 @@ def _is_double(arr): def _is_allowed_sparse_format(matrix): - return _spsparse.isspmatrix_csr(matrix) or _spsparse.isspmatrix_csc(matrix) or _spsparse.isspmatrix_bsr(matrix) + """ + Return True if the matrix is dense or a sparse format we can turn into an MKL object. False otherwise. + :param matrix: + :return: + :rtype: bool + """ + if _spsparse.isspmatrix(matrix): + return _spsparse.isspmatrix_csr(matrix) or _spsparse.isspmatrix_csc(matrix) or _spsparse.isspmatrix_bsr(matrix) + else: + return True def _empty_output_check(matrix_a, matrix_b): diff --git a/sparse_dot_mkl/_sparse_vector.py b/sparse_dot_mkl/_sparse_vector.py index f103512..28c7f80 100644 --- a/sparse_dot_mkl/_sparse_vector.py +++ b/sparse_dot_mkl/_sparse_vector.py @@ -1,6 +1,6 @@ from sparse_dot_mkl._mkl_interface import (MKL, _sanity_check, _empty_output_check, _type_check, _create_mkl_sparse, _destroy_mkl_handle, matrix_descr, RETURN_CODES, _is_dense_vector, - _out_matrix, _check_return_value) + _out_matrix, _check_return_value, _is_allowed_sparse_format) import numpy as np import ctypes as _ctypes @@ -86,7 +86,9 @@ def _sparse_dot_vector(mv_a, mv_b, cast=False, scalar=1., out=None, out_scalar=N _sanity_check(mv_a, mv_b, allow_vector=True) mv_a, mv_b = _type_check(mv_a, mv_b, cast=cast) - if _is_dense_vector(mv_b): + if not _is_allowed_sparse_format(mv_a) or not _is_allowed_sparse_format(mv_b): + raise ValueError("Only CSR, CSC, and BSR-type sparse matrices are supported") + elif _is_dense_vector(mv_b): return _sparse_dense_vector_mult(mv_a, mv_b, scalar=scalar, out=out, out_scalar=out_scalar) elif _is_dense_vector(mv_a) and out is None: return _sparse_dense_vector_mult(mv_b, mv_a.T, scalar=scalar, transpose=True).T diff --git a/sparse_dot_mkl/tests/test_sparse_vector.py b/sparse_dot_mkl/tests/test_sparse_vector.py index 4bf7603..7a43281 100644 --- a/sparse_dot_mkl/tests/test_sparse_vector.py +++ b/sparse_dot_mkl/tests/test_sparse_vector.py @@ -135,6 +135,26 @@ def setUp(self): self.mat2_d = VECTOR.copy() +class TestSparseVectorMultiplicationCOO(unittest.TestCase): + + def setUp(self): + self.mat1 = _spsparse.coo_matrix(MATRIX_1).copy() + self.mat2 = VECTOR.copy() + + self.mat1_d = np.asarray(MATRIX_1.A, order="C") + self.mat2_d = VECTOR.copy() + + def make_2d(self, arr): + return arr.reshape(-1, 1) if arr.ndim == 1 else arr + + def test_fails(self): + with self.assertRaises(ValueError): + dot_product_mkl(self.mat1, self.make_2d(self.mat2), cast=True) + + with self.assertRaises(ValueError): + dot_product_mkl(self.make_2d(self.mat2), self.mat1.T, cast=True) + + class TestVectorSparseMultiplication(TestSparseVectorMultiplication): sparse_func = _spsparse.csr_matrix