diff --git a/src/Test/QuickCheck/Arbitrary.hs b/src/Test/QuickCheck/Arbitrary.hs index f3595d66..27d6fb69 100644 --- a/src/Test/QuickCheck/Arbitrary.hs +++ b/src/Test/QuickCheck/Arbitrary.hs @@ -159,6 +159,7 @@ import qualified Data.Map as Map import qualified Data.IntSet as IntSet import qualified Data.IntMap as IntMap import qualified Data.Sequence as Sequence +import qualified Data.Tree as Tree import Data.Bits import qualified Data.Monoid as Monoid @@ -821,6 +822,35 @@ instance Arbitrary1 Sequence.Seq where instance Arbitrary a => Arbitrary (Sequence.Seq a) where arbitrary = arbitrary1 shrink = shrink1 +instance Arbitrary1 Tree.Tree where + liftArbitrary arb = sized $ \n -> do + k <- chooseInt (0, n) + go k + where + go n = do -- n is the size of the trees. + value <- arb + pars <- arbPartition (n - 1) -- can go negative! + forest <- mapM go pars + return $ Tree.Node value forest + + arbPartition :: Int -> Gen [Int] + arbPartition k = case compare k 1 of + LT -> pure [] + EQ -> pure [1] + GT -> do + first <- chooseInt (1, k) + rest <- arbPartition $ k - first + shuffle (first : rest) + + liftShrink shr = go + where + go (Tree.Node val forest) = forest ++ + [ Tree.Node e fs + | (e, fs) <- liftShrink2 shr (liftShrink go) (val, forest) + ] +instance Arbitrary a => Arbitrary (Tree.Tree a) where + arbitrary = arbitrary1 + shrink = shrink1 -- Arbitrary instance for Ziplist instance Arbitrary1 ZipList where @@ -1360,6 +1390,8 @@ instance CoArbitrary a => CoArbitrary (IntMap.IntMap a) where coarbitrary = coarbitrary . IntMap.toList instance CoArbitrary a => CoArbitrary (Sequence.Seq a) where coarbitrary = coarbitrary . toList +instance CoArbitrary a => CoArbitrary (Tree.Tree a) where + coarbitrary (Tree.Node val forest) = coarbitrary val . coarbitrary forest -- CoArbitrary instance for Ziplist instance CoArbitrary a => CoArbitrary (ZipList a) where diff --git a/src/Test/QuickCheck/Function.hs b/src/Test/QuickCheck/Function.hs index 11dc5d61..1139e84a 100644 --- a/src/Test/QuickCheck/Function.hs +++ b/src/Test/QuickCheck/Function.hs @@ -76,6 +76,7 @@ import qualified Data.IntSet as IntSet import qualified Data.Map as Map import qualified Data.Set as Set import qualified Data.Sequence as Sequence +import qualified Data.Tree as Tree import Data.Int import Data.Complex import Data.Foldable(toList) @@ -339,6 +340,9 @@ instance Function a => Function (IntMap.IntMap a) where instance Function a => Function (Sequence.Seq a) where function = functionMap toList Sequence.fromList +instance Function a => Function (Tree.Tree a) where + function = functionMap (\(Tree.Node x xs) -> (x,xs)) (uncurry Tree.Node) + instance Function Int8 where function = functionBoundedEnum