-- elm/compiler/Type/State.hs
module Type.State where
import Type.Type
import qualified Data.Map as Map
import qualified Data.List as List
import qualified Type.Environment as Env
import qualified Data.UnionFind.IO as UF
import Control.Monad.State
import Control.Applicative ((<$>),(<*>), Applicative)
import qualified Data.Traversable as Traversable
import Text.PrettyPrint as P
import SourceSyntax.PrettyPrint
-- 2013-07-30 23:22:20 +00:00
import SourceSyntax.Location
import qualified SourceSyntax.Type as Src
import qualified Type.Alias as Alias
-- | A pool holds a collection of type variables.
--
-- The rank of each variable in 'inhabitants' is less than or equal to the
-- pool's 'maxRank'. The young pool exists so that these variables can be
-- identified in constant time.
data Pool = Pool
    { maxRank     :: Int
    , inhabitants :: [Variable]
    } deriving Show

-- | A pool at 'outermostRank' with no inhabitants.
emptyPool :: Pool
emptyPool = Pool outermostRank []
-- | Names in scope, mapped to their type variables.
type Env = Map.Map String Variable

-- | State threaded through the solver: the environment, the current
-- variable pool, a mark counter, and the accumulated errors.
data SolverState = SS
    { sEnv      :: Env                        -- ^ current environment
    , sSavedEnv :: Env                        -- ^ snapshot written by 'saveLocalEnv'
    , sPool     :: Pool                       -- ^ pool that 'register' adds variables to
    , sMark     :: Int                        -- ^ next mark returned by 'uniqueMark'
    , sErrors   :: [Alias.Rules -> IO P.Doc]  -- ^ error renderers, newest first
    }

-- | The state before solving starts: empty environments, 'emptyPool', and
-- no errors. The mark counter starts one above 'noMark' so that no mark
-- handed out by 'uniqueMark' can collide with it.
initialState :: SolverState
initialState = SS
    { sEnv      = Map.empty
    , sSavedEnv = Map.empty
    , sPool     = emptyPool
    , sMark     = noMark + 1  -- The mark must never be equal to noMark!
    , sErrors   = []
    }
-- | Apply @f@ to the current environment.
modifyEnv f = modify $ \st@SS{ sEnv = env } -> st { sEnv = f env }

-- | Apply @f@ to the current variable pool.
modifyPool f = modify $ \st@SS{ sPool = pool } -> st { sPool = f pool }
-- | Record a type error for a failed match between @t1@ and @t2@.
--
-- The error is pushed onto the front of 'sErrors' as a function of the alias
-- rules. Both types are converted with 'toSrcType' and pretty-printed with
-- 'Alias.realias' applied only when that function is run. @span@ supplies
-- the location and the offending expression text. @hint@, when present, is
-- printed as the reason the two types had to match.
addError span hint t1 t2 =
    modify $ \st -> st { sErrors = makeError : sErrors st }
  where
    makeError rules = do
        let prettiest = pretty . Alias.realias rules
        t1' <- prettiest <$> toSrcType t1
        t2' <- prettiest <$> toSrcType t2
        return . P.vcat $
          [ display $ case span of { NoSpan msg -> msg ; Span _ _ msg -> msg }
          , case hint of
              Nothing -> P.text " Could not match the following types:"
              Just msg -> P.text $ eightyCharLines 2 $
                  msg ++ ", so I could not match the following types:"
            -- Qualified: an unqualified <> is ambiguous with Prelude's
            -- Semigroup (<>) on newer versions of base.
          , P.text " " P.<> t1'
          , P.text " " P.<> t2'
          ]

    -- Where the error happened, phrased for the header line.
    location = case span of
        NoSpan _ -> ""
        Span p1 p2 _
          | line p1 == line p2 -> " on line " ++ show (line p1)
          | otherwise ->
              " between lines " ++ show (line p1) ++ " and " ++ show (line p2)

    -- Header line, followed by the offending expression when there is one.
    -- Each expression line gets its own P.text (prefixed with one space).
    -- A single P.text with embedded newlines would have its width
    -- miscounted by the pretty printer, and its later lines would ignore
    -- any enclosing P.nest.
    display msg =
        case lines msg of
          [] -> P.text $ "Type error" ++ location ++ ":"
          lines' ->
            P.vcat [ P.text $ "Type error" ++ location ++ ", in or near this expression:"
                   , P.vcat (map (P.text . (' ' :)) lines') ]
-- | Replace the current pool with @pool@.
switchToPool pool = modifyPool (const pool)
-- | The pool new variables are currently registered in.
getPool :: StateT SolverState IO Pool
getPool = gets sPool
-- | The current environment.
getEnv :: StateT SolverState IO Env
getEnv = gets sEnv
-- | Copy the current environment into 'sSavedEnv'.
saveLocalEnv :: StateT SolverState IO ()
saveLocalEnv = modify $ \st -> st { sSavedEnv = sEnv st }
-- | Hand out the current mark and advance the counter by one, so each call
-- returns a mark no earlier call returned.
uniqueMark :: StateT SolverState IO Int
uniqueMark = do
    current <- gets sMark
    modify $ \st -> st { sMark = current + 1 }
    return current
-- | An empty pool whose 'maxRank' is one above the current pool's.
-- The solver state is only read, not changed.
nextRankPool :: StateT SolverState IO Pool
nextRankPool = do
    currentRank <- maxRank <$> getPool
    return Pool { maxRank = currentRank + 1, inhabitants = [] }
-- | Prepend @variable@ to the current pool's inhabitants and return it.
register :: Variable -> StateT SolverState IO Variable
register variable =
    modifyPool addToPool >> return variable
  where
    addToPool p = p { inhabitants = variable : inhabitants p }
-- | Set @variable@'s rank to the current pool's 'maxRank', then register it
-- in that pool.
introduce :: Variable -> StateT SolverState IO Variable
introduce variable = do
    currentRank <- maxRank <$> getPool
    liftIO $ UF.modifyDescriptor variable $ \d -> d { rank = currentRank }
    register variable
-- | Reduce a 'Type' to a single variable.
--
-- A 'VarN' already is one. For a 'TermN', the sub-terms are flattened
-- first, and the result becomes the structure of a new 'Flexible' variable.
-- That variable is created at the current pool's 'maxRank' and registered
-- in that pool.
flatten :: Type -> StateT SolverState IO Variable
flatten (VarN v) = return v
flatten (TermN t) = do
    flatStructure <- traverseTerm flatten t
    currentRank <- maxRank <$> getPool
    freshVar <- liftIO $ UF.fresh Descriptor
        { structure = Just flatStructure
        , rank = currentRank
        , flex = Flexible
        , name = Nothing
        , copy = Nothing
        , mark = noMark
        }
    register freshVar
-- | Instantiate @var@. It is copied with 'makeCopy' under a fresh mark, and
-- that mark is then cleared from the original with 'restore'. Returns the
-- copy.
makeInstance :: Variable -> StateT SolverState IO Variable
makeInstance var = do
    pass <- uniqueMark
    instanceVar <- makeCopy pass var
    _ <- restore pass var
    return instanceVar
-- | Copy @variable@ as part of the copy pass identified by @alreadyCopied@
-- (see 'makeInstance').
--
-- * If the variable already carries mark @alreadyCopied@, it was copied
--   earlier in this pass, and the copy stored in its @copy@ field is
--   returned.
-- * A variable whose rank is not 'noRank', or whose flex is 'Constant', is
--   shared, not copied, and is returned unchanged.
-- * Otherwise a new variable is created at the current pool's rank and
--   registered. The original is linked to it, and the original's structure
--   is copied recursively into it.
makeCopy :: Int -> Variable -> StateT SolverState IO Variable
makeCopy alreadyCopied variable =
    liftIO (UF.descriptor variable) >>= copyFrom
  where
    copyFrom desc
      | mark desc == alreadyCopied = reuseCopy desc
      | rank desc /= noRank || flex desc == Constant = return variable
      | otherwise = freshCopy desc

    -- Already visited in this pass: hand back the copy made then.
    reuseCopy desc =
        case copy desc of
          Just v -> return v
          Nothing -> error $ "Error copying type variable. This should be impossible." ++
                             " Please report an error to the github repo!"

    -- Make, register and link a new variable, then fill in its structure.
    freshCopy desc = do
        pool <- getPool
        newVar <- liftIO $ UF.fresh $ Descriptor
            { structure = Nothing
            , rank = maxRank pool
            , mark = noMark
            , flex = case flex desc of
                       Is s -> Is s
                       _ -> Flexible
            , copy = Nothing
            , name = case flex desc of
                       Rigid -> Nothing
                       _ -> name desc
            }
        register newVar
        -- Link the original variable to the new variable. This lets us
        -- avoid making multiple copies of the variable we are instantiating.
        --
        -- Need to do this before recursively copying the structure of
        -- the variable to avoid looping on cyclic terms.
        liftIO $ UF.modifyDescriptor variable $ \d ->
            d { mark = alreadyCopied, copy = Just newVar }
        -- Now we recursively copy the structure of the variable. It is
        -- already marked as copied, so it will not be crawled again.
        -- `desc` was read before the modification above, so this is the
        -- original structure.
        case structure desc of
          Nothing -> return newVar
          Just term -> do
            newTerm <- traverseTerm (makeCopy alreadyCopied) term
            liftIO $ UF.modifyDescriptor newVar $ \d ->
                d { structure = Just newTerm }
            return newVar
-- | Undo the marks left by the copy pass @alreadyCopied@.
--
-- A variable carrying that mark has its structure restored recursively and
-- its mark and rank reset to 'noMark' and 'noRank'. Any other variable is
-- left untouched. The variable itself is always returned.
restore :: Int -> Variable -> StateT SolverState IO Variable
restore alreadyCopied variable =
    liftIO (UF.descriptor variable) >>= resetIfCopied
  where
    resetIfCopied desc
      | mark desc /= alreadyCopied = return variable
      | otherwise = do
          children <- Traversable.traverse (traverseTerm (restore alreadyCopied)) (structure desc)
          liftIO $ UF.modifyDescriptor variable $ \d ->
              d { mark = noMark, rank = noRank, structure = children }
          return variable
-- | Apply the effectful function @f@ to every variable position of a
-- 'Term1', rebuilding the same constructor from the results.
--
-- Effects are combined with '<*>', so they run left to right. For
-- 'Record1', every field list is traversed first (via 'Traversable.traverse'
-- and 'mapM'), and then the extension. 'EmptyRecord1' has no positions and
-- runs no effects.
traverseTerm :: (Monad f, Applicative f) => (a -> f b) -> Term1 a -> f (Term1 b)
traverseTerm f term =
    case term of
      App1 a b -> App1 <$> f a <*> f b
      Fun1 a b -> Fun1 <$> f a <*> f b
      Var1 x -> Var1 <$> f x
      EmptyRecord1 -> return EmptyRecord1
      Record1 fields ext ->
          Record1 <$> Traversable.traverse (mapM f) fields <*> f ext