Skip to content

Commit 94a6923

Browse files
authored
Merge pull request #578 from pfackeldey/pfackeldey/fix_new_known_scalar_dtype
fix: dtype inference in new_known_scalar
2 parents 54ef202 + 6062dc7 commit 94a6923

File tree

2 files changed

+10
-0
lines changed

2 files changed

+10
-0
lines changed

src/dask_awkward/lib/core.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -706,6 +706,8 @@ def new_known_scalar(
706706
dtype = np.dtype(int)
707707
elif isinstance(s, (float, np.floating)):
708708
dtype = np.dtype(float)
709+
elif hasattr(s, "dtype"):
710+
dtype = getattr(s, "dtype")
709711
else:
710712
dtype = np.dtype(type(s))
711713
else:

tests/test_core.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,14 @@ def test_new_known_scalar() -> None:
466466
assert c.compute() == s1
467467

468468

469+
def test_new_known_scalar_from_array() -> None:
470+
s = np.array(0.0)
471+
c = new_known_scalar(s)
472+
assert c.compute() == s
473+
assert c._meta is not None
474+
assert c.dtype == s.dtype
475+
476+
469477
def test_scalar_dtype() -> None:
470478
s = 2
471479
c = new_known_scalar(s)

0 commit comments

Comments
 (0)