diff --git a/compiler/Type/Constrain/Expression.hs b/compiler/Type/Constrain/Expression.hs index 1c94980..601a9cc 100644 --- a/compiler/Type/Constrain/Expression.hs +++ b/compiler/Type/Constrain/Expression.hs @@ -21,8 +21,11 @@ import qualified Type.Constrain.Pattern as Pattern {-- Testing section --} import SourceSyntax.PrettyPrint +import Text.PrettyPrint as P import Parse.Expression import Parse.Helpers (iParse) +import Type.Solve (solve) +import qualified Type.State as TS test str = case iParse expr "" str of @@ -33,8 +36,9 @@ test str = constraint <- constrain env expression (VarN var) prettyNames constraint print (pretty constraint) - print (pretty var) - return () + print (P.text "Solving for:" <+> pretty var) + (env,_,_,_) <- execStateT (solve constraint) TS.initialState + mapM_ (\(n,t) -> print $ P.text n <+> P.text ":" <+> pretty t) $ Map.toList env {-- todo: remove testing code --} constrain :: Env.Environment -> LExpr a b -> Type -> IO TypeConstraint @@ -152,7 +156,6 @@ constrain env (L _ _ expr) tipe = v <- flexibleVar -- needs an ex c <- constrain env body (VarN v) return ((name, v), c) - Markdown _ -> return $ tipe === Env.get env Env.builtin "Element" diff --git a/compiler/Type/Solve.hs b/compiler/Type/Solve.hs index 9c9be2a..d187d74 100644 --- a/compiler/Type/Solve.hs +++ b/compiler/Type/Solve.hs @@ -4,7 +4,6 @@ module Type.Solve where import Control.Monad import Control.Monad.State import qualified Data.UnionFind.IO as UF -import qualified Data.Array.IO as Array import qualified Data.Map as Map import qualified Data.Traversable as Traversable import qualified Data.Maybe as Maybe @@ -13,73 +12,78 @@ import Type.Unify import qualified Type.Environment as Env import qualified Type.State as TS -register = undefined -generalize :: TS.Pool -> TS.Pool -> IO [Variable] -generalize oldPool youngPool = do - let young = 0 - visited = 1 - youngRank = TS.maxRank youngPool +-- | Every variable has rank less than or equal to the maxRank of the pool. +-- This sorts variables into the young and old pools accordingly. +generalize :: TS.Pool -> StateT TS.SolverState IO () +generalize youngPool = do + let youngRank = TS.maxRank youngPool + insert dict var = do + desc <- liftIO $ UF.descriptor var + return $ Map.insertWith (++) (rank desc) [var] dict - array' <- Array.newArray (0, youngRank) [] - let array = array' :: Array.IOArray Int [Variable] + -- Sort the youngPool variables by rank. + rankDict <- foldM insert Map.empty (TS.inhabitants youngPool) - -- Insert all of the youngPool variables into the array. - -- They are placed into a list at the index corresponding - -- to their rank. - forM (TS.inhabitants youngPool) $ \var -> do - desc <- UF.descriptor var - vars <- Array.readArray array (rank desc) - Array.writeArray array (rank desc) (var : vars) - - -- get the ranks right for each entry - forM [0 .. youngRank] $ \i -> do - vars <- Array.readArray array i - mapM (traverse young visited i) vars + -- get the ranks right for each entry. + -- start at low ranks so that we only have to pass + -- over the information once. + youngMark <- TS.uniqueMark + visitedMark <- TS.uniqueMark + Traversable.traverse (mapM (adjustRank youngMark visitedMark youngRank)) rankDict - -- do not need to work with variables that have become redundant - vars <- Array.readArray array youngRank - forM vars $ \var -> do - isRedundant <- UF.redundant var - if isRedundant then do - desc <- UF.descriptor var - if rank desc < youngRank - then register oldPool var - else let flex' = if flex desc == Flexible then Rigid else flex desc - in UF.setDescriptor var (desc { rank = noRank, flex = flex' }) - else return () + -- Move variables out of the young pool if they do not have a young rank. + -- We should not generalize things we cannot use. + let youngVars = (Map.!) rankDict youngRank - return vars + registerIfNotRedundant var = do + isRedundant <- liftIO $ UF.redundant var + if isRedundant then return var else TS.register var -traverse :: Int -> Int -> Int -> Variable -> IO Int -traverse young visited k variable = - let f = traverse young visited k in - do desc <- UF.descriptor variable - case mark desc == young of + registerIfHigherRank var = do + isRedundant <- liftIO $ UF.redundant var + if isRedundant then return () else do + desc <- liftIO $ UF.descriptor var + if rank desc < youngRank + then TS.register var >> return () + else let flex' = if flex desc == Flexible then Rigid else flex desc + in liftIO $ UF.setDescriptor var (desc { rank = noRank, flex = flex' }) + + Traversable.traverse (mapM registerIfNotRedundant) rankDict + Traversable.traverse (mapM registerIfHigherRank) rankDict + + return () + + +-- adjust the ranks of variables such that ranks never increase as you +-- move deeper into a variable. This mean the rank actually represents the +-- deepest variable in the whole type, and we can ignore things at a lower +-- rank than the current constraints. +adjustRank :: Int -> Int -> Int -> Variable -> StateT TS.SolverState IO Int +adjustRank youngMark visitedMark groupRank variable = + let adjust = adjustRank youngMark visitedMark groupRank in + do desc <- liftIO $ UF.descriptor variable + case mark desc == youngMark of True -> do rank' <- case structure desc of - Nothing -> return k + Nothing -> return groupRank Just term -> case term of - App1 a b -> max `liftM` f a `ap` f b - Fun1 a b -> max `liftM` f a `ap` f b - Var1 x -> f x + App1 a b -> max `liftM` adjust a `ap` adjust b + Fun1 a b -> max `liftM` adjust a `ap` adjust b + Var1 x -> adjust x EmptyRecord1 -> return outermostRank Record1 fields extension -> do - ranks <- mapM f (concat (Map.elems fields)) - max (maximum ranks) `liftM` f extension - UF.setDescriptor variable (desc { mark = visited, rank = rank' }) + ranks <- mapM adjust (concat (Map.elems fields)) + max (maximum ranks) `liftM` adjust extension + liftIO $ UF.setDescriptor variable (desc { mark = visitedMark, rank = rank' }) return rank' False -> do - if mark desc /= visited then do - let rank' = min k (rank desc) - UF.setDescriptor variable (desc { mark = visited, rank = rank' }) + if mark desc == visitedMark then return (rank desc) else do + let rank' = min groupRank (rank desc) + liftIO $ UF.setDescriptor variable (desc { mark = visitedMark, rank = rank' }) return rank' - else return (rank desc) -addTo = undefined -newPool = undefined -introduce = undefined solve :: TypeConstraint -> StateT TS.SolverState IO () solve constraint = @@ -94,41 +98,67 @@ solve constraint = CAnd cs -> mapM_ solve cs CLet [Scheme [] fqs constraint' _] CTrue -> do - mapM_ introduce fqs + mapM_ TS.introduce fqs solve constraint' CLet schemes constraint' -> do - mapM solveScheme schemes + headers <- mapM solveScheme schemes + TS.modifyEnv $ \env -> Map.unions (headers ++ [env]) solve constraint' CInstance name term -> do - let instance' = undefined - inst = undefined --instance' pool (Env.get env value name) + env <- TS.getEnv + freshCopy <- TS.makeInstance ((Map.!) env name) t <- TS.flatten term - unify inst t + unify freshCopy t -solveScheme :: TypeScheme -> StateT TS.SolverState IO () +solveScheme :: TypeScheme -> StateT TS.SolverState IO (Map.Map String Variable) solveScheme scheme = case scheme of Scheme [] [] constraint header -> do solve constraint Traversable.traverse TS.flatten header - return () Scheme rigidQuantifiers flexibleQuantifiers constraint header -> do let quantifiers = rigidQuantifiers ++ flexibleQuantifiers - globalPool <- TS.getPool - localPool <- TS.newPool - TS.modifyPool (\_ -> localPool) + currentPool <- TS.getPool + + -- fill in a new pool when working on this scheme's constraints + emptyPool <- TS.nextRankPool + TS.switchToPool emptyPool mapM TS.introduce quantifiers header' <- Traversable.traverse TS.flatten header solve constraint - -- distinct variables - -- generalize - -- generic variables - TS.modifyPool (\_ -> globalPool) -isGeneric var = - do desc <- UF.descriptor var - undefined + allDistinct rigidQuantifiers + localPool <- TS.getPool + TS.switchToPool currentPool + generalize localPool + mapM isGeneric rigidQuantifiers + return header' + +-- Checks that all of the given variables belong to distinct equivalence classes. +-- Also checks that their structure is Nothing, so they represent a variable, not +-- a more complex term. +allDistinct :: [Variable] -> StateT TS.SolverState IO () +allDistinct vars = do + seen <- TS.uniqueMark + let check var = do + desc <- liftIO $ UF.descriptor var + case structure desc of + Just _ -> TS.addError "Cannot generalize something that is not a type variable." + Nothing -> do + if mark desc == seen + then TS.addError "Duplicate variable during generalization" + else return () + liftIO $ UF.setDescriptor var (desc { mark = seen }) + mapM_ check vars + +-- Check that a variable has rank == noRank, meaning that it can be generalized. +isGeneric :: Variable -> StateT TS.SolverState IO () +isGeneric var = do + desc <- liftIO $ UF.descriptor var + if rank desc == noRank + then return () + else TS.addError "Cannot generalize. Variable must have not have a rank." \ No newline at end of file diff --git a/compiler/Type/State.hs b/compiler/Type/State.hs index 09ce115..068e1c7 100644 --- a/compiler/Type/State.hs +++ b/compiler/Type/State.hs @@ -1,9 +1,16 @@ +{-# OPTIONS_GHC -XMultiWayIf #-} module Type.State where import Type.Type +import qualified Data.Map as Map import qualified Type.Environment as Env import qualified Data.UnionFind.IO as UF import Control.Monad.State +import Control.Applicative ((<$>),(<*>), Applicative) +import qualified Data.Traversable as Traversable + +-- todo: remove later +import SourceSyntax.PrettyPrint -- Pool -- Holds a bunch of variables @@ -13,23 +20,41 @@ import Control.Monad.State data Pool = Pool { maxRank :: Int, inhabitants :: [Variable] -} +} deriving Show + +emptyPool = Pool { maxRank = 0, inhabitants = [] } -- Keeps track of the environment, type variable pool, and a list of errors -type SolverState = (Env.Environment, Pool, [String]) +type SolverState = (Map.Map String Variable, Pool, Int, [String]) -modifyEnv f = modify $ \(env, pool, errors) -> (f env, pool, errors) -modifyPool f = modify $ \(env, pool, errors) -> (env, f pool, errors) -addError err = modify $ \(env, pool, errors) -> (env, pool, err:errors) +-- The mark must never be equal to noMark! +initialState = (Map.empty, emptyPool, noMark + 1, []) + +modifyEnv f = modify $ \(env, pool, mark, errors) -> (f env, pool, mark, errors) +modifyPool f = modify $ \(env, pool, mark, errors) -> (env, f pool, mark, errors) +addError err = modify $ \(env, pool, mark, errors) -> (env, pool, mark, err:errors) + +switchToPool pool = modifyPool (\_ -> pool) getPool :: StateT SolverState IO Pool getPool = do - (_, pool, _) <- get + (_, pool, _, _) <- get return pool -newPool :: StateT SolverState IO Pool -newPool = do - (_, pool, _) <- get +getEnv :: StateT SolverState IO (Map.Map String Variable) +getEnv = do + (env, _, _, _) <- get + return env + +uniqueMark :: StateT SolverState IO Int +uniqueMark = do + (env, pool, mark, errs) <- get + put (env, pool, mark+1, errs) + return mark + +nextRankPool :: StateT SolverState IO Pool +nextRankPool = do + pool <- getPool return $ Pool { maxRank = maxRank pool + 1, inhabitants = [] } register :: Variable -> StateT SolverState IO Variable @@ -39,7 +64,7 @@ register variable = do introduce :: Variable -> StateT SolverState IO Variable introduce variable = do - (_, pool, _) <- get + pool <- getPool liftIO $ UF.modifyDescriptor variable (\desc -> desc { rank = maxRank pool }) register variable @@ -48,13 +73,84 @@ flatten term = case term of VarN v -> return v TermN t -> do - flatStructure <- undefined -- chop t - (_, pool, _) <- get + flatStructure <- traverseTerm flatten t + pool <- getPool var <- liftIO . UF.fresh $ Descriptor { structure = Just flatStructure, rank = maxRank pool, flex = Flexible, name = Nothing, - mark = 0 + copy = Nothing, + mark = noMark } - register var \ No newline at end of file + register var + +makeInstance :: Variable -> StateT SolverState IO Variable +makeInstance var = do + alreadyCopied <- uniqueMark + freshVar <- makeCopy alreadyCopied var + restore alreadyCopied var + return freshVar + +makeCopy :: Int -> Variable -> StateT SolverState IO Variable +makeCopy alreadyCopied variable = do + desc <- liftIO $ UF.descriptor variable + if | mark desc == alreadyCopied -> + case copy desc of + Just v -> return v + Nothing -> error "This should be impossible." + + | mark desc /= noRank || flex desc == Constant -> + return variable + + | otherwise -> do + pool <- getPool + newVar <- liftIO $ UF.fresh $ Descriptor { + structure = Nothing, + rank = maxRank pool, + mark = noMark, + flex = Flexible, + copy = Nothing, + name = case flex desc of + Rigid -> Nothing + _ -> name desc + } + + register newVar + + -- Link the original variable to the new variable + -- Need to do this before recursively copying to + -- avoid looping on cyclic terms. + liftIO $ UF.modifyDescriptor variable $ \desc -> + desc { mark = alreadyCopied, copy = Just newVar } + + case structure desc of + Nothing -> return newVar + Just term -> do + newTerm <- traverseTerm (makeCopy alreadyCopied) term + liftIO $ UF.modifyDescriptor newVar $ \desc -> + desc { structure = Just newTerm } + return newVar + +restore :: Int -> Variable -> StateT SolverState IO Variable +restore alreadyCopied variable = do + desc <- liftIO $ UF.descriptor variable + if mark desc /= alreadyCopied then return variable else do + restoredStructure <- + case structure desc of + Nothing -> return Nothing + Just term -> Just <$> traverseTerm (restore alreadyCopied) term + liftIO $ UF.modifyDescriptor variable $ \desc -> + desc { mark = noMark, rank = noRank, structure = restoredStructure } + return variable + +traverseTerm :: (Monad f, Applicative f) => (a -> f b) -> Term1 a -> f (Term1 b) +traverseTerm f term = + case term of + App1 a b -> App1 <$> f a <*> f b + Fun1 a b -> Fun1 <$> f a <*> f b + Var1 x -> Var1 <$> f x + EmptyRecord1 -> return EmptyRecord1 + Record1 fields ext -> + Record1 <$> Traversable.traverse (mapM f) fields <*> f ext + diff --git a/compiler/Type/Type.hs b/compiler/Type/Type.hs index d590480..ca06885 100644 --- a/compiler/Type/Type.hs +++ b/compiler/Type/Type.hs @@ -50,12 +50,16 @@ data Descriptor = Descriptor { rank :: Int, flex :: Flex, name :: Maybe TypeName, + copy :: Maybe Variable, mark :: Int } deriving Show noRank = -1 outermostRank = 0 :: Int +noMark = 0 +initialMark = 1 + data Flex = Rigid | Flexible | Constant deriving (Show, Eq) @@ -85,7 +89,8 @@ namedVar name = UF.fresh $ Descriptor { rank = noRank, flex = Constant, name = Just name, - mark = 0 + copy = Nothing, + mark = noMark } flexibleVar = UF.fresh $ Descriptor { @@ -93,7 +98,8 @@ flexibleVar = UF.fresh $ Descriptor { rank = noRank, flex = Flexible, name = Nothing, - mark = 0 + copy = Nothing, + mark = noMark } rigidVar = UF.fresh $ Descriptor { @@ -101,7 +107,8 @@ rigidVar = UF.fresh $ Descriptor { rank = noRank, flex = Rigid, name = Nothing, - mark = 0 + copy = Nothing, + mark = noMark } -- ex qs constraint == exists qs. constraint diff --git a/compiler/Type/Unify.hs b/compiler/Type/Unify.hs index e881a78..5712b48 100644 --- a/compiler/Type/Unify.hs +++ b/compiler/Type/Unify.hs @@ -37,17 +37,19 @@ actuallyUnify variable1 variable2 = do merge1 :: StateT TS.SolverState IO () merge1 = liftIO $ do UF.union variable2 variable1 - UF.setDescriptor variable1 (desc1 { flex = flex', name = name', rank = rank', mark = undefined }) + UF.setDescriptor variable1 (desc1 { + flex = flex', name = name', rank = rank', copy = Nothing, mark = noMark }) merge2 :: StateT TS.SolverState IO () merge2 = liftIO $ do UF.union variable1 variable2 - UF.setDescriptor variable2 (desc2 { flex = flex', name = name', rank = rank', mark = undefined }) + UF.setDescriptor variable2 (desc2 { + flex = flex', name = name', rank = rank', copy = Nothing, mark = noMark }) fresh :: Maybe (Term1 Variable) -> StateT TS.SolverState IO Variable fresh structure = do var <- liftIO . UF.fresh $ Descriptor { - structure = structure, rank = rank', flex = flex', name = name', mark = undefined + structure = structure, rank = rank', flex = flex', name = name', copy = Nothing, mark = noMark } TS.register var