{-# LANGUAGE FlexibleContexts #-}

{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
{-# OPTIONS_GHC -fno-warn-unused-top-binds #-}

-- |
-- Module      :  Numeric.GSL.Internal
-- Copyright   :  (c) Alberto Ruiz 2009
-- License     :  GPL
-- Maintainer  :  Alberto Ruiz
-- Stability   :  provisional
--
--
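-- Auxiliary functions shared by the GSL bindings: adapting Haskell functions
-- into C callbacks (and wrapping them as C function pointers), allocating
-- vectors and matrices that are filled by C code, and checking GSL error codes.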
--


module Numeric.GSL.Internal(
    iv,
    mkVecfun,
    mkVecVecfun,
    mkDoubleVecVecfun,
    mkDoublefun,
    aux_vTov,
    mkVecMatfun,
    mkDoubleVecMatfun,
    aux_vTom,
    createV,
    createMIO,
    module Numeric.LinearAlgebra.Devel,
    check,(#),(#!),vec, ww2,
    Res,TV,TM,TCV,TCM
) where

import Numeric.LinearAlgebra.HMatrix
import Numeric.LinearAlgebra.Devel hiding (check)

import Foreign.Marshal.Array(copyArray)
import Foreign.Ptr(Ptr, FunPtr)
import Foreign.C.Types
import Foreign.C.String(peekCString)
import System.IO.Unsafe(unsafePerformIO)
import Data.Vector.Storable as V (unsafeWith,length)
import Control.Monad(when)

-- | Adapt a Haskell function on vectors to the C callback shape
-- @CInt -> Ptr Double -> Double@.
--
-- The @n@ doubles starting at @p@ are copied into a fresh 'Vector', which is
-- then passed to @f@. The copy is built lazily by 'createV', so the data is
-- read from @p@ when @f@ forces its argument. This is presumably before the
-- callback returns, because the 'Double' result must be evaluated to be
-- handed back to C. If @f@ retains the vector somewhere, the pointer may no
-- longer be valid by then.
iv :: (Vector Double -> Double) -> (CInt -> Ptr Double -> Double)
iv f n p = f input
  where
    input = createV (fromIntegral n) fill "iv"
    -- Fill the freshly allocated buffer from the caller's array; 0 = success.
    fill len dst = copyArray dst p (fromIntegral len) >> return 0

-- | Conversion of Haskell functions into function pointers that can be passed
-- to the C side.
--
-- Each declaration below is a GHC @\"wrapper\"@ import: calling it allocates a
-- C-callable 'FunPtr' that invokes the given Haskell closure. Per the Haskell
-- FFI, such a pointer stays alive until it is released with
-- 'freeHaskellFunPtr'. Releasing it is up to the caller; nothing in this module
-- frees it.
--
-- 'mkVecfun' wraps a function taking a C array as (length, data pointer) and
-- returning a 'Double'; see 'iv' for how such a function is built from a
-- Haskell @Vector Double -> Double@.
foreign import ccall safe "wrapper"
    mkVecfun :: (CInt -> Ptr Double -> Double)
             -> IO( FunPtr (CInt -> Ptr Double -> Double))

-- | Wrap a vector-to-vector callback of shape 'TVV'
-- (input length, input pointer, output length, output pointer -> status),
-- e.g. one built with 'aux_vTov'.
foreign import ccall safe "wrapper"
    mkVecVecfun :: TVV -> IO (FunPtr TVV)

-- | Like 'mkVecVecfun', but the callback takes an extra leading 'Double'
-- parameter.
foreign import ccall safe "wrapper"
    mkDoubleVecVecfun :: (Double -> TVV) -> IO (FunPtr (Double -> TVV))

-- | Wrap a scalar @Double -> Double@ function.
foreign import ccall safe "wrapper"
    mkDoublefun :: (Double -> Double) -> IO (FunPtr (Double -> Double))

-- | Adapt a Haskell @Vector -> Vector@ function to the C callback shape 'TVV'
-- (@CInt -> Ptr Double -> CInt -> Ptr Double -> IO CInt@).
--
-- * @n@, @p@: length of, and pointer to, the input array supplied by C.
--   It is copied into a fresh 'Vector', which is passed to @f@.
--
-- * @nr@, @r@: length of, and pointer to, the output buffer. The first @nr@
--   elements of @f@'s result are copied into it.
--
-- The callback returns 0. If @f@ yields fewer than @nr@ elements, it fails
-- with 'error' instead of reading past the end of the result vector (the
-- previous behaviour, which was undefined).
aux_vTov :: (Vector Double -> Vector Double) -> TVV
aux_vTov f n p nr r = do
    let out  = f input
        want = fromIntegral nr
        have = V.length out
    -- Without this guard, copyArray would read beyond the end of 'out'.
    when (have < want) $
        error ("aux_vTov: the function returned " ++ show have
               ++ " elements, but " ++ show want ++ " are expected")
    unsafeWith out $ \src -> copyArray r src want
    return 0
  where
    -- Copy of the C input array. It is built lazily by createV, so it is read
    -- from p only when f forces it.
    input = createV (fromIntegral n) fill "aux_vTov"
    -- Fill the freshly allocated buffer from the caller's array; 0 = success.
    fill len dst = copyArray dst p (fromIntegral len) >> return 0

-- | Wrap a vector-to-matrix callback of shape 'TVM'
-- (input length, input pointer, output rows, output columns, output pointer
-- -> status), e.g. one built with 'aux_vTom', as a C function pointer.
-- As with every @\"wrapper\"@ import, the resulting 'FunPtr' must be
-- released with 'freeHaskellFunPtr' when it is no longer needed.
foreign import ccall safe "wrapper"
    mkVecMatfun :: TVM -> IO (FunPtr TVM)

-- | Like 'mkVecMatfun', but the callback takes an extra leading 'Double'
-- parameter.
foreign import ccall safe "wrapper"
    mkDoubleVecMatfun :: (Double -> TVM) -> IO (FunPtr (Double -> TVM))

-- | Adapt a Haskell @Vector -> Matrix@ function to the C callback shape 'TVM'
-- (@CInt -> Ptr Double -> CInt -> CInt -> Ptr Double -> IO CInt@).
--
-- * @n@, @p@: length of, and pointer to, the input array supplied by C.
--   It is copied into a fresh 'Vector', which is passed to @f@.
--
-- * @rr@, @cr@, @r@: rows and columns of the output buffer, and a pointer
--   to it. The resulting matrix is flattened (in the element order produced
--   by 'flatten'), and its first @rr*cr@ elements are copied into @r@.
--
-- The callback returns 0. If the matrix has fewer than @rr*cr@ elements,
-- it fails with 'error' instead of reading past the end of its data.
aux_vTom :: (Vector Double -> Matrix Double) -> TVM
aux_vTom f n p rr cr r = do
    let out  = flatten (f input)
        -- Multiply in Int rather than CInt, so large sizes cannot overflow.
        want = fromIntegral rr * fromIntegral cr
        have = V.length out
    -- Without this guard, copyArray would read beyond the end of 'out'.
    when (have < want) $
        error ("aux_vTom: the function returned " ++ show have
               ++ " elements, but a " ++ show rr ++ "x" ++ show cr
               ++ " result is expected")
    unsafeWith out $ \src -> copyArray r src want
    return 0
  where
    -- Copy of the C input array. It is built lazily by createV, so it is read
    -- from p only when f forces it. The tag now names this function (it
    -- previously said "aux_vTov").
    input = createV (fromIntegral n) fill "aux_vTom"
    -- Fill the freshly allocated buffer from the caller's array; 0 = success.
    fill len dst = copyArray dst p (fromIntegral len) >> return 0

-- | Allocate a vector of length @n@ and let the C-style action @fun@ fill it.
--
-- @fun@ receives the raw form of the new vector (length and data pointer,
-- via '#'). Its status code is passed to '#|' together with @msg@, which is
-- hmatrix's status check.
--
-- The allocation runs under 'unsafePerformIO' so that the result is a pure
-- 'Vector'. This is only sound if @fun@ depends on nothing but its arguments
-- and on memory that is still valid when the result is first forced, because
-- evaluation is lazy.
createV n fun msg = unsafePerformIO (createVector n >>= fillFresh)
  where
    fillFresh buf = ((buf # id) fun #| msg) >> return buf

-- | Allocate an @r@ x @c@ matrix in 'RowMajor' order and let the C-style
-- action @fun@ fill it, then return the matrix.
--
-- @fun@ receives the raw form of the matrix through '#'. Judging by its type
-- (@CInt -> CInt -> Ptr a -> IO CInt@), this is presumably rows, columns and
-- data pointer. Its status code is checked by '#|' with @msg@. Unlike
-- 'createV', this stays in 'IO'.
createMIO r c fun msg = createMatrix RowMajor r c >>= fillAndReturn
  where
    fillAndReturn m = do
        (m # id) fun #| msg
        pure m

--------------------------------------------------------------------------------

-- | Run a GSL call and check the status code it returns.
--
-- A status of 0 is success. Any other code is translated to GSL's
-- description through 'gsl_strerror', and the call fails with
-- @'error' (msg ++ ": " ++ description)@.
check :: String -> IO CInt -> IO ()
check msg f = f >>= report
  where
    report 0    = return ()
    report code = do
        description <- peekCString =<< gsl_strerror code
        error (msg ++ ": " ++ description)

-- | Look up GSL's human-readable description of an error code by calling the C
-- function @gsl_strerror@. In this module, the returned string is only read
-- (with 'peekCString' in 'check') and is never freed. Presumably it points to
-- static storage owned by GSL.
foreign import ccall unsafe "gsl_strerror" gsl_strerror :: CInt -> IO (Ptr CChar)

-- | Result of a C call or Haskell callback: an IO action yielding a status
-- code ('check' treats 0 as success).
type Res = IO CInt

-- | A real vector argument passed C-style as (length, data pointer),
-- followed by the rest of the signature @x@.
type TV x  = CInt -> PD -> x
-- | A real matrix argument passed C-style as (rows, columns, data pointer),
-- followed by the rest of the signature @x@.
type TM x  = CInt -> CInt -> PD -> x
-- | Complex-vector counterpart of 'TV'.
type TCV x = CInt -> PC -> x
-- | Complex-matrix counterpart of 'TM'.
type TCM x = CInt -> CInt -> PC -> x

-- | Callback from a real vector to a real vector (see 'aux_vTov').
type TVV = TV (TV Res)
-- | Callback from a real vector to a real matrix (see 'aux_vTom').
type TVM = TV (TM Res)

-- Raw data pointers: single/double precision, real/complex.
type PF = Ptr Float
type PD = Ptr Double
type PQ = Ptr (Complex Float)
type PC = Ptr (Complex Double)

-- | Nest two continuation-passing \"with\"-style wrappers.
--
-- @ww2 w1 o1 w2 o2 f@ runs @w1@ on @o1@. Inside its continuation it runs
-- @w2@ on @o2@, and finally hands both provided values to @f@. It is
-- equivalent to @w1 o1 (\\a1 -> w2 o2 (\\a2 -> f a1 a2))@.
ww2 :: (a -> (b -> c) -> d) -> a -> (e -> (g -> h) -> c) -> e -> (b -> g -> h) -> d
ww2 w1 o1 w2 o2 f = w1 o1 outer
  where
    outer a1 = w2 o2 (\a2 -> f a1 a2)

-- | Lend a vector's storage to a continuation in C style.
--
-- The continuation @f@ receives a function. Applying that function to a
-- consumer @g@ calls @g len ptr@, where @len@ is the vector length as a
-- 'CInt' and @ptr@ points to the vector's data. The pointer comes from
-- 'unsafeWith', so it is only valid while @f@ runs.
vec x f = unsafeWith x $ \ptr ->
    f (\consumer -> consumer (fi (V.length x)) ptr)
{-# INLINE vec #-}

-- Left-associative, so @m # g # h@ parses as @(m # g) # h@.
infixl 1 #

-- | Feed the raw C-level representation of @x@ (via hmatrix's 'applyRaw') to
-- a C-style function. For a 'Vector', this is its length and data pointer, as
-- used in 'createV'. @k@ post-processes the resulting action; it is
-- usually 'id'.
(#) x k = applyRaw x k
{-# INLINE (#) #-}

--infixr 1 #
--a # b = apply a b
--{-# INLINE (#) #-}

-- | @x #! k@ applies @x@ with @k@ via '#', then passes the result through
-- '#' once more with 'id' as the post-processing step. It is the same as
-- @x # k # id@.
(#!) x k = (x # k) # id
{-# INLINE (#!) #-}