- Simplify basic types. Add more correct versions of Type Schemes and Super Types.

- Adjust type hints to use type schemes.
- Adjust constraint generation to use type schemes and super types properly.
- Modify unification algorithm to properly handle type schemes and super types.
This commit is contained in:
evancz 2012-07-19 11:51:57 +00:00
parent 6c47a84f96
commit 153d05438c
4 changed files with 205 additions and 176 deletions

View file

@ -10,24 +10,19 @@ import Control.Monad (liftM,mapM)
import Control.Monad.State (evalState) import Control.Monad.State (evalState)
import Guid import Guid
data Constraint = Type :=: Type
| Type :<: Type
| Type :<<: Scheme
deriving (Eq, Ord, Show)
beta = VarT `liftM` guid beta = VarT `liftM` guid
unionA = Map.unionWith (++) unionA = Map.unionWith (++)
unionsA = Map.unionsWith (++) unionsA = Map.unionsWith (++)
constrain hints expr = do constrain hints expr = do
(as,cs,t) <- inference expr (as,cs,t) <- inference expr
hs <- hints let cMap = Map.intersectionWith (\s -> map (\v -> VarT v:<<: s)) (Map.fromList hints) as
let cMap = Map.intersectionWith (\t -> map (:<: t)) (Map.fromList hs) as
return $ Set.toList cs ++ (concat . map snd $ Map.toList cMap) return $ Set.toList cs ++ (concat . map snd $ Map.toList cMap)
inference :: Expr -> GuidCounter (Map.Map String [X], Set.Set Constraint, Type)
inference (Var x) = inference (Var x) =
do b <- beta do b <- guid
return (Map.singleton x [b], Set.empty, b) return (Map.singleton x [b], Set.empty, VarT b)
inference (App e1 e2) = inference (App e1 e2) =
do (a1,c1,t1) <- inference e1 do (a1,c1,t1) <- inference e1
(a2,c2,t2) <- inference e2 (a2,c2,t2) <- inference e2
@ -39,7 +34,7 @@ inference (Lambda x e) =
do (a,c,t) <- inference e do (a,c,t) <- inference e
b <- beta b <- beta
return ( Map.delete x a return ( Map.delete x a
, Set.union c . Set.fromList . map (:=: b) $ , Set.union c . Set.fromList . map (\x -> VarT x :=: b) $
Map.findWithDefault [] x a Map.findWithDefault [] x a
, LambdaT b t ) , LambdaT b t )
inference (Let defs e) = inference (Let defs e) =
@ -47,7 +42,7 @@ inference (Let defs e) =
let (xs,es) = unzip defs let (xs,es) = unzip defs
(as,cs,ts) <- unzip3 `liftM` mapM inference es (as,cs,ts) <- unzip3 `liftM` mapM inference es
let assumptions = unionsA (a:as) let assumptions = unionsA (a:as)
let f x t = map (:<: t) $ Map.findWithDefault [] x assumptions let f x t = map (:=: t) . map VarT $ Map.findWithDefault [] x assumptions
let constraints = Set.fromList . concat $ zipWith f xs ts let constraints = Set.fromList . concat $ zipWith f xs ts
return ( foldr Map.delete assumptions xs return ( foldr Map.delete assumptions xs
, Set.unions $ c:constraints:cs , Set.unions $ c:constraints:cs
@ -58,7 +53,7 @@ inference (If e1 e2 e3) =
(a2,c2,t2) <- inference e2 (a2,c2,t2) <- inference e2
(a3,c3,t3) <- inference e3 (a3,c3,t3) <- inference e3
return ( unionsA [a1,a2,a3] return ( unionsA [a1,a2,a3]
, Set.unions [c1,c2,c3, Set.fromList [ t1 :=: BoolT, t2 :=: t3 ] ] , Set.unions [c1,c2,c3, Set.fromList [ t1 :=: bool, t2 :=: t3 ] ]
, t2 ) , t2 )
inference (Data name es) = inference $ foldl' App (Var name) es inference (Data name es) = inference $ foldl' App (Var name) es
@ -68,10 +63,12 @@ inference (Range e1 e2) = inference (Var "elmRange" `App` e1 `App` e2)
inference other = inference other =
case other of case other of
Number _ -> primitive IntT IntNum _ -> do t <- beta
Chr _ -> primitive CharT return (Map.empty, Set.singleton (t :<: number), t)
FloatNum _ -> primitive float
Chr _ -> primitive char
Str _ -> primitive string Str _ -> primitive string
Boolean _ -> primitive BoolT Boolean _ -> primitive bool
_ -> beta >>= primitive _ -> beta >>= primitive
primitive t = return (Map.empty, Set.empty, t) primitive t = return (Map.empty, Set.empty, t)

View file

@ -16,126 +16,130 @@ textToText = [ "header", "italic", "bold", "underline"
textAttrs = [ "toText" -: string ==> text textAttrs = [ "toText" -: string ==> text
, "link" -: string ==> text ==> text , "link" -: string ==> text ==> text
, "Text.height" -: IntT ==> text ==> text , "Text.height" -: int ==> text ==> text
] ++ hasType (text ==> text) textToText ] ++ hasType (text ==> text) textToText
elements = let iee = IntT ==> element ==> element in elements = let iee = int ==> element ==> element in
[ "flow" -: direction ==> listOf element ==> element [ "flow" -: direction ==> listOf element ==> element
, "layers" -: listOf element ==> element , "layers" -: listOf element ==> element
, "text" -: text ==> element , "text" -: text ==> element
, "opacity" -: iee , "opacity" -: iee
, "width" -: iee , "width" -: iee
, "height" -: iee , "height" -: iee
, "size" -: IntT ==> iee , "size" -: int ==> iee
, "box" -: iee , "box" -: iee
, "centeredText" -: text ==> element , "centeredText" -: text ==> element
, "justifiedText" -: text ==> element , "justifiedText" -: text ==> element
, "collage" -: IntT ==> IntT ==> listOf form ==> element , "collage" -: int ==> int ==> listOf form ==> element
] ]
directions = hasType direction ["up","down","left","right","inward","outward"] directions = hasType direction ["up","down","left","right","inward","outward"]
colors = [ "rgb" -: IntT ==> IntT ==> IntT ==> color colors = [ "rgb" -: int ==> int ==> int ==> color
, "rgba" -: IntT ==> IntT ==> IntT ==> IntT ==> color , "rgba" -: int ==> int ==> int ==> int ==> color
] ++ hasType color ["red","green","blue","black","white"] ] ++ hasType color ["red","green","blue","black","white"]
lineTypes = [ "line" -: listOf point ==> line lineTypes = [ "line" -: listOf point ==> line
, "customLine" -: listOf IntT ==> color ==> line ==> form , "customLine" -: listOf int ==> color ==> line ==> form
] ++ hasType (color ==> line ==> form) ["solid","dashed","dotted"] ] ++ hasType (color ==> line ==> form) ["solid","dashed","dotted"]
shapes = [ "polygon" -: listOf point ==> point ==> shape shapes = [ "polygon" -: listOf point ==> point ==> shape
, "filled" -: color ==> shape ==> form , "filled" -: color ==> shape ==> form
, "outlined" -: color ==> shape ==> form , "outlined" -: color ==> shape ==> form
, "customOutline" -: listOf IntT ==> color ==> shape ==> form , "customOutline" -: listOf int ==> color ==> shape ==> form
] ++ hasType (IntT ==> IntT ==> point ==> shape) ["ngon","rect","oval"] ] ++ hasType (int ==> int ==> point ==> shape) ["ngon","rect","oval"]
-------- Foreign -------- -------- Foreign --------
casts = casts =
[ "castJSBoolToBool" -: jsBool ==> BoolT [ "castJSBoolToBool" -: jsBool ==> bool
, "castBoolToJSBool" -: BoolT ==> jsBool , "castBoolToJSBool" -: bool ==> jsBool
, "castJSNumberToInt" -: jsNumber ==> IntT , "castJSNumberToInt" -: jsNumber ==> int
, "castIntToJSNumber" -: IntT ==> jsNumber , "castIntToJSNumber" -: int ==> jsNumber
, "castJSElementToElement" -: IntT ==> IntT ==> jsElement ==> element , "castJSElementToElement" -: int ==> int ==> jsElement ==> element
, "castElementToJSElement" -: element ==> jsElement , "castElementToJSElement" -: element ==> jsElement
, "castJSStringToString" -: jsString ==> string , "castJSStringToString" -: jsString ==> string
, "castStringToJSString" -: string ==> jsString , "castStringToJSString" -: string ==> jsString
-- , "castJSNumberToFloat -: , "castJSNumberToFloat" -: jsNumber ==> float
-- , "castFloatToJSNumber -: , "castFloatToJSNumber" -: float ==> jsNumber
] ]
polyCasts = sequence castToTuple n = (,) name $ Forall [1..n] [] (jsTuple vs ==> tupleOf vs)
[ do a <- var ; "castJSArrayToList" -:: jsArray a ==> listOf a where vs = map VarT [1..n]
, do a <- var ; "castListToJSArray" -:: listOf a ==> jsArray a name = "castJSTupleToTuple" ++ show n
, do vs <- vars 2 ; "castTupleToJSTuple2" -:: tupleOf vs ==> jsTuple vs castToJSTuple n = (,) name $ Forall [1..n] [] (tupleOf vs ==> jsTuple vs)
, do vs <- vars 3 ; "castTupleToJSTuple3" -:: tupleOf vs ==> jsTuple vs where vs = map VarT [1..n]
, do vs <- vars 4 ; "castTupleToJSTuple4" -:: tupleOf vs ==> jsTuple vs name = "castTupleToJSTuple" ++ show n
, do vs <- vars 5 ; "castTupleToJSTuple5" -:: tupleOf vs ==> jsTuple vs
, do vs <- vars 2 ; "castJSTupleToTuple2" -:: jsTuple vs ==> tupleOf vs polyCasts =
, do vs <- vars 3 ; "castJSTupleToTuple3" -:: jsTuple vs ==> tupleOf vs map castToTuple [2..5] ++ map castToJSTuple [2..5] ++
, do vs <- vars 4 ; "castJSTupleToTuple4" -:: jsTuple vs ==> tupleOf vs [ "castJSArrayToList" -:: jsArray a ==> listOf a
, do vs <- vars 5 ; "castJSTupleToTuple5" -:: jsTuple vs ==> tupleOf vs , "castListToJSArray" -:: listOf a ==> jsArray a
] ]
-------- Signals -------- -------- Signals --------
sig ts = fn ts ==> fn (map signalOf ts) sig n name = (,) name $ Forall [1..n] [] (fn ts ==> fn (map signalOf ts))
where fn = foldr1 (==>) where fn = foldr1 (==>)
ts = map VarT [1..n]
signals = sequence signals =
[ do ts <- vars 1 ; "constant" -:: sig ts [ sig 1 "constant"
, do ts <- vars 2 ; "lift" -:: sig ts , sig 2 "lift"
, do ts <- vars 3 ; "lift2" -:: sig ts , sig 3 "lift2"
, do ts <- vars 4 ; "lift3" -:: sig ts , sig 4 "lift3"
, do ts <- vars 5 ; "lift4" -:: sig ts , sig 5 "lift4"
, do [a,b] <- vars 2 , "foldp" -:: (a ==> b ==> b) ==> b ==> signalOf a ==> signalOf b
"foldp" -:: (a ==> b ==> b) ==> b ==> signalOf a ==> signalOf b , "randomize" -:: int ==> int ==> signalOf a ==> signalOf int
, do a <- var ; "randomize" -:: IntT ==> IntT ==> signalOf a ==> signalOf IntT , "count" -:: signalOf a ==> signalOf int
, do a <- var ; "count" -:: signalOf a ==> signalOf IntT , "keepIf" -:: (a==>bool) ==> a ==> signalOf a ==> signalOf a
, do a <- var ; "keepIf" -:: (a==>BoolT) ==> a ==> signalOf a ==> signalOf a , "dropIf" -:: (a==>bool) ==> a ==> signalOf a ==> signalOf a
, do a <- var ; "dropIf" -:: (a==>BoolT) ==> a ==> signalOf a ==> signalOf a , "keepWhen" -:: signalOf bool ==>a==> signalOf a ==> signalOf a
, do a <- var ; "keepWhen" -:: signalOf BoolT ==>a==> signalOf a ==> signalOf a , "dropWhen" -:: signalOf bool ==>a==> signalOf a ==> signalOf a
, do a <- var ; "dropWhen" -:: signalOf BoolT ==>a==> signalOf a ==> signalOf a , "dropRepeats" -:: signalOf a ==> signalOf a
, do a <- var ; "dropRepeats" -:: signalOf a ==> signalOf a , "sampleOn" -:: signalOf a ==> signalOf b ==> signalOf b
, do [a,b] <- vars 2 ; "sampleOn" -:: signalOf a ==> signalOf b ==> signalOf b
] ]
concreteSignals = concreteSignals =
[ "keysDown" -: signalOf (listOf IntT) [ "keysDown" -: signalOf (listOf int)
, "charPressed" -: signalOf (maybeOf IntT) , "charPressed" -: signalOf (maybeOf int)
, "inRange" -: IntT ==> IntT ==> signalOf IntT , "inRange" -: int ==> int ==> signalOf int
, "every" -: time ==> signalOf time , timeScheme "every" (\t -> t ==> signalOf t)
, "before" -: time ==> signalOf BoolT , timeScheme "before" (\t -> t ==> signalOf bool)
, "after" -: time ==> signalOf BoolT , timeScheme "after" (\t -> t ==> signalOf bool)
, "dimensions" -: signalOf point , "dimensions" -: signalOf point
, "position" -: signalOf point , "position" -: signalOf point
, "x" -: signalOf IntT , "x" -: signalOf int
, "y" -: signalOf IntT , "y" -: signalOf int
, "isDown" -: signalOf BoolT , "isDown" -: signalOf bool
, "isClicked" -: signalOf BoolT , "isClicked" -: signalOf bool
, "textField" -: string ==> tupleOf [element, signalOf string] , "textField" -: string ==> tupleOf [element, signalOf string]
, "password" -: string ==> tupleOf [element, signalOf string] , "password" -: string ==> tupleOf [element, signalOf string]
, "textArea" -: IntT ==> IntT ==> tupleOf [element, signalOf string] , "textArea" -: int ==> int ==> tupleOf [element, signalOf string]
, "checkBox" -: BoolT ==> tupleOf [element, signalOf BoolT] , "checkBox" -: bool ==> tupleOf [element, signalOf bool]
, "button" -: string ==> tupleOf [element, signalOf BoolT] , "button" -: string ==> tupleOf [element, signalOf bool]
, "stringDropDown" -: listOf string ==> tupleOf [element, signalOf string] , "stringDropDown" -: listOf string ==> tupleOf [element, signalOf string]
] ]
-------- Math and Binops -------- -------- Math and Binops --------
iii = IntT ==> IntT ==> IntT binop t = t ==> t ==> t
xxb x = x ==> x ==> BoolT numScheme t name = (name, Forall [0] [VarT 0 :<: number] (t (VarT 0)))
timeScheme name t = (name, Forall [0] [VarT 0 :<: time] (t (VarT 0)))
math = math =
hasType (IntT ==> iii) ["clamp"] ++ map (numScheme (\t -> t ==> binop t)) ["clamp"] ++
hasType iii ["+", "-", "*", "/","rem","mod","logBase","max","min"] ++ map (numScheme (\t -> binop t)) ["+","-","*","max","min"] ++
hasType (IntT ==> IntT) ["sin","cos","tan","asin","acos","atan","sqrt","abs"] [ numScheme (\t -> t ==> t) "abs" ] ++
[ "/" -: binop float ] ++
hasType (binop int) ["rem","div","mod","logBase"] ++
hasType (float ==> float) ["sin","cos","tan","asin","acos","atan","sqrt"]
bool = bools =
[ "not" -: BoolT ==> BoolT ] ++ [ "not" -: bool ==> bool ] ++
hasType (xxb BoolT) ["&&","||"] ++ hasType (binop bool) ["&&","||"] ++
hasType (xxb IntT) ["<",">","<=",">="] hasType (int ==> int ==> bool) ["<",">","<=",">="]
-------- Polymorphic Functions -------- -------- Polymorphic Functions --------
@ -144,62 +148,60 @@ var = VarT `liftM` guid
vars n = mapM (const var) [1..n] vars n = mapM (const var) [1..n]
infix 8 -:: infix 8 -::
name -:: tipe = return $ name -: tipe name -:: tipe = (name, Forall [1,2,3] [] tipe)
funcs = sequence [a,b,c] = map VarT [1,2,3]
[ do a <- var ; "id" -:: a ==> a
, do a <- var ; "==" -:: a ==> a ==> BoolT funcs =
, do a <- var ; "/=" -:: a ==> a ==> BoolT [ "id" -:: a ==> a
, do [a,b,c] <- vars 3 ; "flip" -:: (a ==> b ==> c) ==> (b ==> a ==> c) , "==" -:: a ==> a ==> bool
, do [a,b,c] <- vars 3 ; "." -:: (b ==> c) ==> (a ==> b) ==> (a ==> c) , "/=" -:: a ==> a ==> bool
, do [a,b] <- vars 2 ; "$" -:: (a ==> b) ==> a ==> b , "flip" -:: (a ==> b ==> c) ==> (b ==> a ==> c)
, do a <- var ; ":" -:: a ==> listOf a ==> listOf a , "." -:: (b ==> c) ==> (a ==> b) ==> (a ==> c)
, do a <- var ; "++" -:: a ==> a ==> a , "$" -:: (a ==> b) ==> a ==> b
, do a <- var ; "Cons" -:: a ==> listOf a ==> listOf a , ":" -:: a ==> listOf a ==> listOf a
, do a <- var ; "Nil" -:: listOf a , "++" -:: a ==> a ==> a
, do a <- var ; "Just" -:: a ==> ADT "Maybe" [a] , "Cons" -:: a ==> listOf a ==> listOf a
, do a <- var ; "Nothing" -:: ADT "Maybe" [a] , "Nil" -:: listOf a
, "elmRange" -:: IntT ==> IntT ==> listOf IntT , "Just" -:: a ==> maybeOf a
, "Nothing" -:: maybeOf a
, "elmRange" -:: int ==> int ==> listOf int
] ]
ints = map (-: (listOf IntT ==> IntT)) [ "sum","product","maximum","minimum" ] lists =
[ "and" -:: listOf bool ==> bool
lists = liftM (++ints) . sequence $ , "or" -:: listOf bool ==> bool
[ "and" -:: listOf BoolT ==> BoolT , "sort" -:: listOf int ==> listOf int
, "or" -:: listOf BoolT ==> BoolT , "head" -:: listOf a ==> a
, "sort" -:: listOf IntT ==> listOf IntT , "tail" -:: listOf a ==> listOf a
, do a <- var ; "head" -:: listOf a ==> a , "length" -:: listOf a ==> int
, do a <- var ; "tail" -:: listOf a ==> listOf a , "filter" -:: (a ==> bool) ==> listOf a ==> listOf a
, do a <- var ; "length" -:: listOf a ==> IntT , "foldr1" -:: (a ==> a ==> a) ==> listOf a ==> a
, do a <- var ; "filter" -:: (a ==> BoolT) ==> listOf a ==> listOf a , "foldl1" -:: (a ==> a ==> a) ==> listOf a ==> a
, do a <- var ; "foldr1" -:: (a ==> a ==> a) ==> listOf a ==> a , "scanl1" -:: (a ==> a ==> a) ==> listOf a ==> a
, do a <- var ; "foldl1" -:: (a ==> a ==> a) ==> listOf a ==> a , "forall" -:: (a ==> bool) ==> listOf a ==> bool
, do a <- var ; "scanl1" -:: (a ==> a ==> a) ==> listOf a ==> a , "exists" -:: (a ==> bool) ==> listOf a ==> bool
, do a <- var ; "forall" -:: (a ==> BoolT) ==> listOf a ==> BoolT , "concat" -:: listOf (listOf a) ==> listOf a
, do a <- var ; "exists" -:: (a ==> BoolT) ==> listOf a ==> BoolT , "reverse" -:: listOf a ==> listOf a
, do a <- var ; "concat" -:: listOf (listOf a) ==> listOf a , "take" -:: int ==> listOf a ==> listOf a
, do a <- var ; "reverse" -:: listOf a ==> listOf a , "drop" -:: int ==> listOf a ==> listOf a
, do a <- var ; "take" -:: IntT ==> listOf a ==> listOf a , "partition" -:: (a ==> bool) ==> listOf a ==> tupleOf [listOf a,listOf a]
, do a <- var ; "drop" -:: IntT ==> listOf a ==> listOf a , "intersperse" -:: a ==> listOf a ==> listOf a
, do a <- var ; "partition" -:: (a==>BoolT)==>listOf a==>tupleOf [listOf a,listOf a] , "intercalate" -:: listOf a ==> listOf(listOf a) ==> listOf a
, do a <- var ; "intersperse" -:: a ==> listOf a ==> listOf a , "zip" -:: listOf a ==>listOf b ==>listOf(tupleOf [a,b])
, do a <- var ; "intercalate" -:: listOf a ==> listOf(listOf a) ==> listOf a , "map" -:: (a ==> b) ==> listOf a ==> listOf b
, do [a,b] <- vars 2 ; "zip" -:: listOf a ==>listOf b ==>listOf(tupleOf [a,b]) , "foldr" -:: (a ==> b ==> b) ==> b ==> listOf a ==> b
, do [a,b] <- vars 2 ; "map" -:: (a ==> b) ==> listOf a ==> listOf b , "foldl" -:: (a ==> b ==> b) ==> b ==> listOf a ==> b
, do [a,b] <- vars 2 ; "foldr" -:: (a ==> b ==> b) ==> b ==> listOf a ==> b , "scanl" -:: (a ==> b ==> b) ==> b ==> listOf a ==> listOf b
, do [a,b] <- vars 2 ; "foldl" -:: (a ==> b ==> b) ==> b ==> listOf a ==> b , "concatMap" -:: (a ==> listOf b) ==> listOf a ==> listOf b
, do [a,b] <- vars 2 ; "scanl" -:: (a==>b==>b)==>b==>listOf a==>listOf b , "zipWith" -:: (a ==> b ==> c) ==> listOf a ==> listOf b ==> listOf c
, do [a,b] <- vars 2 ; "concatMap" -:: (a==>listOf b)==>listOf a ==> listOf b ] ++ map (-: (listOf int ==> int)) [ "sum","product","maximum","minimum" ]
, do [a,b,c] <- vars 3
"zipWith" -:: (a ==> b ==> c) ==> listOf a ==> listOf b ==> listOf c
]
-------- Everything -------- -------- Everything --------
hints = do hints =
fs <- funcs ; ls <- lists ; ss <- signals ; pcasts <- polyCasts concat [ funcs, lists, signals, math, bools, str2elem, textAttrs
return $ concat [ fs, ls, ss, math, bool, str2elem, textAttrs
, elements, directions, colors, lineTypes, shapes , elements, directions, colors, lineTypes, shapes
, concreteSignals, casts, pcasts , concreteSignals, casts, polyCasts
] ]

View file

@ -1,43 +1,58 @@
module Types where module Types where
import Data.List (intercalate) import Data.Char (isDigit)
import Data.List (intercalate,isPrefixOf)
import qualified Data.Set as Set import qualified Data.Set as Set
type X = Int type X = Int
data Type = IntT data Type = LambdaT Type Type
| StringT
| CharT
| BoolT
| LambdaT Type Type
| VarT X | VarT X
| ADT String [Type] | ADT String [Type]
deriving (Eq, Ord) deriving (Eq, Ord)
data Scheme = Forall (Set.Set X) Type deriving (Eq, Ord, Show) data Scheme = Forall [X] [Constraint] Type deriving (Eq, Ord, Show)
element = ADT "Element" [] data SuperType = SuperType String (Set.Set Type) deriving (Eq, Ord)
direction = ADT "Direction" []
form = ADT "Form" [] data Constraint = Type :=: Type
line = ADT "Line" [] | Type :<: SuperType
shape = ADT "Shape" [] | Type :<<: Scheme
color = ADT "Color" [] deriving (Eq, Ord, Show)
text = ADT "List" [ADT "Text" []]
point = tupleOf [IntT,IntT] tipe t = ADT t []
int = tipe "Int"
float = tipe "Float"
number = SuperType "Number" (Set.fromList [ int, float ])
char = tipe "Char"
bool = tipe "Bool"
string = tipe "String"
text = tipe "Text"
time = SuperType "Time" (Set.fromList [ int, float ])
element = tipe "Element"
direction = tipe "Direction"
form = tipe "Form"
line = tipe "Line"
shape = tipe "Shape"
color = tipe "Color"
point = tupleOf [int,int]
listOf t = ADT "List" [t] listOf t = ADT "List" [t]
signalOf t = ADT "Signal" [t] signalOf t = ADT "Signal" [t]
tupleOf ts = ADT ("Tuple" ++ show (length ts)) ts tupleOf ts = ADT ("Tuple" ++ show (length ts)) ts
maybeOf t = ADT "Maybe" [t] maybeOf t = ADT "Maybe" [t]
string = listOf CharT
time = IntT
jsBool = ADT "JSBool" [] jsBool = tipe "JSBool"
jsNumber = ADT "JSNumber" [] jsNumber = tipe "JSNumber"
jsString = ADT "JSString" [] jsString = tipe "JSString"
jsElement = ADT "JSElement" [] jsElement = tipe "JSElement"
jsArray t = ADT "JSArray" [t] jsArray t = ADT "JSArray" [t]
jsTuple ts = ADT ("JSTuple" ++ show (length ts)) ts jsTuple ts = ADT ("JSTuple" ++ show (length ts)) ts
@ -45,7 +60,7 @@ infixr ==>
t1 ==> t2 = LambdaT t1 t2 t1 ==> t2 = LambdaT t1 t2
infix 8 -: infix 8 -:
name -: tipe = (,) name tipe name -: tipe = (,) name $ Forall [] [] tipe
hasType t = map (-: t) hasType t = map (-: t)
@ -54,14 +69,19 @@ parens = ("("++) . (++")")
instance Show Type where instance Show Type where
show t = show t =
case t of case t of
{ IntT -> "Int" { LambdaT t1@(LambdaT _ _) t2 -> parens (show t1) ++ " -> " ++ show t2
; StringT -> "String"
; CharT -> "Char"
; BoolT -> "Bool"
; LambdaT t1@(LambdaT _ _) t2 -> parens (show t1) ++ " -> " ++ show t2
; LambdaT t1 t2 -> show t1 ++ " -> " ++ show t2 ; LambdaT t1 t2 -> show t1 ++ " -> " ++ show t2
; VarT x -> show x ; VarT x -> show x
; ADT "List" [tipe] -> "[" ++ show tipe ++ "]" ; ADT "List" [tipe] -> "[" ++ show tipe ++ "]"
; ADT name [] -> name ; ADT name cs ->
; ADT name cs -> parens $ name ++ " " ++ unwords (map show cs) if isTupleString name
then parens . intercalate "," $ map show cs
else case cs of [] -> name
_ -> parens $ name ++ " " ++ unwords (map show cs)
} }
instance Show SuperType where
show (SuperType n ts) = "" ++ n ++ " (a type in {" ++ subs ++ "})"
where subs = intercalate "," . map show $ Set.toList ts
isTupleString str = "Tuple" `isPrefixOf` str && all isDigit (drop 5 str)

View file

@ -29,33 +29,42 @@ solver ((LambdaT t1 t2 :=: LambdaT t1' t2') : cs) subs =
solver ((VarT x :=: t) : cs) subs = solver ((VarT x :=: t) : cs) subs =
solver (map (cSub x t) cs) . map (second $ tSub x t) $ (x,t):subs solver (map (cSub x t) cs) . map (second $ tSub x t) $ (x,t):subs
solver ((t :=: VarT x) : cs) subs = solver ((t :=: VarT x) : cs) subs = solver ((VarT x :=: t) : cs) subs
solver (map (cSub x t) cs) . map (second $ tSub x t) $ (x,t):subs
solver ((t1 :=: t2) : cs) subs = solver ((t1 :=: t2) : cs) subs =
if t1 /= t2 then uniError t1 t2 else solver cs subs if t1 /= t2 then uniError t1 t2 else solver cs subs
-------- subtypes -------- -------- subtypes --------
solver ((t1 :<: t2) : cs) subs = do solver (c@(VarT x :<: SuperType n ts) : cs) subs
let f x = do y <- guid ; return (x,VarT y) | all isSuper cs = solver ((VarT x :=: ADT n []) : cs) subs
pairs <- mapM f . Set.toList $ getVars t2 | otherwise = solver (cs ++ [c]) subs
where isSuper (VarT _ :<: _) = True
isSuper _ = False
solver ((t@(ADT n1 []) :<: st@(SuperType n2 ts)) : cs) subs
| n1 == n2 || Set.member t ts = solver cs subs
| otherwise = return . Left $ "Type error: " ++ show t ++
" is not a subtype of " ++ show st
solver ((t :<: st@(SuperType n ts)) : cs) subs
| Set.member t ts = solver cs subs
| otherwise = return . Left $ "Type error: " ++ show t ++
" is not a subtype of " ++ show st
solver ((t1 :<<: Forall xs cs' t2) : cs) subs = do
pairs <- mapM (\x -> liftM ((,) x . VarT) guid) xs
let t2' = foldr (uncurry tSub) t2 pairs let t2' = foldr (uncurry tSub) t2 pairs
solver ((t1 :=: t2') : cs) subs let cs'' = foldr (\(k,v) -> map (cSub k v)) cs' pairs
solver ((t1 :=: t2') : cs'' ++ cs) subs
cSub k v (t1 :=: t2) = force $ tSub k v t1 :=: tSub k v t2 cSub k v (t1 :=: t2) = force $ tSub k v t1 :=: tSub k v t2
cSub k v (t1 :<: t2) = force $ tSub k v t1 :<: tSub k v t2 cSub k v (t :<: super) = force $ tSub k v t :<: super
cSub k v (t :<<: poly) = force $ tSub k v t :<<: poly
tSub k v t@(VarT x) = if k == x then v else t tSub k v t@(VarT x) = if k == x then v else t
tSub k v (LambdaT t1 t2) = force $ LambdaT (tSub k v t1) (tSub k v t2) tSub k v (LambdaT t1 t2) = force $ LambdaT (tSub k v t1) (tSub k v t2)
tSub k v (ADT name ts) = ADT name (map (force . tSub k v) ts) tSub k v (ADT name ts) = ADT name (map (force . tSub k v) ts)
tSub _ _ t = t tSub _ _ t = t
getVars (VarT x) = Set.singleton x
getVars (LambdaT t1 t2) = Set.union (getVars t1) (getVars t2)
getVars (ADT name ts) = Set.unions $ map getVars ts
getVars _ = Set.empty
uniError t1 t2 = uniError t1 t2 =
return . Left $ "Type error: " ++ show t1 ++ " is not equal to " ++ show t2 return . Left $ "Type error: " ++ show t1 ++ " is not equal to " ++ show t2
@ -63,7 +72,8 @@ force x = x `deepseq` x
instance NFData Constraint where instance NFData Constraint where
rnf (t1 :=: t2) = t1 `deepseq` t2 `deepseq` () rnf (t1 :=: t2) = t1 `deepseq` t2 `deepseq` ()
rnf (t1 :<: t2) = t1 `deepseq` t2 `deepseq` () rnf (t :<: _) = t `deepseq` ()
rnf (t :<<: _) = t `deepseq` ()
instance NFData Type where instance NFData Type where
rnf (LambdaT t1 t2) = t1 `deepseq` t2 `deepseq` () rnf (LambdaT t1 t2) = t1 `deepseq` t2 `deepseq` ()