Skip to content

Commit

Permalink
add tests for device parametrization (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier authored Apr 13, 2021
1 parent ba18726 commit 0bae345
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 1 deletion.
43 changes: 43 additions & 0 deletions tests/assets/test_dtype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import unittest

import torch
from torch.testing._internal.common_device_type import (
dtypes,
instantiate_device_type_tests,
onlyCPU,
)
from torch.testing._internal.common_utils import TestCase


class TestFoo(TestCase):
@onlyCPU
@dtypes(torch.float16, torch.int32)
# fails for float16, passes for int32
def test_bar(self, device, dtype):
assert dtype == torch.int32

# passes for float16, skips for int32
@onlyCPU
@dtypes(torch.float16, torch.int32)
def test_baz(self, device, dtype):
if dtype == torch.int32:
raise unittest.SkipTest

assert True


instantiate_device_type_tests(TestFoo, globals())


class TestSpam(TestCase):
@onlyCPU
def test_ham(self, device):
assert True

@onlyCPU
@dtypes(torch.float16)
def test_eggs(self, device, dtype):
assert False


instantiate_device_type_tests(TestSpam, globals())
94 changes: 93 additions & 1 deletion tests/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,99 @@
passed=1,
),
)
def test_collection(testdir, file, cmds, outcomes):
def test_device(testdir, file, cmds, outcomes):
testdir.copy_example(file)
result = testdir.runpytest(*cmds)
result.assert_outcomes(**outcomes)


@make_params(
"test_dtype.py",
Config(
"*test-*device-*dtype",
new_cmds=(),
legacy_cmds=(),
passed=3,
skipped=7,
failed=2,
),
Config(
"1testcase1-*test-*dtype",
new_cmds="::TestFoo",
legacy_cmds=("-k", "TestFoo"),
passed=2,
skipped=5,
failed=1,
),
Config(
"1testcase2-*test-*device",
new_cmds="::TestSpam",
legacy_cmds=("-k", "TestSpam"),
passed=1,
skipped=2,
failed=1,
),
Config(
"*testcase-*test-1dtype1",
new_cmds=("-k", "float16"),
legacy_cmds=("-k", "float16"),
passed=1,
skipped=3,
failed=2,
),
Config(
"*testcase-*test-1dtype2",
new_cmds=("-k", "int32"),
legacy_cmds=("-k", "int32"),
passed=1,
skipped=3,
),
Config(
"1testcase-*test-1dtype1",
new_cmds=("::TestFoo", "-k", "float16"),
legacy_cmds=("-k", "TestFoo and float16"),
passed=1,
failed=1,
skipped=2,
),
Config(
"1testcase-*test-1dtype2",
new_cmds=("::TestFoo", "-k", "int32"),
legacy_cmds=("-k", "TestFoo and int32"),
passed=1,
skipped=3,
),
Config(
"1testcase-1test1-*dtype",
new_cmds="::TestFoo::test_bar",
legacy_cmds=("-k", "TestFoo and test_bar"),
passed=1,
skipped=2,
failed=1,
),
Config(
"1testcase-1test2-*dtype",
new_cmds="::TestFoo::test_baz",
legacy_cmds=("-k", "TestFoo and test_baz"),
passed=1,
skipped=3,
),
Config(
"1testcase-1test-1dtype1",
new_cmds=("::TestFoo::test_bar", "-k", "float16"),
legacy_cmds=("-k", "TestFoo and test_bar and float16"),
skipped=1,
failed=1,
),
Config(
"1testcase-1test-1dtype2",
new_cmds=("::TestFoo::test_bar", "-k", "int32"),
legacy_cmds=("-k", "TestFoo and test_bar and int32"),
passed=1,
skipped=1,
),
)
def test_dtype(testdir, file, cmds, outcomes):
testdir.copy_example(file)
result = testdir.runpytest(*cmds)
result.assert_outcomes(**outcomes)

0 comments on commit 0bae345

Please sign in to comment.