

{----------------------------------------------------------------------------

                           Compilation A La Carte

                     Laurence E. Day and Graham Hutton
                     Functional Programming Laboratory
                        School of Computer Science
                       University of Nottingham, UK

                               March 2013

----------------------------------------------------------------------------}

{-# LANGUAGE TypeOperators, 
             GADTs, 
             FlexibleInstances, 
             MultiParamTypeClasses, 
             TypeFamilies,
             OverlappingInstances,
             FlexibleContexts, 
             ScopedTypeVariables,
             UndecidableInstances, 
             DeriveFunctor #-}

import Control.Applicative (Applicative (..))
import Control.Monad
import Prelude hiding (catch)

-- | Fixed point of a functor: a recursive tree whose nodes are layers of @f@.
data Fix f = In (f (Fix f))
-- | A language is identified with the fixed point of its signature functor.
type Language = Fix

-- | Coproduct of two signature functors, used to combine language fragments.
data (f :+: g) e = Inl (f e) | Inr (g e)      deriving Functor

-- | Source-language fragments. Each is a signature functor whose @e@
-- positions hold subterms; languages are formed by combining them with ':+:'.
--
-- Integer literals and addition.
data Arith e = Val Int | Add e e              deriving Functor
-- Raising an exception, and @Catch x h@ running @x@ with handler @h@.
data Except e = Throw | Catch e e             deriving Functor
-- Variables as indices into the environment, abstraction, application.
data Lambda e = Index Int | Abs e | Apply e e deriving Functor
-- Reading the integer state, and @Set n e@ setting it to @n@ before @e@.
data State e = Get | Set Int e                deriving Functor

-- | Catamorphism: collapse a 'Fix' tree bottom-up with the algebra @alg@.
fold :: Functor f => (f a -> a) -> Fix f -> a
fold alg = go
  where
    go (In layer) = alg (fmap go layer)

-- | Values produced by the monadic evaluator: integers, or closures that
-- capture an environment together with a monadic body.
data Value m where
  Num  :: Int -> Value m
  Clos :: Monad m => [Value m] -> m (Value m) -> Value m
  
-- | Monadic evaluation algebra for signature functor @f@ in monad @m@:
-- given the already-evaluated (monadic) subterms, evaluate one layer.
-- 'eval' folds this algebra over a whole term.
class (Functor f, Monad m) => Eval f m where
  evAlg :: f (m (Value m)) -> m (Value m)
  
-- | Evaluation of a combined signature dispatches to whichever summand
-- the layer belongs to.
instance (Eval f m, Eval g m) => Eval (f :+: g) m where
  evAlg t = case t of
    Inl l -> evAlg l
    Inr r -> evAlg r

-- | Arithmetic: literals evaluate to numbers. @Add x y@ evaluates @x@,
-- then @y@, and returns the sum.
--
-- Operands are inspected with an explicit @case@, not a failable
-- do-pattern (@Num n <- x@). A do-pattern would require a 'MonadFail'
-- instance, which the @Monad m@ context does not provide on GHC >= 8.8.
-- A non-numeric operand (a closure) is reported with 'error'. When the left
-- operand is at fault, this happens before the right operand is evaluated.
instance Monad m => Eval Arith m where
  evAlg (Val n)   = return (Num n)
  evAlg (Add x y) = x >>= \vx -> case vx of
    Num n -> y >>= \vy -> case vy of
      Num m -> return (Num (n + m))
      _     -> error "evAlg (Add): right operand is not a number"
    _     -> error "evAlg (Add): left operand is not a number"
  
-- | Exceptions via 'MonadPlus': 'Throw' is 'mzero', and @Catch x h@
-- combines the body @x@ with the handler @h@ using 'mplus'.
instance MonadPlus m => Eval Except m where
  evAlg Throw       = mzero
  evAlg (Catch x h) = mplus x h
  
-- | Call-by-value lambda calculus.
--
-- * @Index n@ looks up position @n@ in the current environment.
-- * @Abs t@ captures the current environment in a closure.
-- * @Apply f x@ evaluates the function, then the argument, and runs the
--   closure body with the argument prepended to the captured environment.
instance CBVMonad m => Eval Lambda m where
  evAlg (Index n)   = liftM (!! n) env
  evAlg (Abs t)     = liftM (\ctx -> Clos ctx t) env
  evAlg (Apply f x) = f >>= \(Clos ctx body) ->
                        x >>= \arg -> with (arg : ctx) body
  
-- | State: 'Get' reads the state with @update id@ and returns it as a
-- number. @Set n e@ overwrites the state with @n@ before running @e@.
instance StateMonad m => Eval State m where
  evAlg Get       = liftM Num (update id)
  evAlg (Set n e) = update (const n) >> e

-- | Evaluate a term by folding the evaluation algebra over it.
eval :: Eval f m => Language f -> m (Value m)
eval term = fold evAlg term

-- | Target-code instructions for arithmetic, each carrying its continuation.
-- @PUSH n@ pushes @n@ onto the value stack. @ADD@ pops two numbers and
-- pushes their sum.
data ARITH e = PUSH Int e | ADD e deriving Functor

-- | Target-code instructions for exceptions, each carrying its continuation.
data EXCEPT e where
  -- Raise: the executor jumps to the nearest handler; this continuation is
  -- not used by the executor.
  THROW ::  e -> EXCEPT e
  -- Uninstall the nearest handler, then continue.
  UNMARK :: e -> EXCEPT e
  -- Install the given handler code, then continue.
  MARK :: (Execute f) => Language f -> e -> EXCEPT e
  -- Saving and restoring the state value around a protected expression.
  SAVE    :: e -> EXCEPT e
  RESTORE :: e -> EXCEPT e
  
-- | Target-code instructions for the lambda fragment, each carrying its
-- continuation.
data LAMBDA e where
  -- Push the given entry of the current environment.
  ACCESS  :: Int -> e -> LAMBDA e
  -- Push a closure of the given body code over the current environment.
  CLOSURE :: Execute f => Fix f -> e -> LAMBDA e
  -- Emitted by the compiler after an argument and a function have been pushed.
  APPLY   :: e -> LAMBDA e
  -- Emitted by the compiler at the end of every closure body.
  RETURN  :: e -> LAMBDA e
  
-- | Target-code instructions for state. @GET@ pushes the current state
-- value onto the value stack. @SET n@ pushes @n@ onto the state stack.
data STATE e = GET e | SET Int e deriving Functor

-- | The terminating instruction: it has no continuation, and executing it
-- leaves the snapshot unchanged.
data EMPTY e = EMPTY             deriving Functor
  
-- | Map over the continuation of an exception instruction. Handler code
-- stored by 'MARK' is left untouched.
instance Functor EXCEPT where
  fmap g instr = case instr of
    THROW   k -> THROW   (g k)
    UNMARK  k -> UNMARK  (g k)
    MARK h  k -> MARK h  (g k)
    SAVE    k -> SAVE    (g k)
    RESTORE k -> RESTORE (g k)
  
-- | Map over the continuation of a lambda instruction. Closure body code
-- is left untouched.
instance Functor LAMBDA where
  fmap g instr = case instr of
    ACCESS  i k -> ACCESS  i (g k)
    CLOSURE b k -> CLOSURE b (g k)
    APPLY     k -> APPLY     (g k)
    RETURN    k -> RETURN    (g k)

-- | @sub :<: sup@: the functor @sub@ occurs as a summand of @sup@, witnessed
-- by the injection 'inj'.
class (Functor sub, Functor sup) => sub :<: sup where
  inj :: sub a -> sup a
  
-- | Reflexivity: every functor is a subfunctor of itself.
instance Functor f => (:<:) f f where
  inj = id

-- | @f@ is the left summand.
instance (Functor f, Functor g) => (:<:) f (f :+: g) where
  inj = Inl

-- | Otherwise, look in the right summand. This head overlaps the previous
-- two; resolution relies on the OverlappingInstances extension enabled at
-- the top of the file.
instance (Functor g, (:<:) f h) => (:<:) f (g :+: h) where
  inj = Inr . inj
  
-- | Wrap a single layer of a subfunctor as a term of the larger language.
inject :: (g :<: f) => g (Language f) -> Language f
inject layer = In (inj layer)

{-
Smart constructors: each one injects a single constructor into any
signature functor that contains it (via 'inject').
-}

-- | Unary and binary operations on @a@. Compiled code uses 'Duo' for
-- "code taking its continuation".
type Duo  a = a -> a
type Trio a = a -> a -> a

-- Arith: literal and addition terms.
val_sc :: (Arith :<: f) => Int -> Language f
val_sc n = inject (Val n)

add_sc :: (Arith :<: f) => Trio (Language f)
add_sc x y = inject (Add x y)
-- /Arith

-- Except: throwing, and catching with a handler.
throw_sc :: (Except :<: f) => Language f
throw_sc = inject Throw

catch_sc :: (Except :<: f) => Trio (Language f)
catch_sc x h = inject (Catch x h)
-- /Except

-- Lambda: variables, abstraction and application.
index_sc :: (Lambda :<: f) => Int -> Language f
index_sc n = inject (Index n)

abs_sc   :: (Lambda :<: f) => Duo (Language f)
abs_sc body = inject (Abs body)

apply_sc :: (Lambda :<: f) => Trio (Language f)
apply_sc fun arg = inject (Apply fun arg)
-- /Lambda

-- State: reading and setting the integer state.
get_sc   :: (State :<: f) => Language f
get_sc   = inject Get

set_sc   :: (State :<: f) => Int -> Duo (Language f)
set_sc n rest = inject (Set n rest)
-- /State

-- ARITH: each constructor takes its continuation as the last argument.
push_SC :: (ARITH :<: g) => Int -> Duo (Language g)
push_SC n k = inject (PUSH n k)

add_SC :: (ARITH :<: g) => Duo (Language g)
add_SC k = inject (ADD k)
-- /ARITH

-- EXCEPT
throw_SC :: (EXCEPT :<: g) => Duo (Language g)
throw_SC k = inject (THROW k)

mark_SC :: (EXCEPT :<: g, Execute f) => Fix f -> Duo (Language g)
mark_SC handler k = inject (MARK handler k)

unmark_SC :: (EXCEPT :<: g) => Duo (Language g)
unmark_SC k = inject (UNMARK k)

save_SC :: (EXCEPT :<: g) => Duo (Language g)
save_SC k = inject (SAVE k)

restore_SC :: (EXCEPT :<: g) => Duo (Language g)
restore_SC k = inject (RESTORE k)
-- /EXCEPT

-- LAMBDA
access_SC :: (LAMBDA :<: g) => Int -> Duo (Language g)
access_SC i k = inject (ACCESS i k)

closure_SC :: (LAMBDA :<: g, Execute f) => Fix f -> Duo (Language g)
closure_SC body k = inject (CLOSURE body k)

apply_SC :: (LAMBDA :<: g) => Duo (Language g)
apply_SC k = inject (APPLY k)

return_SC :: (LAMBDA :<: g) => Duo (Language g)
return_SC k = inject (RETURN k)
-- /LAMBDA

-- STATE
get_SC :: (STATE :<: g) => Duo (Language g)
get_SC k = inject (GET k)

set_SC :: (STATE :<: g) => Int -> Duo (Language g)
set_SC n k = inject (SET n k)
-- /STATE

-- | The identity monad.
newtype Identity   a = I  { runI  :: a             }
-- | Error transformer: a computation in @m@ whose @Nothing@ result marks
-- failure (the Monad instance short-circuits on it).
newtype ErrorT   m a = E  { runE  :: m (Maybe a)   }
-- | Plain state monad over state @s@.
newtype StateM s   a = S  { runSM :: s -> (a, s)   }
-- | State transformer over an underlying monad @m@.
newtype StateT s m a = ST { runS  :: s -> m (a, s) }

-- | 'Functor' and 'Applicative' for 'Identity', defined via its 'Monad'
-- instance. Since GHC 7.10 (Functor-Applicative-Monad) every 'Monad' must
-- also be an 'Applicative' and a 'Functor'; without these the 'Monad'
-- instance below is rejected.
instance Functor Identity where
  fmap = liftM

instance Applicative Identity where
  pure  = I
  (<*>) = ap

-- | The trivial monad: @return@ wraps a value, @>>=@ unwraps and applies.
instance Monad Identity where
  return      = pure
  (I x) >>= f = f x
  
-- | 'Functor' and 'Applicative' for 'StateM', defined via its 'Monad'
-- instance (required superclasses of 'Monad' since GHC 7.10).
instance Functor (StateM s) where
  fmap = liftM

instance Applicative (StateM s) where
  pure x = S $ \s -> (x, s)
  (<*>)  = ap

-- | State threading: run the first computation, then feed its result and
-- final state to the next one.
instance Monad (StateM s) where
  return      = pure
  (S t) >>= f = S $ \s -> let (x, u) = t s in runSM (f x) u
  
{- various class instantiations go here -}

-- | Monads with failure and recovery. In the 'Maybe' instance, 'throw' is
-- 'Nothing' and @catch m h@ returns @m@ unless it failed, otherwise @h@.
class Monad m => ErrorMonad m where
  throw :: m a
  catch :: m a -> m a -> m a
                                     
-- | Monads carrying a call-by-value environment of 'Value's. 'env' reads the
-- current environment. @with ctx m@ presumably runs @m@ under environment
-- @ctx@ — no instance is defined in this file; confirm.
class Monad m => CBVMonad m where
  env  :: m [Value m]
  with :: [Value m] -> (m (Value m)) -> m (Value m)
  
-- | Monads with a single integer of state. @update f@ applies @f@ to the
-- state and returns an 'Int'. Whether that is the old or the new state is
-- not fixed here (no instance in this file). 'Eval' uses @update id@ to
-- read it, where the two coincide.
class Monad m => StateMonad m where
  update :: (Int -> Int) -> m Int
  
-- | Monad transformers: 'lift' embeds a computation of the underlying monad.
class MonadT t where
  lift :: Monad m => m a -> t m a

-- | Lifting into 'ErrorT' marks the underlying result as a success.
instance MonadT ErrorT where
  lift = E . liftM Just
  
-- | Lifting into 'StateT' runs the underlying computation and passes the
-- state through unchanged.
instance MonadT (StateT s) where
  lift m = ST (\s -> liftM (\x -> (x, s)) m)

----------------------
  
-- | 'Functor' and 'Applicative' for 'StateT', defined via its 'Monad'
-- instance (required superclasses of 'Monad' since GHC 7.10).
instance Monad m => Functor (StateT s m) where
  fmap = liftM

instance Monad m => Applicative (StateT s m) where
  pure x = ST $ \s -> return (x, s)
  (<*>)  = ap

-- | State threading over an underlying monad: run the first computation,
-- then the continuation on its result and final state.
instance Monad m => Monad (StateT s m) where
  return       = pure
  (ST t) >>= f = ST $ \s -> t s >>= \(x, u) -> runS (f x) u
  
-- | 'Functor' and 'Applicative' for 'ErrorT', defined via its 'Monad'
-- instance (required superclasses of 'Monad' since GHC 7.10).
instance Monad m => Functor (ErrorT m) where
  fmap = liftM

instance Monad m => Applicative (ErrorT m) where
  pure  = E . return . Just
  (<*>) = ap

-- | Sequencing that short-circuits: once a computation yields 'Nothing',
-- the continuation is skipped and 'Nothing' is propagated.
instance Monad m => Monad (ErrorT m) where
  return      = pure
  (E m) >>= f = E (do res <- m; case res of Just n -> runE (f n); _ -> return Nothing)
  
-- | 'Maybe' as an error monad: failure is 'Nothing', and 'catch' falls
-- back to the handler only when the first computation failed.
instance ErrorMonad Maybe where
  throw = Nothing
  catch m h = case m of
    Nothing -> h
    Just _  -> m

----------------------

-- | Value-level description of a monad-transformer stack, indexed by the
-- stack it describes: 'Err' adds an 'ErrorT' layer, 'Sta' adds a 'StateT'
-- layer (with a state value), and 'Id' is the 'Identity' base.
data MTList m where
  Err ::      MTList m -> MTList (ErrorT m)
  Sta :: s -> MTList m -> MTList (StateT s m)
  Id  ::      MTList Identity
  
-- Technique Two: Monadic Reification --

-- | Compilation algebra from source signature @f@ to target signature @g@.
-- Each compiled subterm receives the 'MTList' describing the target monad
-- stack and a continuation (the code to run afterwards), and returns code.
class (Functor f, Functor g) => Comp f g where
  coAlg :: (EMPTY :<: g, Execute g) => f (MTList m -> Duo (Fix g)) -> MTList m -> Duo (Fix g)

-- | Compilation of a combined source signature dispatches on the summand.
instance (Comp f h, Comp g h) => Comp (f :+: g) h where
  coAlg t = case t of
    Inl l -> coAlg l
    Inr r -> coAlg r
           
-- | Literals push their value. Addition compiles the left operand, then
-- the right operand, then an ADD, threaded through the continuation.
instance (ARITH :<: g) => Comp Arith g where
  coAlg (Val n)   = \_ -> push_SC n
  coAlg (Add x y) = \m -> x m . y m . add_SC
  
-- | Compile 'Throw' and 'Catch'. The code emitted for @Catch x h@ depends on
-- the shape of the target monad stack, given at the value level by the
-- 'MTList':
--
-- * 'ErrorT' outermost: install the handler, run the body, then uninstall
--   the handler.
-- * 'StateT' over 'ErrorT': additionally SAVE the state before the body and
--   RESTORE it after the body.
--
-- NOTE(review): RESTORE is only emitted on the body's success path. If the
-- intent is to roll the state back when an exception is raised, the
-- handler path may need it instead — confirm.
--
-- Any other stack shape is rejected with an explicit error instead of an
-- anonymous pattern-match failure. Any 'ErrorT'-outermost stack is
-- accepted, not just 'ErrorT' directly over 'StateT'.
instance (EXCEPT :<: g) => Comp Except g where
  coAlg Throw       = const throw_SC
  coAlg (Catch x h) = \m c ->
    let handler = h m c
    in case m of
         Err _         -> mark_SC handler (x m (unmark_SC c))
         Sta _ (Err _) -> mark_SC handler (save_SC (x m (restore_SC (unmark_SC c))))
         _             -> error "coAlg (Catch): unsupported monad stack; expected ErrorT outermost or StateT over ErrorT"
  
-- | Variables become ACCESS. An abstraction becomes a CLOSURE whose body
-- ends in RETURN followed by the terminating EMPTY. An application compiles
-- the argument, then the function, then APPLY.
instance (LAMBDA :<: g) => Comp Lambda g where
  coAlg (Index n)   = \_ -> access_SC n
  coAlg (Abs t)     = \m -> closure_SC (t m (return_SC (inject EMPTY :: Fix g)))
  coAlg (Apply f x) = \m -> x m . f m . apply_SC

-- | 'Get' becomes GET. @Set n x@ emits SET n before the code for @x@.
instance (STATE :<: g) => Comp State g where
  coAlg Get       = const get_SC
  coAlg (Set n x) = \m -> set_SC n . x m

-- | Compile a source term for the monad stack described by the 'MTList',
-- prepending it to the given continuation code.
comp :: (Comp f g, EMPTY :<: g, Execute g) => Fix f -> MTList m -> Fix g -> Fix g
comp prog = fold coAlg prog

-- Technique One: Matching Monads As Parameters --

-- | Compilation algebra where the target monad is selected at the type
-- level. The @m ()@ argument is only passed along, never inspected, in this
-- file. Its type fixes @m@, which chooses between instances (see the two
-- 'Except' instances below).
class (Functor f, Functor g, Monad m) => Comp' f g m where
  coAlg' :: (EMPTY :<: g, Execute g) => f (m () -> Duo (Fix g)) -> m () -> Duo (Fix g)
  
-- | Type-directed compilation of a combined signature dispatches on the summand.
instance (Comp' f h m, Comp' g h m) => Comp' (f :+: g) h m where
  coAlg' t = case t of
    Inl l -> coAlg' l
    Inr r -> coAlg' r
           
-- | Arithmetic compiles the same way for every monad.
instance (ARITH :<: g, Monad m) => Comp' Arith g m where
  coAlg' (Val n)   = const (push_SC n)
  coAlg' (Add x y) = \m -> x m . y m . add_SC
  
-- | Exceptions over 'ErrorT' outermost: install the handler, run the body,
-- then uninstall the handler.
instance (EXCEPT :<: g, Monad m) => Comp' Except g (ErrorT (StateT s m)) where
  coAlg' Throw       = const throw_SC
  coAlg' (Catch x h) = \m c ->
    let handler = h m c
        body    = x m (unmark_SC c)
    in mark_SC handler body
  
-- | Exceptions with 'StateT' over 'ErrorT': like the 'ErrorT'-outermost
-- case, but the body is wrapped in SAVE ... RESTORE.
instance (EXCEPT :<: g, Monad m) => Comp' Except g (StateT s (ErrorT m)) where
  coAlg' Throw       = const throw_SC
  coAlg' (Catch x h) = \m c ->
    mark_SC (h m c) . save_SC . x m . restore_SC . unmark_SC $ c
  
-- | Lambda terms compile the same way for every monad. A closure body ends
-- in RETURN followed by EMPTY; an application compiles the argument, then
-- the function, then APPLY.
instance (LAMBDA :<: g, Monad m) => Comp' Lambda g m where
  coAlg' (Index n)   = const (access_SC n)
  coAlg' (Abs t)     = \m -> closure_SC (t m (return_SC (inject EMPTY :: Fix g)))
  coAlg' (Apply f x) = \m -> x m . f m . apply_SC

-- | State compiles the same way for every monad.
instance (STATE :<: g, Monad m) => Comp' State g m where
  coAlg' Get       = const get_SC
  coAlg' (Set n e) = \m -> set_SC n . e m
  
-- | Compile a source term, with the target monad chosen by the type of the
-- @m ()@ argument, prepending it to the given continuation code.
comp' :: (EMPTY :<: g, Execute g, Comp' f g m) => Fix f -> m () -> Fix g -> Fix g
comp' prog = fold coAlg' prog
  
---

-- | Items held on the machine's stacks and in its environment. The index
-- @e@ is a phantom; the machine uses @VALUE ()@ throughout.
data VALUE e where
  -- An integer.
  NUM :: Int -> VALUE e
  -- A saved state value, pushed by SAVE and consumed by RESTORE.
  REC :: Int -> VALUE e
  -- Handler code, pushed by MARK and consumed by THROW / UNMARK.
  HAN :: Execute f => Fix f -> VALUE e
  -- A suspended machine continuation together with its environment.
  CTN :: ZAM -> [VALUE e] -> VALUE e
  -- A closure: compiled body code and its captured environment.
  CLO :: Execute f => Fix f -> [VALUE e] -> VALUE e
  
-- | Is this stack item an exception handler?
isHAN :: VALUE e -> Bool
isHAN v = case v of
  HAN _ -> True
  _     -> False

-- | Is this stack item a saved state value?
isREC :: VALUE e -> Bool
isREC v = case v of
  REC _ -> True
  _     -> False

-- | Machine components, all holding @VALUE ()@.
type StateStack = [VALUE ()]
type ZAMEnv     = [VALUE ()]
type ZAMStack   = [VALUE ()]
-- | Machine configuration, taken apart as @(e, s, stk)@ by the executor:
-- environment, state stack, and value stack.
type Snapshot   = (ZAMEnv, StateStack, ZAMStack)
-- | A machine action: a state transformer over 'Snapshot' with no result.
type ZAM        = StateM Snapshot ()

-- | Execution algebra: turn one instruction, whose continuation has already
-- been turned into a machine action, into a machine action.
class Functor f => Execute f where
  exAlg :: f ZAM -> ZAM
  
-- | Execution of a combined instruction set dispatches on the summand.
instance (Execute f, Execute g) => Execute (f :+: g) where
  exAlg t = case t of
    Inl l -> exAlg l
    Inr r -> exAlg r
  
-- | PUSH places a number on the value stack. ADD replaces the top two
-- numbers with their sum.
instance Execute ARITH where
  exAlg instr = case instr of
    PUSH n k -> S $ \(env, st, stk) -> runSM k (env, st, NUM n : stk)
    ADD k    -> S $ \(env, st, stk) -> case stk of
      NUM a : NUM b : rest -> runSM k (env, st, NUM (a + b) : rest)

-- | Execution of exception-handling instructions.
--
-- * MARK pushes a handler ('HAN') onto the value stack.
-- * THROW discards everything down to the nearest handler and runs it. The
--   handler code already contains its own continuation, so THROW's own
--   continuation is unused.
-- * UNMARK removes the nearest handler but keeps any values above it (in
--   particular, the result of the protected expression).
-- * SAVE records the current top of the state stack as a 'REC' item.
-- * RESTORE removes the nearest 'REC' item, keeping values above it, and
--   pushes the saved value back onto the state stack.
--
-- UNMARK and RESTORE split the stack with 'break' rather than 'dropWhile'.
-- The protected expression's result sits above the marker and must survive
-- the marker's removal.
instance Execute EXCEPT where
  exAlg (THROW _)     = S $ \(e, s, stk) -> case dropWhile (not . isHAN) stk of
                                           ((HAN h):stk') -> runSM (fold exAlg h) (e, s, stk')
  exAlg (MARK h c)    = S $ \(e, s, stk) -> runSM c (e, s, (HAN h):stk)
  exAlg (UNMARK c)    = S $ \(e, s, stk) -> case break isHAN stk of
                                           (above, (HAN _):below) -> runSM c (e, s, above ++ below)
  exAlg (SAVE c)      = S $ \(e, s, stk) -> case s of z@((NUM n):_) -> runSM c (e, z, (REC n):stk)
  exAlg (RESTORE c)   = S $ \(e, s, stk) -> case break isREC stk of
                                            (above, (REC n):below) -> runSM c (e, (NUM n):s, above ++ below)

-- | Execution of closure instructions.
--
-- * ACCESS @i@ pushes the @i@-th entry of the current environment.
-- * CLOSURE @k@ pushes a closure pairing body code @k@ with the current
--   environment.
-- * APPLY expects the closure on top of the value stack with its argument
--   directly beneath it, since the compiler emits the argument's code before
--   the function's. It pushes a 'CTN' frame holding the current continuation
--   and environment, then runs the closure body. The body runs with the
--   argument prepended to the closure's captured environment.
-- * RETURN, emitted at the end of every closure body, pops the result and
--   the 'CTN' frame beneath it. It then resumes that continuation in its
--   saved environment with the result on the stack. RETURN's own
--   continuation is unused.
--
-- APPLY performs the call and RETURN the resumption. This matches how the
-- compiler emits them; with the roles reversed, APPLY would find a closure
-- where it expects a continuation frame, and no application could run.
instance Execute LAMBDA where
  exAlg (ACCESS i c)  = S $ \(e, s, stk) -> runSM c (e, s, (e !! i):stk)
  exAlg (CLOSURE k c) = S $ \(e, s, stk) -> runSM c (e, s, (CLO k e):stk)
  exAlg (APPLY c)     = S $ \(e, s, stk) -> case stk of
    ((CLO k e'):v:stk') -> runSM (fold exAlg k) (v:e', s, (CTN c e):stk')
  exAlg (RETURN _)    = S $ \(_, s, stk) -> case stk of
    (v:(CTN c' e'):stk') -> runSM c' (e', s, v:stk')
                                                     
-- | GET copies the top of the state stack onto the value stack. SET pushes
-- a new value onto the state stack.
instance Execute STATE where
  exAlg instr = case instr of
    GET k   -> S $ \(env, st, stk) -> case st of
      NUM n : _ -> runSM k (env, st, NUM n : stk)
    SET n k -> S $ \(env, st, stk) -> runSM k (env, NUM n : st, stk)
  
-- | The terminating instruction leaves the snapshot untouched.
instance Execute EMPTY where
  exAlg EMPTY = S (\snap -> ((), snap))
  
-- | Hand compiled code to 'execute'' together with an empty initial snapshot:
-- no environment, no state, empty value stack.
execute :: Execute f => Fix f -> ZAM
execute prog = execute' prog emptySnapshot
  where
    emptySnapshot :: Snapshot
    emptySnapshot = ([], [], [])

-- | Build a machine action that executes the compiled code starting from
-- the given snapshot @st@.
--
-- The returned action deliberately ignores the snapshot it is later run on,
-- so that @st@ is the one actually used. Binding the lambda parameter as
-- @st@ would shadow the argument and silently discard the caller's starting
-- snapshot.
execute' :: Execute f => Fix f -> Snapshot -> ZAM
execute' f st = S $ \_ -> runSM (fold exAlg f) st

-- Fin.