168 lines
6.4 KiB
Haskell
168 lines
6.4 KiB
Haskell
module Type.Environment where
|
|
|
|
import Control.Applicative ((<$>), (<*>))
|
|
import Control.Exception (try, SomeException)
|
|
import Control.Monad
|
|
import Control.Monad.Trans.Error (ErrorList(..))
|
|
import Control.Monad.Error (ErrorT, throwError, liftIO)
|
|
import qualified Control.Monad.State as State
|
|
import qualified Data.Traversable as Traverse
|
|
import qualified Data.Map as Map
|
|
import Data.List (isPrefixOf)
|
|
import qualified Data.UnionFind.IO as UF
|
|
import qualified Text.PrettyPrint as PP
|
|
|
|
import qualified SourceSyntax.Type as Src
|
|
import SourceSyntax.Module (ADT)
|
|
import Type.Type
|
|
|
|
-- | Dictionary from a name to an internal 'Type'.  It is the type of the
-- 'types' and 'value' tables of 'Environment', and of the second component
-- of the state threaded through 'instantiator'.
type TypeDict = Map.Map String Type
|
|
-- | Dictionary from a name to a 'Variable'.  'instantiator' uses it as the
-- first component of its state, recording the variable created for each
-- type-variable name it has seen.
type VarDict = Map.Map String Variable
|
|
|
|
-- | Lookup tables consulted while instantiating source types.
data Environment = Environment
    { -- | For each data constructor name, an action that yields
      -- @(number of arguments, type variables, argument types, result type)@.
      constructor :: Map.Map String (IO (Int, [Variable], [Type], Type))
      -- | For each type alias name, its type parameters and the aliased
      -- source type.
    , aliases :: Map.Map String ([String], Src.Type)
      -- | For each type constructor name, the 'Type' standing for it.
    , types :: TypeDict
      -- | NOTE(review): only ever set to an empty map in this module;
      -- presumably filled in by callers with the types of values -- confirm.
    , value :: TypeDict
    }
|
|
|
|
-- | Build the starting 'Environment' from the declared ADTs and type aliases.
--
-- The 'types' table comes from 'makeTypes', the alias triples are re-keyed by
-- alias name, and 'value' starts out empty.  The 'constructor' table is
-- filled in last because 'makeConstructors' needs an environment whose
-- 'types' and 'aliases' tables are already populated.
initialEnvironment :: [ADT] -> [(String, [String], Src.Type)] -> IO Environment
initialEnvironment datatypes aliases = do
    typeDict <- makeTypes datatypes
    let aliasDict =
            Map.fromList [ (name, (tvars, tipe)) | (name, tvars, tipe) <- aliases ]

        -- Environment without constructors; used only to build them.
        partialEnv =
            Environment
              { constructor = Map.empty
              , aliases = aliasDict
              , types = typeDict
              , value = Map.empty
              }

    return (partialEnv { constructor = makeConstructors partialEnv datatypes })
|
|
|
|
-- | Allocate one constant type variable per known type constructor and
-- return them keyed by name.
--
-- The known constructors are the built-ins (@_Tuple0@ .. @_Tuple9@, @_List@,
-- @Int@, @Float@, @Char@, @String@, @Bool@) followed by every declared ADT.
-- Because the result is built with 'Map.fromList', an ADT that reuses a
-- built-in name replaces the built-in entry.
makeTypes :: [ADT] -> IO TypeDict
makeTypes datatypes =
    Map.fromList <$> mapM makeCtor (builtins ++ map nameAndKind datatypes)
  where
    -- Name of an ADT paired with its number of type parameters.
    nameAndKind (name, tvars, _) = (name, length tvars)

    -- The arity travels with each name but is not needed to allocate the
    -- constant, so it is ignored explicitly.  (Binding it as @kind@ would
    -- shadow the 'kind' helper below.)
    makeCtor :: (String, Int) -> IO (String, Type)
    makeCtor (name, _arity) = do
      ctor <- VarN <$> namedVar Constant name
      return (name, ctor)

    tuple :: Int -> (String, Int)
    tuple n = ("_Tuple" ++ show n, n)

    -- Tag every name in the list with the same arity.
    kind :: Int -> [String] -> [(String, Int)]
    kind n names = map (\name -> (name, n)) names

    builtins :: [(String, Int)]
    builtins =
      concat
        [ map tuple [0..9]
        , kind 1 ["_List"]
        , kind 0 ["Int", "Float", "Char", "String", "Bool"]
        ]
|
|
|
|
|
|
-- | Build the table of data constructors: the list constructors @[]@ and
-- @::@, the tuple constructors @_Tuple0@ .. @_Tuple9@, and every constructor
-- of the declared ADTs (via 'ctorToType').
--
-- Each entry is an 'IO' action rather than a value, so every run allocates
-- fresh type variables for that constructor.  Later entries win when names
-- collide, because the table is built with 'Map.fromList'.
makeConstructors :: Environment
                 -> [ADT]
                 -> Map.Map String (IO (Int, [Variable], [Type], Type))
makeConstructors env datatypes = Map.fromList builtins
  where
    -- Look type constructors up through 'get' so that a missing entry fails
    -- with a message naming the constructor instead of the opaque error
    -- raised by 'Map.!'.
    typeCtor :: String -> Type
    typeCtor = get env types

    list t = typeCtor "_List" <| t

    -- Allocate @numTVars@ flexible variables, hand them to @tipe@ to obtain
    -- the argument and result types, and package the result as
    -- @(number of arguments, variables, argument types, result type)@.
    inst :: Int -> ([Type] -> ([Type], Type)) -> IO (Int, [Variable], [Type], Type)
    inst numTVars tipe = do
      vars <- replicateM numTVars (var Flexible)
      let (args, result) = tipe (map VarN vars)
      return (length args, vars, args, result)

    -- An n-tuple constructor takes its n type variables as arguments and
    -- returns the @_TupleN@ type applied to them.
    tupleCtor n =
      let name = "_Tuple" ++ show n
      in  (name, inst n $ \vs -> (vs, foldl (<|) (typeCtor name) vs))

    -- The single-element lambda patterns are safe: @inst 1@ always supplies
    -- exactly one variable.
    builtins :: [ (String, IO (Int, [Variable], [Type], Type)) ]
    builtins =
      [ ("[]", inst 1 $ \ [t] -> ([], list t))
      , ("::", inst 1 $ \ [t] -> ([t, list t], list t))
      ]
      ++ map tupleCtor [0..9]
      ++ concatMap (ctorToType env) datatypes
|
|
|
|
|
|
-- | Produce one constructor-table entry per constructor of an ADT.
--
-- Each entry pairs the constructor name with an action that instantiates the
-- constructor's argument types and the ADT's own type (the ADT name applied
-- to its type parameters), starting from empty dictionaries.  The action
-- yields @(number of arguments, variables created, argument types, result
-- type)@.
ctorToType :: Environment -> ADT -> [ (String, IO (Int, [Variable], [Type], Type)) ]
ctorToType env (name, tvars, ctors) =
    [ (ctorName, scheme ctorArgs) | (ctorName, ctorArgs) <- ctors ]
  where
    scheme :: [Src.Type] -> IO (Int, [Variable], [Type], Type)
    scheme ctorArgs = do
      ((argTypes, resultType), (varDict, _)) <-
          State.runStateT (instantiate ctorArgs) (Map.empty, Map.empty)
      return (length argTypes, Map.elems varDict, argTypes, resultType)

    -- Arguments are instantiated before the result type, in one shared
    -- state, so a type variable name maps to the same variable in both.
    instantiate :: [Src.Type] -> State.StateT (VarDict, TypeDict) IO ([Type], Type)
    instantiate ctorArgs = do
      argTypes <- mapM (instantiator env) ctorArgs
      resultType <- instantiator env (Src.Data name (map Src.Var tvars))
      return (argTypes, resultType)
|
|
|
|
|
|
-- | Look a key up in one of the environment's tables, selected by
-- @subDict@.  Calls 'error' with a message naming the key when it is absent.
get :: Environment -> (Environment -> Map.Map String a) -> String -> a
get env subDict key =
    case Map.lookup key (subDict env) of
      Just found -> found
      Nothing ->
        error $ "\nCould not find type constructor '" ++ key ++ "' while checking types."
|
|
|
|
|
|
-- | Fetch the action stored for a data constructor in the environment's
-- 'constructor' table.  A missing name fails through 'get'.
freshDataScheme :: Environment -> String -> IO (Int, [Variable], [Type], Type)
freshDataScheme env = get env constructor
|
|
|
|
-- | Turn a plain string message into a one-element list holding the message
-- as a text document.  'instantiateType' uses @[PP.Doc]@ as its 'ErrorT'
-- error type.
instance ErrorList PP.Doc where
  listMsg = (: []) . PP.text
|
|
|
|
-- | Instantiate a source type starting from the given variable dictionary.
--
-- On success, returns every variable in the final dictionary (the supplied
-- ones plus any created along the way) together with the resulting type.
-- Any exception raised while running 'instantiator' is reported as a single
-- error document holding the exception's 'show' text.
--
-- NOTE(review): 'SomeException' is caught, so exceptions unrelated to the
-- 'error' calls in 'instantiator' are also turned into error documents --
-- confirm this is intended.
instantiateType :: Environment -> Src.Type -> VarDict -> ErrorT [PP.Doc] IO ([Variable], Type)
instantiateType env sourceType dict = do
    outcome <- liftIO (try runInstantiator)
    either failure success outcome
  where
    runInstantiator :: IO (Type, (VarDict, TypeDict))
    runInstantiator =
        State.runStateT (instantiator env sourceType) (dict, Map.empty)

    failure :: SomeException -> ErrorT [PP.Doc] IO ([Variable], Type)
    failure exception = throwError [ PP.text (show exception) ]

    success :: (Type, (VarDict, TypeDict)) -> ErrorT [PP.Doc] IO ([Variable], Type)
    success (tipe, (finalDict, _)) = return (Map.elems finalDict, tipe)
|
|
|
|
-- | Convert a source-level type into an internal 'Type'.
--
-- The state is a pair of dictionaries: the variables created so far for
-- type-variable names, and the substitutions currently in force for the
-- parameters of a type alias being expanded.  Unknown type constructors and
-- aliases applied to the wrong number of arguments are reported with 'error'.
instantiator :: Environment -> Src.Type
             -> State.StateT (VarDict, TypeDict) IO Type
instantiator env sourceType = convert sourceType
  where
    convert :: Src.Type -> State.StateT (VarDict, TypeDict) IO Type
    convert tipe =
      case tipe of
        Src.Lambda arg result ->
            (==>) <$> convert arg <*> convert result

        Src.Var x ->
            lookupVar x

        Src.Data name args -> do
            args' <- mapM convert args
            -- Real type constructors take precedence over aliases.
            case Map.lookup name (types env) of
              Just ctor -> return $ foldl (<|) ctor args'
              Nothing ->
                case Map.lookup name (aliases env) of
                  Just (tvars, body) -> expandAlias name tvars body args'
                  Nothing ->
                    error $ "\nCould not find type constructor '" ++
                            name ++ "' while checking types."

        Src.EmptyRecord ->
            return (TermN EmptyRecord1)

        Src.Record fields ext -> do
            fields' <- Traverse.traverse (mapM convert) (Src.fieldMap fields)
            ext' <- convert ext
            return (TermN (Record1 fields' ext'))

    -- Resolve a type-variable name: an alias substitution wins, then a
    -- variable created earlier; otherwise create and record a new one.
    lookupVar :: String -> State.StateT (VarDict, TypeDict) IO Type
    lookupVar x = do
        (vars, substitutions) <- State.get
        case Map.lookup x substitutions of
          Just substituted -> return substituted
          Nothing ->
            case Map.lookup x vars of
              Just known -> return (VarN known)
              Nothing -> do
                fresh <- State.liftIO $ namedVar (flexFor x) x
                State.put (Map.insert x fresh vars, substitutions)
                return (VarN fresh)

    -- The name prefix of a type variable selects its flexibility.
    flexFor x
      | "number" `isPrefixOf` x = Is Number
      | "comparable" `isPrefixOf` x = Is Comparable
      | "appendable" `isPrefixOf` x = Is Appendable
      | otherwise = Flexible

    -- Expand an alias: bind its parameters to the supplied arguments,
    -- convert the body under those bindings, then put the previous state back.
    --
    -- NOTE(review): the variable dictionary is restored together with the
    -- substitutions, so variables first created while converting the alias
    -- body are dropped from the state -- confirm this is intended.
    expandAlias :: String -> [String] -> Src.Type -> [Type]
                -> State.StateT (VarDict, TypeDict) IO Type
    expandAlias name tvars body args
      | length args /= length tvars = error mismatch
      | otherwise = do
          (vars, substitutions) <- State.get
          let scoped = Map.union (Map.fromList $ zip tvars args) substitutions
          State.put (vars, scoped)
          expanded <- convert body
          State.put (vars, substitutions)
          return expanded
      where
        expected = length tvars
        mismatch =
          "\nType alias '" ++ name ++ "' expects " ++ show expected ++
          " type argument" ++ (if expected == 1 then "" else "s") ++
          " but was given " ++ show (length args)
|