-- elm/compiler/Types/Solver.hs
-- (216 lines, 8.3 KiB, Haskell; header text from a repository web viewer's
--  "Raw / Normal View / History" page.)

module Types.Solver (solver) where
import SourceSyntax.Everything
import Control.Arrow (second)
import Control.Monad (liftM)
import Data.Either (lefts,rights)
import Data.List (foldl')
import Data.Maybe (isJust)
import qualified Data.Set as Set
import qualified Data.Map as Map
import Unique
import Types.Types
import Types.Substitutions
-- revision timestamp: 2013-04-05 16:55:30 +00:00
import Types.Alias (dealias)
import System.IO.Unsafe
-- | Decide whether an (already substituted) constraint is trivially
-- satisfied and can be dropped from a scheme.
--
-- An equation is satisfied when both sides are syntactically identical; a
-- scheme-instance constraint @x :<<: s@ counts as satisfied when @x@ has an
-- entry in the substitution list @ss@. Every other constraint is kept.
isSolved :: [(X, a)] -> Located Constraint -> Bool
isSolved ss (L _ _ constraint) =
    case constraint of
      lhs :=: rhs -> lhs == rhs
      var :<<: _  -> var `elem` map fst ss
      _           -> False
-- revision timestamp: 2013-04-05 16:55:30 +00:00
-- | Type-alias table passed to 'dealias' by the solver.
-- Keys are presumably alias names, each mapped to the alias's type
-- parameters and the type it expands to (confirm against Types.Alias).
type Aliases = Map.Map String ([X],Type)
-- | Simplify a type scheme by solving the constraints it carries.
--
-- The scheme's constraints are run through 'solver' starting from an empty
-- substitution. On success, the resulting substitution is applied to both
-- the constraints and the body type, and constraints that 'isSolved'
-- reports as already satisfied are dropped. The quantified variables are
-- kept unchanged. A solver failure is returned as-is in 'Left'.
crush :: Aliases -> Scheme -> Unique (Either String Scheme)
crush aliases (Forall xs cs t) = do
    solved <- solver aliases Map.empty cs
    return $ do
      subs <- solved
      let ss  = Map.toList subs
          cs' = filter (not . isSolved ss) (subst ss cs)
      return (Forall xs cs' (subst ss t))
-- | Rebuild the binary constraint @rltn t1 t2@ (located at @txt@/@span@)
-- after replacing the variable @x@ with an instance of the scheme @s@.
--
-- Each side is handled on its own. A side that does not mention @x@ is kept
-- untouched. A side that does mention @x@ gets @x@ replaced by the type
-- produced by 'concretize', and the constraints 'concretize' returned are
-- appended after the rebuilt constraint (left side's first, then the right
-- side's).
schemeSubHelp txt span x s t1 rltn t2 =
    do (lhs, lhsExtra) <- instantiate t1
       (rhs, rhsExtra) <- instantiate t2
       return (L txt span (rltn lhs rhs) : lhsExtra ++ rhsExtra)
  where
    instantiate tipe
      | occurs x tipe = do (inst, extra) <- concretize s
                           return (subst [(x, inst)] tipe, extra)
      | otherwise     = return (tipe, [])
-- | Apply 'schemeSub'' to one constraint, given a scheme that may have
-- failed to crush. A 'Left' error is passed through without inspecting the
-- constraint; otherwise the list of rewritten constraints is wrapped in
-- 'Right'.
schemeSub x s' c =
    either (return . Left) substitute s'
  where
    substitute scheme = liftM Right (schemeSub' x scheme c)
-- | Substitute the scheme @s@ for the variable @x@ inside one located
-- constraint, returning the replacement constraints.
--
-- * Equations and subtype constraints are delegated to 'schemeSubHelp'.
-- * A scheme-instance constraint @y :<<: Forall cxs ccs ctipe@ that does not
--   mention @x@ is returned unchanged. Otherwise @s@ is passed through
--   'rescheme'. Its type is substituted for @x@ in the combined constraints
--   and in the inner type, and its quantified variables are appended to the
--   inner scheme's.
schemeSub' x s c@(L txt span constraint) =
    case constraint of
      t1 :=: t2 -> schemeSubHelp txt span x s t1 (:=:) t2
      t1 :<: t2 -> schemeSubHelp txt span x s t1 (:<:) t2
      y :<<: Forall cxs ccs ctipe
        | occurs x c ->
            do Forall xs cs tipe <- rescheme s
               let ss    = [(x, tipe)]
                   inner = Forall (cxs ++ xs) (subst ss (cs ++ ccs)) (subst ss ctipe)
               return [L txt span (y :<<: inner)]
        | otherwise -> return [c]
-- | Produce the constraints needed to unify the record types
-- @{ fs | t }@ and @{ fs' | t' }@, where every label maps to a list of
-- types (a label may occur more than once).
--
-- For each label, the types present on both sides are equated pairwise
-- (these equations are emitted in reverse order). If one side has more
-- types for a label than the other, a fresh variable @v@ is drawn from
-- 'guid'. The *other* record's tail is then equated with
-- @{ label : leftovers | v }@.
--
-- Labels present in both maps are handled first, then labels only in @fs@,
-- then labels only in @fs'@; each group is visited in ascending key order.
--
-- NOTE(review): when no label has leftovers, no constraint relating @t@ and
-- @t'@ is emitted. Each leftover label also gets its own fresh tail
-- variable. Standard row unification would equate the tails (or share one
-- fresh tail between them). Confirm whether this looseness is intended
-- before relying on it.
recordConstraints
    :: (Type -> Type -> Located Constraint)
    -> Map.Map String [Type] -> Type
    -> Map.Map String [Type] -> Type
    -> Unique [Located Constraint]
recordConstraints eq fs t fs' t' =
    do shared    <- mapM (\(k, (xs, ys)) -> align k xs ys)
                         (Map.toList (Map.intersectionWith (,) fs fs'))
       leftOnly  <- mapM (\(k, xs) -> align k xs [])
                         (Map.toList (Map.difference fs fs'))
       rightOnly <- mapM (\(k, ys) -> align k [] ys)
                         (Map.toList (Map.difference fs' fs))
       return (concat (shared ++ leftOnly ++ rightOnly))
  where
    align :: String -> [Type] -> [Type] -> Unique [Located Constraint]
    align k xs ys =
        case (drop n xs, drop n ys) of
          ([], [])    -> return pairs
          (extra, []) -> spill t' extra
          (_, extra)  -> spill t extra
      where
        n     = min (length xs) (length ys)
        pairs = reverse (zipWith eq xs ys)
        -- Equate @side@ with a record holding the leftover types for @k@.
        spill side extra =
            do v <- guid
               return (pairs ++ [eq side (RecordT (Map.singleton k extra) (VarT v))])
-- revision timestamp: 2013-04-05 16:55:30 +00:00
-- | Solve a list of located constraints, threading a substitution from type
-- variables to types.
--
-- Constraints are taken one at a time from the front of the list:
--
-- * type constructors, functions and records are decomposed into
--   constraints on their components;
-- * an equation involving a variable binds it (after an occurs check) and
--   applies the binding to the remaining constraints. A variable only known
--   to lie in a set of types is recorded as @Super ts@, and such sets are
--   narrowed by intersection;
-- * @t :<: Super ts@ checks membership, also trying the alias-expanded
--   type and treating @List a@ specially;
-- * @x :<<: s@ either substitutes the crushed scheme @s@ into the other
--   constraints that mention @x@, or instantiates @s@ with 'concretize'
--   and equates the result with @x@.
--
-- Returns the final substitution, or the first error message produced.
solver :: Aliases
       -> Map.Map X Type
       -> [Located Constraint]
       -> Unique (Either String (Map.Map X Type))
solver _ subs [] = return $ Right subs
solver aliases subs (L txt span c : cs) =
  let loc = L txt span
      eq t1 t2 = loc (t1 :=: t2)
      solv = solver aliases subs
      uniError' = uniError (\t1 t2 -> solv (eq t1 t2 : cs)) aliases txt span
  in case c of
    -- Destruct type constructors
    t1@(ADT n1 ts1) :=: t2@(ADT n2 ts2) ->
      if n1 == n2 then solv (zipWith eq ts1 ts2 ++ cs)
                  else uniError' t1 t2

    LambdaT t1 t2 :=: LambdaT t1' t2' ->
      solv ([ eq t1 t1', eq t2 t2' ] ++ cs)

    RecordT fs t :=: RecordT fs' t' ->
      do cs' <- recordConstraints eq fs t fs' t'
         solv (cs' ++ cs)

    -- Type equality
    VarT x :=: VarT y
      | x == y -> solv cs
      | otherwise ->
          case (Map.lookup x subs, Map.lookup y subs) of
            (Just (Super xts), Just (Super yts)) ->
              let ts = Set.intersection xts yts
                  setXY t = Map.insert x t . Map.insert y t
              in case Set.toList ts of
                   [] -> unionError txt span xts yts
                   [t] -> let cs1 = subst [(x,t),(y,t)] cs in
                          cs1 `seq` solver aliases (setXY t subs) cs1
                   _ -> solver aliases (setXY (Super ts) subs) cs
            (Just (Super _), _) ->
              let cs2 = subst [(y,VarT x)] cs in
              solver aliases (Map.insert y (VarT x) subs) cs2
            (_, _) ->
              let cs3 = subst [(x,VarT y)] cs in
              solver aliases (Map.insert x (VarT y) subs) cs3

    VarT x :=: t
      | x `occurs` t -> occursError txt span (VarT x) t
      | otherwise ->
          case Map.lookup x subs of
            Nothing ->
              let cs4 = subst [(x,t)] cs
                  subs' = Map.map (subst [(x,t)]) $ Map.insert x t subs
              in solver aliases subs' cs4
            Just (Super ts)
              | Set.member t ts ->
                  solver aliases (Map.insert x t subs) (subst [(x,t)] cs)
              | otherwise ->
                  -- t is not literally one of the allowed types (it may be
                  -- an alias or a list type). Still bind x to t, so later
                  -- constraints cannot pick a different member, and let the
                  -- subtype rule decide whether t is acceptable.
                  solver aliases (Map.insert x t subs)
                         (loc (t :<: Super ts) : subst [(x,t)] cs)
            Just t' -> solv (loc (t' :=: t) : cs)

    t :=: VarT x -> solv (loc (VarT x :=: t) : cs)

    t1 :=: t2
      | t1 == t2 -> solv cs
      | otherwise -> uniError' t1 t2

    -- Subtypes
    VarT x :<: Super ts ->
      case Map.lookup x subs of
        Nothing -> solver aliases (Map.insert x (Super ts) subs) cs
        Just (Super ts') ->
          case Set.toList (Set.intersection ts ts') of
            [] -> unionError txt span ts ts'
            [t] -> solver aliases (Map.insert x t subs) (subst [(x,t)] cs)
            ts'' -> solver aliases (Map.insert x (Super (Set.fromList ts'')) subs) cs
        -- x is already bound to an ordinary type: check that type instead of
        -- failing with a non-exhaustive pattern match.
        Just t' -> solv (loc (t' :<: Super ts) : cs)

    ADT "List" [t] :<: Super ts
      | any f (Set.toList ts) -> solv cs
      | otherwise -> subtypeError txt span (ADT "List" [t]) (Super ts)
      where f (ADT "List" [VarT _]) = True
            f (ADT "List" [t']) = dealias aliases t == t'
            f _ = False

    t :<: Super ts
      | Set.member t ts -> solv cs
      | Set.member (dealias aliases t) ts -> solv cs
      | otherwise -> subtypeError txt span t (Super ts)

    -- A subtype constraint whose right-hand side is not a set of allowed
    -- types: report it rather than crashing on an incomplete case.
    t1 :<: t2 -> subtypeError txt span t1 t2

    -- Scheme instantiation
    x :<<: s
      | any (occurs x) cs ->
          do s' <- crush aliases s
             css <- mapM (schemeSub x s') cs
             case lefts css of
               err : _ -> return $ Left err
               [] -> solv (concat (rights css))
      | otherwise ->
          do (t, cs7) <- concretize s
             solv (cs ++ loc (VarT x :=: t) : cs7)
-- | Render an optional context string as the trailing line of an error
-- message. With no context the result is the empty string.
showMsg :: Maybe String -> String
showMsg = maybe "" ("\nIn context: " ++)
-- | Build the error returned when the occurs check fails, i.e. when
-- equating @t1@ with @t2@ would require an infinite type.
occursError msg span t1 t2 =
    return (Left message)
  where
    message = "Type error (" ++ show span ++ "):\n"
           ++ "Occurs check: cannot construct the infinite type:\n"
           ++ show t1 ++ " = " ++ show t2 ++ showMsg msg
-- | Handle two types that failed to unify.
--
-- If expanding type aliases changes either side, unification is retried on
-- the expanded pair via @solveWith@. Otherwise a "not equal" error is
-- returned that mentions the original (unexpanded) types.
uniError solveWith aliases msg span t1 t2
    | expanded1 /= t1 || expanded2 /= t2 = solveWith expanded1 expanded2
    | otherwise = return . Left $
        "Type error (" ++ show span ++ "):\n"
        ++ show t1 ++ " is not equal to " ++ show t2 ++ showMsg msg
  where
    expanded1 = dealias aliases t1
    expanded2 = dealias aliases t2
-- | Build the error returned when two sets of allowed types share no
-- member.
unionError msg span ts ts' =
    return $ Left $ "Type error (" ++ show span ++ "):\n"
                 ++ "There are no types in both "
                 ++ show (Super ts) ++ " and " ++ show (Super ts')
                 ++ showMsg msg
-- | Build the error returned when @t@ is not accepted by the type @s@
-- on the right-hand side of a subtype constraint.
subtypeError msg span t s =
    return . Left $ header ++ body ++ showMsg msg
  where
    header = "Type error (" ++ show span ++ "):\n"
    body   = show t ++ " is not a " ++ show s