Get the solver working on basic programs. It outputs pretty types for the
variables in the program. Need to test further and start doing some benchmarking.
This commit is contained in:
parent
ca62ee64a9
commit
497d478d26
5 changed files with 230 additions and 92 deletions
|
@ -21,8 +21,11 @@ import qualified Type.Constrain.Pattern as Pattern
|
|||
|
||||
{-- Testing section --}
|
||||
import SourceSyntax.PrettyPrint
|
||||
import Text.PrettyPrint as P
|
||||
import Parse.Expression
|
||||
import Parse.Helpers (iParse)
|
||||
import Type.Solve (solve)
|
||||
import qualified Type.State as TS
|
||||
|
||||
test str =
|
||||
case iParse expr "" str of
|
||||
|
@ -33,8 +36,9 @@ test str =
|
|||
constraint <- constrain env expression (VarN var)
|
||||
prettyNames constraint
|
||||
print (pretty constraint)
|
||||
print (pretty var)
|
||||
return ()
|
||||
print (P.text "Solving for:" <+> pretty var)
|
||||
(env,_,_,_) <- execStateT (solve constraint) TS.initialState
|
||||
mapM_ (\(n,t) -> print $ P.text n <+> P.text ":" <+> pretty t) $ Map.toList env
|
||||
{-- todo: remove testing code --}
|
||||
|
||||
constrain :: Env.Environment -> LExpr a b -> Type -> IO TypeConstraint
|
||||
|
@ -153,7 +157,6 @@ constrain env (L _ _ expr) tipe =
|
|||
c <- constrain env body (VarN v)
|
||||
return ((name, v), c)
|
||||
|
||||
|
||||
Markdown _ ->
|
||||
return $ tipe === Env.get env Env.builtin "Element"
|
||||
|
||||
|
|
|
@ -4,7 +4,6 @@ module Type.Solve where
|
|||
import Control.Monad
|
||||
import Control.Monad.State
|
||||
import qualified Data.UnionFind.IO as UF
|
||||
import qualified Data.Array.IO as Array
|
||||
import qualified Data.Map as Map
|
||||
import qualified Data.Traversable as Traversable
|
||||
import qualified Data.Maybe as Maybe
|
||||
|
@ -13,73 +12,78 @@ import Type.Unify
|
|||
import qualified Type.Environment as Env
|
||||
import qualified Type.State as TS
|
||||
|
||||
register = undefined
|
||||
|
||||
generalize :: TS.Pool -> TS.Pool -> IO [Variable]
|
||||
generalize oldPool youngPool = do
|
||||
let young = 0
|
||||
visited = 1
|
||||
youngRank = TS.maxRank youngPool
|
||||
-- | Every variable has rank less than or equal to the maxRank of the pool.
|
||||
-- This sorts variables into the young and old pools accordingly.
|
||||
generalize :: TS.Pool -> StateT TS.SolverState IO ()
|
||||
generalize youngPool = do
|
||||
let youngRank = TS.maxRank youngPool
|
||||
insert dict var = do
|
||||
desc <- liftIO $ UF.descriptor var
|
||||
return $ Map.insertWith (++) (rank desc) [var] dict
|
||||
|
||||
array' <- Array.newArray (0, youngRank) []
|
||||
let array = array' :: Array.IOArray Int [Variable]
|
||||
-- Sort the youngPool variables by rank.
|
||||
rankDict <- foldM insert Map.empty (TS.inhabitants youngPool)
|
||||
|
||||
-- Insert all of the youngPool variables into the array.
|
||||
-- They are placed into a list at the index corresponding
|
||||
-- to their rank.
|
||||
forM (TS.inhabitants youngPool) $ \var -> do
|
||||
desc <- UF.descriptor var
|
||||
vars <- Array.readArray array (rank desc)
|
||||
Array.writeArray array (rank desc) (var : vars)
|
||||
-- get the ranks right for each entry.
|
||||
-- start at low ranks so that we only have to pass
|
||||
-- over the information once.
|
||||
youngMark <- TS.uniqueMark
|
||||
visitedMark <- TS.uniqueMark
|
||||
Traversable.traverse (mapM (adjustRank youngMark visitedMark youngRank)) rankDict
|
||||
|
||||
-- get the ranks right for each entry
|
||||
forM [0 .. youngRank] $ \i -> do
|
||||
vars <- Array.readArray array i
|
||||
mapM (traverse young visited i) vars
|
||||
-- Move variables out of the young pool if they do not have a young rank.
|
||||
-- We should not generalize things we cannot use.
|
||||
let youngVars = (Map.!) rankDict youngRank
|
||||
|
||||
-- do not need to work with variables that have become redundant
|
||||
vars <- Array.readArray array youngRank
|
||||
forM vars $ \var -> do
|
||||
isRedundant <- UF.redundant var
|
||||
if isRedundant then do
|
||||
desc <- UF.descriptor var
|
||||
registerIfNotRedundant var = do
|
||||
isRedundant <- liftIO $ UF.redundant var
|
||||
if isRedundant then return var else TS.register var
|
||||
|
||||
registerIfHigherRank var = do
|
||||
isRedundant <- liftIO $ UF.redundant var
|
||||
if isRedundant then return () else do
|
||||
desc <- liftIO $ UF.descriptor var
|
||||
if rank desc < youngRank
|
||||
then register oldPool var
|
||||
then TS.register var >> return ()
|
||||
else let flex' = if flex desc == Flexible then Rigid else flex desc
|
||||
in UF.setDescriptor var (desc { rank = noRank, flex = flex' })
|
||||
else return ()
|
||||
in liftIO $ UF.setDescriptor var (desc { rank = noRank, flex = flex' })
|
||||
|
||||
return vars
|
||||
Traversable.traverse (mapM registerIfNotRedundant) rankDict
|
||||
Traversable.traverse (mapM registerIfHigherRank) rankDict
|
||||
|
||||
traverse :: Int -> Int -> Int -> Variable -> IO Int
|
||||
traverse young visited k variable =
|
||||
let f = traverse young visited k in
|
||||
do desc <- UF.descriptor variable
|
||||
case mark desc == young of
|
||||
return ()
|
||||
|
||||
|
||||
-- adjust the ranks of variables such that ranks never increase as you
|
||||
-- move deeper into a variable. This mean the rank actually represents the
|
||||
-- deepest variable in the whole type, and we can ignore things at a lower
|
||||
-- rank than the current constraints.
|
||||
adjustRank :: Int -> Int -> Int -> Variable -> StateT TS.SolverState IO Int
|
||||
adjustRank youngMark visitedMark groupRank variable =
|
||||
let adjust = adjustRank youngMark visitedMark groupRank in
|
||||
do desc <- liftIO $ UF.descriptor variable
|
||||
case mark desc == youngMark of
|
||||
True -> do
|
||||
rank' <- case structure desc of
|
||||
Nothing -> return k
|
||||
Nothing -> return groupRank
|
||||
Just term -> case term of
|
||||
App1 a b -> max `liftM` f a `ap` f b
|
||||
Fun1 a b -> max `liftM` f a `ap` f b
|
||||
Var1 x -> f x
|
||||
App1 a b -> max `liftM` adjust a `ap` adjust b
|
||||
Fun1 a b -> max `liftM` adjust a `ap` adjust b
|
||||
Var1 x -> adjust x
|
||||
EmptyRecord1 -> return outermostRank
|
||||
Record1 fields extension -> do
|
||||
ranks <- mapM f (concat (Map.elems fields))
|
||||
max (maximum ranks) `liftM` f extension
|
||||
UF.setDescriptor variable (desc { mark = visited, rank = rank' })
|
||||
ranks <- mapM adjust (concat (Map.elems fields))
|
||||
max (maximum ranks) `liftM` adjust extension
|
||||
liftIO $ UF.setDescriptor variable (desc { mark = visitedMark, rank = rank' })
|
||||
return rank'
|
||||
|
||||
False -> do
|
||||
if mark desc /= visited then do
|
||||
let rank' = min k (rank desc)
|
||||
UF.setDescriptor variable (desc { mark = visited, rank = rank' })
|
||||
if mark desc == visitedMark then return (rank desc) else do
|
||||
let rank' = min groupRank (rank desc)
|
||||
liftIO $ UF.setDescriptor variable (desc { mark = visitedMark, rank = rank' })
|
||||
return rank'
|
||||
else return (rank desc)
|
||||
|
||||
addTo = undefined
|
||||
newPool = undefined
|
||||
introduce = undefined
|
||||
|
||||
solve :: TypeConstraint -> StateT TS.SolverState IO ()
|
||||
solve constraint =
|
||||
|
@ -94,41 +98,67 @@ solve constraint =
|
|||
CAnd cs -> mapM_ solve cs
|
||||
|
||||
CLet [Scheme [] fqs constraint' _] CTrue -> do
|
||||
mapM_ introduce fqs
|
||||
mapM_ TS.introduce fqs
|
||||
solve constraint'
|
||||
|
||||
CLet schemes constraint' -> do
|
||||
mapM solveScheme schemes
|
||||
headers <- mapM solveScheme schemes
|
||||
TS.modifyEnv $ \env -> Map.unions (headers ++ [env])
|
||||
solve constraint'
|
||||
|
||||
CInstance name term -> do
|
||||
let instance' = undefined
|
||||
inst = undefined --instance' pool (Env.get env value name)
|
||||
env <- TS.getEnv
|
||||
freshCopy <- TS.makeInstance ((Map.!) env name)
|
||||
t <- TS.flatten term
|
||||
unify inst t
|
||||
unify freshCopy t
|
||||
|
||||
solveScheme :: TypeScheme -> StateT TS.SolverState IO ()
|
||||
solveScheme :: TypeScheme -> StateT TS.SolverState IO (Map.Map String Variable)
|
||||
solveScheme scheme =
|
||||
case scheme of
|
||||
Scheme [] [] constraint header -> do
|
||||
solve constraint
|
||||
Traversable.traverse TS.flatten header
|
||||
return ()
|
||||
|
||||
Scheme rigidQuantifiers flexibleQuantifiers constraint header -> do
|
||||
let quantifiers = rigidQuantifiers ++ flexibleQuantifiers
|
||||
globalPool <- TS.getPool
|
||||
localPool <- TS.newPool
|
||||
TS.modifyPool (\_ -> localPool)
|
||||
currentPool <- TS.getPool
|
||||
|
||||
-- fill in a new pool when working on this scheme's constraints
|
||||
emptyPool <- TS.nextRankPool
|
||||
TS.switchToPool emptyPool
|
||||
mapM TS.introduce quantifiers
|
||||
header' <- Traversable.traverse TS.flatten header
|
||||
solve constraint
|
||||
-- distinct variables
|
||||
-- generalize
|
||||
-- generic variables
|
||||
TS.modifyPool (\_ -> globalPool)
|
||||
|
||||
isGeneric var =
|
||||
do desc <- UF.descriptor var
|
||||
undefined
|
||||
allDistinct rigidQuantifiers
|
||||
localPool <- TS.getPool
|
||||
TS.switchToPool currentPool
|
||||
generalize localPool
|
||||
mapM isGeneric rigidQuantifiers
|
||||
return header'
|
||||
|
||||
|
||||
-- Checks that all of the given variables belong to distinct equivalence classes.
|
||||
-- Also checks that their structure is Nothing, so they represent a variable, not
|
||||
-- a more complex term.
|
||||
allDistinct :: [Variable] -> StateT TS.SolverState IO ()
|
||||
allDistinct vars = do
|
||||
seen <- TS.uniqueMark
|
||||
let check var = do
|
||||
desc <- liftIO $ UF.descriptor var
|
||||
case structure desc of
|
||||
Just _ -> TS.addError "Cannot generalize something that is not a type variable."
|
||||
Nothing -> do
|
||||
if mark desc == seen
|
||||
then TS.addError "Duplicate variable during generalization"
|
||||
else return ()
|
||||
liftIO $ UF.setDescriptor var (desc { mark = seen })
|
||||
mapM_ check vars
|
||||
|
||||
-- Check that a variable has rank == noRank, meaning that it can be generalized.
|
||||
isGeneric :: Variable -> StateT TS.SolverState IO ()
|
||||
isGeneric var = do
|
||||
desc <- liftIO $ UF.descriptor var
|
||||
if rank desc == noRank
|
||||
then return ()
|
||||
else TS.addError "Cannot generalize. Variable must have not have a rank."
|
|
@ -1,9 +1,16 @@
|
|||
{-# OPTIONS_GHC -XMultiWayIf #-}
|
||||
module Type.State where
|
||||
|
||||
import Type.Type
|
||||
import qualified Data.Map as Map
|
||||
import qualified Type.Environment as Env
|
||||
import qualified Data.UnionFind.IO as UF
|
||||
import Control.Monad.State
|
||||
import Control.Applicative ((<$>),(<*>), Applicative)
|
||||
import qualified Data.Traversable as Traversable
|
||||
|
||||
-- todo: remove later
|
||||
import SourceSyntax.PrettyPrint
|
||||
|
||||
-- Pool
|
||||
-- Holds a bunch of variables
|
||||
|
@ -13,23 +20,41 @@ import Control.Monad.State
|
|||
data Pool = Pool {
|
||||
maxRank :: Int,
|
||||
inhabitants :: [Variable]
|
||||
}
|
||||
} deriving Show
|
||||
|
||||
emptyPool = Pool { maxRank = 0, inhabitants = [] }
|
||||
|
||||
-- Keeps track of the environment, type variable pool, and a list of errors
|
||||
type SolverState = (Env.Environment, Pool, [String])
|
||||
type SolverState = (Map.Map String Variable, Pool, Int, [String])
|
||||
|
||||
modifyEnv f = modify $ \(env, pool, errors) -> (f env, pool, errors)
|
||||
modifyPool f = modify $ \(env, pool, errors) -> (env, f pool, errors)
|
||||
addError err = modify $ \(env, pool, errors) -> (env, pool, err:errors)
|
||||
-- The mark must never be equal to noMark!
|
||||
initialState = (Map.empty, emptyPool, noMark + 1, [])
|
||||
|
||||
modifyEnv f = modify $ \(env, pool, mark, errors) -> (f env, pool, mark, errors)
|
||||
modifyPool f = modify $ \(env, pool, mark, errors) -> (env, f pool, mark, errors)
|
||||
addError err = modify $ \(env, pool, mark, errors) -> (env, pool, mark, err:errors)
|
||||
|
||||
switchToPool pool = modifyPool (\_ -> pool)
|
||||
|
||||
getPool :: StateT SolverState IO Pool
|
||||
getPool = do
|
||||
(_, pool, _) <- get
|
||||
(_, pool, _, _) <- get
|
||||
return pool
|
||||
|
||||
newPool :: StateT SolverState IO Pool
|
||||
newPool = do
|
||||
(_, pool, _) <- get
|
||||
getEnv :: StateT SolverState IO (Map.Map String Variable)
|
||||
getEnv = do
|
||||
(env, _, _, _) <- get
|
||||
return env
|
||||
|
||||
uniqueMark :: StateT SolverState IO Int
|
||||
uniqueMark = do
|
||||
(env, pool, mark, errs) <- get
|
||||
put (env, pool, mark+1, errs)
|
||||
return mark
|
||||
|
||||
nextRankPool :: StateT SolverState IO Pool
|
||||
nextRankPool = do
|
||||
pool <- getPool
|
||||
return $ Pool { maxRank = maxRank pool + 1, inhabitants = [] }
|
||||
|
||||
register :: Variable -> StateT SolverState IO Variable
|
||||
|
@ -39,7 +64,7 @@ register variable = do
|
|||
|
||||
introduce :: Variable -> StateT SolverState IO Variable
|
||||
introduce variable = do
|
||||
(_, pool, _) <- get
|
||||
pool <- getPool
|
||||
liftIO $ UF.modifyDescriptor variable (\desc -> desc { rank = maxRank pool })
|
||||
register variable
|
||||
|
||||
|
@ -48,13 +73,84 @@ flatten term =
|
|||
case term of
|
||||
VarN v -> return v
|
||||
TermN t -> do
|
||||
flatStructure <- undefined -- chop t
|
||||
(_, pool, _) <- get
|
||||
flatStructure <- traverseTerm flatten t
|
||||
pool <- getPool
|
||||
var <- liftIO . UF.fresh $ Descriptor {
|
||||
structure = Just flatStructure,
|
||||
rank = maxRank pool,
|
||||
flex = Flexible,
|
||||
name = Nothing,
|
||||
mark = 0
|
||||
copy = Nothing,
|
||||
mark = noMark
|
||||
}
|
||||
register var
|
||||
|
||||
makeInstance :: Variable -> StateT SolverState IO Variable
|
||||
makeInstance var = do
|
||||
alreadyCopied <- uniqueMark
|
||||
freshVar <- makeCopy alreadyCopied var
|
||||
restore alreadyCopied var
|
||||
return freshVar
|
||||
|
||||
makeCopy :: Int -> Variable -> StateT SolverState IO Variable
|
||||
makeCopy alreadyCopied variable = do
|
||||
desc <- liftIO $ UF.descriptor variable
|
||||
if | mark desc == alreadyCopied ->
|
||||
case copy desc of
|
||||
Just v -> return v
|
||||
Nothing -> error "This should be impossible."
|
||||
|
||||
| mark desc /= noRank || flex desc == Constant ->
|
||||
return variable
|
||||
|
||||
| otherwise -> do
|
||||
pool <- getPool
|
||||
newVar <- liftIO $ UF.fresh $ Descriptor {
|
||||
structure = Nothing,
|
||||
rank = maxRank pool,
|
||||
mark = noMark,
|
||||
flex = Flexible,
|
||||
copy = Nothing,
|
||||
name = case flex desc of
|
||||
Rigid -> Nothing
|
||||
_ -> name desc
|
||||
}
|
||||
|
||||
register newVar
|
||||
|
||||
-- Link the original variable to the new variable
|
||||
-- Need to do this before recursively copying to
|
||||
-- avoid looping on cyclic terms.
|
||||
liftIO $ UF.modifyDescriptor variable $ \desc ->
|
||||
desc { mark = alreadyCopied, copy = Just newVar }
|
||||
|
||||
case structure desc of
|
||||
Nothing -> return newVar
|
||||
Just term -> do
|
||||
newTerm <- traverseTerm (makeCopy alreadyCopied) term
|
||||
liftIO $ UF.modifyDescriptor newVar $ \desc ->
|
||||
desc { structure = Just newTerm }
|
||||
return newVar
|
||||
|
||||
restore :: Int -> Variable -> StateT SolverState IO Variable
|
||||
restore alreadyCopied variable = do
|
||||
desc <- liftIO $ UF.descriptor variable
|
||||
if mark desc /= alreadyCopied then return variable else do
|
||||
restoredStructure <-
|
||||
case structure desc of
|
||||
Nothing -> return Nothing
|
||||
Just term -> Just <$> traverseTerm (restore alreadyCopied) term
|
||||
liftIO $ UF.modifyDescriptor variable $ \desc ->
|
||||
desc { mark = noMark, rank = noRank, structure = restoredStructure }
|
||||
return variable
|
||||
|
||||
traverseTerm :: (Monad f, Applicative f) => (a -> f b) -> Term1 a -> f (Term1 b)
|
||||
traverseTerm f term =
|
||||
case term of
|
||||
App1 a b -> App1 <$> f a <*> f b
|
||||
Fun1 a b -> Fun1 <$> f a <*> f b
|
||||
Var1 x -> Var1 <$> f x
|
||||
EmptyRecord1 -> return EmptyRecord1
|
||||
Record1 fields ext ->
|
||||
Record1 <$> Traversable.traverse (mapM f) fields <*> f ext
|
||||
|
||||
|
|
|
@ -50,12 +50,16 @@ data Descriptor = Descriptor {
|
|||
rank :: Int,
|
||||
flex :: Flex,
|
||||
name :: Maybe TypeName,
|
||||
copy :: Maybe Variable,
|
||||
mark :: Int
|
||||
} deriving Show
|
||||
|
||||
noRank = -1
|
||||
outermostRank = 0 :: Int
|
||||
|
||||
noMark = 0
|
||||
initialMark = 1
|
||||
|
||||
data Flex = Rigid | Flexible | Constant
|
||||
deriving (Show, Eq)
|
||||
|
||||
|
@ -85,7 +89,8 @@ namedVar name = UF.fresh $ Descriptor {
|
|||
rank = noRank,
|
||||
flex = Constant,
|
||||
name = Just name,
|
||||
mark = 0
|
||||
copy = Nothing,
|
||||
mark = noMark
|
||||
}
|
||||
|
||||
flexibleVar = UF.fresh $ Descriptor {
|
||||
|
@ -93,7 +98,8 @@ flexibleVar = UF.fresh $ Descriptor {
|
|||
rank = noRank,
|
||||
flex = Flexible,
|
||||
name = Nothing,
|
||||
mark = 0
|
||||
copy = Nothing,
|
||||
mark = noMark
|
||||
}
|
||||
|
||||
rigidVar = UF.fresh $ Descriptor {
|
||||
|
@ -101,7 +107,8 @@ rigidVar = UF.fresh $ Descriptor {
|
|||
rank = noRank,
|
||||
flex = Rigid,
|
||||
name = Nothing,
|
||||
mark = 0
|
||||
copy = Nothing,
|
||||
mark = noMark
|
||||
}
|
||||
|
||||
-- ex qs constraint == exists qs. constraint
|
||||
|
|
|
@ -37,17 +37,19 @@ actuallyUnify variable1 variable2 = do
|
|||
merge1 :: StateT TS.SolverState IO ()
|
||||
merge1 = liftIO $ do
|
||||
UF.union variable2 variable1
|
||||
UF.setDescriptor variable1 (desc1 { flex = flex', name = name', rank = rank', mark = undefined })
|
||||
UF.setDescriptor variable1 (desc1 {
|
||||
flex = flex', name = name', rank = rank', copy = Nothing, mark = noMark })
|
||||
|
||||
merge2 :: StateT TS.SolverState IO ()
|
||||
merge2 = liftIO $ do
|
||||
UF.union variable1 variable2
|
||||
UF.setDescriptor variable2 (desc2 { flex = flex', name = name', rank = rank', mark = undefined })
|
||||
UF.setDescriptor variable2 (desc2 {
|
||||
flex = flex', name = name', rank = rank', copy = Nothing, mark = noMark })
|
||||
|
||||
fresh :: Maybe (Term1 Variable) -> StateT TS.SolverState IO Variable
|
||||
fresh structure = do
|
||||
var <- liftIO . UF.fresh $ Descriptor {
|
||||
structure = structure, rank = rank', flex = flex', name = name', mark = undefined
|
||||
structure = structure, rank = rank', flex = flex', name = name', copy = Nothing, mark = noMark
|
||||
}
|
||||
TS.register var
|
||||
|
||||
|
|
Loading…
Reference in a new issue