diff --git a/elm/src/Types/Constrain.hs b/elm/src/Types/Constrain.hs index f503045..5c58689 100644 --- a/elm/src/Types/Constrain.hs +++ b/elm/src/Types/Constrain.hs @@ -10,24 +10,19 @@ import Control.Monad (liftM,mapM) import Control.Monad.State (evalState) import Guid -data Constraint = Type :=: Type - | Type :<: Type - | Type :<<: Scheme - deriving (Eq, Ord, Show) - beta = VarT `liftM` guid unionA = Map.unionWith (++) unionsA = Map.unionsWith (++) constrain hints expr = do (as,cs,t) <- inference expr - hs <- hints - let cMap = Map.intersectionWith (\t -> map (:<: t)) (Map.fromList hs) as + let cMap = Map.intersectionWith (\s -> map (\v -> VarT v:<<: s)) (Map.fromList hints) as return $ Set.toList cs ++ (concat . map snd $ Map.toList cMap) +inference :: Expr -> GuidCounter (Map.Map String [X], Set.Set Constraint, Type) inference (Var x) = - do b <- beta - return (Map.singleton x [b], Set.empty, b) + do b <- guid + return (Map.singleton x [b], Set.empty, VarT b) inference (App e1 e2) = do (a1,c1,t1) <- inference e1 (a2,c2,t2) <- inference e2 @@ -39,7 +34,7 @@ inference (Lambda x e) = do (a,c,t) <- inference e b <- beta return ( Map.delete x a - , Set.union c . Set.fromList . map (:=: b) $ + , Set.union c . Set.fromList . map (\x -> VarT x :=: b) $ Map.findWithDefault [] x a , LambdaT b t ) inference (Let defs e) = @@ -47,7 +42,7 @@ inference (Let defs e) = let (xs,es) = unzip defs (as,cs,ts) <- unzip3 `liftM` mapM inference es let assumptions = unionsA (a:as) - let f x t = map (:<: t) $ Map.findWithDefault [] x assumptions + let f x t = map (:=: t) . map VarT $ Map.findWithDefault [] x assumptions let constraints = Set.fromList . concat $ zipWith f xs ts return ( foldr Map.delete assumptions xs , Set.unions $ c:constraints:cs @@ -58,7 +53,7 @@ inference (If e1 e2 e3) = (a2,c2,t2) <- inference e2 (a3,c3,t3) <- inference e3 return ( unionsA [a1,a2,a3] - , Set.unions [c1,c2,c3, Set.fromList [ t1 :=: BoolT, t2 :=: t3 ] ] + , Set.unions [c1,c2,c3, Set.fromList [ t1 :=: bool, t2 :=: t3 ] ] , t2 ) inference (Data name es) = inference $ foldl' App (Var name) es @@ -68,10 +63,12 @@ inference (Range e1 e2) = inference (Var "elmRange" `App` e1 `App` e2) inference other = case other of - Number _ -> primitive IntT - Chr _ -> primitive CharT + IntNum _ -> do t <- beta + return (Map.empty, Set.singleton (t :<: number), t) + FloatNum _ -> primitive float + Chr _ -> primitive char Str _ -> primitive string - Boolean _ -> primitive BoolT + Boolean _ -> primitive bool _ -> beta >>= primitive primitive t = return (Map.empty, Set.empty, t) diff --git a/elm/src/Types/Hints.hs b/elm/src/Types/Hints.hs index bfaefd2..b05e5b1 100644 --- a/elm/src/Types/Hints.hs +++ b/elm/src/Types/Hints.hs @@ -16,126 +16,130 @@ textToText = [ "header", "italic", "bold", "underline" textAttrs = [ "toText" -: string ==> text , "link" -: string ==> text ==> text - , "Text.height" -: IntT ==> text ==> text + , "Text.height" -: int ==> text ==> text ] ++ hasType (text ==> text) textToText -elements = let iee = IntT ==> element ==> element in +elements = let iee = int ==> element ==> element in [ "flow" -: direction ==> listOf element ==> element , "layers" -: listOf element ==> element , "text" -: text ==> element , "opacity" -: iee , "width" -: iee , "height" -: iee - , "size" -: IntT ==> iee + , "size" -: int ==> iee , "box" -: iee , "centeredText" -: text ==> element , "justifiedText" -: text ==> element - , "collage" -: IntT ==> IntT ==> listOf form ==> element + , "collage" -: int ==> int ==> listOf form ==> element ] directions = hasType direction ["up","down","left","right","inward","outward"] -colors = [ "rgb" -: IntT ==> IntT ==> IntT ==> color - , "rgba" -: IntT ==> IntT ==> IntT ==> IntT ==> color +colors = [ "rgb" -: int ==> int ==> int ==> color + , "rgba" -: int ==> int ==> int ==> int ==> color ] ++ hasType color ["red","green","blue","black","white"] lineTypes = [ "line" -: listOf point ==> line - , "customLine" -: listOf IntT ==> color ==> line ==> form + , "customLine" -: listOf int ==> color ==> line ==> form ] ++ hasType (color ==> line ==> form) ["solid","dashed","dotted"] shapes = [ "polygon" -: listOf point ==> point ==> shape , "filled" -: color ==> shape ==> form , "outlined" -: color ==> shape ==> form - , "customOutline" -: listOf IntT ==> color ==> shape ==> form - ] ++ hasType (IntT ==> IntT ==> point ==> shape) ["ngon","rect","oval"] + , "customOutline" -: listOf int ==> color ==> shape ==> form + ] ++ hasType (int ==> int ==> point ==> shape) ["ngon","rect","oval"] -------- Foreign -------- casts = - [ "castJSBoolToBool" -: jsBool ==> BoolT - , "castBoolToJSBool" -: BoolT ==> jsBool - , "castJSNumberToInt" -: jsNumber ==> IntT - , "castIntToJSNumber" -: IntT ==> jsNumber - , "castJSElementToElement" -: IntT ==> IntT ==> jsElement ==> element + [ "castJSBoolToBool" -: jsBool ==> bool + , "castBoolToJSBool" -: bool ==> jsBool + , "castJSNumberToInt" -: jsNumber ==> int + , "castIntToJSNumber" -: int ==> jsNumber + , "castJSElementToElement" -: int ==> int ==> jsElement ==> element , "castElementToJSElement" -: element ==> jsElement , "castJSStringToString" -: jsString ==> string , "castStringToJSString" -: string ==> jsString - -- , "castJSNumberToFloat -: - -- , "castFloatToJSNumber -: + , "castJSNumberToFloat" -: jsNumber ==> float + , "castFloatToJSNumber" -: float ==> jsNumber ] -polyCasts = sequence - [ do a <- var ; "castJSArrayToList" -:: jsArray a ==> listOf a - , do a <- var ; "castListToJSArray" -:: listOf a ==> jsArray a - , do vs <- vars 2 ; "castTupleToJSTuple2" -:: tupleOf vs ==> jsTuple vs - , do vs <- vars 3 ; "castTupleToJSTuple3" -:: tupleOf vs ==> jsTuple vs - , do vs <- vars 4 ; "castTupleToJSTuple4" -:: tupleOf vs ==> jsTuple vs - , do vs <- vars 5 ; "castTupleToJSTuple5" -:: tupleOf vs ==> jsTuple vs - , do vs <- vars 2 ; "castJSTupleToTuple2" -:: jsTuple vs ==> tupleOf vs - , do vs <- vars 3 ; "castJSTupleToTuple3" -:: jsTuple vs ==> tupleOf vs - , do vs <- vars 4 ; "castJSTupleToTuple4" -:: jsTuple vs ==> tupleOf vs - , do vs <- vars 5 ; "castJSTupleToTuple5" -:: jsTuple vs ==> tupleOf vs +castToTuple n = (,) name $ Forall [1..n] [] (jsTuple vs ==> tupleOf vs) + where vs = map VarT [1..n] + name = "castJSTupleToTuple" ++ show n +castToJSTuple n = (,) name $ Forall [1..n] [] (tupleOf vs ==> jsTuple vs) + where vs = map VarT [1..n] + name = "castTupleToJSTuple" ++ show n + +polyCasts = + map castToTuple [2..5] ++ map castToJSTuple [2..5] ++ + [ "castJSArrayToList" -:: jsArray a ==> listOf a + , "castListToJSArray" -:: listOf a ==> jsArray a ] -------- Signals -------- -sig ts = fn ts ==> fn (map signalOf ts) +sig n name = (,) name $ Forall [1..n] [] (fn ts ==> fn (map signalOf ts)) where fn = foldr1 (==>) + ts = map VarT [1..n] -signals = sequence - [ do ts <- vars 1 ; "constant" -:: sig ts - , do ts <- vars 2 ; "lift" -:: sig ts - , do ts <- vars 3 ; "lift2" -:: sig ts - , do ts <- vars 4 ; "lift3" -:: sig ts - , do ts <- vars 5 ; "lift4" -:: sig ts - , do [a,b] <- vars 2 - "foldp" -:: (a ==> b ==> b) ==> b ==> signalOf a ==> signalOf b - , do a <- var ; "randomize" -:: IntT ==> IntT ==> signalOf a ==> signalOf IntT - , do a <- var ; "count" -:: signalOf a ==> signalOf IntT - , do a <- var ; "keepIf" -:: (a==>BoolT) ==> a ==> signalOf a ==> signalOf a - , do a <- var ; "dropIf" -:: (a==>BoolT) ==> a ==> signalOf a ==> signalOf a - , do a <- var ; "keepWhen" -:: signalOf BoolT ==>a==> signalOf a ==> signalOf a - , do a <- var ; "dropWhen" -:: signalOf BoolT ==>a==> signalOf a ==> signalOf a - , do a <- var ; "dropRepeats" -:: signalOf a ==> signalOf a - , do [a,b] <- vars 2 ; "sampleOn" -:: signalOf a ==> signalOf b ==> signalOf b +signals = + [ sig 1 "constant" + , sig 2 "lift" + , sig 3 "lift2" + , sig 4 "lift3" + , sig 5 "lift4" + , "foldp" -:: (a ==> b ==> b) ==> b ==> signalOf a ==> signalOf b + , "randomize" -:: int ==> int ==> signalOf a ==> signalOf int + , "count" -:: signalOf a ==> signalOf int + , "keepIf" -:: (a==>bool) ==> a ==> signalOf a ==> signalOf a + , "dropIf" -:: (a==>bool) ==> a ==> signalOf a ==> signalOf a + , "keepWhen" -:: signalOf bool ==>a==> signalOf a ==> signalOf a + , "dropWhen" -:: signalOf bool ==>a==> signalOf a ==> signalOf a + , "dropRepeats" -:: signalOf a ==> signalOf a + , "sampleOn" -:: signalOf a ==> signalOf b ==> signalOf b ] concreteSignals = - [ "keysDown" -: signalOf (listOf IntT) - , "charPressed" -: signalOf (maybeOf IntT) - , "inRange" -: IntT ==> IntT ==> signalOf IntT - , "every" -: time ==> signalOf time - , "before" -: time ==> signalOf BoolT - , "after" -: time ==> signalOf BoolT + [ "keysDown" -: signalOf (listOf int) + , "charPressed" -: signalOf (maybeOf int) + , "inRange" -: int ==> int ==> signalOf int + , timeScheme "every" (\t -> t ==> signalOf t) + , timeScheme "before" (\t -> t ==> signalOf bool) + , timeScheme "after" (\t -> t ==> signalOf bool) , "dimensions" -: signalOf point , "position" -: signalOf point - , "x" -: signalOf IntT - , "y" -: signalOf IntT - , "isDown" -: signalOf BoolT - , "isClicked" -: signalOf BoolT + , "x" -: signalOf int + , "y" -: signalOf int + , "isDown" -: signalOf bool + , "isClicked" -: signalOf bool , "textField" -: string ==> tupleOf [element, signalOf string] , "password" -: string ==> tupleOf [element, signalOf string] - , "textArea" -: IntT ==> IntT ==> tupleOf [element, signalOf string] - , "checkBox" -: BoolT ==> tupleOf [element, signalOf BoolT] - , "button" -: string ==> tupleOf [element, signalOf BoolT] + , "textArea" -: int ==> int ==> tupleOf [element, signalOf string] + , "checkBox" -: bool ==> tupleOf [element, signalOf bool] + , "button" -: string ==> tupleOf [element, signalOf bool] , "stringDropDown" -: listOf string ==> tupleOf [element, signalOf string] ] -------- Math and Binops -------- -iii = IntT ==> IntT ==> IntT -xxb x = x ==> x ==> BoolT +binop t = t ==> t ==> t +numScheme t name = (name, Forall [0] [VarT 0 :<: number] (t (VarT 0))) +timeScheme name t = (name, Forall [0] [VarT 0 :<: time] (t (VarT 0))) math = - hasType (IntT ==> iii) ["clamp"] ++ - hasType iii ["+", "-", "*", "/","rem","mod","logBase","max","min"] ++ - hasType (IntT ==> IntT) ["sin","cos","tan","asin","acos","atan","sqrt","abs"] + map (numScheme (\t -> t ==> binop t)) ["clamp"] ++ + map (numScheme (\t -> binop t)) ["+","-","*","max","min"] ++ + [ numScheme (\t -> t ==> t) "abs" ] ++ + [ "/" -: binop float ] ++ + hasType (binop int) ["rem","div","mod","logBase"] ++ + hasType (float ==> float) ["sin","cos","tan","asin","acos","atan","sqrt"] -bool = - [ "not" -: BoolT ==> BoolT ] ++ - hasType (xxb BoolT) ["&&","||"] ++ - hasType (xxb IntT) ["<",">","<=",">="] +bools = + [ "not" -: bool ==> bool ] ++ + hasType (binop bool) ["&&","||"] ++ + hasType (int ==> int ==> bool) ["<",">","<=",">="] -------- Polymorphic Functions -------- @@ -144,62 +148,60 @@ var = VarT `liftM` guid vars n = mapM (const var) [1..n] infix 8 -:: -name -:: tipe = return $ name -: tipe +name -:: tipe = (name, Forall [1,2,3] [] tipe) -funcs = sequence - [ do a <- var ; "id" -:: a ==> a - , do a <- var ; "==" -:: a ==> a ==> BoolT - , do a <- var ; "/=" -:: a ==> a ==> BoolT - , do [a,b,c] <- vars 3 ; "flip" -:: (a ==> b ==> c) ==> (b ==> a ==> c) - , do [a,b,c] <- vars 3 ; "." -:: (b ==> c) ==> (a ==> b) ==> (a ==> c) - , do [a,b] <- vars 2 ; "$" -:: (a ==> b) ==> a ==> b - , do a <- var ; ":" -:: a ==> listOf a ==> listOf a - , do a <- var ; "++" -:: a ==> a ==> a - , do a <- var ; "Cons" -:: a ==> listOf a ==> listOf a - , do a <- var ; "Nil" -:: listOf a - , do a <- var ; "Just" -:: a ==> ADT "Maybe" [a] - , do a <- var ; "Nothing" -:: ADT "Maybe" [a] - , "elmRange" -:: IntT ==> IntT ==> listOf IntT +[a,b,c] = map VarT [1,2,3] + +funcs = + [ "id" -:: a ==> a + , "==" -:: a ==> a ==> bool + , "/=" -:: a ==> a ==> bool + , "flip" -:: (a ==> b ==> c) ==> (b ==> a ==> c) + , "." -:: (b ==> c) ==> (a ==> b) ==> (a ==> c) + , "$" -:: (a ==> b) ==> a ==> b + , ":" -:: a ==> listOf a ==> listOf a + , "++" -:: a ==> a ==> a + , "Cons" -:: a ==> listOf a ==> listOf a + , "Nil" -:: listOf a + , "Just" -:: a ==> maybeOf a + , "Nothing" -:: maybeOf a + , "elmRange" -:: int ==> int ==> listOf int ] -ints = map (-: (listOf IntT ==> IntT)) [ "sum","product","maximum","minimum" ] - -lists = liftM (++ints) . sequence $ - [ "and" -:: listOf BoolT ==> BoolT - , "or" -:: listOf BoolT ==> BoolT - , "sort" -:: listOf IntT ==> listOf IntT - , do a <- var ; "head" -:: listOf a ==> a - , do a <- var ; "tail" -:: listOf a ==> listOf a - , do a <- var ; "length" -:: listOf a ==> IntT - , do a <- var ; "filter" -:: (a ==> BoolT) ==> listOf a ==> listOf a - , do a <- var ; "foldr1" -:: (a ==> a ==> a) ==> listOf a ==> a - , do a <- var ; "foldl1" -:: (a ==> a ==> a) ==> listOf a ==> a - , do a <- var ; "scanl1" -:: (a ==> a ==> a) ==> listOf a ==> a - , do a <- var ; "forall" -:: (a ==> BoolT) ==> listOf a ==> BoolT - , do a <- var ; "exists" -:: (a ==> BoolT) ==> listOf a ==> BoolT - , do a <- var ; "concat" -:: listOf (listOf a) ==> listOf a - , do a <- var ; "reverse" -:: listOf a ==> listOf a - , do a <- var ; "take" -:: IntT ==> listOf a ==> listOf a - , do a <- var ; "drop" -:: IntT ==> listOf a ==> listOf a - , do a <- var ; "partition" -:: (a==>BoolT)==>listOf a==>tupleOf [listOf a,listOf a] - , do a <- var ; "intersperse" -:: a ==> listOf a ==> listOf a - , do a <- var ; "intercalate" -:: listOf a ==> listOf(listOf a) ==> listOf a - , do [a,b] <- vars 2 ; "zip" -:: listOf a ==>listOf b ==>listOf(tupleOf [a,b]) - , do [a,b] <- vars 2 ; "map" -:: (a ==> b) ==> listOf a ==> listOf b - , do [a,b] <- vars 2 ; "foldr" -:: (a ==> b ==> b) ==> b ==> listOf a ==> b - , do [a,b] <- vars 2 ; "foldl" -:: (a ==> b ==> b) ==> b ==> listOf a ==> b - , do [a,b] <- vars 2 ; "scanl" -:: (a==>b==>b)==>b==>listOf a==>listOf b - , do [a,b] <- vars 2 ; "concatMap" -:: (a==>listOf b)==>listOf a ==> listOf b - , do [a,b,c] <- vars 3 - "zipWith" -:: (a ==> b ==> c) ==> listOf a ==> listOf b ==> listOf c - ] +lists = + [ "and" -:: listOf bool ==> bool + , "or" -:: listOf bool ==> bool + , "sort" -:: listOf int ==> listOf int + , "head" -:: listOf a ==> a + , "tail" -:: listOf a ==> listOf a + , "length" -:: listOf a ==> int + , "filter" -:: (a ==> bool) ==> listOf a ==> listOf a + , "foldr1" -:: (a ==> a ==> a) ==> listOf a ==> a + , "foldl1" -:: (a ==> a ==> a) ==> listOf a ==> a + , "scanl1" -:: (a ==> a ==> a) ==> listOf a ==> a + , "forall" -:: (a ==> bool) ==> listOf a ==> bool + , "exists" -:: (a ==> bool) ==> listOf a ==> bool + , "concat" -:: listOf (listOf a) ==> listOf a + , "reverse" -:: listOf a ==> listOf a + , "take" -:: int ==> listOf a ==> listOf a + , "drop" -:: int ==> listOf a ==> listOf a + , "partition" -:: (a ==> bool) ==> listOf a ==> tupleOf [listOf a,listOf a] + , "intersperse" -:: a ==> listOf a ==> listOf a + , "intercalate" -:: listOf a ==> listOf(listOf a) ==> listOf a + , "zip" -:: listOf a ==>listOf b ==>listOf(tupleOf [a,b]) + , "map" -:: (a ==> b) ==> listOf a ==> listOf b + , "foldr" -:: (a ==> b ==> b) ==> b ==> listOf a ==> b + , "foldl" -:: (a ==> b ==> b) ==> b ==> listOf a ==> b + , "scanl" -:: (a ==> b ==> b) ==> b ==> listOf a ==> listOf b + , "concatMap" -:: (a ==> listOf b) ==> listOf a ==> listOf b + , "zipWith" -:: (a ==> b ==> c) ==> listOf a ==> listOf b ==> listOf c + ] ++ map (-: (listOf int ==> int)) [ "sum","product","maximum","minimum" ] -------- Everything -------- -hints = do - fs <- funcs ; ls <- lists ; ss <- signals ; pcasts <- polyCasts - return $ concat [ fs, ls, ss, math, bool, str2elem, textAttrs - , elements, directions, colors, lineTypes, shapes - , concreteSignals, casts, pcasts - ] +hints = + concat [ funcs, lists, signals, math, bools, str2elem, textAttrs + , elements, directions, colors, lineTypes, shapes + , concreteSignals, casts, polyCasts + ] diff --git a/elm/src/Types/Types.hs b/elm/src/Types/Types.hs index 6b24e27..746f39c 100644 --- a/elm/src/Types/Types.hs +++ b/elm/src/Types/Types.hs @@ -1,43 +1,58 @@ module Types where -import Data.List (intercalate) +import Data.Char (isDigit) +import Data.List (intercalate,isPrefixOf) import qualified Data.Set as Set type X = Int -data Type = IntT - | StringT - | CharT - | BoolT - | LambdaT Type Type +data Type = LambdaT Type Type | VarT X | ADT String [Type] deriving (Eq, Ord) -data Scheme = Forall (Set.Set X) Type deriving (Eq, Ord, Show) +data Scheme = Forall [X] [Constraint] Type deriving (Eq, Ord, Show) -element = ADT "Element" [] -direction = ADT "Direction" [] +data SuperType = SuperType String (Set.Set Type) deriving (Eq, Ord) -form = ADT "Form" [] -line = ADT "Line" [] -shape = ADT "Shape" [] -color = ADT "Color" [] -text = ADT "List" [ADT "Text" []] -point = tupleOf [IntT,IntT] +data Constraint = Type :=: Type + | Type :<: SuperType + | Type :<<: Scheme + deriving (Eq, Ord, Show) + +tipe t = ADT t [] + + +int = tipe "Int" +float = tipe "Float" +number = SuperType "Number" (Set.fromList [ int, float ]) + +char = tipe "Char" +bool = tipe "Bool" + +string = tipe "String" +text = tipe "Text" + +time = SuperType "Time" (Set.fromList [ int, float ]) + +element = tipe "Element" +direction = tipe "Direction" +form = tipe "Form" +line = tipe "Line" +shape = tipe "Shape" +color = tipe "Color" +point = tupleOf [int,int] listOf t = ADT "List" [t] signalOf t = ADT "Signal" [t] tupleOf ts = ADT ("Tuple" ++ show (length ts)) ts maybeOf t = ADT "Maybe" [t] -string = listOf CharT -time = IntT -jsBool = ADT "JSBool" [] -jsNumber = ADT "JSNumber" [] -jsString = ADT "JSString" [] -jsElement = ADT "JSElement" [] +jsBool = tipe "JSBool" +jsNumber = tipe "JSNumber" +jsString = tipe "JSString" +jsElement = tipe "JSElement" jsArray t = ADT "JSArray" [t] jsTuple ts = ADT ("JSTuple" ++ show (length ts)) ts @@ -45,7 +60,7 @@ infixr ==> t1 ==> t2 = LambdaT t1 t2 infix 8 -: -name -: tipe = (,) name tipe +name -: tipe = (,) name $ Forall [] [] tipe hasType t = map (-: t) @@ -54,14 +69,19 @@ parens = ("("++) . (++")") instance Show Type where show t = case t of - { IntT -> "Int" - ; StringT -> "String" - ; CharT -> "Char" - ; BoolT -> "Bool" - ; LambdaT t1@(LambdaT _ _) t2 -> parens (show t1) ++ " -> " ++ show t2 + { LambdaT t1@(LambdaT _ _) t2 -> parens (show t1) ++ " -> " ++ show t2 ; LambdaT t1 t2 -> show t1 ++ " -> " ++ show t2 ; VarT x -> show x ; ADT "List" [tipe] -> "[" ++ show tipe ++ "]" - ; ADT name [] -> name - ; ADT name cs -> parens $ name ++ " " ++ unwords (map show cs) + ; ADT name cs -> + if isTupleString name + then parens . intercalate "," $ map show cs + else case cs of [] -> name + _ -> parens $ name ++ " " ++ unwords (map show cs) } + +instance Show SuperType where + show (SuperType n ts) = "" ++ n ++ " (a type in {" ++ subs ++ "})" + where subs = intercalate "," . map show $ Set.toList ts + +isTupleString str = "Tuple" `isPrefixOf` str && all isDigit (drop 5 str) \ No newline at end of file diff --git a/elm/src/Types/Unify.hs b/elm/src/Types/Unify.hs index bc43fb8..cacdc20 100644 --- a/elm/src/Types/Unify.hs +++ b/elm/src/Types/Unify.hs @@ -29,33 +29,42 @@ solver ((LambdaT t1 t2 :=: LambdaT t1' t2') : cs) subs = solver ((VarT x :=: t) : cs) subs = solver (map (cSub x t) cs) . map (second $ tSub x t) $ (x,t):subs -solver ((t :=: VarT x) : cs) subs = - solver (map (cSub x t) cs) . map (second $ tSub x t) $ (x,t):subs +solver ((t :=: VarT x) : cs) subs = solver ((VarT x :=: t) : cs) subs solver ((t1 :=: t2) : cs) subs = if t1 /= t2 then uniError t1 t2 else solver cs subs -------- subtypes -------- -solver ((t1 :<: t2) : cs) subs = do - let f x = do y <- guid ; return (x,VarT y) - pairs <- mapM f . Set.toList $ getVars t2 - let t2' = foldr (uncurry tSub) t2 pairs - solver ((t1 :=: t2') : cs) subs +solver (c@(VarT x :<: SuperType n ts) : cs) subs + | all isSuper cs = solver ((VarT x :=: ADT n []) : cs) subs + | otherwise = solver (cs ++ [c]) subs + where isSuper (VarT _ :<: _) = True + isSuper _ = False +solver ((t@(ADT n1 []) :<: st@(SuperType n2 ts)) : cs) subs + | n1 == n2 || Set.member t ts = solver cs subs + | otherwise = return . Left $ "Type error: " ++ show t ++ + " is not a subtype of " ++ show st +solver ((t :<: st@(SuperType n ts)) : cs) subs + | Set.member t ts = solver cs subs + | otherwise = return . Left $ "Type error: " ++ show t ++ + " is not a subtype of " ++ show st + +solver ((t1 :<<: Forall xs cs' t2) : cs) subs = do + pairs <- mapM (\x -> liftM ((,) x . VarT) guid) xs + let t2' = foldr (uncurry tSub) t2 pairs + let cs'' = foldr (\(k,v) -> map (cSub k v)) cs' pairs + solver ((t1 :=: t2') : cs'' ++ cs) subs cSub k v (t1 :=: t2) = force $ tSub k v t1 :=: tSub k v t2 -cSub k v (t1 :<: t2) = force $ tSub k v t1 :<: tSub k v t2 +cSub k v (t :<: super) = force $ tSub k v t :<: super +cSub k v (t :<<: poly) = force $ tSub k v t :<<: poly tSub k v t@(VarT x) = if k == x then v else t tSub k v (LambdaT t1 t2) = force $ LambdaT (tSub k v t1) (tSub k v t2) tSub k v (ADT name ts) = ADT name (map (force . tSub k v) ts) tSub _ _ t = t -getVars (VarT x) = Set.singleton x -getVars (LambdaT t1 t2) = Set.union (getVars t1) (getVars t2) -getVars (ADT name ts) = Set.unions $ map getVars ts -getVars _ = Set.empty - uniError t1 t2 = return . Left $ "Type error: " ++ show t1 ++ " is not equal to " ++ show t2 @@ -63,7 +72,8 @@ force x = x `deepseq` x instance NFData Constraint where rnf (t1 :=: t2) = t1 `deepseq` t2 `deepseq` () - rnf (t1 :<: t2) = t1 `deepseq` t2 `deepseq` () + rnf (t :<: _) = t `deepseq` () + rnf (t :<<: _) = t `deepseq` () instance NFData Type where rnf (LambdaT t1 t2) = t1 `deepseq` t2 `deepseq` ()