diff --git a/src/finch/tensor.py b/src/finch/tensor.py index 0caed35..afbc971 100644 --- a/src/finch/tensor.py +++ b/src/finch/tensor.py @@ -62,17 +62,22 @@ class Tensor(_Display): def __init__( self, - obj: Union[np.ndarray, spmatrix, Storage, JuliaObj], + obj: Union[np.ndarray, np.number, spmatrix, Storage, JuliaObj], /, *, - fill_value: np.number = 0.0, + fill_value: Optional[np.number] = None, ): if _is_scipy_sparse_obj(obj): # scipy constructor jl_data = self._from_scipy_sparse(obj) self._obj = jl_data elif isinstance(obj, np.ndarray): # numpy constructor + fill_value = 0.0 if fill_value is None else fill_value jl_data = self._from_numpy(obj, fill_value=fill_value) self._obj = jl_data + elif np.isscalar(obj): + if fill_value is not None: + raise UserWarning("`fill_value` argument is ignored for scalar input") + self._obj = jl.Scalar(obj) elif isinstance(obj, Storage): # from-storage constructor order = self.preprocess_order(obj.order, self.get_lvl_ndim(obj.levels_descr._obj)) self._obj = jl.swizzle(jl.Tensor(obj.levels_descr._obj), *order) diff --git a/tests/test_ops.py b/tests/test_ops.py index 4fb484b..7f6ce30 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -143,3 +143,16 @@ def test_matmul(arr2d, arr3d): with pytest.raises(ValueError, match="Both tensors must be 2-dimensional"): A_finch @ D_finch + + +def test_scalars(arr3d): + A_finch = finch.Tensor(arr3d) + result = A_finch + finch.Tensor(1) # Scalar{1, Int64}(1) + + assert result._is_dense + + storage = finch.Storage(finch.Dense(finch.SparseList(finch.SparseList(finch.Element(0))))) + B_finch = A_finch.to_device(storage) + result = B_finch + finch.Tensor(1) + + assert not result._is_dense