Skip to content

Commit

Permalink
D. Jabs:
Browse files Browse the repository at this point in the history
- Adjusted unittests in categorical
  • Loading branch information
Dennis Jabs committed Nov 15, 2023
1 parent 9f964f0 commit 3113b01
Showing 1 changed file with 27 additions and 9 deletions.
36 changes: 27 additions & 9 deletions tests/hp/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,13 @@ def setUp(self):
self.hp3 = Categorical(name=self.name, choices=self.choices, default=self.default_X1, distribution=None)
# Tests with default=None and distribution=None
self.hp4 = Categorical(name=self.name, choices=self.choices, default=None, distribution=None)
# Test with default=None, shape=None, distribution=None
self.hp5 = Categorical(name=self.name, choices=self.choices, default=None, distribution=None, shape=None)
# Test with all options
self.hp5 = Categorical(name=self.name, choices=self.choices5, default=self.default5,
self.hp6 = Categorical(name=self.name, choices=self.choices5, default=self.default5,
distribution=self.distribution2, shape=self.shape5)
# Test with no shape and default given
self.hp6 = Categorical(name=self.name, choices=self.choices5, default=None, distribution=self.distribution2,
self.hp7 = Categorical(name=self.name, choices=self.choices5, default=None, distribution=self.distribution2,
shape=None)

def test_name(self):
Expand All @@ -62,6 +64,7 @@ def test_name(self):
self.assertEqual(self.name, self.hp4._name)
self.assertEqual(self.name, self.hp5._name)
self.assertEqual(self.name, self.hp6._name)
self.assertEqual(self.name, self.hp7._name)

def test_shape(self):
"""
Expand All @@ -71,8 +74,9 @@ def test_shape(self):
self.assertEqual(self.shape, self.hp2._shape)
self.assertEqual(self.shape, self.hp3._shape)
self.assertEqual(self.shape, self.hp4._shape)
self.assertEqual(self.shape5, self.hp5._shape)
self.assertEqual(self.shape, self.hp5._shape)
self.assertEqual(self.shape5, self.hp6._shape)
self.assertEqual(self.shape5, self.hp7._shape)

def test_choices(self):
"""
Expand All @@ -82,8 +86,9 @@ def test_choices(self):
self.assertTrue(np.array_equal(self.choices, self.hp2._choices))
self.assertTrue(np.array_equal(self.choices, self.hp3._choices))
self.assertTrue(np.array_equal(self.choices, self.hp4._choices))
self.assertTrue(np.array_equal(self.choices5, self.hp5._choices))
self.assertTrue(np.array_equal(self.choices, self.hp5._choices))
self.assertTrue(np.array_equal(self.choices5, self.hp6._choices))
self.assertTrue(np.array_equal(self.choices5, self.hp7._choices))

def test_default(self):
"""
Expand All @@ -93,8 +98,9 @@ def test_default(self):
self.assertEqual(self.default_X5, self.hp2._default)
self.assertEqual(self.default_X1, self.hp3._default)
self.assertEqual(self.default_X1, self.hp4._default)
self.assertTrue(np.all(self.default5 == self.hp5._default))
self.assertEqual(self.default_X1, self.hp5._default)
self.assertTrue(np.all(self.default5 == self.hp6._default))
self.assertTrue(np.all(self.default5 == self.hp7._default))

def test_distribution(self):
"""
Expand All @@ -106,6 +112,7 @@ def test_distribution(self):
self.assertIsInstance(self.hp4._distribution, Choice)
self.assertIsInstance(self.hp5._distribution, Choice)
self.assertIsInstance(self.hp6._distribution, Choice)
self.assertIsInstance(self.hp7._distribution, Choice)

def test_get_name(self):
"""
Expand All @@ -117,6 +124,7 @@ def test_get_name(self):
self.assertEqual(self.name, self.hp4.get_name())
self.assertEqual(self.name, self.hp5.get_name())
self.assertEqual(self.name, self.hp6.get_name())
self.assertEqual(self.name, self.hp7.get_name())

def test_get_default(self):
"""
Expand All @@ -126,8 +134,9 @@ def test_get_default(self):
self.assertEqual(self.default_X5, self.hp2.get_default())
self.assertEqual(self.default_X1, self.hp3.get_default())
self.assertEqual(self.default_X1, self.hp4.get_default())
self.assertTrue(np.all(self.default5 == self.hp5.get_default()))
self.assertEqual(self.default_X1, self.hp5.get_default())
self.assertTrue(np.all(self.default5 == self.hp6.get_default()))
self.assertTrue(np.all(self.default5 == self.hp7.get_default()))

def test_get_shape(self):
"""
Expand All @@ -137,8 +146,9 @@ def test_get_shape(self):
self.assertEqual(self.shape, self.hp2.get_shape())
self.assertEqual(self.shape, self.hp3.get_shape())
self.assertEqual(self.shape, self.hp4.get_shape())
self.assertEqual(self.shape5, self.hp5.get_shape())
self.assertEqual(self.shape, self.hp5.get_shape())
self.assertEqual(self.shape5, self.hp6.get_shape())
self.assertEqual(self.shape5, self.hp7.get_shape())

def test_get_choices(self):
"""
Expand All @@ -148,8 +158,9 @@ def test_get_choices(self):
self.assertTrue(np.array_equal(self.choices, self.hp2.get_choices()))
self.assertTrue(np.array_equal(self.choices, self.hp3.get_choices()))
self.assertTrue(np.array_equal(self.choices, self.hp4.get_choices()))
self.assertTrue(np.array_equal(self.choices5, self.hp5.get_choices()))
self.assertTrue(np.array_equal(self.choices, self.hp5.get_choices()))
self.assertTrue(np.array_equal(self.choices5, self.hp6.get_choices()))
self.assertTrue(np.array_equal(self.choices5, self.hp7.get_choices()))

def test_get_distribution(self):
"""
Expand All @@ -161,6 +172,7 @@ def test_get_distribution(self):
self.assertIsInstance(self.hp4.get_distribution(), Choice)
self.assertIsInstance(self.hp5.get_distribution(), Choice)
self.assertIsInstance(self.hp6.get_distribution(), Choice)
self.assertIsInstance(self.hp7.get_distribution(), Choice)

def test_change_distribution(self):
"""
Expand Down Expand Up @@ -193,10 +205,14 @@ def test_sample(self):
self.assertEqual(self.size, len(sample5))
self.assertTrue(s in self.choices for s in sample5)

sample6 = self.hp5.sample(random=self.random, size=self.size)
sample6 = self.hp6.sample(random=self.random, size=self.size)
self.assertEqual(self.size, len(sample6))
self.assertTrue(s in self.choices for s in sample6)

sample7 = self.hp6.sample(random=self.random, size=self.size)
self.assertEqual(self.size, len(sample7))
self.assertTrue(s in self.choices for s in sample7)

def test_valid_configuration(self):
"""
Tests the method valid_configuration().
Expand All @@ -214,6 +230,7 @@ def test_eq(self):
self.assertNotEqual(self.hp, self.hp4)
self.assertNotEqual(self.hp, self.hp5)
self.assertNotEqual(self.hp, self.hp6)
self.assertNotEqual(self.hp, self.hp7)

def test_hash(self):
"""
Expand All @@ -225,6 +242,7 @@ def test_hash(self):
self.assertNotEqual(hash(self.hp), hash(self.hp4))
self.assertNotEqual(hash(self.hp), hash(self.hp5))
self.assertNotEqual(hash(self.hp), hash(self.hp6))
self.assertNotEqual(hash(self.hp), hash(self.hp7))

def test_set_get_state(self):
"""
Expand Down

0 comments on commit 3113b01

Please sign in to comment.