diff --git a/hedgehog/src/Hedgehog/Internal/Gen.hs b/hedgehog/src/Hedgehog/Internal/Gen.hs index f181037a..b85aaeb7 100644 --- a/hedgehog/src/Hedgehog/Internal/Gen.hs +++ b/hedgehog/src/Hedgehog/Internal/Gen.hs @@ -1,5 +1,6 @@ {-# OPTIONS_HADDOCK not-home #-} {-# LANGUAGE ApplicativeDo #-} +{-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveFoldable #-} @@ -12,6 +13,7 @@ {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE PatternGuards #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} @@ -202,6 +204,8 @@ import Data.Coerce (coerce) import Data.Foldable (for_, toList) import Data.Functor.Identity (Identity(..)) import Data.Int (Int8, Int16, Int32, Int64) +import qualified Data.IntMap.Strict as IM +import qualified Data.List as List import Data.List.NonEmpty (NonEmpty) import qualified Data.List.NonEmpty as NonEmpty import Data.Map (Map) @@ -1170,28 +1174,38 @@ choice = \case -- -- This generator shrinks towards the first generator in the list. -- --- /The input list must be non-empty./ +-- /The sum of the frequencies must be at least @1@ and at most @'maxBound' :: 'Int'@. +-- No frequency may be negative./ -- frequency :: MonadGen m => [(Int, m a)] -> m a -frequency = \case - [] -> - error "Hedgehog.Gen.frequency: used with empty list" - xs0 -> do - let - pick n = \case - [] -> - error "Hedgehog.Gen.frequency/pick: used with empty list" - (k, x) : xs -> - if n <= k then - x - else - pick (n - k) xs - - total = - sum (fmap fst xs0) - +-- We calculate a running sum of the individual frequencies and build +-- an IntMap mapping the results to the generators. This makes the +-- resulting generator much faster than a naive list-based one when +-- the input list is long, and not much slower when it's short. +frequency xs0 = + do n <- integral $ Range.constant 1 total - pick n xs0 + case IM.lookupGE n sum_map of + Just (_, a) -> a + Nothing -> error "Hedgehog.Gen.frequency: Something went wrong." + where + --[(1, x), (7, y), (10, z)] In + --[(1, x), (8, y), (18, z)] Out + sum_map = IM.fromDistinctAscList $ List.unfoldr go (0, xs0) + where + go (_, []) = Nothing + go (n, (k, x) : xs) + | k < 0 = error "Hedgehog.Gen.frequency: Negative frequency." + -- nk < 0 means the sum overflowed. + | nk < 0 = error "Hedgehog.Gen.frequency: Frequency sum above maxBound :: Int" + | k > 0 = Just ((nk, x), (nk, xs)) + | otherwise = go (n, xs) + where !nk = n + fromIntegral k + total + | Just (mx, _) <- IM.lookupMax sum_map + = mx + | otherwise + = error "Hedgehog.Gen.frequency: frequencies sum to zero" -- | Modifies combinators which choose from a list of generators, like 'choice' -- or 'frequency', so that they can be used in recursive scenarios.