From 3113b0197bdfb60be5c66d8560c38cca5becd525 Mon Sep 17 00:00:00 2001 From: Dennis Jabs Date: Wed, 15 Nov 2023 11:28:31 +0100 Subject: [PATCH] D. Jabs: - Adjusted unittests in categorical --- tests/hp/test_categorical.py | 36 +++++++++++++++++++++++++++--------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/tests/hp/test_categorical.py b/tests/hp/test_categorical.py index f9292c2..8620c1c 100644 --- a/tests/hp/test_categorical.py +++ b/tests/hp/test_categorical.py @@ -45,11 +45,13 @@ def setUp(self): self.hp3 = Categorical(name=self.name, choices=self.choices, default=self.default_X1, distribution=None) # Tests with default=None and distribution=None self.hp4 = Categorical(name=self.name, choices=self.choices, default=None, distribution=None) + # Test with default=None, shape=None, distribution=None + self.hp5 = Categorical(name=self.name, choices=self.choices, default=None, distribution=None, shape=None) # Test with all options - self.hp5 = Categorical(name=self.name, choices=self.choices5, default=self.default5, + self.hp6 = Categorical(name=self.name, choices=self.choices5, default=self.default5, distribution=self.distribution2, shape=self.shape5) # Test with no shape and default given - self.hp6 = Categorical(name=self.name, choices=self.choices5, default=None, distribution=self.distribution2, + self.hp7 = Categorical(name=self.name, choices=self.choices5, default=None, distribution=self.distribution2, shape=None) def test_name(self): @@ -62,6 +64,7 @@ def test_name(self): self.assertEqual(self.name, self.hp4._name) self.assertEqual(self.name, self.hp5._name) self.assertEqual(self.name, self.hp6._name) + self.assertEqual(self.name, self.hp7._name) def test_shape(self): """ @@ -71,8 +74,9 @@ def test_shape(self): self.assertEqual(self.shape, self.hp2._shape) self.assertEqual(self.shape, self.hp3._shape) self.assertEqual(self.shape, self.hp4._shape) - self.assertEqual(self.shape5, self.hp5._shape) + self.assertEqual(self.shape, self.hp5._shape) self.assertEqual(self.shape5, self.hp6._shape) + self.assertEqual(self.shape5, self.hp7._shape) def test_choices(self): """ @@ -82,8 +86,9 @@ def test_choices(self): self.assertTrue(np.array_equal(self.choices, self.hp2._choices)) self.assertTrue(np.array_equal(self.choices, self.hp3._choices)) self.assertTrue(np.array_equal(self.choices, self.hp4._choices)) - self.assertTrue(np.array_equal(self.choices5, self.hp5._choices)) + self.assertTrue(np.array_equal(self.choices, self.hp5._choices)) self.assertTrue(np.array_equal(self.choices5, self.hp6._choices)) + self.assertTrue(np.array_equal(self.choices5, self.hp7._choices)) def test_default(self): """ @@ -93,8 +98,9 @@ def test_default(self): self.assertEqual(self.default_X5, self.hp2._default) self.assertEqual(self.default_X1, self.hp3._default) self.assertEqual(self.default_X1, self.hp4._default) - self.assertTrue(np.all(self.default5 == self.hp5._default)) + self.assertEqual(self.default_X1, self.hp5._default) self.assertTrue(np.all(self.default5 == self.hp6._default)) + self.assertTrue(np.all(self.default5 == self.hp7._default)) def test_distribution(self): """ @@ -106,6 +112,7 @@ def test_distribution(self): self.assertIsInstance(self.hp4._distribution, Choice) self.assertIsInstance(self.hp5._distribution, Choice) self.assertIsInstance(self.hp6._distribution, Choice) + self.assertIsInstance(self.hp7._distribution, Choice) def test_get_name(self): """ @@ -117,6 +124,7 @@ def test_get_name(self): self.assertEqual(self.name, self.hp4.get_name()) self.assertEqual(self.name, self.hp5.get_name()) self.assertEqual(self.name, self.hp6.get_name()) + self.assertEqual(self.name, self.hp7.get_name()) def test_get_default(self): """ @@ -126,8 +134,9 @@ def test_get_default(self): self.assertEqual(self.default_X5, self.hp2.get_default()) self.assertEqual(self.default_X1, self.hp3.get_default()) self.assertEqual(self.default_X1, self.hp4.get_default()) - self.assertTrue(np.all(self.default5 == self.hp5.get_default())) + self.assertEqual(self.default_X1, self.hp5.get_default()) self.assertTrue(np.all(self.default5 == self.hp6.get_default())) + self.assertTrue(np.all(self.default5 == self.hp7.get_default())) def test_get_shape(self): """ @@ -137,8 +146,9 @@ def test_get_shape(self): self.assertEqual(self.shape, self.hp2.get_shape()) self.assertEqual(self.shape, self.hp3.get_shape()) self.assertEqual(self.shape, self.hp4.get_shape()) - self.assertEqual(self.shape5, self.hp5.get_shape()) + self.assertEqual(self.shape, self.hp5.get_shape()) self.assertEqual(self.shape5, self.hp6.get_shape()) + self.assertEqual(self.shape5, self.hp7.get_shape()) def test_get_choices(self): """ @@ -148,8 +158,9 @@ def test_get_choices(self): self.assertTrue(np.array_equal(self.choices, self.hp2.get_choices())) self.assertTrue(np.array_equal(self.choices, self.hp3.get_choices())) self.assertTrue(np.array_equal(self.choices, self.hp4.get_choices())) - self.assertTrue(np.array_equal(self.choices5, self.hp5.get_choices())) + self.assertTrue(np.array_equal(self.choices, self.hp5.get_choices())) self.assertTrue(np.array_equal(self.choices5, self.hp6.get_choices())) + self.assertTrue(np.array_equal(self.choices5, self.hp7.get_choices())) def test_get_distribution(self): """ @@ -161,6 +172,7 @@ def test_get_distribution(self): self.assertIsInstance(self.hp4.get_distribution(), Choice) self.assertIsInstance(self.hp5.get_distribution(), Choice) self.assertIsInstance(self.hp6.get_distribution(), Choice) + self.assertIsInstance(self.hp7.get_distribution(), Choice) def test_change_distribution(self): """ @@ -193,10 +205,14 @@ def test_sample(self): self.assertEqual(self.size, len(sample5)) self.assertTrue(s in self.choices for s in sample5) - sample6 = self.hp5.sample(random=self.random, size=self.size) + sample6 = self.hp6.sample(random=self.random, size=self.size) self.assertEqual(self.size, len(sample6)) self.assertTrue(s in self.choices for s in sample6) + sample7 = self.hp6.sample(random=self.random, size=self.size) + self.assertEqual(self.size, len(sample7)) + self.assertTrue(s in self.choices for s in sample7) + def test_valid_configuration(self): """ Tests the method valid_configuration(). @@ -214,6 +230,7 @@ def test_eq(self): self.assertNotEqual(self.hp, self.hp4) self.assertNotEqual(self.hp, self.hp5) self.assertNotEqual(self.hp, self.hp6) + self.assertNotEqual(self.hp, self.hp7) def test_hash(self): """ @@ -225,6 +242,7 @@ def test_hash(self): self.assertNotEqual(hash(self.hp), hash(self.hp4)) self.assertNotEqual(hash(self.hp), hash(self.hp5)) self.assertNotEqual(hash(self.hp), hash(self.hp6)) + self.assertNotEqual(hash(self.hp), hash(self.hp7)) def test_set_get_state(self): """