Main.hs (10750B)
module Main where

import Control.Monad
import Control.Monad.Except
import Control.Monad.Reader
import Data.ByteString (ByteString)
import Data.Foldable (foldlM)
import qualified Data.ByteString as ByteString
import Sparsec
import System.Exit

import Algebra
import Qul (Qul)
import Ledger (Ledger)
import qualified Ledger as L
import Common
import qualified Environment as Env
import Context (Ctx)
import qualified Context as Ctx
import qualified Surface as S
import qualified Core as C
import qualified Value as V
import qualified Parser as P

--------------------------------------------------------------------------------
-- Util

-- | Extract the payload of a 'Maybe', raising @err@ through 'MonadError'
-- when there is nothing to extract.
orThrowError :: MonadError e m => Maybe a -> e -> m a
orThrowError ma err = maybe (throwError err) pure ma

--------------------------------------------------------------------------------
-- Elaboration monad

-- | Context used for type variables. Built from 'S.TypeName', 'C.TypeIdx',
-- type values and 'Kind' (presumably name / index / value / classifier —
-- see the "Context" module to confirm the parameter roles).
type TypeCtx q = Ctx S.TypeName C.TypeIdx (V.Type q) Kind

-- | Context used for term variables, classified by type values
-- (same parameter-role caveat as 'TypeCtx').
type TermCtx q = Ctx S.TermName C.TermIdx (V.Term q) (V.Type q)

-- | The pair of contexts threaded through elaboration.
data BothCtx q = BothCtx
  { typeCtx :: TypeCtx q
  , termCtx :: TermCtx q
  }

-- | Both contexts empty.
bothCtxEmpty :: BothCtx q
bothCtxEmpty = BothCtx { typeCtx = Ctx.empty, termCtx = Ctx.empty }

-- | Everything that can go wrong during elaboration. Constructor names read
-- as @Err<Level><Construct><Failure>@.
data Err
  = ErrTypeInferCheckMismatch
  | ErrTypeVarUnbound
  | ErrTypeArrowAbsNotInferable
  | ErrTypeArrowAppNotArrow
  | ErrTermInferCheckMismatch
  | ErrTermVarUnbound
  | ErrTermLetQuantityMismatch
  | ErrTermArrowAbsNotInferable
  | ErrTermArrowAbsQuantityMismatch
  | ErrTermArrowAbsNotArrow
  | ErrTermArrowAppQuantityMismatch
  | ErrTermArrowAppNotArrow
  | ErrTermForallAbsNotInferable
  | ErrTermForallAbsNotForall
  | ErrTermForallAppQuantityMismatch
  | ErrTermForallAppNotForall
  deriving (Show)

-- | Elaboration reads the current contexts and may fail with an 'Err'.
type Elab q = ReaderT (BothCtx q) (Either Err)

-- | Run an elaboration action against the given starting contexts.
runElab :: Elab q a -> BothCtx q -> Either Err a
runElab action ctx = runReaderT action ctx
--------------------------------------------------------------------------------
-- Type-level elaboration

-- | Evaluate a core type in the current type environment.
evalType :: QuantityAlgebra q => C.Type q -> Elab q (V.Type q)
evalType a = V.evalType <$> asks (.typeCtx.env) <*> pure a

-- | Read a type value back into core syntax relative to the current type
-- environment.
quoteType :: QuantityAlgebra q => V.Type q -> Elab q (C.Type q)
quoteType a = V.quoteType <$> asks (.typeCtx.env) <*> pure a

-- | Test two type values for conversion in the current type environment.
convType :: QuantityAlgebra q => V.Type q -> V.Type q -> Elab q Bool
convType a b = V.convType <$> asks (.typeCtx.env) <*> pure a <*> pure b

-- | Bind type name @x@ at kind @k@ (via 'Ctx.bindAndAssign') and run @f@
-- with the level it was given, under the extended type context.
bindAndAssignType :: QuantityAlgebra q
  => S.TypeName
  -> Kind
  -> (V.TypeLvl -> Elab q a)
  -> Elab q a
bindAndAssignType x k f = do
  ctx <- ask
  let (lvl, typeCtx') = Ctx.bindAndAssign x k ctx.typeCtx
  local (const $ ctx { typeCtx = typeCtx' }) (f lvl)

-- | Look up a type name, returning its value and kind.
-- Throws 'ErrTypeVarUnbound' when the name is not in scope.
lookupType :: QuantityAlgebra q => S.TypeName -> Elab q (V.Type q, Kind)
lookupType x = (Ctx.lookup x 0 <$> asks (.typeCtx)) >>= (`orThrowError` ErrTypeVarUnbound)

-- | Infer the kind of a surface type, producing its core form.
-- Type-level lambdas ('S.TypeArrowAbs') are check-only.
inferType :: QuantityAlgebra q => S.Type q -> Elab q (C.Type q, Kind)
inferType = \case
  S.TypeVar x -> do
    (va, k) <- lookupType x
    a <- quoteType va
    pure (a, k)

  S.TypeAnnot a k -> (, k) <$> checkType a k

  S.TypeTArrow dom cod -> do
    dom' <- checkType dom KindQT
    cod' <- checkType cod KindQT
    pure (C.TypeTArrow dom' cod', KindT)

  S.TypeTForall x k a -> do
    a' <- bindAndAssignType x k \_ -> checkType a KindQT
    pure (C.TypeTForall k a', KindT)

  S.TypeQ q -> pure (C.TypeQ q, KindQ)

  S.TypeQAdd q r -> do
    q' <- checkType q KindQ
    r' <- checkType r KindQ
    pure (C.TypeQAdd q' r', KindQ)

  S.TypeQMul q r -> do
    q' <- checkType q KindQ
    r' <- checkType r KindQ
    pure (C.TypeQMul q' r', KindQ)

  S.TypeQTPair q a -> do
    q' <- checkType q KindQ
    a' <- checkType a KindT
    pure (C.TypeQTPair q' a', KindQT)

  S.TypeQTFst qa -> do
    qa' <- checkType qa KindQT
    pure (C.TypeQTFst qa', KindQ)

  S.TypeQTSnd qa -> do
    qa' <- checkType qa KindQT
    pure (C.TypeQTSnd qa', KindT)

  S.TypeArrowAbs _ _ -> throwError ErrTypeArrowAbsNotInferable

  S.TypeArrowApp fun arg -> do
    (fun', k) <- inferType fun
    case k of
      KindArrow dom cod -> do
        arg' <- checkType arg dom
        pure (C.TypeArrowApp fun' arg', cod)
      _ -> throwError ErrTypeArrowAppNotArrow

-- | Check a surface type against an expected kind.
--
-- Type-level lambdas are checked against arrow kinds. Everything else is
-- inferred and compared; a type of kind @KindT@ is accepted where
-- @KindQT@ is expected by pairing it with 'defaultQuantity'.
checkType :: QuantityAlgebra q => S.Type q -> Kind -> Elab q (C.Type q)
checkType a k = case (a, k) of
  (S.TypeArrowAbs x body, KindArrow dom cod) ->
    bindAndAssignType x dom \_ -> checkType body cod
  _ -> do
    (a', k') <- inferType a
    if k' == KindT && k == KindQT
      then
        -- Coerce from kind * to kind $*
        pure $ C.TypeQTPair (C.TypeQ defaultQuantity) a'
      else do
        unless (k' == k) $ throwError ErrTypeInferCheckMismatch
        pure a'

--------------------------------------------------------------------------------
-- Term-level elaboration

-- TODO: Think of a better way of tracking ledgers

-- | Bind term name @x@ at type @a@ (via 'Ctx.bindAndAssign') and run @f@
-- with the level it was given, under the extended term context.
bindAndAssignTerm :: QuantityAlgebra q
  => S.TermName
  -> V.Type q
  -> (V.TermLvl -> Elab q a)
  -> Elab q a
bindAndAssignTerm x a f = do
  ctx <- ask
  let (lvl, termCtx') = Ctx.bindAndAssign x a ctx.termCtx
  local (const $ ctx { termCtx = termCtx' }) (f lvl)

-- | Look up a term name, returning the level of its generic value and its
-- type. Throws 'ErrTermVarUnbound' when the name is not in scope.
--
-- NOTE(review): the 'V.TermGeneric' pattern below is refutable unless
-- 'V.Term' has a single constructor; confirm how a mismatch is reported.
lookupTerm :: QuantityAlgebra q => S.TermName -> Elab q (V.TermLvl, V.Type q)
lookupTerm x = do
  m <- Ctx.lookup x 0 <$> asks (.termCtx)
  (V.TermGeneric lvl, a) <- m `orThrowError` ErrTermVarUnbound
  pure (lvl, a)

-- | Convert a term level into an index relative to the current term context.
termLvlToIdx :: QuantityAlgebra q => V.TermLvl -> Elab q C.TermIdx
termLvlToIdx lvl = Ctx.lvlToIdx lvl <$> asks (.termCtx)

-- | Infer the type of a surface term used at quantity @q@.
--
-- Returns the core term, its type, and a ledger of term-variable demands
-- (a variable occurrence contributes @L.singleton lvl q@). Lambdas and
-- type abstractions are check-only.
inferTerm :: QuantityAlgebra q
  => S.Term q
  -> V.Type q
  -> Elab q (C.Term q, V.Type q, Ledger q)
inferTerm t q = case t of
  S.TermVar x -> do
    (lvl, a) <- lookupTerm x
    idx <- termLvlToIdx lvl
    pure (C.TermVar idx, a, L.singleton lvl q)

  S.TermAnnot inner ann -> do
    ann' <- checkType ann KindT
    ann'' <- evalType ann'
    (inner', innerLedger) <- checkTerm inner q ann''
    pure (inner', ann'', innerLedger)

  S.TermLet x arg rb body -> do
    rb' <- checkType rb KindQT
    rb'' <- evalType rb'
    let r = V.qtFst rb''
        b = V.qtSnd rb''
    (arg', argLedger) <- checkTerm arg r b
    bindAndAssignTerm x b \lvl -> do
      (body', bodyType, bodyLedger) <- inferTerm body q
      -- Pull the bound variable's demand out of the body's ledger; what
      -- remains is charged to the enclosing scope.
      let (argDemand, letLedger) = L.split lvl bodyLedger
      unless (r `V.qGeq` argDemand) $ throwError ErrTermLetQuantityMismatch
      -- A let is elaborated as an immediately-applied lambda.
      pure (C.TermArrowApp (C.TermArrowAbs body') arg', bodyType, letLedger <+> argLedger)

  S.TermArrowAbs _ _ -> throwError ErrTermArrowAbsNotInferable

  S.TermArrowApp fun arg -> do
    (fun', funType, funLedger) <- inferTerm fun one
    case funType of
      V.TypeTArrow dom cod -> do
        (arg', argLedger) <- checkTerm arg (V.qtFst dom) (V.qtSnd dom)
        unless (V.qtFst cod `V.qGeq` q) $ throwError ErrTermArrowAppQuantityMismatch
        pure (C.TermArrowApp fun' arg', V.qtSnd cod, funLedger <+> argLedger)
      _ -> throwError ErrTermArrowAppNotArrow

  S.TermForallAbs _ _ -> throwError ErrTermForallAbsNotInferable

  S.TermForallApp fun arg -> do
    (fun', funType, funLedger) <- inferTerm fun one
    case funType of
      V.TypeTForall k clo -> do
        arg' <- checkType arg k
        arg'' <- evalType arg'
        let qa = V.applyClosure clo arg''
        unless (V.qtFst qa `V.qGeq` q) $ throwError ErrTermForallAppQuantityMismatch
        pure (C.TermForallApp fun' arg', V.qtSnd qa, funLedger)
      _ -> throwError ErrTermForallAppNotForall

-- | Check a surface term, used at quantity @q@, against type @a@.
--
-- Lambdas and type abstractions are checked structurally and their body
-- ledgers are scaled by @q@. Everything else is inferred and compared
-- with 'convType'.
checkTerm :: QuantityAlgebra q
  => S.Term q
  -> V.Type q
  -> V.Type q
  -> Elab q (C.Term q, Ledger q)
checkTerm t q a = case (t, a) of
  (S.TermArrowAbs x body, V.TypeTArrow dom cod) ->
    bindAndAssignTerm x (V.qtSnd dom) \lvl -> do
      (body', bodyLedger) <- checkTerm body (V.qtFst cod) (V.qtSnd cod)
      let (domDemand, absLedger) = L.split lvl bodyLedger
      unless (V.qtFst dom `V.qGeq` domDemand) $ throwError ErrTermArrowAbsQuantityMismatch
      pure (C.TermArrowAbs body', q <⋅> absLedger)

  (S.TermArrowAbs _ _, _) -> throwError ErrTermArrowAbsNotArrow

  (S.TermForallAbs x body, V.TypeTForall k clo) ->
    bindAndAssignType x k \lvl -> do
      let qa = V.applyClosure clo (V.TypeGeneric lvl)
      (body', bodyLedger) <- checkTerm body (V.qtFst qa) (V.qtSnd qa)
      pure (C.TermForallAbs body', q <⋅> bodyLedger)

  (S.TermForallAbs _ _, _) -> throwError ErrTermForallAbsNotForall

  _ -> do
    (t', a', tLedger) <- inferTerm t q
    equal <- convType a' a
    unless equal $ throwError ErrTermInferCheckMismatch
    pure (t', tLedger)

--------------------------------------------------------------------------------

type Q = Qul
pQ :: Parse P.Err Q
pQ = P.pQul

-- | A top-level type declaration: name, kind source and type source.
data Decl = Decl ByteString ByteString ByteString

-- | A program: type declarations followed by the source of the main term.
data Prog = Prog {
  decls :: [Decl],
  body :: ByteString
}

-- | Parse a kind, printing the parse result and exiting with failure on error.
parseKind :: ByteString -> IO Kind
parseKind src =
  case P.runP (P.pTop P.pKind) pQ src of
    Ok k _ _ -> pure k
    result -> print result *> exitFailure

-- | Parse a type, printing the parse result and exiting with failure on error.
parseType :: ByteString -> IO (S.Type Q)
parseType src =
  case P.runP (P.pTop P.pType) pQ src of
    Ok a _ _ -> pure a
    result -> print result *> exitFailure

-- | Parse a term, printing the parse result and exiting with failure on error.
parseTerm :: ByteString -> IO (S.Term Q)
parseTerm src =
  case P.runP (P.pTop P.pTerm) pQ src of
    Ok t _ _ -> pure t
    result -> print result *> exitFailure

-- | Elaborate one declaration and assign its evaluated type, at the declared
-- kind, to the name in the type context. Exits with failure on any parse
-- or elaboration error.
declare :: TypeCtx Q -> Decl -> IO (TypeCtx Q)
declare typeCtx (Decl x k a) = do
  let x' = S.TypeName x
  k' <- parseKind k
  a' <- parseType a
  a'' <- case checkType a' k' `runElab` BothCtx typeCtx Ctx.empty of
    Left e -> print e *> exitFailure
    Right checked -> pure checked
  let a''' = V.evalType typeCtx.env a''
  pure $ Ctx.assign x' a''' k' typeCtx

input :: Prog
input =
  Prog [
    Decl "Nat" "%" "@M:% -> M -> (M -> M) -> M"
  ] $ ByteString.intercalate "\n" [
    "let zero : Nat = \\[M] z s => z;",
    "let succ : Nat -> Nat = \\n => \\[M] z s => s (n [M] z s);",
    "succ (succ zero)"
  ]

-- | Elaborate 'input' and print the resulting core term and its type.
--
-- Any parse or elaboration failure is printed and the process exits with a
-- failure status, so callers (scripts, CI) can tell a type error from
-- success.
main :: IO ()
main = do
  typeCtx <- foldlM declare Ctx.empty input.decls
  t <- parseTerm input.body
  case inferTerm t one `runElab` BothCtx typeCtx Ctx.empty of
    -- Previously this only printed and exited 0; fail like the other
    -- error paths in this file.
    Left e -> print e *> exitFailure
    Right (t', a, _) -> do
      putStrLn "term:"
      print t'
      putStrLn "----------------------------"
      putStrLn "type:"
      print $ V.quoteType Env.empty a