Skip to content

Commit

Permalink
Correct boundry type test
Browse files Browse the repository at this point in the history
  • Loading branch information
asistradition committed Oct 29, 2020
1 parent c55f410 commit 1f312a4
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 3 deletions.
11 changes: 10 additions & 1 deletion sparse_dot_mkl/_mkl_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,7 +1107,16 @@ def _is_double(arr):


def _is_allowed_sparse_format(matrix):
return _spsparse.isspmatrix_csr(matrix) or _spsparse.isspmatrix_csc(matrix) or _spsparse.isspmatrix_bsr(matrix)
"""
Return True if the matrix is dense or a sparse format we can turn into an MKL object. False otherwise.
:param matrix:
:return:
:rtype: bool
"""
if _spsparse.isspmatrix(matrix):
return _spsparse.isspmatrix_csr(matrix) or _spsparse.isspmatrix_csc(matrix) or _spsparse.isspmatrix_bsr(matrix)
else:
return True


def _empty_output_check(matrix_a, matrix_b):
Expand Down
6 changes: 4 additions & 2 deletions sparse_dot_mkl/_sparse_vector.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from sparse_dot_mkl._mkl_interface import (MKL, _sanity_check, _empty_output_check, _type_check, _create_mkl_sparse,
_destroy_mkl_handle, matrix_descr, RETURN_CODES, _is_dense_vector,
_out_matrix, _check_return_value)
_out_matrix, _check_return_value, _is_allowed_sparse_format)

import numpy as np
import ctypes as _ctypes
Expand Down Expand Up @@ -86,7 +86,9 @@ def _sparse_dot_vector(mv_a, mv_b, cast=False, scalar=1., out=None, out_scalar=N
_sanity_check(mv_a, mv_b, allow_vector=True)
mv_a, mv_b = _type_check(mv_a, mv_b, cast=cast)

if _is_dense_vector(mv_b):
if not _is_allowed_sparse_format(mv_a) or not _is_allowed_sparse_format(mv_b):
raise ValueError("Only CSR, CSC, and BSR-type sparse matrices are supported")
elif _is_dense_vector(mv_b):
return _sparse_dense_vector_mult(mv_a, mv_b, scalar=scalar, out=out, out_scalar=out_scalar)
elif _is_dense_vector(mv_a) and out is None:
return _sparse_dense_vector_mult(mv_b, mv_a.T, scalar=scalar, transpose=True).T
Expand Down
20 changes: 20 additions & 0 deletions sparse_dot_mkl/tests/test_sparse_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,26 @@ def setUp(self):
self.mat2_d = VECTOR.copy()


class TestSparseVectorMultiplicationCOO(unittest.TestCase):

def setUp(self):
self.mat1 = _spsparse.coo_matrix(MATRIX_1).copy()
self.mat2 = VECTOR.copy()

self.mat1_d = np.asarray(MATRIX_1.A, order="C")
self.mat2_d = VECTOR.copy()

def make_2d(self, arr):
return arr.reshape(-1, 1) if arr.ndim == 1 else arr

def test_fails(self):
with self.assertRaises(ValueError):
dot_product_mkl(self.mat1, self.make_2d(self.mat2), cast=True)

with self.assertRaises(ValueError):
dot_product_mkl(self.make_2d(self.mat2), self.mat1.T, cast=True)


class TestVectorSparseMultiplication(TestSparseVectorMultiplication):

sparse_func = _spsparse.csr_matrix
Expand Down

0 comments on commit 1f312a4

Please sign in to comment.