Make vectorTH use singleton when possible.
[matthijs/master-project/support/tfvec.git] / Data / Param / TFVec.hs
1 ------------------------------------------------------------------------------
2 -- |
3 -- Module       : Data.Param.TFVec
4 -- Copyright    : (c) 2009 Christiaan Baaij
5 -- Licence      : BSD-style (see the file LICENCE)
6 --
7 -- Maintainer   : christiaan.baaij@gmail.com
8 -- Stability    : experimental
9 -- Portability  : non-portable
10 --
11 -- 'TFVec': Fixed sized vectors. Vectors with numerically parameterized size,
12 --          using type-level numerals from 'tfp' library
13 --
14 ------------------------------------------------------------------------------
15
16 module Data.Param.TFVec
17   ( TFVec
18   , empty
19   , (+>)
20   , singleton
21   , vectorCPS
22   , vectorTH
23   , unsafeVector
24   , readTFVec
25   , length
26   , lengthT
27   , fromVector
28   , null
29   , (!)
30   , replace
31   , head
32   , last
33   , init
34   , tail
35   , take
36   , drop
37   , select
38   , (<+)
39   , (++)
40   , map
41   , zipWith
42   , foldl
43   , foldr
44   , zip
45   , unzip
46   , shiftl
47   , shiftr
48   , rotl
49   , rotr
50   , concat
51   , reverse
52   , iterate
53   , iteraten
54   , generate
55   , generaten
56   , copy
57   , copyn
58   ) where
59     
60 import Types
61 import Types.Data.Num
62 import Types.Data.Num.Decimal.Literals.TH
63 import Data.RangedWord
64
65 import Data.Generics (Data)
66 import Data.Typeable
67 import qualified Prelude as P
68 import Prelude hiding (
69   null, length, head, tail, last, init, take, drop, (++), map, foldl, foldr,
70   zipWith, zip, unzip, concat, reverse, iterate )
71 import qualified Data.Foldable as DF (Foldable, foldr)
72 import qualified Data.Traversable as DT (Traversable(traverse))
73 import Language.Haskell.TH hiding (Pred)
74 import Language.Haskell.TH.Syntax (Lift(..))
75
76 import Language.Haskell.TH.TypeLib
77
78 newtype (NaturalT s) => TFVec s a = TFVec {unTFVec :: [a]}
79   deriving (Eq, Typeable)
80
81 deriving instance (NaturalT s, Typeable s, Data s, Typeable a, Data a) => Data (TFVec s a)
82
83 -- ==========================
84 -- = Constructing functions =
85 -- ==========================
86                                                   
87 empty :: TFVec D0 a
88 empty = TFVec []
89
90 (+>) :: a -> TFVec s a -> TFVec (Succ s) a
91 x +> (TFVec xs) = TFVec (x:xs)
92
93 infix 5 +>
94
95 singleton :: a -> TFVec D1 a
96 singleton x = x +> empty
97
98 vectorCPS :: [a] -> (forall s . NaturalT s => TFVec s a -> w) -> w
99 vectorCPS xs = unsafeVectorCPS (toInteger (P.length xs)) xs
100
101 -- FIXME: Not the most elegant solution... but it works for now in clash
102 vectorTH :: (Lift a, Typeable a) => [a] -> ExpQ
103 -- vectorTH xs = sigE [| (TFVec xs) |] (decTFVecT (toInteger (P.length xs)) xs)
104 vectorTH [] = [| empty |]
105 vectorTH [x] = [| singleton x |]
106 vectorTH (x:xs) = [| x +> $(vectorTH xs) |]
107
108 unsafeVector :: NaturalT s => s -> [a] -> TFVec s a
109 unsafeVector l xs
110   | fromIntegerT l /= P.length xs =
111     error (show 'unsafeVector P.++ ": dynamic/static lenght mismatch")
112   | otherwise = TFVec xs
113
114 readTFVec :: (Read a, NaturalT s) => String -> TFVec s a
115 readTFVec = read
116
117 readTFVecCPS :: Read a => String -> (forall s . NaturalT s => TFVec s a -> w) -> w
118 readTFVecCPS str = unsafeVectorCPS (toInteger l) xs
119  where fName = show 'readTFVecCPS
120        (xs,l) = case [(xs,l) | (xs,l,rest) <- readTFVecList str,  
121                            ("","") <- lexTFVec rest] of
122                        [(xs,l)] -> (xs,l)
123                        []   -> error (fName P.++ ": no parse")
124                        _    -> error (fName P.++ ": ambiguous parse")
125         
126 -- =======================
127 -- = Observing functions =
128 -- =======================
129 length :: forall s a . NaturalT s => TFVec s a -> Int
130 length _ = fromIntegerT (undefined :: s)
131
132 lengthT :: NaturalT s => TFVec s a -> s
133 lengthT = undefined
134
135 fromVector :: NaturalT s => TFVec s a -> [a]
136 fromVector (TFVec xs) = xs
137
138 null :: TFVec D0 a -> Bool
139 null _ = True
140
141 (!) ::  ( PositiveT s
142         , NaturalT u
143         , (s :>: u) ~ True) => TFVec s a -> RangedWord u -> a
144 (TFVec xs) ! i = xs !! (fromInteger (toInteger i))
145
146 -- ==========================
147 -- = Transforming functions =
148 -- ==========================
149 replace :: (PositiveT s, NaturalT u, (s :>: u) ~ True) =>
150   TFVec s a -> RangedWord u -> a -> TFVec s a
151 replace (TFVec xs) i y = TFVec $ replace' xs (toInteger i) y
152   where replace' []     _ _ = []
153         replace' (_:xs) 0 y = (y:xs)
154         replace' (x:xs) n y = x : (replace' xs (n-1) y)
155   
156 head :: PositiveT s => TFVec s a -> a
157 head = P.head . unTFVec
158
159 tail :: PositiveT s => TFVec s a -> TFVec (Pred s) a
160 tail = liftV P.tail
161
162 last :: PositiveT s => TFVec s a -> a
163 last = P.last . unTFVec
164
165 init :: PositiveT s => TFVec s a -> TFVec (Pred s) a
166 init = liftV P.init
167
168 take :: NaturalT i => i -> TFVec s a -> TFVec (Min s i) a
169 take i = liftV $ P.take (fromIntegerT i)
170
171 drop :: NaturalT i => i -> TFVec s a -> TFVec (s :-: (Min s i)) a
172 drop i = liftV $ P.drop (fromIntegerT i)
173
174 select :: (NaturalT f, NaturalT s, NaturalT n, (f :<: i) ~ True, 
175           (((s :*: n) :+: f) :<=: i) ~ True) => 
176           f -> s -> n -> TFVec i a -> TFVec n a
177 select f s n = liftV (select' f' s' n')
178   where (f', s', n') = (fromIntegerT f, fromIntegerT s, fromIntegerT n)
179         select' f s n = ((selectFirst0 s n).(P.drop f))
180         selectFirst0 :: Int -> Int -> [a] -> [a]
181         selectFirst0 s n l@(x:_)
182           | n > 0 = x : selectFirst0 s (n-1) (P.drop s l)
183           | otherwise = []
184         selectFirst0 _ 0 [] = []
185
186 (<+) :: TFVec s a -> a -> TFVec (Succ s) a
187 (<+) (TFVec xs) x = TFVec (xs P.++ [x])
188
189 (++) :: TFVec s a -> TFVec s2 a -> TFVec (s :+: s2) a
190 (++) = liftV2 (P.++)
191
192 infixl 5 <+
193 infixr 5 ++
194
195 map :: (a -> b) -> TFVec s a -> TFVec s b
196 map f = liftV (P.map f)
197
198 zipWith :: (a -> b -> c) -> TFVec s a -> TFVec s b -> TFVec s c
199 zipWith f = liftV2 (P.zipWith f)
200
201 foldl :: (a -> b -> a) -> a -> TFVec s b -> a
202 foldl f e = (P.foldl f e) . unTFVec
203
204 foldr :: (b -> a -> a) -> a -> TFVec s b -> a
205 foldr f e = (P.foldr f e) . unTFVec
206
207 zip :: TFVec s a -> TFVec s b -> TFVec s (a, b)
208 zip = liftV2 P.zip
209
210 unzip :: TFVec s (a, b) -> (TFVec s a, TFVec s b)
211 unzip (TFVec xs) = let (a,b) = P.unzip xs in (TFVec a, TFVec b)
212
213 shiftl :: (PositiveT s, NaturalT n, n ~ Pred s, s ~ Succ n) => 
214           TFVec s a -> a -> TFVec s a
215 shiftl xs x = x +> init xs
216
217 shiftr :: (PositiveT s, NaturalT n, n ~ Pred s, s ~ Succ n) => 
218           TFVec s a -> a -> TFVec s a
219 shiftr xs x = tail xs <+ x
220   
221 rotl :: forall s a . NaturalT s => TFVec s a -> TFVec s a
222 rotl = liftV rotl'
223   where vlen = fromIntegerT (undefined :: s)
224         rotl' [] = []
225         rotl' xs = let (i,[l]) = splitAt (vlen - 1) xs
226                    in l : i 
227
228 rotr :: NaturalT s => TFVec s a -> TFVec s a
229 rotr = liftV rotr'
230   where
231     rotr' [] = []
232     rotr' (x:xs) = xs P.++ [x] 
233
234 concat :: TFVec s1 (TFVec s2 a) -> TFVec (s1 :*: s2) a
235 concat = liftV (P.foldr ((P.++).unTFVec) [])
236
237 reverse :: TFVec s a -> TFVec s a
238 reverse = liftV P.reverse
239
240 iterate :: NaturalT s => (a -> a) -> a -> TFVec s a
241 iterate = iteraten (undefined :: s)
242
243 iteraten :: NaturalT s => s -> (a -> a) -> a -> TFVec s a
244 iteraten s f x = let s' = fromIntegerT s in TFVec (P.take s' $ P.iterate f x)
245
246 generate :: NaturalT s => (a -> a) -> a -> TFVec s a
247 generate = generaten (undefined :: s)
248
249 generaten :: NaturalT s => s -> (a -> a) -> a -> TFVec s a
250 generaten s f x = let s' = fromIntegerT s in TFVec (P.take s' $ P.tail $ P.iterate f x)
251
252 copy :: NaturalT s => a -> TFVec s a
253 copy x = copyn (undefined :: s) x
254
255 copyn :: NaturalT s => s -> a -> TFVec s a
256 copyn s x = iteraten s id x
257
258 -- =============
259 -- = Instances =
260 -- =============
261 instance Show a => Show (TFVec s a) where
262   showsPrec _ = showV.unTFVec
263     where showV []      = showString "<>"
264           showV (x:xs)  = showChar '<' . shows x . showl xs
265                             where showl []      = showChar '>'
266                                   showl (x:xs)  = showChar ',' . shows x .
267                                                   showl xs
268
269 instance (Read a, NaturalT nT) => Read (TFVec nT a) where
270   readsPrec _ str
271     | all fitsLength possibilities = P.map toReadS possibilities
272     | otherwise = error (fName P.++ ": string/dynamic length mismatch")
273     where 
274       fName = "Data.Param.TFVec.read"
275       expectedL = fromIntegerT (undefined :: nT)
276       possibilities = readTFVecList str
277       fitsLength (_, l, _) = l == expectedL
278       toReadS (xs, _, rest) = (TFVec xs, rest)
279       
280 instance NaturalT s => DF.Foldable (TFVec s) where
281  foldr = foldr
282  
283 instance NaturalT s => Functor (TFVec s) where
284  fmap = map
285
286 instance NaturalT s => DT.Traversable (TFVec s) where 
287   traverse f = (fmap TFVec).(DT.traverse f).unTFVec
288
289 -- instance (Lift a, NaturalT nT) => Lift (TFVec nT a) where
290 --   lift (TFVec xs) = [|  unsafeTFVecCoerse
291 --                         $(decLiteralV (fromIntegerT (undefined :: nT)))
292 --                         (TFVec xs) |]
293
294 instance (Lift a, Typeable a, NaturalT nT) => Lift (TFVec nT a) where
295   lift (TFVec xs) = sigE [| (TFVec xs) |] (decTFVecT (fromIntegerT (undefined :: nT)) xs)
296
297 decTFVecT :: Typeable x => Integer -> x -> Q Type
298 decTFVecT n a = appT (appT (conT (''TFVec)) (decLiteralT n)) elemT
299   where
300     (con,reps) = splitTyConApp (typeOf a)
301     elemT = typeRep2Type (P.head reps)
302
303
304 -- ======================
305 -- = Internal Functions =
306 -- ======================
307 liftV :: ([a] -> [b]) -> TFVec nT a -> TFVec nT' b
308 liftV f = TFVec . f . unTFVec
309
310 liftV2 :: ([a] -> [b] -> [c]) -> TFVec s a -> TFVec s2 b -> TFVec s3 c
311 liftV2 f a b = TFVec (f (unTFVec a) (unTFVec b))
312
313 splitAtM :: Int -> [a] -> Maybe ([a],[a])
314 splitAtM n xs = splitAtM' n [] xs
315   where splitAtM' 0 xs ys = Just (xs, ys)
316         splitAtM' n xs (y:ys) | n > 0 = do
317           (ls, rs) <- splitAtM' (n-1) xs ys
318           return (y:ls,rs)
319         splitAtM' _ _ _ = Nothing
320
321 unsafeTFVecCoerse :: nT' -> TFVec nT a -> TFVec nT' a
322 unsafeTFVecCoerse _ (TFVec v) = (TFVec v)
323
324 unsafeVectorCPS :: forall a w . Integer -> [a] ->
325                         (forall s . NaturalT s => TFVec s a -> w) -> w
326 unsafeVectorCPS l xs f = reifyNaturalD l 
327                         (\(_ :: lt) -> f ((TFVec xs) :: (TFVec lt a)))
328
329 readTFVecList :: Read a => String -> [([a], Int, String)]
330 readTFVecList = readParen' False (\r -> [pr | ("<",s) <- lexTFVec r,
331                                               pr <- readl s])
332   where
333     readl   s = [([],0,t) | (">",t) <- lexTFVec s] P.++
334                             [(x:xs,1+n,u) | (x,t)       <- reads s,
335                                             (xs, n, u)  <- readl' t]
336     readl'  s = [([],0,t) | (">",t) <- lexTFVec s] P.++
337                             [(x:xs,1+n,v) | (",",t)   <- lex s,
338                                             (x,u)     <- reads t,
339                                             (xs,n,v)  <- readl' u]
340     readParen' b g  = if b then mandatory else optional
341       where optional r  = g r P.++ mandatory r
342             mandatory r = [(x,n,u) | ("(",s)  <- lexTFVec r,
343                                       (x,n,t) <- optional s,
344                                       (")",u) <- lexTFVec t]
345
346 -- Custom lexer for FSVecs, we cannot use lex directly because it considers
347 -- sequences of < and > as unique lexemes, and that breaks nested FSVecs, e.g.
348 -- <<1,2><3,4>>
349 lexTFVec :: ReadS String
350 lexTFVec ('>':rest) = [(">",rest)]
351 lexTFVec ('<':rest) = [("<",rest)]
352 lexTFVec str = lex str
353