{-# LANGUAGE GADTs, MultiParamTypeClasses, FlexibleInstances, FlexibleContexts #-}
module Data.Random.Distribution.Multinomial where
import Data.Random.RVar
import Data.Random.Distribution
import Data.Random.Distribution.Binomial
-- | Sample a multinomial distribution as an 'RVar'.
--
-- @multinomial ps n@ distributes @n@ trials among one category per element
-- of @ps@ and returns the per-category counts, in the same order as @ps@.
-- With the instance defined in this module, each weight is divided by the
-- sum of itself and all later weights, so @ps@ need not sum to 1.
multinomial :: Distribution (Multinomial p) [a] => [p] -> a -> RVar [a]
multinomial ps n = rvar (Multinomial ps n)
-- | Sample a multinomial distribution as an 'RVarT' over any base monad @m@.
--
-- Same contract as 'multinomial': @multinomialT ps n@ returns one count per
-- element of @ps@, in order, for @n@ total trials.
multinomialT :: Distribution (Multinomial p) [a] => [p] -> a -> RVarT m [a]
multinomialT ps n = rvarT (Multinomial ps n)
-- | A multinomial distribution, given category weights and a trial count.
--
-- The GADT index fixes the sampled type to @[a]@: a list of counts, one per
-- weight, where @a@ is the type of the trial count.
data Multinomial p a where
    Multinomial :: [p] -> a -> Multinomial p [a]
-- | Samples by sequential conditional binomials.
--
-- Walking the weights left to right, category @i@ receives
-- @Binomial n_i (p_i / (p_i + p_(i+1) + ... + p_k))@ trials, where @n_i@ is
-- the number of trials not yet assigned. The last category receives every
-- remaining trial. Once no trials remain, all later categories get 0 and no
-- further draws are made.
--
-- NOTE(review): if the weights from some non-final position onward sum to
-- zero while trials remain, @p / psum@ is @0 / 0@ (presumably NaN for
-- floating types, a division error for 'Rational'). The result then depends
-- on the 'Binomial' instance, so confirm whether callers can pass such weights.
instance (Num a, Eq a, Fractional p, Distribution (Binomial p) a)
      => Distribution (Multinomial p) [a] where
    rvarT (Multinomial ps0 t) = go t (zip ps0 (scanr (+) 0 ps0))
      where
        -- The zipped list pairs each weight with the sum of itself and every
        -- later weight. @scanr (+) 0@ yields one extra trailing 0, which
        -- 'zip' drops, so the pairing is total and needs no "impossible" case.
        --
        -- @go n cats@ distributes @n@ trials over @cats@, in order.
        go _ [] = return []
        -- The last category absorbs whatever is left.
        go n [_] = return [n]
        -- Nothing left to distribute: remaining counts are 0, with no draws.
        go 0 cats = return (map (const 0) cats)
        go n ((p, psum) : rest) = do
            x <- binomialT n (p / psum)
            (x :) <$> go (n - x) rest