Get the solver working on basic programs. It outputs pretty types for the

variables in the program. Need to test further and start doing some benchmarking.
This commit is contained in:
Evan Czaplicki 2013-07-09 21:52:05 +02:00
parent ca62ee64a9
commit 497d478d26
5 changed files with 230 additions and 92 deletions

View file

@ -21,8 +21,11 @@ import qualified Type.Constrain.Pattern as Pattern
{-- Testing section --}
import SourceSyntax.PrettyPrint
import Text.PrettyPrint as P
import Parse.Expression
import Parse.Helpers (iParse)
import Type.Solve (solve)
import qualified Type.State as TS
test str =
case iParse expr "" str of
@ -33,8 +36,9 @@ test str =
constraint <- constrain env expression (VarN var)
prettyNames constraint
print (pretty constraint)
print (pretty var)
return ()
print (P.text "Solving for:" <+> pretty var)
(env,_,_,_) <- execStateT (solve constraint) TS.initialState
mapM_ (\(n,t) -> print $ P.text n <+> P.text ":" <+> pretty t) $ Map.toList env
{-- todo: remove testing code --}
constrain :: Env.Environment -> LExpr a b -> Type -> IO TypeConstraint
@ -152,7 +156,6 @@ constrain env (L _ _ expr) tipe =
v <- flexibleVar -- needs an ex
c <- constrain env body (VarN v)
return ((name, v), c)
Markdown _ ->
return $ tipe === Env.get env Env.builtin "Element"

View file

@ -4,7 +4,6 @@ module Type.Solve where
import Control.Monad
import Control.Monad.State
import qualified Data.UnionFind.IO as UF
import qualified Data.Array.IO as Array
import qualified Data.Map as Map
import qualified Data.Traversable as Traversable
import qualified Data.Maybe as Maybe
@ -13,73 +12,78 @@ import Type.Unify
import qualified Type.Environment as Env
import qualified Type.State as TS
register = undefined
generalize :: TS.Pool -> TS.Pool -> IO [Variable]
generalize oldPool youngPool = do
let young = 0
visited = 1
youngRank = TS.maxRank youngPool
-- | Every variable has rank less than or equal to the maxRank of the pool.
-- This sorts variables into the young and old pools accordingly.
generalize :: TS.Pool -> StateT TS.SolverState IO ()
generalize youngPool = do
let youngRank = TS.maxRank youngPool
insert dict var = do
desc <- liftIO $ UF.descriptor var
return $ Map.insertWith (++) (rank desc) [var] dict
array' <- Array.newArray (0, youngRank) []
let array = array' :: Array.IOArray Int [Variable]
-- Sort the youngPool variables by rank.
rankDict <- foldM insert Map.empty (TS.inhabitants youngPool)
-- Insert all of the youngPool variables into the array.
-- They are placed into a list at the index corresponding
-- to their rank.
forM (TS.inhabitants youngPool) $ \var -> do
desc <- UF.descriptor var
vars <- Array.readArray array (rank desc)
Array.writeArray array (rank desc) (var : vars)
-- get the ranks right for each entry
forM [0 .. youngRank] $ \i -> do
vars <- Array.readArray array i
mapM (traverse young visited i) vars
-- get the ranks right for each entry.
-- start at low ranks so that we only have to pass
-- over the information once.
youngMark <- TS.uniqueMark
visitedMark <- TS.uniqueMark
Traversable.traverse (mapM (adjustRank youngMark visitedMark youngRank)) rankDict
-- do not need to work with variables that have become redundant
vars <- Array.readArray array youngRank
forM vars $ \var -> do
isRedundant <- UF.redundant var
if isRedundant then do
desc <- UF.descriptor var
if rank desc < youngRank
then register oldPool var
else let flex' = if flex desc == Flexible then Rigid else flex desc
in UF.setDescriptor var (desc { rank = noRank, flex = flex' })
else return ()
-- Move variables out of the young pool if they do not have a young rank.
-- We should not generalize things we cannot use.
let youngVars = (Map.!) rankDict youngRank
return vars
registerIfNotRedundant var = do
isRedundant <- liftIO $ UF.redundant var
if isRedundant then return var else TS.register var
traverse :: Int -> Int -> Int -> Variable -> IO Int
traverse young visited k variable =
let f = traverse young visited k in
do desc <- UF.descriptor variable
case mark desc == young of
registerIfHigherRank var = do
isRedundant <- liftIO $ UF.redundant var
if isRedundant then return () else do
desc <- liftIO $ UF.descriptor var
if rank desc < youngRank
then TS.register var >> return ()
else let flex' = if flex desc == Flexible then Rigid else flex desc
in liftIO $ UF.setDescriptor var (desc { rank = noRank, flex = flex' })
Traversable.traverse (mapM registerIfNotRedundant) rankDict
Traversable.traverse (mapM registerIfHigherRank) rankDict
return ()
-- adjust the ranks of variables such that ranks never increase as you
-- move deeper into a variable. This mean the rank actually represents the
-- deepest variable in the whole type, and we can ignore things at a lower
-- rank than the current constraints.
adjustRank :: Int -> Int -> Int -> Variable -> StateT TS.SolverState IO Int
adjustRank youngMark visitedMark groupRank variable =
let adjust = adjustRank youngMark visitedMark groupRank in
do desc <- liftIO $ UF.descriptor variable
case mark desc == youngMark of
True -> do
rank' <- case structure desc of
Nothing -> return k
Nothing -> return groupRank
Just term -> case term of
App1 a b -> max `liftM` f a `ap` f b
Fun1 a b -> max `liftM` f a `ap` f b
Var1 x -> f x
App1 a b -> max `liftM` adjust a `ap` adjust b
Fun1 a b -> max `liftM` adjust a `ap` adjust b
Var1 x -> adjust x
EmptyRecord1 -> return outermostRank
Record1 fields extension -> do
ranks <- mapM f (concat (Map.elems fields))
max (maximum ranks) `liftM` f extension
UF.setDescriptor variable (desc { mark = visited, rank = rank' })
ranks <- mapM adjust (concat (Map.elems fields))
max (maximum ranks) `liftM` adjust extension
liftIO $ UF.setDescriptor variable (desc { mark = visitedMark, rank = rank' })
return rank'
False -> do
if mark desc /= visited then do
let rank' = min k (rank desc)
UF.setDescriptor variable (desc { mark = visited, rank = rank' })
if mark desc == visitedMark then return (rank desc) else do
let rank' = min groupRank (rank desc)
liftIO $ UF.setDescriptor variable (desc { mark = visitedMark, rank = rank' })
return rank'
else return (rank desc)
addTo = undefined
newPool = undefined
introduce = undefined
solve :: TypeConstraint -> StateT TS.SolverState IO ()
solve constraint =
@ -94,41 +98,67 @@ solve constraint =
CAnd cs -> mapM_ solve cs
CLet [Scheme [] fqs constraint' _] CTrue -> do
mapM_ introduce fqs
mapM_ TS.introduce fqs
solve constraint'
CLet schemes constraint' -> do
mapM solveScheme schemes
headers <- mapM solveScheme schemes
TS.modifyEnv $ \env -> Map.unions (headers ++ [env])
solve constraint'
CInstance name term -> do
let instance' = undefined
inst = undefined --instance' pool (Env.get env value name)
env <- TS.getEnv
freshCopy <- TS.makeInstance ((Map.!) env name)
t <- TS.flatten term
unify inst t
unify freshCopy t
solveScheme :: TypeScheme -> StateT TS.SolverState IO ()
solveScheme :: TypeScheme -> StateT TS.SolverState IO (Map.Map String Variable)
solveScheme scheme =
case scheme of
Scheme [] [] constraint header -> do
solve constraint
Traversable.traverse TS.flatten header
return ()
Scheme rigidQuantifiers flexibleQuantifiers constraint header -> do
let quantifiers = rigidQuantifiers ++ flexibleQuantifiers
globalPool <- TS.getPool
localPool <- TS.newPool
TS.modifyPool (\_ -> localPool)
currentPool <- TS.getPool
-- fill in a new pool when working on this scheme's constraints
emptyPool <- TS.nextRankPool
TS.switchToPool emptyPool
mapM TS.introduce quantifiers
header' <- Traversable.traverse TS.flatten header
solve constraint
-- distinct variables
-- generalize
-- generic variables
TS.modifyPool (\_ -> globalPool)
isGeneric var =
do desc <- UF.descriptor var
undefined
allDistinct rigidQuantifiers
localPool <- TS.getPool
TS.switchToPool currentPool
generalize localPool
mapM isGeneric rigidQuantifiers
return header'
-- Checks that all of the given variables belong to distinct equivalence classes.
-- Also checks that their structure is Nothing, so they represent a variable, not
-- a more complex term.
allDistinct :: [Variable] -> StateT TS.SolverState IO ()
allDistinct vars = do
seen <- TS.uniqueMark
let check var = do
desc <- liftIO $ UF.descriptor var
case structure desc of
Just _ -> TS.addError "Cannot generalize something that is not a type variable."
Nothing -> do
if mark desc == seen
then TS.addError "Duplicate variable during generalization"
else return ()
liftIO $ UF.setDescriptor var (desc { mark = seen })
mapM_ check vars
-- Check that a variable has rank == noRank, meaning that it can be generalized.
isGeneric :: Variable -> StateT TS.SolverState IO ()
isGeneric var = do
desc <- liftIO $ UF.descriptor var
if rank desc == noRank
then return ()
else TS.addError "Cannot generalize. Variable must have not have a rank."

View file

@ -1,9 +1,16 @@
{-# OPTIONS_GHC -XMultiWayIf #-}
module Type.State where
import Type.Type
import qualified Data.Map as Map
import qualified Type.Environment as Env
import qualified Data.UnionFind.IO as UF
import Control.Monad.State
import Control.Applicative ((<$>),(<*>), Applicative)
import qualified Data.Traversable as Traversable
-- todo: remove later
import SourceSyntax.PrettyPrint
-- Pool
-- Holds a bunch of variables
@ -13,23 +20,41 @@ import Control.Monad.State
data Pool = Pool {
maxRank :: Int,
inhabitants :: [Variable]
}
} deriving Show
emptyPool = Pool { maxRank = 0, inhabitants = [] }
-- Keeps track of the environment, type variable pool, and a list of errors
type SolverState = (Env.Environment, Pool, [String])
type SolverState = (Map.Map String Variable, Pool, Int, [String])
modifyEnv f = modify $ \(env, pool, errors) -> (f env, pool, errors)
modifyPool f = modify $ \(env, pool, errors) -> (env, f pool, errors)
addError err = modify $ \(env, pool, errors) -> (env, pool, err:errors)
-- The mark must never be equal to noMark!
initialState = (Map.empty, emptyPool, noMark + 1, [])
modifyEnv f = modify $ \(env, pool, mark, errors) -> (f env, pool, mark, errors)
modifyPool f = modify $ \(env, pool, mark, errors) -> (env, f pool, mark, errors)
addError err = modify $ \(env, pool, mark, errors) -> (env, pool, mark, err:errors)
switchToPool pool = modifyPool (\_ -> pool)
getPool :: StateT SolverState IO Pool
getPool = do
(_, pool, _) <- get
(_, pool, _, _) <- get
return pool
newPool :: StateT SolverState IO Pool
newPool = do
(_, pool, _) <- get
getEnv :: StateT SolverState IO (Map.Map String Variable)
getEnv = do
(env, _, _, _) <- get
return env
uniqueMark :: StateT SolverState IO Int
uniqueMark = do
(env, pool, mark, errs) <- get
put (env, pool, mark+1, errs)
return mark
nextRankPool :: StateT SolverState IO Pool
nextRankPool = do
pool <- getPool
return $ Pool { maxRank = maxRank pool + 1, inhabitants = [] }
register :: Variable -> StateT SolverState IO Variable
@ -39,7 +64,7 @@ register variable = do
introduce :: Variable -> StateT SolverState IO Variable
introduce variable = do
(_, pool, _) <- get
pool <- getPool
liftIO $ UF.modifyDescriptor variable (\desc -> desc { rank = maxRank pool })
register variable
@ -48,13 +73,84 @@ flatten term =
case term of
VarN v -> return v
TermN t -> do
flatStructure <- undefined -- chop t
(_, pool, _) <- get
flatStructure <- traverseTerm flatten t
pool <- getPool
var <- liftIO . UF.fresh $ Descriptor {
structure = Just flatStructure,
rank = maxRank pool,
flex = Flexible,
name = Nothing,
mark = 0
copy = Nothing,
mark = noMark
}
register var
register var
makeInstance :: Variable -> StateT SolverState IO Variable
makeInstance var = do
alreadyCopied <- uniqueMark
freshVar <- makeCopy alreadyCopied var
restore alreadyCopied var
return freshVar
makeCopy :: Int -> Variable -> StateT SolverState IO Variable
makeCopy alreadyCopied variable = do
desc <- liftIO $ UF.descriptor variable
if | mark desc == alreadyCopied ->
case copy desc of
Just v -> return v
Nothing -> error "This should be impossible."
| mark desc /= noRank || flex desc == Constant ->
return variable
| otherwise -> do
pool <- getPool
newVar <- liftIO $ UF.fresh $ Descriptor {
structure = Nothing,
rank = maxRank pool,
mark = noMark,
flex = Flexible,
copy = Nothing,
name = case flex desc of
Rigid -> Nothing
_ -> name desc
}
register newVar
-- Link the original variable to the new variable
-- Need to do this before recursively copying to
-- avoid looping on cyclic terms.
liftIO $ UF.modifyDescriptor variable $ \desc ->
desc { mark = alreadyCopied, copy = Just newVar }
case structure desc of
Nothing -> return newVar
Just term -> do
newTerm <- traverseTerm (makeCopy alreadyCopied) term
liftIO $ UF.modifyDescriptor newVar $ \desc ->
desc { structure = Just newTerm }
return newVar
restore :: Int -> Variable -> StateT SolverState IO Variable
restore alreadyCopied variable = do
desc <- liftIO $ UF.descriptor variable
if mark desc /= alreadyCopied then return variable else do
restoredStructure <-
case structure desc of
Nothing -> return Nothing
Just term -> Just <$> traverseTerm (restore alreadyCopied) term
liftIO $ UF.modifyDescriptor variable $ \desc ->
desc { mark = noMark, rank = noRank, structure = restoredStructure }
return variable
traverseTerm :: (Monad f, Applicative f) => (a -> f b) -> Term1 a -> f (Term1 b)
traverseTerm f term =
case term of
App1 a b -> App1 <$> f a <*> f b
Fun1 a b -> Fun1 <$> f a <*> f b
Var1 x -> Var1 <$> f x
EmptyRecord1 -> return EmptyRecord1
Record1 fields ext ->
Record1 <$> Traversable.traverse (mapM f) fields <*> f ext

View file

@ -50,12 +50,16 @@ data Descriptor = Descriptor {
rank :: Int,
flex :: Flex,
name :: Maybe TypeName,
copy :: Maybe Variable,
mark :: Int
} deriving Show
noRank = -1
outermostRank = 0 :: Int
noMark = 0
initialMark = 1
data Flex = Rigid | Flexible | Constant
deriving (Show, Eq)
@ -85,7 +89,8 @@ namedVar name = UF.fresh $ Descriptor {
rank = noRank,
flex = Constant,
name = Just name,
mark = 0
copy = Nothing,
mark = noMark
}
flexibleVar = UF.fresh $ Descriptor {
@ -93,7 +98,8 @@ flexibleVar = UF.fresh $ Descriptor {
rank = noRank,
flex = Flexible,
name = Nothing,
mark = 0
copy = Nothing,
mark = noMark
}
rigidVar = UF.fresh $ Descriptor {
@ -101,7 +107,8 @@ rigidVar = UF.fresh $ Descriptor {
rank = noRank,
flex = Rigid,
name = Nothing,
mark = 0
copy = Nothing,
mark = noMark
}
-- ex qs constraint == exists qs. constraint

View file

@ -37,17 +37,19 @@ actuallyUnify variable1 variable2 = do
merge1 :: StateT TS.SolverState IO ()
merge1 = liftIO $ do
UF.union variable2 variable1
UF.setDescriptor variable1 (desc1 { flex = flex', name = name', rank = rank', mark = undefined })
UF.setDescriptor variable1 (desc1 {
flex = flex', name = name', rank = rank', copy = Nothing, mark = noMark })
merge2 :: StateT TS.SolverState IO ()
merge2 = liftIO $ do
UF.union variable1 variable2
UF.setDescriptor variable2 (desc2 { flex = flex', name = name', rank = rank', mark = undefined })
UF.setDescriptor variable2 (desc2 {
flex = flex', name = name', rank = rank', copy = Nothing, mark = noMark })
fresh :: Maybe (Term1 Variable) -> StateT TS.SolverState IO Variable
fresh structure = do
var <- liftIO . UF.fresh $ Descriptor {
structure = structure, rank = rank', flex = flex', name = name', mark = undefined
structure = structure, rank = rank', flex = flex', name = name', copy = Nothing, mark = noMark
}
TS.register var