Move TFVec and TFP integers (Signed, Unsiged and Index) into clash
[matthijs/master-project/cλash.git] / cλash / Data / Param / Vector.hs
1 {-# LANGUAGE StandaloneDeriving, ExistentialQuantification, ScopedTypeVariables, TemplateHaskell, TypeOperators, TypeFamilies #-}
2 module Data.Param.Vector
3   ( Vector
4   , empty
5   , (+>)
6   , singleton
7   , vectorTH
8   , unsafeVector
9   , readVector
10   , length
11   , lengthT
12   , fromVector
13   , null
14   , (!)
15   , replace
16   , head
17   , last
18   , init
19   , tail
20   , take
21   , drop
22   , select
23   , (<+)
24   , (++)
25   , map
26   , zipWith
27   , foldl
28   , foldr
29   , zip
30   , unzip
31   , shiftl
32   , shiftr
33   , rotl
34   , rotr
35   , concat
36   , reverse
37   , iterate
38   , iteraten
39   , generate
40   , generaten
41   , copy
42   , copyn
43   , split
44   ) where
45     
46 import Types
47 import Types.Data.Num
48 import Types.Data.Num.Decimal.Literals.TH
49 import Data.Param.Index
50
51 import Data.Typeable
52 import qualified Prelude as P
53 import Prelude hiding (
54   null, length, head, tail, last, init, take, drop, (++), map, foldl, foldr,
55   zipWith, zip, unzip, concat, reverse, iterate )
56 import qualified Data.Foldable as DF (Foldable, foldr)
57 import qualified Data.Traversable as DT (Traversable(traverse))
58 import Language.Haskell.TH hiding (Pred)
59 import Language.Haskell.TH.Syntax (Lift(..))
60
61 newtype (NaturalT s) => Vector s a = Vector {unVec :: [a]}
62   deriving Eq
63
64 -- deriving instance (NaturalT s, Typeable s, Data s, Typeable a, Data a) => Data (TFVec s a)
65
66 -- ==========================
67 -- = Constructing functions =
68 -- ==========================
69                                                   
70 empty :: Vector D0 a
71 empty = Vector []
72
73 (+>) :: a -> Vector s a -> Vector (Succ s) a
74 x +> (Vector xs) = Vector (x:xs)
75
76 infix 5 +>
77
78 singleton :: a -> Vector D1 a
79 singleton x = x +> empty
80
81 -- FIXME: Not the most elegant solution... but it works for now in clash
82 vectorTH :: (Lift a) => [a] -> ExpQ
83 -- vectorTH xs = sigE [| (TFVec xs) |] (decTFVecT (toInteger (P.length xs)) xs)
84 vectorTH [] = [| empty |]
85 vectorTH [x] = [| singleton x |]
86 vectorTH (x:xs) = [| x +> $(vectorTH xs) |]
87
88 unsafeVector :: NaturalT s => s -> [a] -> Vector s a
89 unsafeVector l xs
90   | fromIntegerT l /= P.length xs =
91     error (show 'unsafeVector P.++ ": dynamic/static lenght mismatch")
92   | otherwise = Vector xs
93
94 readVector :: (Read a, NaturalT s) => String -> Vector s a
95 readVector = read
96         
97 -- =======================
98 -- = Observing functions =
99 -- =======================
100 length :: forall s a . NaturalT s => Vector s a -> Int
101 length _ = fromIntegerT (undefined :: s)
102
103 lengthT :: NaturalT s => Vector s a -> s
104 lengthT = undefined
105
106 fromVector :: NaturalT s => Vector s a -> [a]
107 fromVector (Vector xs) = xs
108
109 null :: Vector D0 a -> Bool
110 null _ = True
111
112 (!) ::  ( PositiveT s
113         , NaturalT u
114         , (s :>: u) ~ True) => Vector s a -> Index u -> a
115 (Vector xs) ! i = xs !! (fromInteger (toInteger i))
116
117 -- ==========================
118 -- = Transforming functions =
119 -- ==========================
120 replace :: (PositiveT s, NaturalT u, (s :>: u) ~ True) =>
121   Vector s a -> Index u -> a -> Vector s a
122 replace (Vector xs) i y = Vector $ replace' xs (toInteger i) y
123   where replace' []     _ _ = []
124         replace' (_:xs) 0 y = (y:xs)
125         replace' (x:xs) n y = x : (replace' xs (n-1) y)
126   
127 head :: PositiveT s => Vector s a -> a
128 head = P.head . unVec
129
130 tail :: PositiveT s => Vector s a -> Vector (Pred s) a
131 tail = liftV P.tail
132
133 last :: PositiveT s => Vector s a -> a
134 last = P.last . unVec
135
136 init :: PositiveT s => Vector s a -> Vector (Pred s) a
137 init = liftV P.init
138
139 take :: NaturalT i => i -> Vector s a -> Vector (Min s i) a
140 take i = liftV $ P.take (fromIntegerT i)
141
142 drop :: NaturalT i => i -> Vector s a -> Vector (s :-: (Min s i)) a
143 drop i = liftV $ P.drop (fromIntegerT i)
144
145 select :: (NaturalT f, NaturalT s, NaturalT n, (f :<: i) ~ True, 
146           (((s :*: n) :+: f) :<=: i) ~ True) => 
147           f -> s -> n -> Vector i a -> Vector n a
148 select f s n = liftV (select' f' s' n')
149   where (f', s', n') = (fromIntegerT f, fromIntegerT s, fromIntegerT n)
150         select' f s n = ((selectFirst0 s n).(P.drop f))
151         selectFirst0 :: Int -> Int -> [a] -> [a]
152         selectFirst0 s n l@(x:_)
153           | n > 0 = x : selectFirst0 s (n-1) (P.drop s l)
154           | otherwise = []
155         selectFirst0 _ 0 [] = []
156
157 (<+) :: Vector s a -> a -> Vector (Succ s) a
158 (<+) (Vector xs) x = Vector (xs P.++ [x])
159
160 (++) :: Vector s a -> Vector s2 a -> Vector (s :+: s2) a
161 (++) = liftV2 (P.++)
162
163 infixl 5 <+
164 infixr 5 ++
165
166 map :: (a -> b) -> Vector s a -> Vector s b
167 map f = liftV (P.map f)
168
169 zipWith :: (a -> b -> c) -> Vector s a -> Vector s b -> Vector s c
170 zipWith f = liftV2 (P.zipWith f)
171
172 foldl :: (a -> b -> a) -> a -> Vector s b -> a
173 foldl f e = (P.foldl f e) . unVec
174
175 foldr :: (b -> a -> a) -> a -> Vector s b -> a
176 foldr f e = (P.foldr f e) . unVec
177
178 zip :: Vector s a -> Vector s b -> Vector s (a, b)
179 zip = liftV2 P.zip
180
181 unzip :: Vector s (a, b) -> (Vector s a, Vector s b)
182 unzip (Vector xs) = let (a,b) = P.unzip xs in (Vector a, Vector b)
183
184 shiftl :: (PositiveT s, NaturalT n, n ~ Pred s, s ~ Succ n) => 
185           Vector s a -> a -> Vector s a
186 shiftl xs x = x +> init xs
187
188 shiftr :: (PositiveT s, NaturalT n, n ~ Pred s, s ~ Succ n) => 
189           Vector s a -> a -> Vector s a
190 shiftr xs x = tail xs <+ x
191   
192 rotl :: forall s a . NaturalT s => Vector s a -> Vector s a
193 rotl = liftV rotl'
194   where vlen = fromIntegerT (undefined :: s)
195         rotl' [] = []
196         rotl' xs = let (i,[l]) = splitAt (vlen - 1) xs
197                    in l : i 
198
199 rotr :: NaturalT s => Vector s a -> Vector s a
200 rotr = liftV rotr'
201   where
202     rotr' [] = []
203     rotr' (x:xs) = xs P.++ [x] 
204
205 concat :: Vector s1 (Vector s2 a) -> Vector (s1 :*: s2) a
206 concat = liftV (P.foldr ((P.++).unVec) [])
207
208 reverse :: Vector s a -> Vector s a
209 reverse = liftV P.reverse
210
211 iterate :: NaturalT s => (a -> a) -> a -> Vector s a
212 iterate = iteraten (undefined :: s)
213
214 iteraten :: NaturalT s => s -> (a -> a) -> a -> Vector s a
215 iteraten s f x = let s' = fromIntegerT s in Vector (P.take s' $ P.iterate f x)
216
217 generate :: NaturalT s => (a -> a) -> a -> Vector s a
218 generate = generaten (undefined :: s)
219
220 generaten :: NaturalT s => s -> (a -> a) -> a -> Vector s a
221 generaten s f x = let s' = fromIntegerT s in Vector (P.take s' $ P.tail $ P.iterate f x)
222
223 copy :: NaturalT s => a -> Vector s a
224 copy x = copyn (undefined :: s) x
225
226 copyn :: NaturalT s => s -> a -> Vector s a
227 copyn s x = iteraten s id x
228
229 split :: ( NaturalT s
230          -- , IsEven s ~ True
231          ) => Vector s a -> (Vector (Div2 s) a, Vector (Div2 s) a)
232 split (Vector xs) = (Vector (P.take vlen xs), Vector (P.drop vlen xs))
233   where
234     vlen = round ((fromIntegral (P.length xs)) / 2)
235
236 -- =============
237 -- = Instances =
238 -- =============
239 instance Show a => Show (Vector s a) where
240   showsPrec _ = showV.unVec
241     where showV []      = showString "<>"
242           showV (x:xs)  = showChar '<' . shows x . showl xs
243                             where showl []      = showChar '>'
244                                   showl (x:xs)  = showChar ',' . shows x .
245                                                   showl xs
246
247 instance (Read a, NaturalT nT) => Read (Vector nT a) where
248   readsPrec _ str
249     | all fitsLength possibilities = P.map toReadS possibilities
250     | otherwise = error (fName P.++ ": string/dynamic length mismatch")
251     where 
252       fName = "Data.Param.TFVec.read"
253       expectedL = fromIntegerT (undefined :: nT)
254       possibilities = readVectorList str
255       fitsLength (_, l, _) = l == expectedL
256       toReadS (xs, _, rest) = (Vector xs, rest)
257       
258 instance NaturalT s => DF.Foldable (Vector s) where
259  foldr = foldr
260  
261 instance NaturalT s => Functor (Vector s) where
262  fmap = map
263
264 instance NaturalT s => DT.Traversable (Vector s) where 
265   traverse f = (fmap Vector).(DT.traverse f).unVec
266
267 instance (Lift a, NaturalT nT) => Lift (Vector nT a) where
268   lift (Vector xs) = [|  unsafeVectorCoerse
269                          $(decLiteralV (fromIntegerT (undefined :: nT)))
270                           (Vector xs) |]
271
272 -- ======================
273 -- = Internal Functions =
274 -- ======================
275 liftV :: ([a] -> [b]) -> Vector nT a -> Vector nT' b
276 liftV f = Vector . f . unVec
277
278 liftV2 :: ([a] -> [b] -> [c]) -> Vector s a -> Vector s2 b -> Vector s3 c
279 liftV2 f a b = Vector (f (unVec a) (unVec b))
280
281 splitAtM :: Int -> [a] -> Maybe ([a],[a])
282 splitAtM n xs = splitAtM' n [] xs
283   where splitAtM' 0 xs ys = Just (xs, ys)
284         splitAtM' n xs (y:ys) | n > 0 = do
285           (ls, rs) <- splitAtM' (n-1) xs ys
286           return (y:ls,rs)
287         splitAtM' _ _ _ = Nothing
288
289 unsafeVectorCoerse :: nT' -> Vector nT a -> Vector nT' a
290 unsafeVectorCoerse _ (Vector v) = (Vector v)
291
292 readVectorList :: Read a => String -> [([a], Int, String)]
293 readVectorList = readParen' False (\r -> [pr | ("<",s) <- lexVector r,
294                                               pr <- readl s])
295   where
296     readl   s = [([],0,t) | (">",t) <- lexVector s] P.++
297                             [(x:xs,1+n,u) | (x,t)       <- reads s,
298                                             (xs, n, u)  <- readl' t]
299     readl'  s = [([],0,t) | (">",t) <- lexVector s] P.++
300                             [(x:xs,1+n,v) | (",",t)   <- lex s,
301                                             (x,u)     <- reads t,
302                                             (xs,n,v)  <- readl' u]
303     readParen' b g  = if b then mandatory else optional
304       where optional r  = g r P.++ mandatory r
305             mandatory r = [(x,n,u) | ("(",s)  <- lexVector r,
306                                       (x,n,t) <- optional s,
307                                       (")",u) <- lexVector t]
308
309 -- Custom lexer for FSVecs, we cannot use lex directly because it considers
310 -- sequences of < and > as unique lexemes, and that breaks nested FSVecs, e.g.
311 -- <<1,2><3,4>>
312 lexVector :: ReadS String
313 lexVector ('>':rest) = [(">",rest)]
314 lexVector ('<':rest) = [("<",rest)]
315 lexVector str = lex str
316