Skip to content

Commit e355ac4

Browse files
committed
address review comments
1 parent 42e9893 commit e355ac4

File tree

2 files changed

+101
-80
lines changed

2 files changed

+101
-80
lines changed

dpnp/dpnp_iface_logic.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1202,6 +1202,8 @@ def isin(
12021202
test_elements,
12031203
assume_unique=False, # pylint: disable=unused-argument
12041204
invert=False,
1205+
*,
1206+
kind=None, # pylint: disable=unused-argument
12051207
):
12061208
"""
12071209
Calculates ``element in test_elements``, broadcasting over `element` only.
@@ -1221,22 +1223,27 @@ def isin(
12211223
assume_unique : bool, optional
12221224
Ignored, as no performance benefit is gained by assuming the
12231225
input arrays are unique. Included for compatibility with NumPy.
1226+
12241227
Default: ``False``.
12251228
invert : bool, optional
12261229
If ``True``, the values in the returned array are inverted, as if
1227-
calculating `element not in test_elements`.
1230+
calculating ``element not in test_elements``.
12281231
``dpnp.isin(a, b, invert=True)`` is equivalent to (but faster
12291232
than) ``dpnp.invert(dpnp.isin(a, b))``.
1233+
12301234
Default: ``False``.
1235+
kind : {None, "sort"}, optional
1236+
Ignored, as the only algorithm implemented is ``"sort"``. Included for
1237+
compatibility with NumPy.
12311238
1239+
Default: ``None``.
12321240
12331241
Returns
12341242
-------
12351243
isin : dpnp.ndarray of bool dtype
12361244
Has the same shape as `element`. The values `element[isin]`
12371245
are in `test_elements`.
12381246
1239-
12401247
Examples
12411248
--------
12421249
>>> import dpnp as np
@@ -1269,14 +1276,27 @@ def isin(
12691276
"""
12701277

12711278
dpnp.check_supported_arrays_type(element, test_elements, scalar_type=True)
1272-
usm_element = dpnp.as_usm_ndarray(
1273-
element, usm_type=element.usm_type, sycl_queue=element.sycl_queue
1274-
)
1275-
usm_test = dpnp.as_usm_ndarray(
1276-
test_elements,
1277-
usm_type=test_elements.usm_type,
1278-
sycl_queue=test_elements.sycl_queue,
1279-
)
1279+
if dpnp.isscalar(element):
1280+
usm_element = dpnp.as_usm_ndarray(
1281+
element,
1282+
usm_type=test_elements.usm_type,
1283+
sycl_queue=test_elements.sycl_queue,
1284+
)
1285+
usm_test = dpnp.get_usm_ndarray(test_elements)
1286+
elif dpnp.isscalar(test_elements):
1287+
usm_test = dpnp.as_usm_ndarray(
1288+
test_elements,
1289+
usm_type=element.usm_type,
1290+
sycl_queue=element.sycl_queue,
1291+
)
1292+
usm_element = dpnp.get_usm_ndarray(element)
1293+
else:
1294+
if dpu.get_execution_queue((element.sycl_queue, test_elements.sycl_queue)) is None:
1295+
raise dpu.ExecutionPlacementError(
1296+
"Input arrays have incompatible allocation queues"
1297+
)
1298+
usm_element = dpnp.get_usm_ndarray(element)
1299+
usm_test = dpnp.get_usm_ndarray(test_elements)
12801300
return dpnp.get_result_array(
12811301
dpt.isin(
12821302
usm_element,

dpnp/tests/test_logic.py

Lines changed: 71 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -797,98 +797,99 @@ def test_array_equal_nan(a):
797797
assert_equal(result, expected)
798798

799799

800-
@pytest.mark.parametrize(
801-
"a",
802-
[
803-
numpy.array([1, 2, 3, 4]),
804-
numpy.array([[1, 2], [3, 4]]),
805-
],
806-
)
807-
@pytest.mark.parametrize(
808-
"b",
809-
[
810-
numpy.array([2, 4, 6]),
811-
numpy.array([[1, 3], [5, 7]]),
812-
],
813-
)
814-
def test_isin_basic(a, b):
815-
dp_a = dpnp.array(a)
816-
dp_b = dpnp.array(b)
800+
class TestIsin:
801+
@pytest.mark.parametrize(
802+
"a",
803+
[
804+
numpy.array([1, 2, 3, 4]),
805+
numpy.array([[1, 2], [3, 4]]),
806+
],
807+
)
808+
@pytest.mark.parametrize(
809+
"b",
810+
[
811+
numpy.array([2, 4, 6]),
812+
numpy.array([[1, 3], [5, 7]]),
813+
],
814+
)
815+
def test_isin_basic(a, b):
816+
dp_a = dpnp.array(a)
817+
dp_b = dpnp.array(b)
817818

818-
expected = numpy.isin(a, b)
819-
result = dpnp.isin(dp_a, dp_b)
820-
assert_equal(result, expected)
819+
expected = numpy.isin(a, b)
820+
result = dpnp.isin(dp_a, dp_b)
821+
assert_equal(result, expected)
821822

822823

823-
@pytest.mark.parametrize("dtype", get_all_dtypes())
824-
def test_isin_dtype(dtype):
825-
a = numpy.array([1, 2, 3, 4], dtype=dtype)
826-
b = numpy.array([2, 4], dtype=dtype)
824+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
825+
def test_isin_dtype(dtype):
826+
a = numpy.array([1, 2, 3, 4], dtype=dtype)
827+
b = numpy.array([2, 4], dtype=dtype)
827828

828-
dp_a = dpnp.array(a, dtype=dtype)
829-
dp_b = dpnp.array(b, dtype=dtype)
829+
dp_a = dpnp.array(a, dtype=dtype)
830+
dp_b = dpnp.array(b, dtype=dtype)
830831

831-
expected = numpy.isin(a, b)
832-
result = dpnp.isin(dp_a, dp_b)
833-
assert_equal(result, expected)
832+
expected = numpy.isin(a, b)
833+
result = dpnp.isin(dp_a, dp_b)
834+
assert_equal(result, expected)
834835

835836

836-
@pytest.mark.parametrize("sh_a, sh_b", [((3, 1), (1, 4)), ((2, 3, 1), (1, 1))])
837-
def test_isin_broadcast(sh_a, sh_b):
838-
a = numpy.arange(numpy.prod(sh_a)).reshape(sh_a)
839-
b = numpy.arange(numpy.prod(sh_b)).reshape(sh_b)
837+
@pytest.mark.parametrize("sh_a, sh_b", [((3, 1), (1, 4)), ((2, 3, 1), (1, 1))])
838+
def test_isin_broadcast(sh_a, sh_b):
839+
a = numpy.arange(numpy.prod(sh_a)).reshape(sh_a)
840+
b = numpy.arange(numpy.prod(sh_b)).reshape(sh_b)
840841

841-
dp_a = dpnp.array(a)
842-
dp_b = dpnp.array(b)
842+
dp_a = dpnp.array(a)
843+
dp_b = dpnp.array(b)
843844

844-
expected = numpy.isin(a, b)
845-
result = dpnp.isin(dp_a, dp_b)
846-
assert_equal(result, expected)
845+
expected = numpy.isin(a, b)
846+
result = dpnp.isin(dp_a, dp_b)
847+
assert_equal(result, expected)
847848

848849

849-
def test_isin_scalar_elements():
850-
a = numpy.array([1, 2, 3])
851-
b = 2
850+
def test_isin_scalar_elements():
851+
a = numpy.array([1, 2, 3])
852+
b = 2
852853

853-
dp_a = dpnp.array(a)
854-
dp_b = dpnp.array(b)
854+
dp_a = dpnp.array(a)
855+
dp_b = dpnp.array(b)
855856

856-
expected = numpy.isin(a, b)
857-
result = dpnp.isin(dp_a, dp_b)
858-
assert_equal(result, expected)
857+
expected = numpy.isin(a, b)
858+
result = dpnp.isin(dp_a, dp_b)
859+
assert_equal(result, expected)
859860

860861

861-
def test_isin_scalar_test_elements():
862-
a = 2
863-
b = numpy.array([1, 2, 3])
862+
def test_isin_scalar_test_elements():
863+
a = 2
864+
b = numpy.array([1, 2, 3])
864865

865-
dp_a = dpnp.array(a)
866-
dp_b = dpnp.array(b)
866+
dp_a = dpnp.array(a)
867+
dp_b = dpnp.array(b)
867868

868-
expected = numpy.isin(a, b)
869-
result = dpnp.isin(dp_a, dp_b)
870-
assert_equal(result, expected)
869+
expected = numpy.isin(a, b)
870+
result = dpnp.isin(dp_a, dp_b)
871+
assert_equal(result, expected)
871872

872873

873-
def test_isin_empty():
874-
a = numpy.array([], dtype=int)
875-
b = numpy.array([1, 2, 3])
874+
def test_isin_empty():
875+
a = numpy.array([], dtype=int)
876+
b = numpy.array([1, 2, 3])
876877

877-
dp_a = dpnp.array(a)
878-
dp_b = dpnp.array(b)
878+
dp_a = dpnp.array(a)
879+
dp_b = dpnp.array(b)
879880

880-
expected = numpy.isin(a, b)
881-
result = dpnp.isin(dp_a, dp_b)
882-
assert_equal(result, expected)
881+
expected = numpy.isin(a, b)
882+
result = dpnp.isin(dp_a, dp_b)
883+
assert_equal(result, expected)
883884

884885

885-
def test_isin_errors():
886-
a = dpnp.arange(5)
887-
b = dpnp.arange(3)
886+
def test_isin_errors():
887+
a = dpnp.arange(5)
888+
b = dpnp.arange(3)
888889

889-
# unsupported type for elements or test_elements
890-
with pytest.raises(TypeError):
891-
dpnp.isin(dict(), b)
890+
# unsupported type for elements or test_elements
891+
with pytest.raises(TypeError):
892+
dpnp.isin(dict(), b)
892893

893-
with pytest.raises(TypeError):
894-
dpnp.isin(a, dict())
894+
with pytest.raises(TypeError):
895+
dpnp.isin(a, dict())

0 commit comments

Comments
 (0)