expand tabs.

This commit is contained in:
Vincent Hanquez 2012-03-27 08:57:51 +01:00
parent 3b4baf2f91
commit 9da6b9c8c8
20 changed files with 1662 additions and 1662 deletions

View file

@ -7,9 +7,9 @@
-- --
module Network.TLS.Cap module Network.TLS.Cap
( hasHelloExtensions ( hasHelloExtensions
, hasExplicitBlockIV , hasExplicitBlockIV
) where ) where
import Network.TLS.Struct import Network.TLS.Struct

View file

@ -8,16 +8,16 @@
-- Portability : unknown -- Portability : unknown
-- --
module Network.TLS.Cipher module Network.TLS.Cipher
( BulkFunctions(..) ( BulkFunctions(..)
, CipherKeyExchangeType(..) , CipherKeyExchangeType(..)
, Bulk(..) , Bulk(..)
, Hash(..) , Hash(..)
, Cipher(..) , Cipher(..)
, cipherKeyBlockSize , cipherKeyBlockSize
, Key , Key
, IV , IV
, cipherExchangeNeedMoreData , cipherExchangeNeedMoreData
) where ) where
import Data.Word import Data.Word
import Network.TLS.Struct (Version(..)) import Network.TLS.Struct (Version(..))
@ -29,59 +29,59 @@ type Key = B.ByteString
type IV = B.ByteString type IV = B.ByteString
data BulkFunctions = data BulkFunctions =
BulkNoneF -- special value for 0 BulkNoneF -- special value for 0
| BulkBlockF (Key -> IV -> B.ByteString -> B.ByteString) | BulkBlockF (Key -> IV -> B.ByteString -> B.ByteString)
(Key -> IV -> B.ByteString -> B.ByteString) (Key -> IV -> B.ByteString -> B.ByteString)
| BulkStreamF (Key -> IV) | BulkStreamF (Key -> IV)
(IV -> B.ByteString -> (B.ByteString, IV)) (IV -> B.ByteString -> (B.ByteString, IV))
(IV -> B.ByteString -> (B.ByteString, IV)) (IV -> B.ByteString -> (B.ByteString, IV))
data CipherKeyExchangeType = data CipherKeyExchangeType =
CipherKeyExchange_RSA CipherKeyExchange_RSA
| CipherKeyExchange_DH_Anon | CipherKeyExchange_DH_Anon
| CipherKeyExchange_DHE_RSA | CipherKeyExchange_DHE_RSA
| CipherKeyExchange_ECDHE_RSA | CipherKeyExchange_ECDHE_RSA
| CipherKeyExchange_DHE_DSS | CipherKeyExchange_DHE_DSS
| CipherKeyExchange_DH_DSS | CipherKeyExchange_DH_DSS
| CipherKeyExchange_DH_RSA | CipherKeyExchange_DH_RSA
| CipherKeyExchange_ECDH_ECDSA | CipherKeyExchange_ECDH_ECDSA
| CipherKeyExchange_ECDH_RSA | CipherKeyExchange_ECDH_RSA
| CipherKeyExchange_ECDHE_ECDSA | CipherKeyExchange_ECDHE_ECDSA
deriving (Show,Eq) deriving (Show,Eq)
data Bulk = Bulk data Bulk = Bulk
{ bulkName :: String { bulkName :: String
, bulkKeySize :: Int , bulkKeySize :: Int
, bulkIVSize :: Int , bulkIVSize :: Int
, bulkBlockSize :: Int , bulkBlockSize :: Int
, bulkF :: BulkFunctions , bulkF :: BulkFunctions
} }
data Hash = Hash data Hash = Hash
{ hashName :: String { hashName :: String
, hashSize :: Int , hashSize :: Int
, hashF :: B.ByteString -> B.ByteString , hashF :: B.ByteString -> B.ByteString
} }
-- | Cipher algorithm -- | Cipher algorithm
data Cipher = Cipher data Cipher = Cipher
{ cipherID :: Word16 { cipherID :: Word16
, cipherName :: String , cipherName :: String
, cipherHash :: Hash , cipherHash :: Hash
, cipherBulk :: Bulk , cipherBulk :: Bulk
, cipherKeyExchange :: CipherKeyExchangeType , cipherKeyExchange :: CipherKeyExchangeType
, cipherMinVer :: Maybe Version , cipherMinVer :: Maybe Version
} }
cipherKeyBlockSize :: Cipher -> Int cipherKeyBlockSize :: Cipher -> Int
cipherKeyBlockSize cipher = 2 * (hashSize (cipherHash cipher) + bulkIVSize bulk + bulkKeySize bulk) cipherKeyBlockSize cipher = 2 * (hashSize (cipherHash cipher) + bulkIVSize bulk + bulkKeySize bulk)
where bulk = cipherBulk cipher where bulk = cipherBulk cipher
instance Show Cipher where instance Show Cipher where
show c = cipherName c show c = cipherName c
instance Eq Cipher where instance Eq Cipher where
(==) c1 c2 = cipherID c1 == cipherID c2 (==) c1 c2 = cipherID c1 == cipherID c2
cipherExchangeNeedMoreData :: CipherKeyExchangeType -> Bool cipherExchangeNeedMoreData :: CipherKeyExchangeType -> Bool
cipherExchangeNeedMoreData CipherKeyExchange_RSA = False cipherExchangeNeedMoreData CipherKeyExchange_RSA = False

View file

@ -8,18 +8,18 @@
-- Portability : unknown -- Portability : unknown
-- --
module Network.TLS.Compression module Network.TLS.Compression
( CompressionC(..) ( CompressionC(..)
, Compression(..) , Compression(..)
, nullCompression , nullCompression
-- * member redefined for the class abstraction -- * member redefined for the class abstraction
, compressionID , compressionID
, compressionDeflate , compressionDeflate
, compressionInflate , compressionInflate
-- * helper -- * helper
, compressionIntersectID , compressionIntersectID
) where ) where
import Data.Word import Data.Word
import Data.ByteString (ByteString) import Data.ByteString (ByteString)
@ -27,9 +27,9 @@ import Control.Arrow (first)
-- | supported compression algorithms need to be part of this class -- | supported compression algorithms need to be part of this class
class CompressionC a where class CompressionC a where
compressionCID :: a -> Word8 compressionCID :: a -> Word8
compressionCDeflate :: a -> ByteString -> (a, ByteString) compressionCDeflate :: a -> ByteString -> (a, ByteString)
compressionCInflate :: a -> ByteString -> (a, ByteString) compressionCInflate :: a -> ByteString -> (a, ByteString)
-- | every compression need to be wrapped in this, to fit in structure -- | every compression need to be wrapped in this, to fit in structure
data Compression = forall a . CompressionC a => Compression a data Compression = forall a . CompressionC a => Compression a
@ -49,7 +49,7 @@ compressionInflate :: ByteString -> Compression -> (Compression, ByteString)
compressionInflate bytes (Compression c) = first Compression $ compressionCInflate c bytes compressionInflate bytes (Compression c) = first Compression $ compressionCInflate c bytes
instance Show Compression where instance Show Compression where
show = show . compressionID show = show . compressionID
-- | intersect a list of ids commonly given by the other side with a list of compression -- | intersect a list of ids commonly given by the other side with a list of compression
-- the function keeps the list of compression in order, to be able to find quickly the prefered -- the function keeps the list of compression in order, to be able to find quickly the prefered
@ -60,9 +60,9 @@ compressionIntersectID l ids = filter (\c -> elem (compressionID c) ids) l
data NullCompression = NullCompression data NullCompression = NullCompression
instance CompressionC NullCompression where instance CompressionC NullCompression where
compressionCID _ = 0 compressionCID _ = 0
compressionCDeflate s b = (s, b) compressionCDeflate s b = (s, b)
compressionCInflate s b = (s, b) compressionCInflate s b = (s, b)
-- | default null compression -- | default null compression
nullCompression :: Compression nullCompression :: Compression

View file

@ -6,54 +6,54 @@
-- Portability : unknown -- Portability : unknown
-- --
module Network.TLS.Context module Network.TLS.Context
( (
-- * Context configuration -- * Context configuration
Params(..) Params(..)
, Logging(..) , Logging(..)
, SessionData(..) , SessionData(..)
, Measurement(..) , Measurement(..)
, CertificateUsage(..) , CertificateUsage(..)
, CertificateRejectReason(..) , CertificateRejectReason(..)
, defaultLogging , defaultLogging
, defaultParamsClient , defaultParamsClient
, defaultParamsServer , defaultParamsServer
-- * Context object and accessor -- * Context object and accessor
, Backend(..) , Backend(..)
, Context , Context
, ctxParams , ctxParams
, ctxConnection , ctxConnection
, ctxEOF , ctxEOF
, ctxEstablished , ctxEstablished
, ctxLogging , ctxLogging
, setEOF , setEOF
, setEstablished , setEstablished
, connectionFlush , connectionFlush
, connectionSend , connectionSend
, connectionRecv , connectionRecv
, updateMeasure , updateMeasure
, withMeasure , withMeasure
-- * deprecated types -- * deprecated types
, TLSParams , TLSParams
, TLSLogging , TLSLogging
, TLSCertificateUsage , TLSCertificateUsage
, TLSCertificateRejectReason , TLSCertificateRejectReason
, TLSCtx , TLSCtx
-- * deprecated values -- * deprecated values
, defaultParams , defaultParams
-- * New contexts -- * New contexts
, contextNew , contextNew
, contextNewOnHandle , contextNewOnHandle
-- * Using context states -- * Using context states
, throwCore , throwCore
, usingState , usingState
, usingState_ , usingState_
, getStateRNG , getStateRNG
) where ) where
import Network.TLS.Struct import Network.TLS.Struct
import Network.TLS.Cipher import Network.TLS.Cipher
@ -77,11 +77,11 @@ import System.IO (Handle, hSetBuffering, BufferMode(..), hFlush)
import Prelude hiding (catch) import Prelude hiding (catch)
data Logging = Logging data Logging = Logging
{ loggingPacketSent :: String -> IO () { loggingPacketSent :: String -> IO ()
, loggingPacketRecv :: String -> IO () , loggingPacketRecv :: String -> IO ()
, loggingIOSent :: B.ByteString -> IO () , loggingIOSent :: B.ByteString -> IO ()
, loggingIORecv :: Header -> B.ByteString -> IO () , loggingIORecv :: Header -> B.ByteString -> IO ()
} }
data ClientParams = ClientParams data ClientParams = ClientParams
data ServerParams = ServerParams data ServerParams = ServerParams
@ -89,61 +89,61 @@ data ServerParams = ServerParams
data RoleParams = Client ClientParams | Server ServerParams data RoleParams = Client ClientParams | Server ServerParams
data Params = Params data Params = Params
{ pConnectVersion :: Version -- ^ version to use on client connection. { pConnectVersion :: Version -- ^ version to use on client connection.
, pAllowedVersions :: [Version] -- ^ allowed versions that we can use. , pAllowedVersions :: [Version] -- ^ allowed versions that we can use.
, pCiphers :: [Cipher] -- ^ all ciphers supported ordered by priority. , pCiphers :: [Cipher] -- ^ all ciphers supported ordered by priority.
, pCompressions :: [Compression] -- ^ all compression supported ordered by priority. , pCompressions :: [Compression] -- ^ all compression supported ordered by priority.
, pWantClientCert :: Bool -- ^ request a certificate from client. , pWantClientCert :: Bool -- ^ request a certificate from client.
-- use by server only. -- use by server only.
, pUseSecureRenegotiation :: Bool -- ^ notify that we want to use secure renegotation , pUseSecureRenegotiation :: Bool -- ^ notify that we want to use secure renegotation
, pUseSession :: Bool -- ^ generate new session if specified , pUseSession :: Bool -- ^ generate new session if specified
, pCertificates :: [(X509, Maybe PrivateKey)] -- ^ the cert chain for this context with the associated keys if any. , pCertificates :: [(X509, Maybe PrivateKey)] -- ^ the cert chain for this context with the associated keys if any.
, pLogging :: Logging -- ^ callback for logging , pLogging :: Logging -- ^ callback for logging
, onHandshake :: Measurement -> IO Bool -- ^ callback on a beggining of handshake , onHandshake :: Measurement -> IO Bool -- ^ callback on a beggining of handshake
, onCertificatesRecv :: [X509] -> IO CertificateUsage -- ^ callback to verify received cert chain. , onCertificatesRecv :: [X509] -> IO CertificateUsage -- ^ callback to verify received cert chain.
, onSessionResumption :: SessionID -> IO (Maybe SessionData) -- ^ callback to maybe resume session on server. , onSessionResumption :: SessionID -> IO (Maybe SessionData) -- ^ callback to maybe resume session on server.
, onSessionEstablished :: SessionID -> SessionData -> IO () -- ^ callback when session have been established , onSessionEstablished :: SessionID -> SessionData -> IO () -- ^ callback when session have been established
, onSessionInvalidated :: SessionID -> IO () -- ^ callback when session is invalidated by error , onSessionInvalidated :: SessionID -> IO () -- ^ callback when session is invalidated by error
, onSuggestNextProtocols :: IO (Maybe [B.ByteString]) -- ^ suggested next protocols accoring to the next protocol negotiation extension. , onSuggestNextProtocols :: IO (Maybe [B.ByteString]) -- ^ suggested next protocols accoring to the next protocol negotiation extension.
, onNPNServerSuggest :: Maybe ([B.ByteString] -> IO B.ByteString) , onNPNServerSuggest :: Maybe ([B.ByteString] -> IO B.ByteString)
, sessionResumeWith :: Maybe (SessionID, SessionData) -- ^ try to establish a connection using this session. , sessionResumeWith :: Maybe (SessionID, SessionData) -- ^ try to establish a connection using this session.
, roleParams :: RoleParams , roleParams :: RoleParams
} }
defaultLogging :: Logging defaultLogging :: Logging
defaultLogging = Logging defaultLogging = Logging
{ loggingPacketSent = (\_ -> return ()) { loggingPacketSent = (\_ -> return ())
, loggingPacketRecv = (\_ -> return ()) , loggingPacketRecv = (\_ -> return ())
, loggingIOSent = (\_ -> return ()) , loggingIOSent = (\_ -> return ())
, loggingIORecv = (\_ _ -> return ()) , loggingIORecv = (\_ _ -> return ())
} }
defaultParamsClient :: Params defaultParamsClient :: Params
defaultParamsClient = Params defaultParamsClient = Params
{ pConnectVersion = TLS10 { pConnectVersion = TLS10
, pAllowedVersions = [TLS10,TLS11,TLS12] , pAllowedVersions = [TLS10,TLS11,TLS12]
, pCiphers = [] , pCiphers = []
, pCompressions = [nullCompression] , pCompressions = [nullCompression]
, pWantClientCert = False , pWantClientCert = False
, pUseSecureRenegotiation = True , pUseSecureRenegotiation = True
, pUseSession = True , pUseSession = True
, pCertificates = [] , pCertificates = []
, pLogging = defaultLogging , pLogging = defaultLogging
, onHandshake = (\_ -> return True) , onHandshake = (\_ -> return True)
, onCertificatesRecv = (\_ -> return CertificateUsageAccept) , onCertificatesRecv = (\_ -> return CertificateUsageAccept)
, onSessionResumption = (\_ -> return Nothing) , onSessionResumption = (\_ -> return Nothing)
, onSessionEstablished = (\_ _ -> return ()) , onSessionEstablished = (\_ _ -> return ())
, onSessionInvalidated = (\_ -> return ()) , onSessionInvalidated = (\_ -> return ())
, onSuggestNextProtocols = return Nothing , onSuggestNextProtocols = return Nothing
, onNPNServerSuggest = Nothing , onNPNServerSuggest = Nothing
, sessionResumeWith = Nothing , sessionResumeWith = Nothing
, roleParams = Client $ ClientParams , roleParams = Client $ ClientParams
} }
defaultParamsServer :: Params defaultParamsServer :: Params
defaultParamsServer = defaultParamsClient defaultParamsServer = defaultParamsClient
{ roleParams = Server $ ServerParams { roleParams = Server $ ServerParams
} }
defaultParams :: Params defaultParams :: Params
defaultParams = defaultParamsClient defaultParams = defaultParamsClient
@ -151,45 +151,45 @@ defaultParams = defaultParamsClient
instance Show Params where instance Show Params where
show p = "Params { " ++ (intercalate "," $ map (\(k,v) -> k ++ "=" ++ v) show p = "Params { " ++ (intercalate "," $ map (\(k,v) -> k ++ "=" ++ v)
[ ("connectVersion", show $ pConnectVersion p) [ ("connectVersion", show $ pConnectVersion p)
, ("allowedVersions", show $ pAllowedVersions p) , ("allowedVersions", show $ pAllowedVersions p)
, ("ciphers", show $ pCiphers p) , ("ciphers", show $ pCiphers p)
, ("compressions", show $ pCompressions p) , ("compressions", show $ pCompressions p)
, ("want-client-cert", show $ pWantClientCert p) , ("want-client-cert", show $ pWantClientCert p)
, ("certificates", show $ length $ pCertificates p) , ("certificates", show $ length $ pCertificates p)
]) ++ " }" ]) ++ " }"
-- | Certificate and Chain rejection reason -- | Certificate and Chain rejection reason
data CertificateRejectReason = data CertificateRejectReason =
CertificateRejectExpired CertificateRejectExpired
| CertificateRejectRevoked | CertificateRejectRevoked
| CertificateRejectUnknownCA | CertificateRejectUnknownCA
| CertificateRejectOther String | CertificateRejectOther String
deriving (Show,Eq) deriving (Show,Eq)
-- | Certificate Usage callback possible returns values. -- | Certificate Usage callback possible returns values.
data CertificateUsage = data CertificateUsage =
CertificateUsageAccept -- ^ usage of certificate accepted CertificateUsageAccept -- ^ usage of certificate accepted
| CertificateUsageReject CertificateRejectReason -- ^ usage of certificate rejected | CertificateUsageReject CertificateRejectReason -- ^ usage of certificate rejected
deriving (Show,Eq) deriving (Show,Eq)
-- | -- |
data Backend = Backend data Backend = Backend
{ backendFlush :: IO () -- ^ Flush the connection sending buffer, if any. { backendFlush :: IO () -- ^ Flush the connection sending buffer, if any.
, backendSend :: ByteString -> IO () -- ^ Send a bytestring through the connection. , backendSend :: ByteString -> IO () -- ^ Send a bytestring through the connection.
, backendRecv :: Int -> IO ByteString -- ^ Receive specified number of bytes from the connection. , backendRecv :: Int -> IO ByteString -- ^ Receive specified number of bytes from the connection.
} }
-- | A TLS Context keep tls specific state, parameters and backend information. -- | A TLS Context keep tls specific state, parameters and backend information.
data Context = Context data Context = Context
{ ctxConnection :: Backend -- ^ return the backend object associated with this context { ctxConnection :: Backend -- ^ return the backend object associated with this context
, ctxParams :: Params , ctxParams :: Params
, ctxState :: MVar TLSState , ctxState :: MVar TLSState
, ctxMeasurement :: IORef Measurement , ctxMeasurement :: IORef Measurement
, ctxEOF_ :: IORef Bool -- ^ has the handle EOFed or not. , ctxEOF_ :: IORef Bool -- ^ has the handle EOFed or not.
, ctxEstablished_ :: IORef Bool -- ^ has the handshake been done and been successful. , ctxEstablished_ :: IORef Bool -- ^ has the handshake been done and been successful.
} }
-- deprecated types, setup as aliases for compatibility. -- deprecated types, setup as aliases for compatibility.
type TLSParams = Params type TLSParams = Params
@ -238,23 +238,23 @@ contextNew :: (MonadIO m, CryptoRandomGen rng)
-> m Context -> m Context
contextNew backend params rng = liftIO $ do contextNew backend params rng = liftIO $ do
let clientContext = case roleParams params of let clientContext = case roleParams params of
Client {} -> True Client {} -> True
Server {} -> False Server {} -> False
let st = (newTLSState rng) { stClientContext = clientContext } let st = (newTLSState rng) { stClientContext = clientContext }
stvar <- newMVar st stvar <- newMVar st
eof <- newIORef False eof <- newIORef False
established <- newIORef False established <- newIORef False
stats <- newIORef newMeasurement stats <- newIORef newMeasurement
return $ Context return $ Context
{ ctxConnection = backend { ctxConnection = backend
, ctxParams = params , ctxParams = params
, ctxState = stvar , ctxState = stvar
, ctxMeasurement = stats , ctxMeasurement = stats
, ctxEOF_ = eof , ctxEOF_ = eof
, ctxEstablished_ = established , ctxEstablished_ = established
} }
-- | create a new context on an handle. -- | create a new context on an handle.
contextNewOnHandle :: (MonadIO m, CryptoRandomGen rng) contextNewOnHandle :: (MonadIO m, CryptoRandomGen rng)
@ -263,8 +263,8 @@ contextNewOnHandle :: (MonadIO m, CryptoRandomGen rng)
-> rng -- ^ Random number generator associated with this context. -> rng -- ^ Random number generator associated with this context.
-> m Context -> m Context
contextNewOnHandle handle params st = contextNewOnHandle handle params st =
liftIO (hSetBuffering handle NoBuffering) >> contextNew backend params st liftIO (hSetBuffering handle NoBuffering) >> contextNew backend params st
where backend = Backend (hFlush handle) (B.hPut handle) (B.hGet handle) where backend = Backend (hFlush handle) (B.hPut handle) (B.hGet handle)
throwCore :: (MonadIO m, Exception e) => e -> m a throwCore :: (MonadIO m, Exception e) => e -> m a
throwCore = liftIO . throwIO throwCore = liftIO . throwIO
@ -272,16 +272,16 @@ throwCore = liftIO . throwIO
usingState :: MonadIO m => Context -> TLSSt a -> m (Either TLSError a) usingState :: MonadIO m => Context -> TLSSt a -> m (Either TLSError a)
usingState ctx f = usingState ctx f =
liftIO $ modifyMVar (ctxState ctx) $ \st -> liftIO $ modifyMVar (ctxState ctx) $ \st ->
let (a, newst) = runTLSState f st let (a, newst) = runTLSState f st
in newst `seq` return (newst, a) in newst `seq` return (newst, a)
usingState_ :: MonadIO m => Context -> TLSSt a -> m a usingState_ :: MonadIO m => Context -> TLSSt a -> m a
usingState_ ctx f = do usingState_ ctx f = do
ret <- usingState ctx f ret <- usingState ctx f
case ret of case ret of
Left err -> throwCore err Left err -> throwCore err
Right r -> return r Right r -> return r
getStateRNG :: MonadIO m => Context -> Int -> m Bytes getStateRNG :: MonadIO m => Context -> Int -> m Bytes
getStateRNG ctx n = usingState_ ctx (genTLSRandom n) getStateRNG ctx n = usingState_ ctx (genTLSRandom n)

View file

@ -8,25 +8,25 @@
-- Portability : unknown -- Portability : unknown
-- --
module Network.TLS.Core module Network.TLS.Core
( (
-- * Internal packet sending and receiving -- * Internal packet sending and receiving
sendPacket sendPacket
, recvPacket , recvPacket
-- * Initialisation and Termination of context -- * Initialisation and Termination of context
, bye , bye
, handshake , handshake
, HandshakeFailed(..) , HandshakeFailed(..)
, ConnectionNotEstablished(..) , ConnectionNotEstablished(..)
-- * Next Protocol Negotiation -- * Next Protocol Negotiation
, getNegotiatedProtocol , getNegotiatedProtocol
-- * High level API -- * High level API
, sendData , sendData
, recvData , recvData
, recvData' , recvData'
) where ) where
import Network.TLS.Context import Network.TLS.Context
import Network.TLS.Struct import Network.TLS.Struct
@ -54,10 +54,10 @@ import System.IO.Error (mkIOError, eofErrorType)
import Prelude hiding (catch) import Prelude hiding (catch)
data HandshakeFailed = HandshakeFailed TLSError data HandshakeFailed = HandshakeFailed TLSError
deriving (Show,Eq,Typeable) deriving (Show,Eq,Typeable)
data ConnectionNotEstablished = ConnectionNotEstablished data ConnectionNotEstablished = ConnectionNotEstablished
deriving (Show,Eq,Typeable) deriving (Show,Eq,Typeable)
instance Exception HandshakeFailed instance Exception HandshakeFailed
instance Exception ConnectionNotEstablished instance Exception ConnectionNotEstablished
@ -71,29 +71,29 @@ handshakeFailed err = throwIO $ HandshakeFailed err
checkValid :: MonadIO m => Context -> m () checkValid :: MonadIO m => Context -> m ()
checkValid ctx = do checkValid ctx = do
established <- ctxEstablished ctx established <- ctxEstablished ctx
unless established $ liftIO $ throwIO ConnectionNotEstablished unless established $ liftIO $ throwIO ConnectionNotEstablished
eofed <- ctxEOF ctx eofed <- ctxEOF ctx
when eofed $ liftIO $ throwIO $ mkIOError eofErrorType "data" Nothing Nothing when eofed $ liftIO $ throwIO $ mkIOError eofErrorType "data" Nothing Nothing
readExact :: MonadIO m => Context -> Int -> m Bytes readExact :: MonadIO m => Context -> Int -> m Bytes
readExact ctx sz = do readExact ctx sz = do
hdrbs <- liftIO $ connectionRecv ctx sz hdrbs <- liftIO $ connectionRecv ctx sz
when (B.length hdrbs < sz) $ do when (B.length hdrbs < sz) $ do
setEOF ctx setEOF ctx
if B.null hdrbs if B.null hdrbs
then throwCore Error_EOF then throwCore Error_EOF
else throwCore (Error_Packet ("partial packet: expecting " ++ show sz ++ " bytes, got: " ++ (show $B.length hdrbs))) else throwCore (Error_Packet ("partial packet: expecting " ++ show sz ++ " bytes, got: " ++ (show $B.length hdrbs)))
return hdrbs return hdrbs
recvRecord :: MonadIO m => Context -> m (Either TLSError (Record Plaintext)) recvRecord :: MonadIO m => Context -> m (Either TLSError (Record Plaintext))
recvRecord ctx = readExact ctx 5 >>= either (return . Left) recvLength . decodeHeader recvRecord ctx = readExact ctx 5 >>= either (return . Left) recvLength . decodeHeader
where recvLength header@(Header _ _ readlen) where recvLength header@(Header _ _ readlen)
| readlen > 16384 + 2048 = return $ Left $ Error_Protocol ("record exceeding maximum size", True, RecordOverflow) | readlen > 16384 + 2048 = return $ Left $ Error_Protocol ("record exceeding maximum size", True, RecordOverflow)
| otherwise = do | otherwise = do
content <- readExact ctx (fromIntegral readlen) content <- readExact ctx (fromIntegral readlen)
liftIO $ (loggingIORecv $ ctxLogging ctx) header content liftIO $ (loggingIORecv $ ctxLogging ctx) header content
usingState ctx $ disengageRecord $ rawToRecord header (fragmentCiphertext content) usingState ctx $ disengageRecord $ rawToRecord header (fragmentCiphertext content)
-- | receive one packet from the context that contains 1 or -- | receive one packet from the context that contains 1 or
@ -101,45 +101,45 @@ recvRecord ctx = readExact ctx 5 >>= either (return . Left) recvLength . decodeH
-- TLSError if the packet is unexpected or malformed -- TLSError if the packet is unexpected or malformed
recvPacket :: MonadIO m => Context -> m (Either TLSError Packet) recvPacket :: MonadIO m => Context -> m (Either TLSError Packet)
recvPacket ctx = do recvPacket ctx = do
erecord <- recvRecord ctx erecord <- recvRecord ctx
case erecord of case erecord of
Left err -> return $ Left err Left err -> return $ Left err
Right record -> do Right record -> do
pkt <- usingState ctx $ processPacket record pkt <- usingState ctx $ processPacket record
case pkt of case pkt of
Right p -> liftIO $ (loggingPacketRecv $ ctxLogging ctx) $ show p Right p -> liftIO $ (loggingPacketRecv $ ctxLogging ctx) $ show p
_ -> return () _ -> return ()
return pkt return pkt
recvPacketHandshake :: MonadIO m => Context -> m [Handshake] recvPacketHandshake :: MonadIO m => Context -> m [Handshake]
recvPacketHandshake ctx = do recvPacketHandshake ctx = do
pkts <- recvPacket ctx pkts <- recvPacket ctx
case pkts of case pkts of
Right (Handshake l) -> return l Right (Handshake l) -> return l
Right x -> fail ("unexpected type received. expecting handshake and got: " ++ show x) Right x -> fail ("unexpected type received. expecting handshake and got: " ++ show x)
Left err -> throwCore err Left err -> throwCore err
data RecvState m = data RecvState m =
RecvStateNext (Packet -> m (RecvState m)) RecvStateNext (Packet -> m (RecvState m))
| RecvStateHandshake (Handshake -> m (RecvState m)) | RecvStateHandshake (Handshake -> m (RecvState m))
| RecvStateDone | RecvStateDone
runRecvState :: MonadIO m => Context -> RecvState m -> m () runRecvState :: MonadIO m => Context -> RecvState m -> m ()
runRecvState _ (RecvStateDone) = return () runRecvState _ (RecvStateDone) = return ()
runRecvState ctx (RecvStateNext f) = recvPacket ctx >>= either throwCore f >>= runRecvState ctx runRecvState ctx (RecvStateNext f) = recvPacket ctx >>= either throwCore f >>= runRecvState ctx
runRecvState ctx iniState = recvPacketHandshake ctx >>= loop iniState >>= runRecvState ctx runRecvState ctx iniState = recvPacketHandshake ctx >>= loop iniState >>= runRecvState ctx
where where
loop :: MonadIO m => RecvState m -> [Handshake] -> m (RecvState m) loop :: MonadIO m => RecvState m -> [Handshake] -> m (RecvState m)
loop recvState [] = return recvState loop recvState [] = return recvState
loop (RecvStateHandshake f) (x:xs) = do loop (RecvStateHandshake f) (x:xs) = do
nstate <- f x nstate <- f x
usingState_ ctx $ processHandshake x usingState_ ctx $ processHandshake x
loop nstate xs loop nstate xs
loop _ _ = unexpected "spurious handshake" Nothing loop _ _ = unexpected "spurious handshake" Nothing
sendChangeCipherAndFinish :: MonadIO m => Context -> Bool -> m () sendChangeCipherAndFinish :: MonadIO m => Context -> Bool -> m ()
sendChangeCipherAndFinish ctx isClient = do sendChangeCipherAndFinish ctx isClient = do
sendPacket ctx ChangeCipherSpec sendPacket ctx ChangeCipherSpec
when isClient $ do when isClient $ do
suggest <- usingState_ ctx $ getServerNextProtocolSuggest suggest <- usingState_ ctx $ getServerNextProtocolSuggest
case (onNPNServerSuggest (ctxParams ctx), suggest) of case (onNPNServerSuggest (ctxParams ctx), suggest) of
@ -151,34 +151,34 @@ sendChangeCipherAndFinish ctx isClient = do
(Just _, Nothing) -> return () (Just _, Nothing) -> return ()
-- client didn't offer. do nothing. -- client didn't offer. do nothing.
(Nothing, _) -> return () (Nothing, _) -> return ()
liftIO $ connectionFlush ctx liftIO $ connectionFlush ctx
cf <- usingState_ ctx $ getHandshakeDigest isClient cf <- usingState_ ctx $ getHandshakeDigest isClient
sendPacket ctx (Handshake [Finished cf]) sendPacket ctx (Handshake [Finished cf])
liftIO $ connectionFlush ctx liftIO $ connectionFlush ctx
recvChangeCipherAndFinish :: MonadIO m => Context -> m () recvChangeCipherAndFinish :: MonadIO m => Context -> m ()
recvChangeCipherAndFinish ctx = runRecvState ctx (RecvStateNext expectChangeCipher) recvChangeCipherAndFinish ctx = runRecvState ctx (RecvStateNext expectChangeCipher)
where where
expectChangeCipher ChangeCipherSpec = return $ RecvStateHandshake expectFinish expectChangeCipher ChangeCipherSpec = return $ RecvStateHandshake expectFinish
expectChangeCipher p = unexpected (show p) (Just "change cipher") expectChangeCipher p = unexpected (show p) (Just "change cipher")
expectFinish (Finished _) = return RecvStateDone expectFinish (Finished _) = return RecvStateDone
expectFinish p = unexpected (show p) (Just "Handshake Finished") expectFinish p = unexpected (show p) (Just "Handshake Finished")
unexpected :: MonadIO m => String -> Maybe [Char] -> m a unexpected :: MonadIO m => String -> Maybe [Char] -> m a
unexpected msg expected = throwCore $ Error_Packet_unexpected msg (maybe "" (" expected: " ++) expected) unexpected msg expected = throwCore $ Error_Packet_unexpected msg (maybe "" (" expected: " ++) expected)
newSession :: MonadIO m => Context -> m Session newSession :: MonadIO m => Context -> m Session
newSession ctx newSession ctx
| pUseSession $ ctxParams ctx = getStateRNG ctx 32 >>= return . Session . Just | pUseSession $ ctxParams ctx = getStateRNG ctx 32 >>= return . Session . Just
| otherwise = return $ Session Nothing | otherwise = return $ Session Nothing
-- | Send one packet to the context -- | Send one packet to the context
sendPacket :: MonadIO m => Context -> Packet -> m () sendPacket :: MonadIO m => Context -> Packet -> m ()
sendPacket ctx pkt = do sendPacket ctx pkt = do
liftIO $ (loggingPacketSent $ ctxLogging ctx) (show pkt) liftIO $ (loggingPacketSent $ ctxLogging ctx) (show pkt)
dataToSend <- usingState_ ctx $ writePacket pkt dataToSend <- usingState_ ctx $ writePacket pkt
liftIO $ (loggingIOSent $ ctxLogging ctx) dataToSend liftIO $ (loggingIOSent $ ctxLogging ctx) dataToSend
liftIO $ connectionSend ctx dataToSend liftIO $ connectionSend ctx dataToSend
-- | notify the context that this side wants to close connection. -- | notify the context that this side wants to close connection.
-- this is important that it is called before closing the handle, otherwise -- this is important that it is called before closing the handle, otherwise
@ -196,249 +196,249 @@ getNegotiatedProtocol ctx = usingState_ ctx S.getNegotiatedProtocol
-- | when a new handshake is done, wrap up & clean up. -- | when a new handshake is done, wrap up & clean up.
handshakeTerminate :: MonadIO m => Context -> m () handshakeTerminate :: MonadIO m => Context -> m ()
handshakeTerminate ctx = do handshakeTerminate ctx = do
session <- usingState_ ctx getSession session <- usingState_ ctx getSession
-- only callback the session established if we have a session -- only callback the session established if we have a session
case session of case session of
Session (Just sessionId) -> do Session (Just sessionId) -> do
sessionData <- usingState_ ctx getSessionData sessionData <- usingState_ ctx getSessionData
liftIO $ (onSessionEstablished $ ctxParams ctx) sessionId (fromJust sessionData) liftIO $ (onSessionEstablished $ ctxParams ctx) sessionId (fromJust sessionData)
_ -> return () _ -> return ()
-- forget all handshake data now and reset bytes counters. -- forget all handshake data now and reset bytes counters.
usingState_ ctx endHandshake usingState_ ctx endHandshake
updateMeasure ctx resetBytesCounters updateMeasure ctx resetBytesCounters
-- mark the secure connection up and running. -- mark the secure connection up and running.
setEstablished ctx True setEstablished ctx True
return () return ()
-- client part of handshake. send a bunch of handshake of client -- client part of handshake. send a bunch of handshake of client
-- values intertwined with response from the server. -- values intertwined with response from the server.
handshakeClient :: MonadIO m => Context -> m () handshakeClient :: MonadIO m => Context -> m ()
handshakeClient ctx = do handshakeClient ctx = do
updateMeasure ctx incrementNbHandshakes updateMeasure ctx incrementNbHandshakes
sendClientHello sendClientHello
recvServerHello recvServerHello
sessionResuming <- usingState_ ctx isSessionResuming sessionResuming <- usingState_ ctx isSessionResuming
if sessionResuming if sessionResuming
then sendChangeCipherAndFinish ctx True then sendChangeCipherAndFinish ctx True
else do else do
sendCertificate >> sendClientKeyXchg >> sendCertificateVerify sendCertificate >> sendClientKeyXchg >> sendCertificateVerify
sendChangeCipherAndFinish ctx True sendChangeCipherAndFinish ctx True
recvChangeCipherAndFinish ctx recvChangeCipherAndFinish ctx
handshakeTerminate ctx handshakeTerminate ctx
where where
params = ctxParams ctx params = ctxParams ctx
ver = pConnectVersion params ver = pConnectVersion params
allowedvers = pAllowedVersions params allowedvers = pAllowedVersions params
ciphers = pCiphers params ciphers = pCiphers params
compressions = pCompressions params compressions = pCompressions params
clientCerts = map fst $ pCertificates params clientCerts = map fst $ pCertificates params
getExtensions = sequence [secureReneg, npnExtention] >>= return . catMaybes getExtensions = sequence [secureReneg, npnExtention] >>= return . catMaybes
secureReneg = secureReneg =
if pUseSecureRenegotiation params if pUseSecureRenegotiation params
then usingState_ ctx (getVerifiedData True) >>= \vd -> return $ Just (0xff01, encodeExtSecureRenegotiation vd Nothing) then usingState_ ctx (getVerifiedData True) >>= \vd -> return $ Just (0xff01, encodeExtSecureRenegotiation vd Nothing)
else return Nothing else return Nothing
npnExtention = if isJust $ onNPNServerSuggest params npnExtention = if isJust $ onNPNServerSuggest params
then return $ Just (13172, "") then return $ Just (13172, "")
else return Nothing else return Nothing
sendClientHello = do sendClientHello = do
crand <- getStateRNG ctx 32 >>= return . ClientRandom crand <- getStateRNG ctx 32 >>= return . ClientRandom
let clientSession = Session . maybe Nothing (Just . fst) $ sessionResumeWith params let clientSession = Session . maybe Nothing (Just . fst) $ sessionResumeWith params
extensions <- getExtensions extensions <- getExtensions
usingState_ ctx (startHandshakeClient ver crand) usingState_ ctx (startHandshakeClient ver crand)
sendPacket ctx $ Handshake sendPacket ctx $ Handshake
[ ClientHello ver crand clientSession (map cipherID ciphers) [ ClientHello ver crand clientSession (map cipherID ciphers)
(map compressionID compressions) extensions (map compressionID compressions) extensions
] ]
expectChangeCipher ChangeCipherSpec = return $ RecvStateHandshake expectFinish expectChangeCipher ChangeCipherSpec = return $ RecvStateHandshake expectFinish
expectChangeCipher p = unexpected (show p) (Just "change cipher") expectChangeCipher p = unexpected (show p) (Just "change cipher")
expectFinish (Finished _) = return RecvStateDone expectFinish (Finished _) = return RecvStateDone
expectFinish p = unexpected (show p) (Just "Handshake Finished") expectFinish p = unexpected (show p) (Just "Handshake Finished")
sendCertificate = do sendCertificate = do
-- Send Certificate if requested. XXX disabled for now. -- Send Certificate if requested. XXX disabled for now.
certRequested <- return False certRequested <- return False
when certRequested (sendPacket ctx $ Handshake [Certificates clientCerts]) when certRequested (sendPacket ctx $ Handshake [Certificates clientCerts])
sendCertificateVerify = sendCertificateVerify =
{- maybe send certificateVerify -} {- maybe send certificateVerify -}
{- FIXME not implemented yet -} {- FIXME not implemented yet -}
return () return ()
recvServerHello = runRecvState ctx (RecvStateHandshake onServerHello) recvServerHello = runRecvState ctx (RecvStateHandshake onServerHello)
onServerHello :: MonadIO m => Handshake -> m (RecvState m) onServerHello :: MonadIO m => Handshake -> m (RecvState m)
onServerHello sh@(ServerHello rver _ serverSession cipher _ exts) = do onServerHello sh@(ServerHello rver _ serverSession cipher _ exts) = do
when (rver == SSL2) $ throwCore $ Error_Protocol ("ssl2 is not supported", True, ProtocolVersion) when (rver == SSL2) $ throwCore $ Error_Protocol ("ssl2 is not supported", True, ProtocolVersion)
case find ((==) rver) allowedvers of case find ((==) rver) allowedvers of
Nothing -> throwCore $ Error_Protocol ("version " ++ show ver ++ "is not supported", True, ProtocolVersion) Nothing -> throwCore $ Error_Protocol ("version " ++ show ver ++ "is not supported", True, ProtocolVersion)
Just _ -> usingState_ ctx $ setVersion ver Just _ -> usingState_ ctx $ setVersion ver
case find ((==) cipher . cipherID) ciphers of case find ((==) cipher . cipherID) ciphers of
Nothing -> throwCore $ Error_Protocol ("no cipher in common with the server", True, HandshakeFailure) Nothing -> throwCore $ Error_Protocol ("no cipher in common with the server", True, HandshakeFailure)
Just c -> usingState_ ctx $ setCipher c Just c -> usingState_ ctx $ setCipher c
let resumingSession = case sessionResumeWith params of let resumingSession = case sessionResumeWith params of
Just (sessionId, sessionData) -> if serverSession == Session (Just sessionId) then Just sessionData else Nothing Just (sessionId, sessionData) -> if serverSession == Session (Just sessionId) then Just sessionData else Nothing
Nothing -> Nothing Nothing -> Nothing
usingState_ ctx $ setSession serverSession (isJust resumingSession) usingState_ ctx $ setSession serverSession (isJust resumingSession)
usingState_ ctx $ processServerHello sh usingState_ ctx $ processServerHello sh
case decodeExtNextProtocolNegotiation `fmap` (lookup 13172 exts) of case decodeExtNextProtocolNegotiation `fmap` (lookup 13172 exts) of
Just (Right protos) -> usingState_ ctx $ do Just (Right protos) -> usingState_ ctx $ do
setExtensionNPN True setExtensionNPN True
setServerNextProtocolSuggest protos setServerNextProtocolSuggest protos
Just (Left err) -> throwCore (Error_Protocol ("could not decode NPN handshake: " ++ show err, True, DecodeError)) Just (Left err) -> throwCore (Error_Protocol ("could not decode NPN handshake: " ++ show err, True, DecodeError))
Nothing -> return () Nothing -> return ()
case resumingSession of case resumingSession of
Nothing -> return $ RecvStateHandshake processCertificate Nothing -> return $ RecvStateHandshake processCertificate
Just sessionData -> do Just sessionData -> do
usingState_ ctx (setMasterSecret $ sessionSecret sessionData) usingState_ ctx (setMasterSecret $ sessionSecret sessionData)
return $ RecvStateNext expectChangeCipher return $ RecvStateNext expectChangeCipher
onServerHello p = unexpected (show p) (Just "server hello") onServerHello p = unexpected (show p) (Just "server hello")
processCertificate :: MonadIO m => Handshake -> m (RecvState m) processCertificate :: MonadIO m => Handshake -> m (RecvState m)
processCertificate (Certificates certs) = do processCertificate (Certificates certs) = do
usage <- liftIO $ catch (onCertificatesRecv params $ certs) rejectOnException usage <- liftIO $ catch (onCertificatesRecv params $ certs) rejectOnException
case usage of case usage of
CertificateUsageAccept -> return () CertificateUsageAccept -> return ()
CertificateUsageReject reason -> certificateRejected reason CertificateUsageReject reason -> certificateRejected reason
return $ RecvStateHandshake processServerKeyExchange return $ RecvStateHandshake processServerKeyExchange
where where
rejectOnException :: SomeException -> IO TLSCertificateUsage rejectOnException :: SomeException -> IO TLSCertificateUsage
rejectOnException e = return $ CertificateUsageReject $ CertificateRejectOther $ show e rejectOnException e = return $ CertificateUsageReject $ CertificateRejectOther $ show e
processCertificate p = processServerKeyExchange p processCertificate p = processServerKeyExchange p
processServerKeyExchange :: MonadIO m => Handshake -> m (RecvState m) processServerKeyExchange :: MonadIO m => Handshake -> m (RecvState m)
processServerKeyExchange (ServerKeyXchg _) = return $ RecvStateHandshake processCertificateRequest processServerKeyExchange (ServerKeyXchg _) = return $ RecvStateHandshake processCertificateRequest
processServerKeyExchange p = processCertificateRequest p processServerKeyExchange p = processCertificateRequest p
processCertificateRequest (CertRequest _ _ _) = do processCertificateRequest (CertRequest _ _ _) = do
--modify (\sc -> sc { scCertRequested = True }) --modify (\sc -> sc { scCertRequested = True })
return $ RecvStateHandshake processServerHelloDone return $ RecvStateHandshake processServerHelloDone
processCertificateRequest p = processServerHelloDone p processCertificateRequest p = processServerHelloDone p
processServerHelloDone ServerHelloDone = return RecvStateDone processServerHelloDone ServerHelloDone = return RecvStateDone
processServerHelloDone p = unexpected (show p) (Just "server hello data") processServerHelloDone p = unexpected (show p) (Just "server hello data")
sendClientKeyXchg = do sendClientKeyXchg = do
encryptedPreMaster <- usingState_ ctx $ do encryptedPreMaster <- usingState_ ctx $ do
xver <- stVersion <$> get xver <- stVersion <$> get
prerand <- genTLSRandom 46 prerand <- genTLSRandom 46
let premaster = encodePreMasterSecret xver prerand let premaster = encodePreMasterSecret xver prerand
setMasterSecretFromPre premaster setMasterSecretFromPre premaster
-- SSL3 implementation generally forget this length field since it's redundant, -- SSL3 implementation generally forget this length field since it's redundant,
-- however TLS10 make it clear that the length field need to be present. -- however TLS10 make it clear that the length field need to be present.
e <- encryptRSA premaster e <- encryptRSA premaster
let extra = if xver < TLS10 let extra = if xver < TLS10
then B.empty then B.empty
else encodeWord16 $ fromIntegral $ B.length e else encodeWord16 $ fromIntegral $ B.length e
return $ extra `B.append` e return $ extra `B.append` e
sendPacket ctx $ Handshake [ClientKeyXchg encryptedPreMaster] sendPacket ctx $ Handshake [ClientKeyXchg encryptedPreMaster]
-- on certificate reject, throw an exception with the proper protocol alert error. -- on certificate reject, throw an exception with the proper protocol alert error.
certificateRejected CertificateRejectRevoked = certificateRejected CertificateRejectRevoked =
throwCore $ Error_Protocol ("certificate is revoked", True, CertificateRevoked) throwCore $ Error_Protocol ("certificate is revoked", True, CertificateRevoked)
certificateRejected CertificateRejectExpired = certificateRejected CertificateRejectExpired =
throwCore $ Error_Protocol ("certificate has expired", True, CertificateExpired) throwCore $ Error_Protocol ("certificate has expired", True, CertificateExpired)
certificateRejected CertificateRejectUnknownCA = certificateRejected CertificateRejectUnknownCA =
throwCore $ Error_Protocol ("certificate has unknown CA", True, UnknownCa) throwCore $ Error_Protocol ("certificate has unknown CA", True, UnknownCa)
certificateRejected (CertificateRejectOther s) = certificateRejected (CertificateRejectOther s) =
throwCore $ Error_Protocol ("certificate rejected: " ++ s, True, CertificateUnknown) throwCore $ Error_Protocol ("certificate rejected: " ++ s, True, CertificateUnknown)
handshakeServerWith :: MonadIO m => Context -> Handshake -> m () handshakeServerWith :: MonadIO m => Context -> Handshake -> m ()
handshakeServerWith ctx clientHello@(ClientHello ver _ clientSession ciphers compressions exts) = do handshakeServerWith ctx clientHello@(ClientHello ver _ clientSession ciphers compressions exts) = do
-- check if policy allow this new handshake to happens -- check if policy allow this new handshake to happens
handshakeAuthorized <- withMeasure ctx (onHandshake $ ctxParams ctx) handshakeAuthorized <- withMeasure ctx (onHandshake $ ctxParams ctx)
unless handshakeAuthorized (throwCore $ Error_HandshakePolicy "server: handshake denied") unless handshakeAuthorized (throwCore $ Error_HandshakePolicy "server: handshake denied")
updateMeasure ctx incrementNbHandshakes updateMeasure ctx incrementNbHandshakes
-- Handle Client hello -- Handle Client hello
usingState_ ctx $ processHandshake clientHello usingState_ ctx $ processHandshake clientHello
when (ver == SSL2) $ throwCore $ Error_Protocol ("ssl2 is not supported", True, ProtocolVersion) when (ver == SSL2) $ throwCore $ Error_Protocol ("ssl2 is not supported", True, ProtocolVersion)
when (not $ elem ver (pAllowedVersions params)) $ when (not $ elem ver (pAllowedVersions params)) $
throwCore $ Error_Protocol ("version " ++ show ver ++ "is not supported", True, ProtocolVersion) throwCore $ Error_Protocol ("version " ++ show ver ++ "is not supported", True, ProtocolVersion)
when (commonCiphers == []) $ when (commonCiphers == []) $
throwCore $ Error_Protocol ("no cipher in common with the client", True, HandshakeFailure) throwCore $ Error_Protocol ("no cipher in common with the client", True, HandshakeFailure)
when (null commonCompressions) $ when (null commonCompressions) $
throwCore $ Error_Protocol ("no compression in common with the client", True, HandshakeFailure) throwCore $ Error_Protocol ("no compression in common with the client", True, HandshakeFailure)
usingState_ ctx $ modify (\st -> st usingState_ ctx $ modify (\st -> st
{ stVersion = ver { stVersion = ver
, stCipher = Just usedCipher , stCipher = Just usedCipher
, stCompression = usedCompression , stCompression = usedCompression
}) })
resumeSessionData <- case clientSession of resumeSessionData <- case clientSession of
(Session (Just clientSessionId)) -> liftIO $ onSessionResumption params $ clientSessionId (Session (Just clientSessionId)) -> liftIO $ onSessionResumption params $ clientSessionId
(Session Nothing) -> return Nothing (Session Nothing) -> return Nothing
case resumeSessionData of case resumeSessionData of
Nothing -> do Nothing -> do
handshakeSendServerData handshakeSendServerData
liftIO $ connectionFlush ctx liftIO $ connectionFlush ctx
-- Receive client info until client Finished. -- Receive client info until client Finished.
recvClientData recvClientData
sendChangeCipherAndFinish ctx False sendChangeCipherAndFinish ctx False
Just sessionData -> do Just sessionData -> do
usingState_ ctx (setSession clientSession True) usingState_ ctx (setSession clientSession True)
serverhello <- makeServerHello clientSession serverhello <- makeServerHello clientSession
sendPacket ctx $ Handshake [serverhello] sendPacket ctx $ Handshake [serverhello]
usingState_ ctx $ setMasterSecret $ sessionSecret sessionData usingState_ ctx $ setMasterSecret $ sessionSecret sessionData
sendChangeCipherAndFinish ctx False sendChangeCipherAndFinish ctx False
recvChangeCipherAndFinish ctx recvChangeCipherAndFinish ctx
handshakeTerminate ctx handshakeTerminate ctx
where where
params = ctxParams ctx params = ctxParams ctx
commonCiphers = intersect ciphers (map cipherID $ pCiphers params) commonCiphers = intersect ciphers (map cipherID $ pCiphers params)
usedCipher = fromJust $ find (\c -> cipherID c == head commonCiphers) (pCiphers params) usedCipher = fromJust $ find (\c -> cipherID c == head commonCiphers) (pCiphers params)
commonCompressions = compressionIntersectID (pCompressions params) compressions commonCompressions = compressionIntersectID (pCompressions params) compressions
usedCompression = head commonCompressions usedCompression = head commonCompressions
srvCerts = map fst $ pCertificates params srvCerts = map fst $ pCertificates params
privKeys = map snd $ pCertificates params privKeys = map snd $ pCertificates params
needKeyXchg = cipherExchangeNeedMoreData $ cipherKeyExchange usedCipher needKeyXchg = cipherExchangeNeedMoreData $ cipherKeyExchange usedCipher
clientRequestedNPN = isJust $ lookup 13172 exts clientRequestedNPN = isJust $ lookup 13172 exts
--- ---
recvClientData = runRecvState ctx (RecvStateHandshake processClientCertificate) recvClientData = runRecvState ctx (RecvStateHandshake processClientCertificate)
processClientCertificate (Certificates _) = return $ RecvStateHandshake processClientKeyExchange processClientCertificate (Certificates _) = return $ RecvStateHandshake processClientKeyExchange
processClientCertificate p = processClientKeyExchange p processClientCertificate p = processClientKeyExchange p
processClientKeyExchange (ClientKeyXchg _) = return $ RecvStateNext processCertificateVerify processClientKeyExchange (ClientKeyXchg _) = return $ RecvStateNext processCertificateVerify
processClientKeyExchange p = unexpected (show p) (Just "client key exchange") processClientKeyExchange p = unexpected (show p) (Just "client key exchange")
processCertificateVerify (Handshake [CertVerify _]) = return $ RecvStateNext expectChangeCipher processCertificateVerify (Handshake [CertVerify _]) = return $ RecvStateNext expectChangeCipher
processCertificateVerify p = expectChangeCipher p processCertificateVerify p = expectChangeCipher p
expectChangeCipher ChangeCipherSpec = do npn <- usingState_ ctx getExtensionNPN expectChangeCipher ChangeCipherSpec = do npn <- usingState_ ctx getExtensionNPN
return $ RecvStateHandshake $ if npn return $ RecvStateHandshake $ if npn
then expectNPN then expectNPN
else expectFinish else expectFinish
expectChangeCipher p = unexpected (show p) (Just "change cipher") expectChangeCipher p = unexpected (show p) (Just "change cipher")
expectNPN (NextProtocolNegotiation _) = return $ RecvStateHandshake expectFinish expectNPN (NextProtocolNegotiation _) = return $ RecvStateHandshake expectFinish
expectNPN p = unexpected (show p) (Just "Handshake NextProtocolNegotiation") expectNPN p = unexpected (show p) (Just "Handshake NextProtocolNegotiation")
expectFinish (Finished _) = return RecvStateDone expectFinish (Finished _) = return RecvStateDone
expectFinish p = unexpected (show p) (Just "Handshake Finished") expectFinish p = unexpected (show p) (Just "Handshake Finished")
--- ---
makeServerHello session = do makeServerHello session = do
srand <- getStateRNG ctx 32 >>= return . ServerRandom srand <- getStateRNG ctx 32 >>= return . ServerRandom
case privKeys of case privKeys of
(Just privkey : _) -> usingState_ ctx $ setPrivateKey privkey (Just privkey : _) -> usingState_ ctx $ setPrivateKey privkey
_ -> return () -- return a sensible error _ -> return () -- return a sensible error
-- in TLS12, we need to check as well the certificates we are sending if they have in the extension -- in TLS12, we need to check as well the certificates we are sending if they have in the extension
-- the necessary bits set. -- the necessary bits set.
secReneg <- usingState_ ctx getSecureRenegotiation secReneg <- usingState_ ctx getSecureRenegotiation
secRengExt <- if secReneg secRengExt <- if secReneg
then do then do
vf <- usingState_ ctx $ do vf <- usingState_ ctx $ do
cvf <- getVerifiedData True cvf <- getVerifiedData True
svf <- getVerifiedData False svf <- getVerifiedData False
return $ encodeExtSecureRenegotiation cvf (Just svf) return $ encodeExtSecureRenegotiation cvf (Just svf)
return [ (0xff01, vf) ] return [ (0xff01, vf) ]
else return [] else return []
nextProtocols <- nextProtocols <-
if clientRequestedNPN if clientRequestedNPN
then liftIO $ onSuggestNextProtocols params then liftIO $ onSuggestNextProtocols params
@ -449,83 +449,83 @@ handshakeServerWith ctx clientHello@(ClientHello ver _ clientSession ciphers com
return [ (13172, encodeExtNextProtocolNegotiation protos) ] return [ (13172, encodeExtNextProtocolNegotiation protos) ]
Nothing -> return [] Nothing -> return []
let extensions = secRengExt ++ npnExt let extensions = secRengExt ++ npnExt
usingState_ ctx (setVersion ver >> setServerRandom srand) usingState_ ctx (setVersion ver >> setServerRandom srand)
return $ ServerHello ver srand session (cipherID usedCipher) return $ ServerHello ver srand session (cipherID usedCipher)
(compressionID usedCompression) extensions (compressionID usedCompression) extensions
handshakeSendServerData = do handshakeSendServerData = do
serverSession <- newSession ctx serverSession <- newSession ctx
usingState_ ctx (setSession serverSession False) usingState_ ctx (setSession serverSession False)
serverhello <- makeServerHello serverSession serverhello <- makeServerHello serverSession
-- send ServerHello & Certificate & ServerKeyXchg & CertReq -- send ServerHello & Certificate & ServerKeyXchg & CertReq
sendPacket ctx $ Handshake [ serverhello, Certificates srvCerts ] sendPacket ctx $ Handshake [ serverhello, Certificates srvCerts ]
when needKeyXchg $ do when needKeyXchg $ do
let skg = SKX_RSA Nothing let skg = SKX_RSA Nothing
sendPacket ctx (Handshake [ServerKeyXchg skg]) sendPacket ctx (Handshake [ServerKeyXchg skg])
-- FIXME we don't do this on a Anonymous server -- FIXME we don't do this on a Anonymous server
when (pWantClientCert params) $ do when (pWantClientCert params) $ do
let certTypes = [ CertificateType_RSA_Sign ] let certTypes = [ CertificateType_RSA_Sign ]
let creq = CertRequest certTypes Nothing [0,0,0] let creq = CertRequest certTypes Nothing [0,0,0]
sendPacket ctx (Handshake [creq]) sendPacket ctx (Handshake [creq])
-- Send HelloDone -- Send HelloDone
sendPacket ctx (Handshake [ServerHelloDone]) sendPacket ctx (Handshake [ServerHelloDone])
handshakeServerWith _ _ = fail "unexpected handshake type received. expecting client hello" handshakeServerWith _ _ = fail "unexpected handshake type received. expecting client hello"
-- after receiving a client hello, we need to redo a handshake -- after receiving a client hello, we need to redo a handshake
handshakeServer :: MonadIO m => Context -> m () handshakeServer :: MonadIO m => Context -> m ()
handshakeServer ctx = do handshakeServer ctx = do
hss <- recvPacketHandshake ctx hss <- recvPacketHandshake ctx
case hss of case hss of
[ch] -> handshakeServerWith ctx ch [ch] -> handshakeServerWith ctx ch
_ -> fail ("unexpected handshake received, excepting client hello and received " ++ show hss) _ -> fail ("unexpected handshake received, excepting client hello and received " ++ show hss)
-- | Handshake for a new TLS connection -- | Handshake for a new TLS connection
-- This is to be called at the beginning of a connection, and during renegotiation -- This is to be called at the beginning of a connection, and during renegotiation
handshake :: MonadIO m => Context -> m () handshake :: MonadIO m => Context -> m ()
handshake ctx = do handshake ctx = do
cc <- usingState_ ctx (stClientContext <$> get) cc <- usingState_ ctx (stClientContext <$> get)
liftIO $ handleException $ if cc then handshakeClient ctx else handshakeServer ctx liftIO $ handleException $ if cc then handshakeClient ctx else handshakeServer ctx
where where
handleException f = catch f $ \exception -> do handleException f = catch f $ \exception -> do
let tlserror = maybe (Error_Misc $ show exception) id $ fromException exception let tlserror = maybe (Error_Misc $ show exception) id $ fromException exception
setEstablished ctx False setEstablished ctx False
sendPacket ctx (errorToAlert tlserror) sendPacket ctx (errorToAlert tlserror)
handshakeFailed tlserror handshakeFailed tlserror
-- | sendData sends a bunch of data. -- | sendData sends a bunch of data.
-- It will automatically chunk data to acceptable packet size -- It will automatically chunk data to acceptable packet size
sendData :: MonadIO m => Context -> L.ByteString -> m () sendData :: MonadIO m => Context -> L.ByteString -> m ()
sendData ctx dataToSend = checkValid ctx >> mapM_ sendDataChunk (L.toChunks dataToSend) sendData ctx dataToSend = checkValid ctx >> mapM_ sendDataChunk (L.toChunks dataToSend)
where sendDataChunk d where sendDataChunk d
| B.length d > 16384 = do | B.length d > 16384 = do
let (sending, remain) = B.splitAt 16384 d let (sending, remain) = B.splitAt 16384 d
sendPacket ctx $ AppData sending sendPacket ctx $ AppData sending
sendDataChunk remain sendDataChunk remain
| otherwise = sendPacket ctx $ AppData d | otherwise = sendPacket ctx $ AppData d
-- | recvData get data out of Data packet, and automatically renegotiate if -- | recvData get data out of Data packet, and automatically renegotiate if
-- a Handshake ClientHello is received -- a Handshake ClientHello is received
recvData :: MonadIO m => Context -> m B.ByteString recvData :: MonadIO m => Context -> m B.ByteString
recvData ctx = do recvData ctx = do
checkValid ctx checkValid ctx
pkt <- recvPacket ctx pkt <- recvPacket ctx
case pkt of case pkt of
-- on server context receiving a client hello == renegotiation -- on server context receiving a client hello == renegotiation
Right (Handshake [ch@(ClientHello {})]) -> Right (Handshake [ch@(ClientHello {})]) ->
handshakeServerWith ctx ch >> recvData ctx handshakeServerWith ctx ch >> recvData ctx
-- on client context, receiving a hello request == renegotiation -- on client context, receiving a hello request == renegotiation
Right (Handshake [HelloRequest]) -> Right (Handshake [HelloRequest]) ->
handshakeClient ctx >> recvData ctx handshakeClient ctx >> recvData ctx
Right (Alert [(AlertLevel_Fatal, _)]) -> do Right (Alert [(AlertLevel_Fatal, _)]) -> do
setEOF ctx setEOF ctx
return B.empty return B.empty
Right (Alert [(AlertLevel_Warning, CloseNotify)]) -> do Right (Alert [(AlertLevel_Warning, CloseNotify)]) -> do
setEOF ctx setEOF ctx
return B.empty return B.empty
Right (AppData x) -> return x Right (AppData x) -> return x
Right p -> error ("error unexpected packet: " ++ show p) Right p -> error ("error unexpected packet: " ++ show p)
Left err -> error ("error received: " ++ show err) Left err -> error ("error received: " ++ show err)
recvData' :: MonadIO m => Context -> m L.ByteString recvData' :: MonadIO m => Context -> m L.ByteString
recvData' ctx = recvData ctx >>= return . L.fromChunks . (:[]) recvData' ctx = recvData ctx >>= return . L.fromChunks . (:[])

View file

@ -1,23 +1,23 @@
{-# OPTIONS_HADDOCK hide #-} {-# OPTIONS_HADDOCK hide #-}
{-# LANGUAGE ExistentialQuantification #-} {-# LANGUAGE ExistentialQuantification #-}
module Network.TLS.Crypto module Network.TLS.Crypto
( HashCtx(..) ( HashCtx(..)
, hashInit , hashInit
, hashUpdate , hashUpdate
, hashUpdateSSL , hashUpdateSSL
, hashFinal , hashFinal
-- * constructor -- * constructor
, hashMD5SHA1 , hashMD5SHA1
, hashSHA256 , hashSHA256
-- * key exchange generic interface -- * key exchange generic interface
, PublicKey(..) , PublicKey(..)
, PrivateKey(..) , PrivateKey(..)
, kxEncrypt , kxEncrypt
, kxDecrypt , kxDecrypt
, KxError(..) , KxError(..)
) where ) where
import qualified Crypto.Hash.SHA256 as SHA256 import qualified Crypto.Hash.SHA256 as SHA256
import qualified Crypto.Hash.SHA1 as SHA1 import qualified Crypto.Hash.SHA1 as SHA1
@ -32,48 +32,48 @@ data PublicKey = PubRSA RSA.PublicKey
data PrivateKey = PrivRSA RSA.PrivateKey data PrivateKey = PrivRSA RSA.PrivateKey
instance Show PublicKey where instance Show PublicKey where
show (_) = "PublicKey(..)" show (_) = "PublicKey(..)"
instance Show PrivateKey where instance Show PrivateKey where
show (_) = "privateKey(..)" show (_) = "privateKey(..)"
data KxError = RSAError RSA.Error data KxError = RSAError RSA.Error
deriving (Show) deriving (Show)
data KeyXchg = data KeyXchg =
KxRSA RSA.PublicKey RSA.PrivateKey KxRSA RSA.PublicKey RSA.PrivateKey
deriving (Show) deriving (Show)
class HashCtxC a where class HashCtxC a where
hashCName :: a -> String hashCName :: a -> String
hashCInit :: a -> a hashCInit :: a -> a
hashCUpdate :: a -> B.ByteString -> a hashCUpdate :: a -> B.ByteString -> a
hashCUpdateSSL :: a -> (B.ByteString,B.ByteString) -> a hashCUpdateSSL :: a -> (B.ByteString,B.ByteString) -> a
hashCFinal :: a -> B.ByteString hashCFinal :: a -> B.ByteString
data HashCtx = forall h . HashCtxC h => HashCtx h data HashCtx = forall h . HashCtxC h => HashCtx h
instance Show HashCtx where instance Show HashCtx where
show (HashCtx c) = hashCName c show (HashCtx c) = hashCName c
{- MD5 & SHA1 joined -} {- MD5 & SHA1 joined -}
data HashMD5SHA1 = HashMD5SHA1 SHA1.Ctx MD5.Ctx data HashMD5SHA1 = HashMD5SHA1 SHA1.Ctx MD5.Ctx
instance HashCtxC HashMD5SHA1 where instance HashCtxC HashMD5SHA1 where
hashCName _ = "MD5-SHA1" hashCName _ = "MD5-SHA1"
hashCInit _ = HashMD5SHA1 SHA1.init MD5.init hashCInit _ = HashMD5SHA1 SHA1.init MD5.init
hashCUpdate (HashMD5SHA1 sha1ctx md5ctx) b = HashMD5SHA1 (SHA1.update sha1ctx b) (MD5.update md5ctx b) hashCUpdate (HashMD5SHA1 sha1ctx md5ctx) b = HashMD5SHA1 (SHA1.update sha1ctx b) (MD5.update md5ctx b)
hashCUpdateSSL (HashMD5SHA1 sha1ctx md5ctx) (b1,b2) = HashMD5SHA1 (SHA1.update sha1ctx b2) (MD5.update md5ctx b1) hashCUpdateSSL (HashMD5SHA1 sha1ctx md5ctx) (b1,b2) = HashMD5SHA1 (SHA1.update sha1ctx b2) (MD5.update md5ctx b1)
hashCFinal (HashMD5SHA1 sha1ctx md5ctx) = B.concat [MD5.finalize md5ctx, SHA1.finalize sha1ctx] hashCFinal (HashMD5SHA1 sha1ctx md5ctx) = B.concat [MD5.finalize md5ctx, SHA1.finalize sha1ctx]
data HashSHA256 = HashSHA256 SHA256.Ctx data HashSHA256 = HashSHA256 SHA256.Ctx
instance HashCtxC HashSHA256 where instance HashCtxC HashSHA256 where
hashCName _ = "SHA256" hashCName _ = "SHA256"
hashCInit _ = HashSHA256 SHA256.init hashCInit _ = HashSHA256 SHA256.init
hashCUpdate (HashSHA256 ctx) b = HashSHA256 (SHA256.update ctx b) hashCUpdate (HashSHA256 ctx) b = HashSHA256 (SHA256.update ctx b)
hashCUpdateSSL _ _ = undefined hashCUpdateSSL _ _ = undefined
hashCFinal (HashSHA256 ctx) = SHA256.finalize ctx hashCFinal (HashSHA256 ctx) = SHA256.finalize ctx
-- functions to use the hidden class. -- functions to use the hidden class.
hashInit :: HashCtx -> HashCtx hashInit :: HashCtx -> HashCtx

View file

@ -7,13 +7,13 @@
-- Portability : unknown -- Portability : unknown
-- --
module Network.TLS.Internal module Network.TLS.Internal
( module Network.TLS.Struct ( module Network.TLS.Struct
, module Network.TLS.Packet , module Network.TLS.Packet
, module Network.TLS.Receiving , module Network.TLS.Receiving
, module Network.TLS.Sending , module Network.TLS.Sending
, sendPacket , sendPacket
, recvPacket , recvPacket
) where ) where
import Network.TLS.Struct import Network.TLS.Struct
import Network.TLS.Packet import Network.TLS.Packet

View file

@ -1,14 +1,14 @@
module Network.TLS.MAC module Network.TLS.MAC
( hmacMD5 ( hmacMD5
, hmacSHA1 , hmacSHA1
, hmacSHA256 , hmacSHA256
, macSSL , macSSL
, hmac , hmac
, prf_MD5 , prf_MD5
, prf_SHA1 , prf_SHA1
, prf_SHA256 , prf_SHA256
, prf_MD5SHA1 , prf_MD5SHA1
) where ) where
import qualified Crypto.Hash.MD5 as MD5 import qualified Crypto.Hash.MD5 as MD5
import qualified Crypto.Hash.SHA1 as SHA1 import qualified Crypto.Hash.SHA1 as SHA1
@ -21,22 +21,22 @@ type HMAC = ByteString -> ByteString -> ByteString
macSSL :: (ByteString -> ByteString) -> HMAC macSSL :: (ByteString -> ByteString) -> HMAC
macSSL f secret msg = f $! B.concat [ secret, B.replicate padlen 0x5c, macSSL f secret msg = f $! B.concat [ secret, B.replicate padlen 0x5c,
f $! B.concat [ secret, B.replicate padlen 0x36, msg ] ] f $! B.concat [ secret, B.replicate padlen 0x36, msg ] ]
where where
-- get the type of algorithm out of the digest length by using the hash fct. -- get the type of algorithm out of the digest length by using the hash fct.
padlen = if (B.length $ f B.empty) == 16 then 48 else 40 padlen = if (B.length $ f B.empty) == 16 then 48 else 40
hmac :: (ByteString -> ByteString) -> Int -> HMAC hmac :: (ByteString -> ByteString) -> Int -> HMAC
hmac f bl secret msg = hmac f bl secret msg =
f $! B.append opad (f $! B.append ipad msg) f $! B.append opad (f $! B.append ipad msg)
where where
opad = B.map (xor 0x5c) k' opad = B.map (xor 0x5c) k'
ipad = B.map (xor 0x36) k' ipad = B.map (xor 0x36) k'
k' = B.append kt pad k' = B.append kt pad
where where
kt = if B.length secret > fromIntegral bl then f secret else secret kt = if B.length secret > fromIntegral bl then f secret else secret
pad = B.replicate (fromIntegral bl - B.length kt) 0 pad = B.replicate (fromIntegral bl - B.length kt) 0
hmacMD5 :: HMAC hmacMD5 :: HMAC
hmacMD5 secret msg = hmac MD5.hash 64 secret msg hmacMD5 secret msg = hmac MD5.hash 64 secret msg
@ -49,12 +49,12 @@ hmacSHA256 secret msg = hmac SHA256.hash 64 secret msg
hmacIter :: HMAC -> ByteString -> ByteString -> ByteString -> Int -> [ByteString] hmacIter :: HMAC -> ByteString -> ByteString -> ByteString -> Int -> [ByteString]
hmacIter f secret seed aprev len = hmacIter f secret seed aprev len =
let an = f secret aprev in let an = f secret aprev in
let out = f secret (B.concat [an, seed]) in let out = f secret (B.concat [an, seed]) in
let digestsize = fromIntegral $ B.length out in let digestsize = fromIntegral $ B.length out in
if digestsize >= len if digestsize >= len
then [ B.take (fromIntegral len) out ] then [ B.take (fromIntegral len) out ]
else out : hmacIter f secret seed an (len - digestsize) else out : hmacIter f secret seed an (len - digestsize)
prf_SHA1 :: ByteString -> ByteString -> Int -> ByteString prf_SHA1 :: ByteString -> ByteString -> Int -> ByteString
prf_SHA1 secret seed len = B.concat $ hmacIter hmacSHA1 secret seed seed len prf_SHA1 secret seed len = B.concat $ hmacIter hmacSHA1 secret seed seed len
@ -64,11 +64,11 @@ prf_MD5 secret seed len = B.concat $ hmacIter hmacMD5 secret seed seed len
prf_MD5SHA1 :: ByteString -> ByteString -> Int -> ByteString prf_MD5SHA1 :: ByteString -> ByteString -> Int -> ByteString
prf_MD5SHA1 secret seed len = prf_MD5SHA1 secret seed len =
B.pack $ B.zipWith xor (prf_MD5 s1 seed len) (prf_SHA1 s2 seed len) B.pack $ B.zipWith xor (prf_MD5 s1 seed len) (prf_SHA1 s2 seed len)
where where
slen = B.length secret slen = B.length secret
s1 = B.take (slen `div` 2 + slen `mod` 2) secret s1 = B.take (slen `div` 2 + slen `mod` 2) secret
s2 = B.drop (slen `div` 2) secret s2 = B.drop (slen `div` 2) secret
prf_SHA256 :: ByteString -> ByteString -> Int -> ByteString prf_SHA256 :: ByteString -> ByteString -> Int -> ByteString
prf_SHA256 secret seed len = B.concat $ hmacIter hmacSHA256 secret seed seed len prf_SHA256 secret seed len = B.concat $ hmacIter hmacSHA256 secret seed seed len

View file

@ -6,41 +6,41 @@
-- Portability : unknown -- Portability : unknown
-- --
module Network.TLS.Measurement module Network.TLS.Measurement
( Measurement(..) ( Measurement(..)
, newMeasurement , newMeasurement
, addBytesReceived , addBytesReceived
, addBytesSent , addBytesSent
, resetBytesCounters , resetBytesCounters
, incrementNbHandshakes , incrementNbHandshakes
) where ) where
import Data.Word import Data.Word
-- | record some data about this connection. -- | record some data about this connection.
data Measurement = Measurement data Measurement = Measurement
{ nbHandshakes :: !Word32 -- ^ number of handshakes on this context { nbHandshakes :: !Word32 -- ^ number of handshakes on this context
, bytesReceived :: !Word32 -- ^ bytes received since last handshake , bytesReceived :: !Word32 -- ^ bytes received since last handshake
, bytesSent :: !Word32 -- ^ bytes sent since last handshake , bytesSent :: !Word32 -- ^ bytes sent since last handshake
} deriving (Show,Eq) } deriving (Show,Eq)
newMeasurement :: Measurement newMeasurement :: Measurement
newMeasurement = Measurement newMeasurement = Measurement
{ nbHandshakes = 0 { nbHandshakes = 0
, bytesReceived = 0 , bytesReceived = 0
, bytesSent = 0 , bytesSent = 0
} }
addBytesReceived :: Int -> Measurement -> Measurement addBytesReceived :: Int -> Measurement -> Measurement
addBytesReceived sz measure = addBytesReceived sz measure =
measure { bytesReceived = bytesReceived measure + fromIntegral sz } measure { bytesReceived = bytesReceived measure + fromIntegral sz }
addBytesSent :: Int -> Measurement -> Measurement addBytesSent :: Int -> Measurement -> Measurement
addBytesSent sz measure = addBytesSent sz measure =
measure { bytesSent = bytesSent measure + fromIntegral sz } measure { bytesSent = bytesSent measure + fromIntegral sz }
resetBytesCounters :: Measurement -> Measurement resetBytesCounters :: Measurement -> Measurement
resetBytesCounters measure = measure { bytesReceived = 0, bytesSent = 0 } resetBytesCounters measure = measure { bytesReceived = 0, bytesSent = 0 }
incrementNbHandshakes :: Measurement -> Measurement incrementNbHandshakes :: Measurement -> Measurement
incrementNbHandshakes measure = incrementNbHandshakes measure =
measure { nbHandshakes = nbHandshakes measure + 1 } measure { nbHandshakes = nbHandshakes measure + 1 }

View file

@ -10,46 +10,46 @@
-- with only explicit parameters, no TLS state is involved here. -- with only explicit parameters, no TLS state is involved here.
-- --
module Network.TLS.Packet module Network.TLS.Packet
( (
-- * params for encoding and decoding -- * params for encoding and decoding
CurrentParams(..) CurrentParams(..)
-- * marshall functions for header messages -- * marshall functions for header messages
, decodeHeader , decodeHeader
, encodeHeader , encodeHeader
, encodeHeaderNoVer -- use for SSL3 , encodeHeaderNoVer -- use for SSL3
-- * marshall functions for alert messages -- * marshall functions for alert messages
, decodeAlert , decodeAlert
, decodeAlerts , decodeAlerts
, encodeAlerts , encodeAlerts
-- * marshall functions for handshake messages -- * marshall functions for handshake messages
, decodeHandshakes , decodeHandshakes
, decodeHandshake , decodeHandshake
, encodeHandshake , encodeHandshake
, encodeHandshakes , encodeHandshakes
, encodeHandshakeHeader , encodeHandshakeHeader
, encodeHandshakeContent , encodeHandshakeContent
-- * marshall functions for change cipher spec message -- * marshall functions for change cipher spec message
, decodeChangeCipherSpec , decodeChangeCipherSpec
, encodeChangeCipherSpec , encodeChangeCipherSpec
, decodePreMasterSecret , decodePreMasterSecret
, encodePreMasterSecret , encodePreMasterSecret
-- * marshall extensions -- * marshall extensions
, decodeExtSecureRenegotiation , decodeExtSecureRenegotiation
, encodeExtSecureRenegotiation , encodeExtSecureRenegotiation
, decodeExtNextProtocolNegotiation , decodeExtNextProtocolNegotiation
, encodeExtNextProtocolNegotiation , encodeExtNextProtocolNegotiation
-- * generate things for packet content -- * generate things for packet content
, generateMasterSecret , generateMasterSecret
, generateKeyBlock , generateKeyBlock
, generateClientFinished , generateClientFinished
, generateServerFinished , generateServerFinished
) where ) where
import Network.TLS.Struct import Network.TLS.Struct
import Network.TLS.Wire import Network.TLS.Wire
@ -72,10 +72,10 @@ import qualified Crypto.Hash.SHA1 as SHA1
import qualified Crypto.Hash.MD5 as MD5 import qualified Crypto.Hash.MD5 as MD5
data CurrentParams = CurrentParams data CurrentParams = CurrentParams
{ cParamsVersion :: Version -- ^ current protocol version { cParamsVersion :: Version -- ^ current protocol version
, cParamsKeyXchgType :: CipherKeyExchangeType -- ^ current key exchange type , cParamsKeyXchgType :: CipherKeyExchangeType -- ^ current key exchange type
, cParamsSupportNPN :: Bool -- ^ support Next Protocol Negotiation extension , cParamsSupportNPN :: Bool -- ^ support Next Protocol Negotiation extension
} deriving (Show,Eq) } deriving (Show,Eq)
runGetErr :: String -> Get a -> ByteString -> Either TLSError a runGetErr :: String -> Get a -> ByteString -> Either TLSError a
runGetErr lbl f = either (Left . Error_Packet_Parsing) Right . runGet lbl f runGetErr lbl f = either (Left . Error_Packet_Parsing) Right . runGet lbl f
@ -83,32 +83,32 @@ runGetErr lbl f = either (Left . Error_Packet_Parsing) Right . runGet lbl f
{- marshall helpers -} {- marshall helpers -}
getVersion :: Get Version getVersion :: Get Version
getVersion = do getVersion = do
major <- getWord8 major <- getWord8
minor <- getWord8 minor <- getWord8
case verOfNum (major, minor) of case verOfNum (major, minor) of
Nothing -> fail ("invalid version : " ++ show major ++ "," ++ show minor) Nothing -> fail ("invalid version : " ++ show major ++ "," ++ show minor)
Just v -> return v Just v -> return v
putVersion :: Version -> Put putVersion :: Version -> Put
putVersion ver = putWord8 major >> putWord8 minor putVersion ver = putWord8 major >> putWord8 minor
where (major, minor) = numericalVer ver where (major, minor) = numericalVer ver
getHeaderType :: Get ProtocolType getHeaderType :: Get ProtocolType
getHeaderType = do getHeaderType = do
ty <- getWord8 ty <- getWord8
case valToType ty of case valToType ty of
Nothing -> fail ("invalid header type: " ++ show ty) Nothing -> fail ("invalid header type: " ++ show ty)
Just t -> return t Just t -> return t
putHeaderType :: ProtocolType -> Put putHeaderType :: ProtocolType -> Put
putHeaderType = putWord8 . valOfType putHeaderType = putWord8 . valOfType
getHandshakeType :: Get HandshakeType getHandshakeType :: Get HandshakeType
getHandshakeType = do getHandshakeType = do
ty <- getWord8 ty <- getWord8
case valToType ty of case valToType ty of
Nothing -> fail ("invalid handshake type: " ++ show ty) Nothing -> fail ("invalid handshake type: " ++ show ty)
Just t -> return t Just t -> return t
{- {-
- decode and encode headers - decode and encode headers
@ -118,122 +118,122 @@ decodeHeader = runGetErr "header" $ liftM3 Header getHeaderType getVersion getWo
encodeHeader :: Header -> ByteString encodeHeader :: Header -> ByteString
encodeHeader (Header pt ver len) = runPut (putHeaderType pt >> putVersion ver >> putWord16 len) encodeHeader (Header pt ver len) = runPut (putHeaderType pt >> putVersion ver >> putWord16 len)
{- FIXME check len <= 2^14 -} {- FIXME check len <= 2^14 -}
encodeHeaderNoVer :: Header -> ByteString encodeHeaderNoVer :: Header -> ByteString
encodeHeaderNoVer (Header pt _ len) = runPut (putHeaderType pt >> putWord16 len) encodeHeaderNoVer (Header pt _ len) = runPut (putHeaderType pt >> putWord16 len)
{- FIXME check len <= 2^14 -} {- FIXME check len <= 2^14 -}
{- {-
- decode and encode ALERT - decode and encode ALERT
-} -}
decodeAlert :: Get (AlertLevel, AlertDescription) decodeAlert :: Get (AlertLevel, AlertDescription)
decodeAlert = do decodeAlert = do
al <- getWord8 al <- getWord8
ad <- getWord8 ad <- getWord8
case (valToType al, valToType ad) of case (valToType al, valToType ad) of
(Just a, Just d) -> return (a, d) (Just a, Just d) -> return (a, d)
(Nothing, _) -> fail "cannot decode alert level" (Nothing, _) -> fail "cannot decode alert level"
(_, Nothing) -> fail "cannot decode alert description" (_, Nothing) -> fail "cannot decode alert description"
decodeAlerts :: ByteString -> Either TLSError [(AlertLevel, AlertDescription)] decodeAlerts :: ByteString -> Either TLSError [(AlertLevel, AlertDescription)]
decodeAlerts = runGetErr "alerts" $ loop decodeAlerts = runGetErr "alerts" $ loop
where loop = do where loop = do
r <- remaining r <- remaining
if r == 0 if r == 0
then return [] then return []
else liftM2 (:) decodeAlert loop else liftM2 (:) decodeAlert loop
encodeAlerts :: [(AlertLevel, AlertDescription)] -> ByteString encodeAlerts :: [(AlertLevel, AlertDescription)] -> ByteString
encodeAlerts l = runPut $ mapM_ encodeAlert l encodeAlerts l = runPut $ mapM_ encodeAlert l
where encodeAlert (al, ad) = putWord8 (valOfType al) >> putWord8 (valOfType ad) where encodeAlert (al, ad) = putWord8 (valOfType al) >> putWord8 (valOfType ad)
{- decode and encode HANDSHAKE -} {- decode and encode HANDSHAKE -}
decodeHandshakeHeader :: Get (HandshakeType, Bytes) decodeHandshakeHeader :: Get (HandshakeType, Bytes)
decodeHandshakeHeader = do decodeHandshakeHeader = do
ty <- getHandshakeType ty <- getHandshakeType
content <- getOpaque24 content <- getOpaque24
return (ty, content) return (ty, content)
decodeHandshakes :: ByteString -> Either TLSError [(HandshakeType, Bytes)] decodeHandshakes :: ByteString -> Either TLSError [(HandshakeType, Bytes)]
decodeHandshakes b = runGetErr "handshakes" getAll b where decodeHandshakes b = runGetErr "handshakes" getAll b where
getAll = do getAll = do
x <- decodeHandshakeHeader x <- decodeHandshakeHeader
empty <- isEmpty empty <- isEmpty
if empty if empty
then return [x] then return [x]
else getAll >>= \l -> return (x : l) else getAll >>= \l -> return (x : l)
decodeHandshake :: CurrentParams -> HandshakeType -> ByteString -> Either TLSError Handshake decodeHandshake :: CurrentParams -> HandshakeType -> ByteString -> Either TLSError Handshake
decodeHandshake cp ty = runGetErr "handshake" $ case ty of decodeHandshake cp ty = runGetErr "handshake" $ case ty of
HandshakeType_HelloRequest -> decodeHelloRequest HandshakeType_HelloRequest -> decodeHelloRequest
HandshakeType_ClientHello -> decodeClientHello HandshakeType_ClientHello -> decodeClientHello
HandshakeType_ServerHello -> decodeServerHello HandshakeType_ServerHello -> decodeServerHello
HandshakeType_Certificate -> decodeCertificates HandshakeType_Certificate -> decodeCertificates
HandshakeType_ServerKeyXchg -> decodeServerKeyXchg cp HandshakeType_ServerKeyXchg -> decodeServerKeyXchg cp
HandshakeType_CertRequest -> decodeCertRequest cp HandshakeType_CertRequest -> decodeCertRequest cp
HandshakeType_ServerHelloDone -> decodeServerHelloDone HandshakeType_ServerHelloDone -> decodeServerHelloDone
HandshakeType_CertVerify -> decodeCertVerify HandshakeType_CertVerify -> decodeCertVerify
HandshakeType_ClientKeyXchg -> decodeClientKeyXchg HandshakeType_ClientKeyXchg -> decodeClientKeyXchg
HandshakeType_Finished -> decodeFinished HandshakeType_Finished -> decodeFinished
HandshakeType_NPN -> do HandshakeType_NPN -> do
unless (cParamsSupportNPN cp) $ fail "unsupported handshake type" unless (cParamsSupportNPN cp) $ fail "unsupported handshake type"
decodeNextProtocolNegotiation decodeNextProtocolNegotiation
decodeHelloRequest :: Get Handshake decodeHelloRequest :: Get Handshake
decodeHelloRequest = return HelloRequest decodeHelloRequest = return HelloRequest
decodeClientHello :: Get Handshake decodeClientHello :: Get Handshake
decodeClientHello = do decodeClientHello = do
ver <- getVersion ver <- getVersion
random <- getClientRandom32 random <- getClientRandom32
session <- getSession session <- getSession
ciphers <- getWords16 ciphers <- getWords16
compressions <- getWords8 compressions <- getWords8
r <- remaining r <- remaining
exts <- if hasHelloExtensions ver && r > 0 exts <- if hasHelloExtensions ver && r > 0
then fmap fromIntegral getWord16 >>= getExtensions then fmap fromIntegral getWord16 >>= getExtensions
else return [] else return []
return $ ClientHello ver random session ciphers compressions exts return $ ClientHello ver random session ciphers compressions exts
decodeServerHello :: Get Handshake decodeServerHello :: Get Handshake
decodeServerHello = do decodeServerHello = do
ver <- getVersion ver <- getVersion
random <- getServerRandom32 random <- getServerRandom32
session <- getSession session <- getSession
cipherid <- getWord16 cipherid <- getWord16
compressionid <- getWord8 compressionid <- getWord8
r <- remaining r <- remaining
exts <- if hasHelloExtensions ver && r > 0 exts <- if hasHelloExtensions ver && r > 0
then fmap fromIntegral getWord16 >>= getExtensions then fmap fromIntegral getWord16 >>= getExtensions
else return [] else return []
return $ ServerHello ver random session cipherid compressionid exts return $ ServerHello ver random session cipherid compressionid exts
decodeServerHelloDone :: Get Handshake decodeServerHelloDone :: Get Handshake
decodeServerHelloDone = return ServerHelloDone decodeServerHelloDone = return ServerHelloDone
decodeCertificates :: Get Handshake decodeCertificates :: Get Handshake
decodeCertificates = do decodeCertificates = do
certs <- getWord24 >>= getCerts >>= return . map (decodeCertificate . L.fromChunks . (:[])) certs <- getWord24 >>= getCerts >>= return . map (decodeCertificate . L.fromChunks . (:[]))
let (l, r) = partitionEithers certs let (l, r) = partitionEithers certs
if length l > 0 if length l > 0
then fail ("error certificate parsing: " ++ show l) then fail ("error certificate parsing: " ++ show l)
else return $ Certificates r else return $ Certificates r
decodeFinished :: Get Handshake decodeFinished :: Get Handshake
decodeFinished = Finished <$> (remaining >>= getBytes) decodeFinished = Finished <$> (remaining >>= getBytes)
decodeNextProtocolNegotiation :: Get Handshake decodeNextProtocolNegotiation :: Get Handshake
decodeNextProtocolNegotiation = do decodeNextProtocolNegotiation = do
opaque <- getOpaque8 opaque <- getOpaque8
_ <- getOpaque8 _ <- getOpaque8
return $ NextProtocolNegotiation opaque return $ NextProtocolNegotiation opaque
getSignatureHashAlgorithm :: Get (HashAlgorithm, SignatureAlgorithm) getSignatureHashAlgorithm :: Get (HashAlgorithm, SignatureAlgorithm)
getSignatureHashAlgorithm = do getSignatureHashAlgorithm = do
h <- fromJust . valToType <$> getWord8 h <- fromJust . valToType <$> getWord8
s <- fromJust . valToType <$> getWord8 s <- fromJust . valToType <$> getWord8
return (h,s) return (h,s)
getSignatureHashAlgorithms :: Int -> Get [ (HashAlgorithm, SignatureAlgorithm) ] getSignatureHashAlgorithms :: Int -> Get [ (HashAlgorithm, SignatureAlgorithm) ]
getSignatureHashAlgorithms 0 = return [] getSignatureHashAlgorithms 0 = return []
@ -241,22 +241,22 @@ getSignatureHashAlgorithms len = liftM2 (:) getSignatureHashAlgorithm (getSignat
decodeCertRequest :: CurrentParams -> Get Handshake decodeCertRequest :: CurrentParams -> Get Handshake
decodeCertRequest cp = do decodeCertRequest cp = do
certTypes <- map (fromJust . valToType . fromIntegral) <$> getWords8 certTypes <- map (fromJust . valToType . fromIntegral) <$> getWords8
sigHashAlgs <- if cParamsVersion cp >= TLS12 sigHashAlgs <- if cParamsVersion cp >= TLS12
then do then do
sighashlen <- getWord16 sighashlen <- getWord16
Just <$> getSignatureHashAlgorithms (fromIntegral sighashlen) Just <$> getSignatureHashAlgorithms (fromIntegral sighashlen)
else return Nothing else return Nothing
dNameLen <- getWord16 dNameLen <- getWord16
when (cParamsVersion cp < TLS12 && dNameLen < 3) $ fail "certrequest distinguishname not of the correct size" when (cParamsVersion cp < TLS12 && dNameLen < 3) $ fail "certrequest distinguishname not of the correct size"
dName <- getBytes $ fromIntegral dNameLen dName <- getBytes $ fromIntegral dNameLen
return $ CertRequest certTypes sigHashAlgs (B.unpack dName) return $ CertRequest certTypes sigHashAlgs (B.unpack dName)
decodeCertVerify :: Get Handshake decodeCertVerify :: Get Handshake
decodeCertVerify = decodeCertVerify =
{- FIXME -} {- FIXME -}
return $ CertVerify [] return $ CertVerify []
decodeClientKeyXchg :: Get Handshake decodeClientKeyXchg :: Get Handshake
decodeClientKeyXchg = ClientKeyXchg <$> (remaining >>= getBytes) decodeClientKeyXchg = ClientKeyXchg <$> (remaining >>= getBytes)
@ -266,39 +266,39 @@ os2ip = B.foldl' (\a b -> (256 * a) .|. (fromIntegral b)) 0
decodeServerKeyXchg_DH :: Get ServerDHParams decodeServerKeyXchg_DH :: Get ServerDHParams
decodeServerKeyXchg_DH = do decodeServerKeyXchg_DH = do
p <- getOpaque16 p <- getOpaque16
g <- getOpaque16 g <- getOpaque16
y <- getOpaque16 y <- getOpaque16
return $ ServerDHParams { dh_p = os2ip p, dh_g = os2ip g, dh_Ys = os2ip y } return $ ServerDHParams { dh_p = os2ip p, dh_g = os2ip g, dh_Ys = os2ip y }
decodeServerKeyXchg_RSA :: Get ServerRSAParams decodeServerKeyXchg_RSA :: Get ServerRSAParams
decodeServerKeyXchg_RSA = do decodeServerKeyXchg_RSA = do
modulus <- getOpaque16 modulus <- getOpaque16
expo <- getOpaque16 expo <- getOpaque16
return $ ServerRSAParams { rsa_modulus = os2ip modulus, rsa_exponent = os2ip expo } return $ ServerRSAParams { rsa_modulus = os2ip modulus, rsa_exponent = os2ip expo }
decodeServerKeyXchg :: CurrentParams -> Get Handshake decodeServerKeyXchg :: CurrentParams -> Get Handshake
decodeServerKeyXchg cp = ServerKeyXchg <$> case cParamsKeyXchgType cp of decodeServerKeyXchg cp = ServerKeyXchg <$> case cParamsKeyXchgType cp of
CipherKeyExchange_RSA -> SKX_RSA . Just <$> decodeServerKeyXchg_RSA CipherKeyExchange_RSA -> SKX_RSA . Just <$> decodeServerKeyXchg_RSA
CipherKeyExchange_DH_Anon -> SKX_DH_Anon <$> decodeServerKeyXchg_DH CipherKeyExchange_DH_Anon -> SKX_DH_Anon <$> decodeServerKeyXchg_DH
CipherKeyExchange_DHE_RSA -> do CipherKeyExchange_DHE_RSA -> do
dhparams <- decodeServerKeyXchg_DH dhparams <- decodeServerKeyXchg_DH
signature <- getOpaque16 signature <- getOpaque16
return $ SKX_DHE_RSA dhparams (B.unpack signature) return $ SKX_DHE_RSA dhparams (B.unpack signature)
CipherKeyExchange_DHE_DSS -> do CipherKeyExchange_DHE_DSS -> do
dhparams <- decodeServerKeyXchg_DH dhparams <- decodeServerKeyXchg_DH
signature <- getOpaque16 signature <- getOpaque16
return $ SKX_DHE_DSS dhparams (B.unpack signature) return $ SKX_DHE_DSS dhparams (B.unpack signature)
_ -> do _ -> do
bs <- remaining >>= getBytes bs <- remaining >>= getBytes
return $ SKX_Unknown bs return $ SKX_Unknown bs
encodeHandshake :: Handshake -> ByteString encodeHandshake :: Handshake -> ByteString
encodeHandshake o = encodeHandshake o =
let content = runPut $ encodeHandshakeContent o in let content = runPut $ encodeHandshakeContent o in
let len = fromIntegral $ B.length content in let len = fromIntegral $ B.length content in
let header = runPut $ encodeHandshakeHeader (typeOfHandshake o) len in let header = runPut $ encodeHandshakeHeader (typeOfHandshake o) len in
B.concat [ header, content ] B.concat [ header, content ]
encodeHandshakes :: [Handshake] -> ByteString encodeHandshakes :: [Handshake] -> ByteString
encodeHandshakes hss = B.concat $ map encodeHandshake hss encodeHandshakes hss = B.concat $ map encodeHandshake hss
@ -309,46 +309,46 @@ encodeHandshakeHeader ty len = putWord8 (valOfType ty) >> putWord24 len
encodeHandshakeContent :: Handshake -> Put encodeHandshakeContent :: Handshake -> Put
encodeHandshakeContent (ClientHello version random session cipherIDs compressionIDs exts) = do encodeHandshakeContent (ClientHello version random session cipherIDs compressionIDs exts) = do
putVersion version putVersion version
putClientRandom32 random putClientRandom32 random
putSession session putSession session
putWords16 cipherIDs putWords16 cipherIDs
putWords8 compressionIDs putWords8 compressionIDs
putExtensions exts putExtensions exts
return () return ()
encodeHandshakeContent (ServerHello version random session cipherID compressionID exts) = encodeHandshakeContent (ServerHello version random session cipherID compressionID exts) =
putVersion version >> putServerRandom32 random >> putSession session putVersion version >> putServerRandom32 random >> putSession session
>> putWord16 cipherID >> putWord8 compressionID >> putWord16 cipherID >> putWord8 compressionID
>> putExtensions exts >> return () >> putExtensions exts >> return ()
encodeHandshakeContent (Certificates certs) = putOpaque24 (runPut $ mapM_ putCert certs) encodeHandshakeContent (Certificates certs) = putOpaque24 (runPut $ mapM_ putCert certs)
encodeHandshakeContent (ClientKeyXchg content) = do encodeHandshakeContent (ClientKeyXchg content) = do
putBytes content putBytes content
encodeHandshakeContent (ServerKeyXchg _) = do encodeHandshakeContent (ServerKeyXchg _) = do
-- FIXME -- FIXME
return () return ()
encodeHandshakeContent (HelloRequest) = return () encodeHandshakeContent (HelloRequest) = return ()
encodeHandshakeContent (ServerHelloDone) = return () encodeHandshakeContent (ServerHelloDone) = return ()
encodeHandshakeContent (CertRequest certTypes sigAlgs certAuthorities) = do encodeHandshakeContent (CertRequest certTypes sigAlgs certAuthorities) = do
putWords8 (map valOfType certTypes) putWords8 (map valOfType certTypes)
case sigAlgs of case sigAlgs of
Nothing -> return () Nothing -> return ()
Just l -> putWords16 $ map (\(x,y) -> (fromIntegral $ valOfType x) * 256 + (fromIntegral $ valOfType y)) l Just l -> putWords16 $ map (\(x,y) -> (fromIntegral $ valOfType x) * 256 + (fromIntegral $ valOfType y)) l
putBytes $ B.pack certAuthorities putBytes $ B.pack certAuthorities
encodeHandshakeContent (CertVerify _) = undefined encodeHandshakeContent (CertVerify _) = undefined
encodeHandshakeContent (Finished opaque) = putBytes opaque encodeHandshakeContent (Finished opaque) = putBytes opaque
encodeHandshakeContent (NextProtocolNegotiation protocol) = do encodeHandshakeContent (NextProtocolNegotiation protocol) = do
putOpaque8 protocol putOpaque8 protocol
putOpaque8 $ B.replicate paddingLen 0 putOpaque8 $ B.replicate paddingLen 0
where paddingLen = 32 - ((B.length protocol + 2) `mod` 32) where paddingLen = 32 - ((B.length protocol + 2) `mod` 32)
{- FIXME make sure it return error if not 32 available -} {- FIXME make sure it return error if not 32 available -}
getRandom32 :: Get Bytes getRandom32 :: Get Bytes
@ -371,10 +371,10 @@ putServerRandom32 (ServerRandom r) = putRandom32 r
getSession :: Get Session getSession :: Get Session
getSession = do getSession = do
len8 <- getWord8 len8 <- getWord8
case fromIntegral len8 of case fromIntegral len8 of
0 -> return $ Session Nothing 0 -> return $ Session Nothing
len -> Session . Just <$> getBytes len len -> Session . Just <$> getBytes len
putSession :: Session -> Put putSession :: Session -> Put
putSession (Session Nothing) = putWord8 0 putSession (Session Nothing) = putWord8 0
@ -383,10 +383,10 @@ putSession (Session (Just s)) = putOpaque8 s
getCerts :: Int -> Get [Bytes] getCerts :: Int -> Get [Bytes]
getCerts 0 = return [] getCerts 0 = return []
getCerts len = do getCerts len = do
certlen <- getWord24 certlen <- getWord24
cert <- getBytes certlen cert <- getBytes certlen
certxs <- getCerts (len - certlen - 3) certxs <- getCerts (len - certlen - 3)
return (cert : certxs) return (cert : certxs)
putCert :: X509 -> Put putCert :: X509 -> Put
putCert cert = putOpaque24 (B.concat $ L.toChunks $ encodeCertificate cert) putCert cert = putOpaque24 (B.concat $ L.toChunks $ encodeCertificate cert)
@ -394,11 +394,11 @@ putCert cert = putOpaque24 (B.concat $ L.toChunks $ encodeCertificate cert)
getExtensions :: Int -> Get [Extension] getExtensions :: Int -> Get [Extension]
getExtensions 0 = return [] getExtensions 0 = return []
getExtensions len = do getExtensions len = do
extty <- getWord16 extty <- getWord16
extdatalen <- getWord16 extdatalen <- getWord16
extdata <- getBytes $ fromIntegral extdatalen extdata <- getBytes $ fromIntegral extdatalen
extxs <- getExtensions (len - fromIntegral extdatalen - 4) extxs <- getExtensions (len - fromIntegral extdatalen - 4)
return $ (extty, extdata) : extxs return $ (extty, extdata) : extxs
putExtension :: Extension -> Put putExtension :: Extension -> Put
putExtension (ty, l) = putWord16 ty >> putOpaque16 l putExtension (ty, l) = putWord16 ty >> putOpaque16 l
@ -413,8 +413,8 @@ putExtensions es = putOpaque16 (runPut $ mapM_ putExtension es)
decodeChangeCipherSpec :: ByteString -> Either TLSError () decodeChangeCipherSpec :: ByteString -> Either TLSError ()
decodeChangeCipherSpec = runGetErr "changecipherspec" $ do decodeChangeCipherSpec = runGetErr "changecipherspec" $ do
x <- getWord8 x <- getWord8
when (x /= 1) (fail "unknown change cipher spec content") when (x /= 1) (fail "unknown change cipher spec content")
encodeChangeCipherSpec :: ByteString encodeChangeCipherSpec :: ByteString
encodeChangeCipherSpec = runPut (putWord8 1) encodeChangeCipherSpec = runPut (putWord8 1)
@ -425,18 +425,18 @@ encodeChangeCipherSpec = runPut (putWord8 1)
-} -}
decodeExtSecureRenegotiation :: Bool -> Bytes -> Either TLSError (Bytes, Maybe Bytes) decodeExtSecureRenegotiation :: Bool -> Bytes -> Either TLSError (Bytes, Maybe Bytes)
decodeExtSecureRenegotiation isServerHello = runGetErr "ext-secure-renegotiation" $ do decodeExtSecureRenegotiation isServerHello = runGetErr "ext-secure-renegotiation" $ do
l <- fromIntegral <$> getWord8 l <- fromIntegral <$> getWord8
if isServerHello if isServerHello
then do then do
cvd <- getBytes (l `div` 2) cvd <- getBytes (l `div` 2)
svd <- getBytes (l `div` 2) svd <- getBytes (l `div` 2)
return (cvd, Just svd) return (cvd, Just svd)
else getBytes (l `div` 2) >>= \cvd -> return (cvd, Nothing) else getBytes (l `div` 2) >>= \cvd -> return (cvd, Nothing)
encodeExtSecureRenegotiation :: Bytes -> Maybe Bytes -> Bytes encodeExtSecureRenegotiation :: Bytes -> Maybe Bytes -> Bytes
encodeExtSecureRenegotiation cvd msvd = runPut $ do encodeExtSecureRenegotiation cvd msvd = runPut $ do
let svd = maybe B.empty id msvd let svd = maybe B.empty id msvd
putOpaque8 (cvd `B.append` svd) putOpaque8 (cvd `B.append` svd)
decodeExtNextProtocolNegotiation :: Bytes -> Either TLSError [Bytes] decodeExtNextProtocolNegotiation :: Bytes -> Either TLSError [Bytes]
decodeExtNextProtocolNegotiation = runGetErr "ext-next-protocol-negotiation" p decodeExtNextProtocolNegotiation = runGetErr "ext-next-protocol-negotiation" p
@ -451,7 +451,7 @@ encodeExtNextProtocolNegotiation = runPut . mapM_ putOpaque8
-- rsa pre master secret -- rsa pre master secret
decodePreMasterSecret :: Bytes -> Either TLSError (Version, Bytes) decodePreMasterSecret :: Bytes -> Either TLSError (Version, Bytes)
decodePreMasterSecret = runGetErr "pre-master-secret" $ do decodePreMasterSecret = runGetErr "pre-master-secret" $ do
liftM2 (,) getVersion (getBytes 46) liftM2 (,) getVersion (getBytes 46)
encodePreMasterSecret :: Version -> Bytes -> Bytes encodePreMasterSecret :: Version -> Bytes -> Bytes
encodePreMasterSecret version bytes = runPut (putVersion version >> putBytes bytes) encodePreMasterSecret version bytes = runPut (putVersion version >> putBytes bytes)
@ -463,16 +463,16 @@ type PRF = Bytes -> Bytes -> Int -> Bytes
generateMasterSecret_SSL :: Bytes -> ClientRandom -> ServerRandom -> Bytes generateMasterSecret_SSL :: Bytes -> ClientRandom -> ServerRandom -> Bytes
generateMasterSecret_SSL premasterSecret (ClientRandom c) (ServerRandom s) = generateMasterSecret_SSL premasterSecret (ClientRandom c) (ServerRandom s) =
B.concat $ map (computeMD5) ["A","BB","CCC"] B.concat $ map (computeMD5) ["A","BB","CCC"]
where where
computeMD5 label = MD5.hash $ B.concat [ premasterSecret, computeSHA1 label ] computeMD5 label = MD5.hash $ B.concat [ premasterSecret, computeSHA1 label ]
computeSHA1 label = SHA1.hash $ B.concat [ label, premasterSecret, c, s ] computeSHA1 label = SHA1.hash $ B.concat [ label, premasterSecret, c, s ]
generateMasterSecret_TLS :: PRF -> Bytes -> ClientRandom -> ServerRandom -> Bytes generateMasterSecret_TLS :: PRF -> Bytes -> ClientRandom -> ServerRandom -> Bytes
generateMasterSecret_TLS prf premasterSecret (ClientRandom c) (ServerRandom s) = generateMasterSecret_TLS prf premasterSecret (ClientRandom c) (ServerRandom s) =
prf premasterSecret seed 48 prf premasterSecret seed 48
where where
seed = B.concat [ "master secret", c, s ] seed = B.concat [ "master secret", c, s ]
generateMasterSecret :: Version -> Bytes -> ClientRandom -> ServerRandom -> Bytes generateMasterSecret :: Version -> Bytes -> ClientRandom -> ServerRandom -> Bytes
generateMasterSecret SSL2 = generateMasterSecret_SSL generateMasterSecret SSL2 = generateMasterSecret_SSL
@ -483,15 +483,15 @@ generateMasterSecret TLS12 = generateMasterSecret_TLS prf_SHA256
generateKeyBlock_TLS :: PRF -> ClientRandom -> ServerRandom -> Bytes -> Int -> Bytes generateKeyBlock_TLS :: PRF -> ClientRandom -> ServerRandom -> Bytes -> Int -> Bytes
generateKeyBlock_TLS prf (ClientRandom c) (ServerRandom s) mastersecret kbsize = generateKeyBlock_TLS prf (ClientRandom c) (ServerRandom s) mastersecret kbsize =
prf mastersecret seed kbsize where seed = B.concat [ "key expansion", s, c ] prf mastersecret seed kbsize where seed = B.concat [ "key expansion", s, c ]
generateKeyBlock_SSL :: ClientRandom -> ServerRandom -> Bytes -> Int -> Bytes generateKeyBlock_SSL :: ClientRandom -> ServerRandom -> Bytes -> Int -> Bytes
generateKeyBlock_SSL (ClientRandom c) (ServerRandom s) mastersecret kbsize = generateKeyBlock_SSL (ClientRandom c) (ServerRandom s) mastersecret kbsize =
B.concat $ map computeMD5 $ take ((kbsize `div` 16) + 1) labels B.concat $ map computeMD5 $ take ((kbsize `div` 16) + 1) labels
where where
labels = [ uncurry BC.replicate x | x <- zip [1..] ['A'..'Z'] ] labels = [ uncurry BC.replicate x | x <- zip [1..] ['A'..'Z'] ]
computeMD5 label = MD5.hash $ B.concat [ mastersecret, computeSHA1 label ] computeMD5 label = MD5.hash $ B.concat [ mastersecret, computeSHA1 label ]
computeSHA1 label = SHA1.hash $ B.concat [ label, mastersecret, s, c ] computeSHA1 label = SHA1.hash $ B.concat [ label, mastersecret, s, c ]
generateKeyBlock :: Version -> ClientRandom -> ServerRandom -> Bytes -> Int -> Bytes generateKeyBlock :: Version -> ClientRandom -> ServerRandom -> Bytes -> Int -> Bytes
generateKeyBlock SSL2 = generateKeyBlock_SSL generateKeyBlock SSL2 = generateKeyBlock_SSL
@ -502,29 +502,29 @@ generateKeyBlock TLS12 = generateKeyBlock_TLS prf_SHA256
generateFinished_TLS :: PRF -> Bytes -> Bytes -> HashCtx -> Bytes generateFinished_TLS :: PRF -> Bytes -> Bytes -> HashCtx -> Bytes
generateFinished_TLS prf label mastersecret hashctx = prf mastersecret seed 12 generateFinished_TLS prf label mastersecret hashctx = prf mastersecret seed 12
where where
seed = B.concat [ label, hashFinal hashctx ] seed = B.concat [ label, hashFinal hashctx ]
generateFinished_SSL :: Bytes -> Bytes -> HashCtx -> Bytes generateFinished_SSL :: Bytes -> Bytes -> HashCtx -> Bytes
generateFinished_SSL sender mastersecret hashctx = B.concat [md5hash, sha1hash] generateFinished_SSL sender mastersecret hashctx = B.concat [md5hash, sha1hash]
where where
md5hash = MD5.hash $ B.concat [ mastersecret, pad2, md5left ] md5hash = MD5.hash $ B.concat [ mastersecret, pad2, md5left ]
sha1hash = SHA1.hash $ B.concat [ mastersecret, B.take 40 pad2, sha1left ] sha1hash = SHA1.hash $ B.concat [ mastersecret, B.take 40 pad2, sha1left ]
lefthash = hashFinal $ flip hashUpdateSSL (pad1, B.take 40 pad1) lefthash = hashFinal $ flip hashUpdateSSL (pad1, B.take 40 pad1)
$ foldl hashUpdate hashctx [sender,mastersecret] $ foldl hashUpdate hashctx [sender,mastersecret]
(md5left,sha1left) = B.splitAt 16 lefthash (md5left,sha1left) = B.splitAt 16 lefthash
pad2 = B.replicate 48 0x5c pad2 = B.replicate 48 0x5c
pad1 = B.replicate 48 0x36 pad1 = B.replicate 48 0x36
generateClientFinished :: Version -> Bytes -> HashCtx -> Bytes generateClientFinished :: Version -> Bytes -> HashCtx -> Bytes
generateClientFinished ver generateClientFinished ver
| ver < TLS10 = generateFinished_SSL "CLNT" | ver < TLS10 = generateFinished_SSL "CLNT"
| ver < TLS12 = generateFinished_TLS prf_MD5SHA1 "client finished" | ver < TLS12 = generateFinished_TLS prf_MD5SHA1 "client finished"
| otherwise = generateFinished_TLS prf_SHA256 "client finished" | otherwise = generateFinished_TLS prf_SHA256 "client finished"
generateServerFinished :: Version -> Bytes -> HashCtx -> Bytes generateServerFinished :: Version -> Bytes -> HashCtx -> Bytes
generateServerFinished ver generateServerFinished ver
| ver < TLS10 = generateFinished_SSL "SRVR" | ver < TLS10 = generateFinished_SSL "SRVR"
| ver < TLS12 = generateFinished_TLS prf_MD5SHA1 "server finished" | ver < TLS12 = generateFinished_TLS prf_MD5SHA1 "server finished"
| otherwise = generateFinished_TLS prf_SHA256 "server finished" | otherwise = generateFinished_TLS prf_SHA256 "server finished"

View file

@ -37,75 +37,75 @@ processPacket (Record ProtocolType_AppData _ fragment) = return $ AppData $ frag
processPacket (Record ProtocolType_Alert _ fragment) = return . Alert =<< returnEither (decodeAlerts $ fragmentGetBytes fragment) processPacket (Record ProtocolType_Alert _ fragment) = return . Alert =<< returnEither (decodeAlerts $ fragmentGetBytes fragment)
processPacket (Record ProtocolType_ChangeCipherSpec _ fragment) = do processPacket (Record ProtocolType_ChangeCipherSpec _ fragment) = do
returnEither $ decodeChangeCipherSpec $ fragmentGetBytes fragment returnEither $ decodeChangeCipherSpec $ fragmentGetBytes fragment
switchRxEncryption switchRxEncryption
return ChangeCipherSpec return ChangeCipherSpec
processPacket (Record ProtocolType_Handshake ver fragment) = do processPacket (Record ProtocolType_Handshake ver fragment) = do
keyxchg <- getCipherKeyExchangeType keyxchg <- getCipherKeyExchangeType
npn <- getExtensionNPN npn <- getExtensionNPN
let currentparams = CurrentParams let currentparams = CurrentParams
{ cParamsVersion = ver { cParamsVersion = ver
, cParamsKeyXchgType = maybe CipherKeyExchange_RSA id $ keyxchg , cParamsKeyXchgType = maybe CipherKeyExchange_RSA id $ keyxchg
, cParamsSupportNPN = npn , cParamsSupportNPN = npn
} }
handshakes <- returnEither (decodeHandshakes $ fragmentGetBytes fragment) handshakes <- returnEither (decodeHandshakes $ fragmentGetBytes fragment)
hss <- forM handshakes $ \(ty, content) -> do hss <- forM handshakes $ \(ty, content) -> do
case decodeHandshake currentparams ty content of case decodeHandshake currentparams ty content of
Left err -> throwError err Left err -> throwError err
Right hs -> return hs Right hs -> return hs
return $ Handshake hss return $ Handshake hss
processHandshake :: Handshake -> TLSSt () processHandshake :: Handshake -> TLSSt ()
processHandshake hs = do processHandshake hs = do
clientmode <- isClientContext clientmode <- isClientContext
case hs of case hs of
ClientHello cver ran _ _ _ ex -> unless clientmode $ do ClientHello cver ran _ _ _ ex -> unless clientmode $ do
mapM_ processClientExtension ex mapM_ processClientExtension ex
startHandshakeClient cver ran startHandshakeClient cver ran
Certificates certs -> when clientmode $ do processCertificates certs Certificates certs -> when clientmode $ do processCertificates certs
ClientKeyXchg content -> unless clientmode $ do ClientKeyXchg content -> unless clientmode $ do
processClientKeyXchg content processClientKeyXchg content
NextProtocolNegotiation selected_protocol -> NextProtocolNegotiation selected_protocol ->
unless clientmode $ do unless clientmode $ do
setNegotiatedProtocol selected_protocol setNegotiatedProtocol selected_protocol
Finished fdata -> processClientFinished fdata Finished fdata -> processClientFinished fdata
_ -> return () _ -> return ()
when (finishHandshakeTypeMaterial $ typeOfHandshake hs) (updateHandshakeDigest $ encodeHandshake hs) when (finishHandshakeTypeMaterial $ typeOfHandshake hs) (updateHandshakeDigest $ encodeHandshake hs)
where where
-- secure renegotiation -- secure renegotiation
processClientExtension (0xff01, content) = do processClientExtension (0xff01, content) = do
v <- getVerifiedData True v <- getVerifiedData True
let bs = encodeExtSecureRenegotiation v Nothing let bs = encodeExtSecureRenegotiation v Nothing
when (bs /= content) $ throwError $ when (bs /= content) $ throwError $
Error_Protocol ("client verified data not matching: " ++ show v ++ ":" ++ show content, True, HandshakeFailure) Error_Protocol ("client verified data not matching: " ++ show v ++ ":" ++ show content, True, HandshakeFailure)
setSecureRenegotiation True setSecureRenegotiation True
-- unknown extensions -- unknown extensions
processClientExtension _ = return () processClientExtension _ = return ()
decryptRSA :: ByteString -> TLSSt (Either KxError ByteString) decryptRSA :: ByteString -> TLSSt (Either KxError ByteString)
decryptRSA econtent = do decryptRSA econtent = do
ver <- stVersion <$> get ver <- stVersion <$> get
rsapriv <- fromJust "rsa private key" . hstRSAPrivateKey . fromJust "handshake" . stHandshake <$> get rsapriv <- fromJust "rsa private key" . hstRSAPrivateKey . fromJust "handshake" . stHandshake <$> get
return $ kxDecrypt rsapriv (if ver < TLS10 then econtent else B.drop 2 econtent) return $ kxDecrypt rsapriv (if ver < TLS10 then econtent else B.drop 2 econtent)
processServerHello :: Handshake -> TLSSt () processServerHello :: Handshake -> TLSSt ()
processServerHello (ServerHello sver ran _ _ _ ex) = do processServerHello (ServerHello sver ran _ _ _ ex) = do
-- FIXME notify the user to take action if the extension requested is missing -- FIXME notify the user to take action if the extension requested is missing
-- secreneg <- getSecureRenegotiation -- secreneg <- getSecureRenegotiation
-- when (secreneg && (isNothing $ lookup 0xff01 ex)) $ ... -- when (secreneg && (isNothing $ lookup 0xff01 ex)) $ ...
mapM_ processServerExtension ex mapM_ processServerExtension ex
setServerRandom ran setServerRandom ran
setVersion sver setVersion sver
where where
processServerExtension (0xff01, content) = do processServerExtension (0xff01, content) = do
cv <- getVerifiedData True cv <- getVerifiedData True
sv <- getVerifiedData False sv <- getVerifiedData False
let bs = encodeExtSecureRenegotiation cv (Just sv) let bs = encodeExtSecureRenegotiation cv (Just sv)
when (bs /= content) $ throwError $ Error_Protocol ("server secure renegotiation data not matching", True, HandshakeFailure) when (bs /= content) $ throwError $ Error_Protocol ("server secure renegotiation data not matching", True, HandshakeFailure)
return () return ()
processServerExtension _ = return () processServerExtension _ = return ()
processServerHello _ = error "processServerHello called on wrong type" processServerHello _ = error "processServerHello called on wrong type"
-- process the client key exchange message. the protocol expects the initial -- process the client key exchange message. the protocol expects the initial
@ -113,29 +113,29 @@ processServerHello _ = error "processServerHello called on wrong type"
-- in case the version mismatch, generate a random master secret -- in case the version mismatch, generate a random master secret
processClientKeyXchg :: ByteString -> TLSSt () processClientKeyXchg :: ByteString -> TLSSt ()
processClientKeyXchg encryptedPremaster = do processClientKeyXchg encryptedPremaster = do
expectedVer <- hstClientVersion . fromJust "handshake" . stHandshake <$> get expectedVer <- hstClientVersion . fromJust "handshake" . stHandshake <$> get
random <- genTLSRandom 48 random <- genTLSRandom 48
ePremaster <- decryptRSA encryptedPremaster ePremaster <- decryptRSA encryptedPremaster
case ePremaster of case ePremaster of
Left _ -> setMasterSecretFromPre random Left _ -> setMasterSecretFromPre random
Right premaster -> case decodePreMasterSecret premaster of Right premaster -> case decodePreMasterSecret premaster of
Left _ -> setMasterSecretFromPre random Left _ -> setMasterSecretFromPre random
Right (ver, _) Right (ver, _)
| ver /= expectedVer -> setMasterSecretFromPre random | ver /= expectedVer -> setMasterSecretFromPre random
| otherwise -> setMasterSecretFromPre premaster | otherwise -> setMasterSecretFromPre premaster
processClientFinished :: FinishedData -> TLSSt () processClientFinished :: FinishedData -> TLSSt ()
processClientFinished fdata = do processClientFinished fdata = do
cc <- stClientContext <$> get cc <- stClientContext <$> get
expected <- getHandshakeDigest (not cc) expected <- getHandshakeDigest (not cc)
when (expected /= fdata) $ do when (expected /= fdata) $ do
throwError $ Error_Protocol("bad record mac", True, BadRecordMac) throwError $ Error_Protocol("bad record mac", True, BadRecordMac)
updateVerifiedData False fdata updateVerifiedData False fdata
return () return ()
processCertificates :: [X509] -> TLSSt () processCertificates :: [X509] -> TLSSt ()
processCertificates certs = do processCertificates certs = do
let (X509 mainCert _ _ _ _) = head certs let (X509 mainCert _ _ _ _) = head certs
case certPubKey mainCert of case certPubKey mainCert of
PubKeyRSA pubkey -> setPublicKey (PubRSA pubkey) PubKeyRSA pubkey -> setPublicKey (PubRSA pubkey)
_ -> return () _ -> return ()

View file

@ -12,20 +12,20 @@
-- higher-level clients. -- higher-level clients.
-- --
module Network.TLS.Record module Network.TLS.Record
( Record(..) ( Record(..)
, Fragment , Fragment
, fragmentGetBytes , fragmentGetBytes
, fragmentPlaintext , fragmentPlaintext
, fragmentCiphertext , fragmentCiphertext
, recordToRaw , recordToRaw
, rawToRecord , rawToRecord
, recordToHeader , recordToHeader
, Plaintext , Plaintext
, Compressed , Compressed
, Ciphertext , Ciphertext
, engageRecord , engageRecord
, disengageRecord , disengageRecord
) where ) where
import Network.TLS.Record.Types import Network.TLS.Record.Types
import Network.TLS.Record.Engage import Network.TLS.Record.Engage

View file

@ -9,8 +9,8 @@
-- The record is decrypted, checked for integrity and then decompressed. -- The record is decrypted, checked for integrity and then decompressed.
-- --
module Network.TLS.Record.Disengage module Network.TLS.Record.Disengage
( disengageRecord ( disengageRecord
) where ) where
import Control.Monad.State import Control.Monad.State
import Control.Monad.Error import Control.Monad.Error
@ -30,88 +30,88 @@ disengageRecord = decryptRecord >=> uncompressRecord
uncompressRecord :: Record Compressed -> TLSSt (Record Plaintext) uncompressRecord :: Record Compressed -> TLSSt (Record Plaintext)
uncompressRecord record = onRecordFragment record $ fragmentUncompress $ \bytes -> uncompressRecord record = onRecordFragment record $ fragmentUncompress $ \bytes ->
withCompression $ compressionInflate bytes withCompression $ compressionInflate bytes
decryptRecord :: Record Ciphertext -> TLSSt (Record Compressed) decryptRecord :: Record Ciphertext -> TLSSt (Record Compressed)
decryptRecord record = onRecordFragment record $ fragmentUncipher $ \e -> do decryptRecord record = onRecordFragment record $ fragmentUncipher $ \e -> do
st <- get st <- get
if stRxEncrypted st if stRxEncrypted st
then decryptData e >>= getCipherData record then decryptData e >>= getCipherData record
else return e else return e
getCipherData :: Record a -> CipherData -> TLSSt ByteString getCipherData :: Record a -> CipherData -> TLSSt ByteString
getCipherData (Record pt ver _) cdata = do getCipherData (Record pt ver _) cdata = do
-- check if the MAC is valid. -- check if the MAC is valid.
macValid <- case cipherDataMAC cdata of macValid <- case cipherDataMAC cdata of
Nothing -> return True Nothing -> return True
Just digest -> do Just digest -> do
let new_hdr = Header pt ver (fromIntegral $ B.length $ cipherDataContent cdata) let new_hdr = Header pt ver (fromIntegral $ B.length $ cipherDataContent cdata)
expected_digest <- makeDigest False new_hdr $ cipherDataContent cdata expected_digest <- makeDigest False new_hdr $ cipherDataContent cdata
return (expected_digest `bytesEq` digest) return (expected_digest `bytesEq` digest)
-- check if the padding is filled with the correct pattern if it exists -- check if the padding is filled with the correct pattern if it exists
paddingValid <- case cipherDataPadding cdata of paddingValid <- case cipherDataPadding cdata of
Nothing -> return True Nothing -> return True
Just pad -> do Just pad -> do
cver <- gets stVersion cver <- gets stVersion
let b = B.length pad - 1 let b = B.length pad - 1
return (if cver < TLS10 then True else B.replicate (B.length pad) (fromIntegral b) `bytesEq` pad) return (if cver < TLS10 then True else B.replicate (B.length pad) (fromIntegral b) `bytesEq` pad)
unless (macValid &&! paddingValid) $ do unless (macValid &&! paddingValid) $ do
throwError $ Error_Protocol ("bad record mac", True, BadRecordMac) throwError $ Error_Protocol ("bad record mac", True, BadRecordMac)
return $ cipherDataContent cdata return $ cipherDataContent cdata
decryptData :: Bytes -> TLSSt CipherData decryptData :: Bytes -> TLSSt CipherData
decryptData econtent = do decryptData econtent = do
st <- get st <- get
let cipher = fromJust "cipher" $ stCipher st let cipher = fromJust "cipher" $ stCipher st
let bulk = cipherBulk cipher let bulk = cipherBulk cipher
let cst = fromJust "rx crypt state" $ stRxCryptState st let cst = fromJust "rx crypt state" $ stRxCryptState st
let digestSize = hashSize $ cipherHash cipher let digestSize = hashSize $ cipherHash cipher
let writekey = cstKey cst let writekey = cstKey cst
case bulkF bulk of case bulkF bulk of
BulkNoneF -> do BulkNoneF -> do
let contentlen = B.length econtent - digestSize let contentlen = B.length econtent - digestSize
case partition3 econtent (contentlen, digestSize, 0) of case partition3 econtent (contentlen, digestSize, 0) of
Nothing -> Nothing ->
throwError $ Error_Misc "partition3 failed" throwError $ Error_Misc "partition3 failed"
Just (content, mac, _) -> Just (content, mac, _) ->
return $ CipherData return $ CipherData
{ cipherDataContent = content { cipherDataContent = content
, cipherDataMAC = Just mac , cipherDataMAC = Just mac
, cipherDataPadding = Nothing , cipherDataPadding = Nothing
} }
BulkBlockF _ decryptF -> do BulkBlockF _ decryptF -> do
{- update IV -} {- update IV -}
let (iv, econtent') = let (iv, econtent') =
if hasExplicitBlockIV $ stVersion st if hasExplicitBlockIV $ stVersion st
then B.splitAt (bulkIVSize bulk) econtent then B.splitAt (bulkIVSize bulk) econtent
else (cstIV cst, econtent) else (cstIV cst, econtent)
let newiv = fromJust "new iv" $ takelast (bulkBlockSize bulk) econtent' let newiv = fromJust "new iv" $ takelast (bulkBlockSize bulk) econtent'
put $ st { stRxCryptState = Just $ cst { cstIV = newiv } } put $ st { stRxCryptState = Just $ cst { cstIV = newiv } }
let content' = decryptF writekey iv econtent' let content' = decryptF writekey iv econtent'
let paddinglength = fromIntegral (B.last content') + 1 let paddinglength = fromIntegral (B.last content') + 1
let contentlen = B.length content' - paddinglength - digestSize let contentlen = B.length content' - paddinglength - digestSize
let (content, mac, padding) = fromJust "p3" $ partition3 content' (contentlen, digestSize, paddinglength) let (content, mac, padding) = fromJust "p3" $ partition3 content' (contentlen, digestSize, paddinglength)
return $ CipherData return $ CipherData
{ cipherDataContent = content { cipherDataContent = content
, cipherDataMAC = Just mac , cipherDataMAC = Just mac
, cipherDataPadding = Just padding , cipherDataPadding = Just padding
} }
BulkStreamF initF _ decryptF -> do BulkStreamF initF _ decryptF -> do
let iv = cstIV cst let iv = cstIV cst
let (content', newiv) = decryptF (if iv /= B.empty then iv else initF writekey) econtent let (content', newiv) = decryptF (if iv /= B.empty then iv else initF writekey) econtent
{- update Ctx -} {- update Ctx -}
let contentlen = B.length content' - digestSize let contentlen = B.length content' - digestSize
let (content, mac, _) = fromJust "p3" $ partition3 content' (contentlen, digestSize, 0) let (content, mac, _) = fromJust "p3" $ partition3 content' (contentlen, digestSize, 0)
put $ st { stRxCryptState = Just $ cst { cstIV = newiv } } put $ st { stRxCryptState = Just $ cst { cstIV = newiv } }
return $ CipherData return $ CipherData
{ cipherDataContent = content { cipherDataContent = content
, cipherDataMAC = Just mac , cipherDataMAC = Just mac
, cipherDataPadding = Nothing , cipherDataPadding = Nothing
} }

View file

@ -9,8 +9,8 @@
-- The record is compressed, added some integrity field, then encrypted. -- The record is compressed, added some integrity field, then encrypted.
-- --
module Network.TLS.Record.Engage module Network.TLS.Record.Engage
( engageRecord ( engageRecord
) where ) where
import Control.Monad.State import Control.Monad.State
@ -28,8 +28,8 @@ engageRecord = compressRecord >=> encryptRecord
compressRecord :: Record Plaintext -> TLSSt (Record Compressed) compressRecord :: Record Plaintext -> TLSSt (Record Compressed)
compressRecord record = compressRecord record =
onRecordFragment record $ fragmentCompress $ \bytes -> do onRecordFragment record $ fragmentCompress $ \bytes -> do
withCompression $ compressionDeflate bytes withCompression $ compressionDeflate bytes
{- {-
- when Tx Encrypted is set, we pass the data through encryptContent, otherwise - when Tx Encrypted is set, we pass the data through encryptContent, otherwise
@ -37,53 +37,53 @@ compressRecord record =
-} -}
encryptRecord :: Record Compressed -> TLSSt (Record Ciphertext) encryptRecord :: Record Compressed -> TLSSt (Record Ciphertext)
encryptRecord record = onRecordFragment record $ fragmentCipher $ \bytes -> do encryptRecord record = onRecordFragment record $ fragmentCipher $ \bytes -> do
st <- get st <- get
if stTxEncrypted st if stTxEncrypted st
then encryptContent record bytes then encryptContent record bytes
else return bytes else return bytes
encryptContent :: Record Compressed -> ByteString -> TLSSt ByteString encryptContent :: Record Compressed -> ByteString -> TLSSt ByteString
encryptContent record content = do encryptContent record content = do
digest <- makeDigest True (recordToHeader record) content digest <- makeDigest True (recordToHeader record) content
encryptData $ B.concat [content, digest] encryptData $ B.concat [content, digest]
encryptData :: ByteString -> TLSSt ByteString encryptData :: ByteString -> TLSSt ByteString
encryptData content = do encryptData content = do
st <- get st <- get
let cipher = fromJust "cipher" $ stCipher st let cipher = fromJust "cipher" $ stCipher st
let bulk = cipherBulk cipher let bulk = cipherBulk cipher
let cst = fromJust "tx crypt state" $ stTxCryptState st let cst = fromJust "tx crypt state" $ stTxCryptState st
let writekey = cstKey cst let writekey = cstKey cst
case bulkF bulk of case bulkF bulk of
BulkNoneF -> return content BulkNoneF -> return content
BulkBlockF encrypt _ -> do BulkBlockF encrypt _ -> do
let blockSize = fromIntegral $ bulkBlockSize bulk let blockSize = fromIntegral $ bulkBlockSize bulk
let msg_len = B.length content let msg_len = B.length content
let padding = if blockSize > 0 let padding = if blockSize > 0
then then
let padbyte = blockSize - (msg_len `mod` blockSize) in let padbyte = blockSize - (msg_len `mod` blockSize) in
let padbyte' = if padbyte == 0 then blockSize else padbyte in let padbyte' = if padbyte == 0 then blockSize else padbyte in
B.replicate padbyte' (fromIntegral (padbyte' - 1)) B.replicate padbyte' (fromIntegral (padbyte' - 1))
else else
B.empty B.empty
-- before TLS 1.1, the block cipher IV is made of the residual of the previous block. -- before TLS 1.1, the block cipher IV is made of the residual of the previous block.
iv <- if hasExplicitBlockIV $ stVersion st iv <- if hasExplicitBlockIV $ stVersion st
then genTLSRandom (bulkIVSize bulk) then genTLSRandom (bulkIVSize bulk)
else return $ cstIV cst else return $ cstIV cst
let e = encrypt writekey iv (B.concat [ content, padding ]) let e = encrypt writekey iv (B.concat [ content, padding ])
if hasExplicitBlockIV $ stVersion st if hasExplicitBlockIV $ stVersion st
then return $ B.concat [iv,e] then return $ B.concat [iv,e]
else do else do
let newiv = fromJust "new iv" $ takelast (bulkIVSize bulk) e let newiv = fromJust "new iv" $ takelast (bulkIVSize bulk) e
put $ st { stTxCryptState = Just $ cst { cstIV = newiv } } put $ st { stTxCryptState = Just $ cst { cstIV = newiv } }
return e return e
BulkStreamF initF encryptF _ -> do BulkStreamF initF encryptF _ -> do
let iv = cstIV cst let iv = cstIV cst
let (e, newiv) = encryptF (if iv /= B.empty then iv else initF writekey) content let (e, newiv) = encryptF (if iv /= B.empty then iv else initF writekey) content
put $ st { stTxCryptState = Just $ cst { cstIV = newiv } } put $ st { stTxCryptState = Just $ cst { cstIV = newiv } }
return e return e

View file

@ -13,30 +13,30 @@
-- higher-level clients. -- higher-level clients.
-- --
module Network.TLS.Record.Types module Network.TLS.Record.Types
( Header(..) ( Header(..)
, ProtocolType(..) , ProtocolType(..)
, packetType , packetType
-- * TLS Records -- * TLS Records
, Record(..) , Record(..)
-- * TLS Record fragment and constructors -- * TLS Record fragment and constructors
, Fragment , Fragment
, fragmentPlaintext , fragmentPlaintext
, fragmentCiphertext , fragmentCiphertext
, fragmentGetBytes , fragmentGetBytes
, Plaintext , Plaintext
, Compressed , Compressed
, Ciphertext , Ciphertext
-- * manipulate record -- * manipulate record
, onRecordFragment , onRecordFragment
, fragmentCompress , fragmentCompress
, fragmentCipher , fragmentCipher
, fragmentUncipher , fragmentUncipher
, fragmentUncompress , fragmentUncompress
-- * serialize record -- * serialize record
, rawToRecord , rawToRecord
, recordToRaw , recordToRaw
, recordToHeader , recordToHeader
) where ) where
import Network.TLS.Struct import Network.TLS.Struct
import Network.TLS.State import Network.TLS.State

View file

@ -29,17 +29,17 @@ import Network.TLS.Crypto
-} -}
makeRecord :: Packet -> TLSSt (Record Plaintext) makeRecord :: Packet -> TLSSt (Record Plaintext)
makeRecord pkt = do makeRecord pkt = do
ver <- stVersion <$> get ver <- stVersion <$> get
content <- writePacketContent pkt content <- writePacketContent pkt
return $ Record (packetType pkt) ver (fragmentPlaintext content) return $ Record (packetType pkt) ver (fragmentPlaintext content)
{- {-
- Handshake data need to update a digest - Handshake data need to update a digest
-} -}
processRecord :: Record Plaintext -> TLSSt (Record Plaintext) processRecord :: Record Plaintext -> TLSSt (Record Plaintext)
processRecord record@(Record ty _ fragment) = do processRecord record@(Record ty _ fragment) = do
when (ty == ProtocolType_Handshake) (updateHandshakeDigest $ fragmentGetBytes fragment) when (ty == ProtocolType_Handshake) (updateHandshakeDigest $ fragmentGetBytes fragment)
return record return record
{- {-
- ChangeCipherSpec state change need to be handled after encryption otherwise - ChangeCipherSpec state change need to be handled after encryption otherwise
@ -48,7 +48,7 @@ processRecord record@(Record ty _ fragment) = do
-} -}
postprocessRecord :: Record Ciphertext -> TLSSt (Record Ciphertext) postprocessRecord :: Record Ciphertext -> TLSSt (Record Ciphertext)
postprocessRecord record@(Record ProtocolType_ChangeCipherSpec _ _) = postprocessRecord record@(Record ProtocolType_ChangeCipherSpec _ _) =
switchTxEncryption >> return record switchTxEncryption >> return record
postprocessRecord record = return record postprocessRecord record = return record
{- {-
@ -56,7 +56,7 @@ postprocessRecord record = return record
-} -}
encodeRecord :: Record Ciphertext -> TLSSt ByteString encodeRecord :: Record Ciphertext -> TLSSt ByteString
encodeRecord record = return $ B.concat [ encodeHeader hdr, content ] encodeRecord record = return $ B.concat [ encodeHeader hdr, content ]
where (hdr, content) = recordToRaw record where (hdr, content) = recordToRaw record
{- {-
- just update TLS state machine - just update TLS state machine
@ -66,9 +66,9 @@ preProcessPacket (Alert _) = return ()
preProcessPacket (AppData _) = return () preProcessPacket (AppData _) = return ()
preProcessPacket (ChangeCipherSpec) = return () preProcessPacket (ChangeCipherSpec) = return ()
preProcessPacket (Handshake hss) = forM_ hss $ \hs -> do preProcessPacket (Handshake hss) = forM_ hss $ \hs -> do
case hs of case hs of
Finished fdata -> updateVerifiedData True fdata Finished fdata -> updateVerifiedData True fdata
_ -> return () _ -> return ()
{- {-
- writePacket transform a packet into marshalled data related to current state - writePacket transform a packet into marshalled data related to current state
@ -76,8 +76,8 @@ preProcessPacket (Handshake hss) = forM_ hss $ \hs -> do
-} -}
writePacket :: Packet -> TLSSt ByteString writePacket :: Packet -> TLSSt ByteString
writePacket pkt = do writePacket pkt = do
preProcessPacket pkt preProcessPacket pkt
makeRecord pkt >>= processRecord >>= engageRecord >>= postprocessRecord >>= encodeRecord makeRecord pkt >>= processRecord >>= engageRecord >>= postprocessRecord >>= encodeRecord
{------------------------------------------------------------------------------} {------------------------------------------------------------------------------}
{- SENDING Helpers -} {- SENDING Helpers -}
@ -88,11 +88,11 @@ writePacket pkt = do
-} -}
encryptRSA :: ByteString -> TLSSt ByteString encryptRSA :: ByteString -> TLSSt ByteString
encryptRSA content = do encryptRSA content = do
st <- get st <- get
let rsakey = fromJust "rsa public key" $ hstRSAPublicKey $ fromJust "handshake" $ stHandshake st let rsakey = fromJust "rsa public key" $ hstRSAPublicKey $ fromJust "handshake" $ stHandshake st
case withTLSRNG (stRandomGen st) (\g -> kxEncrypt g rsakey content) of case withTLSRNG (stRandomGen st) (\g -> kxEncrypt g rsakey content) of
Left err -> fail ("rsa encrypt failed: " ++ show err) Left err -> fail ("rsa encrypt failed: " ++ show err)
Right (econtent, rng') -> put (st { stRandomGen = rng' }) >> return econtent Right (econtent, rng') -> put (st { stRandomGen = rng' }) >> return econtent
writePacketContent :: Packet -> TLSSt ByteString writePacketContent :: Packet -> TLSSt ByteString
writePacketContent (Handshake hss) = return $ encodeHandshakes hss writePacketContent (Handshake hss) = return $ encodeHandshakes hss

View file

@ -10,51 +10,51 @@
-- which is use by the Receiving module and the Sending module. -- which is use by the Receiving module and the Sending module.
-- --
module Network.TLS.State module Network.TLS.State
( TLSState(..) ( TLSState(..)
, TLSSt , TLSSt
, runTLSState , runTLSState
, TLSHandshakeState(..) , TLSHandshakeState(..)
, TLSCryptState(..) , TLSCryptState(..)
, TLSMacState(..) , TLSMacState(..)
, newTLSState , newTLSState
, genTLSRandom , genTLSRandom
, withTLSRNG , withTLSRNG
, withCompression , withCompression
, assert -- FIXME move somewhere else (Internal.hs ?) , assert -- FIXME move somewhere else (Internal.hs ?)
, updateVerifiedData , updateVerifiedData
, finishHandshakeTypeMaterial , finishHandshakeTypeMaterial
, finishHandshakeMaterial , finishHandshakeMaterial
, makeDigest , makeDigest
, setMasterSecret , setMasterSecret
, setMasterSecretFromPre , setMasterSecretFromPre
, setPublicKey , setPublicKey
, setPrivateKey , setPrivateKey
, setKeyBlock , setKeyBlock
, setVersion , setVersion
, setCipher , setCipher
, setServerRandom , setServerRandom
, setSecureRenegotiation , setSecureRenegotiation
, getSecureRenegotiation , getSecureRenegotiation
, setExtensionNPN , setExtensionNPN
, getExtensionNPN , getExtensionNPN
, setNegotiatedProtocol , setNegotiatedProtocol
, getNegotiatedProtocol , getNegotiatedProtocol
, setServerNextProtocolSuggest , setServerNextProtocolSuggest
, getServerNextProtocolSuggest , getServerNextProtocolSuggest
, getVerifiedData , getVerifiedData
, setSession , setSession
, getSession , getSession
, getSessionData , getSessionData
, isSessionResuming , isSessionResuming
, switchTxEncryption , switchTxEncryption
, switchRxEncryption , switchRxEncryption
, getCipherKeyExchangeType , getCipherKeyExchangeType
, isClientContext , isClientContext
, startHandshakeClient , startHandshakeClient
, updateHandshakeDigest , updateHandshakeDigest
, getHandshakeDigest , getHandshakeDigest
, endHandshake , endHandshake
) where ) where
import Data.Word import Data.Word
import Data.Maybe (isNothing) import Data.Maybe (isNothing)
@ -75,138 +75,138 @@ import Crypto.Random
assert :: Monad m => String -> [(String,Bool)] -> m () assert :: Monad m => String -> [(String,Bool)] -> m ()
assert fctname list = forM_ list $ \ (name, assumption) -> do assert fctname list = forM_ list $ \ (name, assumption) -> do
when assumption $ fail (fctname ++ ": assumption about " ++ name ++ " failed") when assumption $ fail (fctname ++ ": assumption about " ++ name ++ " failed")
data TLSCryptState = TLSCryptState data TLSCryptState = TLSCryptState
{ cstKey :: !Bytes { cstKey :: !Bytes
, cstIV :: !Bytes , cstIV :: !Bytes
, cstMacSecret :: !Bytes , cstMacSecret :: !Bytes
} deriving (Show) } deriving (Show)
data TLSMacState = TLSMacState data TLSMacState = TLSMacState
{ msSequence :: Word64 { msSequence :: Word64
} deriving (Show) } deriving (Show)
data TLSHandshakeState = TLSHandshakeState data TLSHandshakeState = TLSHandshakeState
{ hstClientVersion :: !(Version) { hstClientVersion :: !(Version)
, hstClientRandom :: !ClientRandom , hstClientRandom :: !ClientRandom
, hstServerRandom :: !(Maybe ServerRandom) , hstServerRandom :: !(Maybe ServerRandom)
, hstMasterSecret :: !(Maybe Bytes) , hstMasterSecret :: !(Maybe Bytes)
, hstRSAPublicKey :: !(Maybe PublicKey) , hstRSAPublicKey :: !(Maybe PublicKey)
, hstRSAPrivateKey :: !(Maybe PrivateKey) , hstRSAPrivateKey :: !(Maybe PrivateKey)
, hstHandshakeDigest :: !HashCtx , hstHandshakeDigest :: !HashCtx
} deriving (Show) } deriving (Show)
data StateRNG = forall g . CryptoRandomGen g => StateRNG g data StateRNG = forall g . CryptoRandomGen g => StateRNG g
instance Show StateRNG where instance Show StateRNG where
show _ = "rng[..]" show _ = "rng[..]"
data TLSState = TLSState data TLSState = TLSState
{ stClientContext :: Bool { stClientContext :: Bool
, stVersion :: !Version , stVersion :: !Version
, stHandshake :: !(Maybe TLSHandshakeState) , stHandshake :: !(Maybe TLSHandshakeState)
, stSession :: Session , stSession :: Session
, stSessionResuming :: Bool , stSessionResuming :: Bool
, stTxEncrypted :: Bool , stTxEncrypted :: Bool
, stRxEncrypted :: Bool , stRxEncrypted :: Bool
, stTxCryptState :: !(Maybe TLSCryptState) , stTxCryptState :: !(Maybe TLSCryptState)
, stRxCryptState :: !(Maybe TLSCryptState) , stRxCryptState :: !(Maybe TLSCryptState)
, stTxMacState :: !(Maybe TLSMacState) , stTxMacState :: !(Maybe TLSMacState)
, stRxMacState :: !(Maybe TLSMacState) , stRxMacState :: !(Maybe TLSMacState)
, stCipher :: Maybe Cipher , stCipher :: Maybe Cipher
, stCompression :: Compression , stCompression :: Compression
, stRandomGen :: StateRNG , stRandomGen :: StateRNG
, stSecureRenegotiation :: Bool -- RFC 5746 , stSecureRenegotiation :: Bool -- RFC 5746
, stClientVerifiedData :: Bytes -- RFC 5746 , stClientVerifiedData :: Bytes -- RFC 5746
, stServerVerifiedData :: Bytes -- RFC 5746 , stServerVerifiedData :: Bytes -- RFC 5746
, stExtensionNPN :: Bool -- NPN draft extension , stExtensionNPN :: Bool -- NPN draft extension
, stNegotiatedProtocol :: Maybe B.ByteString -- NPN protocol , stNegotiatedProtocol :: Maybe B.ByteString -- NPN protocol
, stServerNextProtocolSuggest :: Maybe [B.ByteString] , stServerNextProtocolSuggest :: Maybe [B.ByteString]
} deriving (Show) } deriving (Show)
newtype TLSSt a = TLSSt { runTLSSt :: ErrorT TLSError (State TLSState) a } newtype TLSSt a = TLSSt { runTLSSt :: ErrorT TLSError (State TLSState) a }
deriving (Monad, MonadError TLSError) deriving (Monad, MonadError TLSError)
instance Functor TLSSt where instance Functor TLSSt where
fmap f = TLSSt . fmap f . runTLSSt fmap f = TLSSt . fmap f . runTLSSt
instance MonadState TLSState TLSSt where instance MonadState TLSState TLSSt where
put x = TLSSt (lift $ put x) put x = TLSSt (lift $ put x)
get = TLSSt (lift get) get = TLSSt (lift get)
runTLSState :: TLSSt a -> TLSState -> (Either TLSError a, TLSState) runTLSState :: TLSSt a -> TLSState -> (Either TLSError a, TLSState)
runTLSState f st = runState (runErrorT (runTLSSt f)) st runTLSState f st = runState (runErrorT (runTLSSt f)) st
newTLSState :: CryptoRandomGen g => g -> TLSState newTLSState :: CryptoRandomGen g => g -> TLSState
newTLSState rng = TLSState newTLSState rng = TLSState
{ stClientContext = False { stClientContext = False
, stVersion = TLS10 , stVersion = TLS10
, stHandshake = Nothing , stHandshake = Nothing
, stSession = Session Nothing , stSession = Session Nothing
, stSessionResuming = False , stSessionResuming = False
, stTxEncrypted = False , stTxEncrypted = False
, stRxEncrypted = False , stRxEncrypted = False
, stTxCryptState = Nothing , stTxCryptState = Nothing
, stRxCryptState = Nothing , stRxCryptState = Nothing
, stTxMacState = Nothing , stTxMacState = Nothing
, stRxMacState = Nothing , stRxMacState = Nothing
, stCipher = Nothing , stCipher = Nothing
, stCompression = nullCompression , stCompression = nullCompression
, stRandomGen = StateRNG rng , stRandomGen = StateRNG rng
, stSecureRenegotiation = False , stSecureRenegotiation = False
, stClientVerifiedData = B.empty , stClientVerifiedData = B.empty
, stServerVerifiedData = B.empty , stServerVerifiedData = B.empty
, stExtensionNPN = False , stExtensionNPN = False
, stNegotiatedProtocol = Nothing , stNegotiatedProtocol = Nothing
, stServerNextProtocolSuggest = Nothing , stServerNextProtocolSuggest = Nothing
} }
withTLSRNG :: StateRNG -> (forall g . CryptoRandomGen g => g -> Either e (a,g)) -> Either e (a, StateRNG) withTLSRNG :: StateRNG -> (forall g . CryptoRandomGen g => g -> Either e (a,g)) -> Either e (a, StateRNG)
withTLSRNG (StateRNG rng) f = case f rng of withTLSRNG (StateRNG rng) f = case f rng of
Left err -> Left err Left err -> Left err
Right (a, rng') -> Right (a, StateRNG rng') Right (a, rng') -> Right (a, StateRNG rng')
withCompression :: (Compression -> (Compression, a)) -> TLSSt a withCompression :: (Compression -> (Compression, a)) -> TLSSt a
withCompression f = do withCompression f = do
compression <- stCompression <$> get compression <- stCompression <$> get
let (nc, a) = f compression let (nc, a) = f compression
modify (\st -> st { stCompression = nc }) modify (\st -> st { stCompression = nc })
return a return a
genTLSRandom :: (MonadState TLSState m, MonadError TLSError m) => Int -> m Bytes genTLSRandom :: (MonadState TLSState m, MonadError TLSError m) => Int -> m Bytes
genTLSRandom n = do genTLSRandom n = do
st <- get st <- get
case withTLSRNG (stRandomGen st) (genBytes n) of case withTLSRNG (stRandomGen st) (genBytes n) of
Left err -> throwError $ Error_Random $ show err Left err -> throwError $ Error_Random $ show err
Right (bytes, rng') -> put (st { stRandomGen = rng' }) >> return bytes Right (bytes, rng') -> put (st { stRandomGen = rng' }) >> return bytes
makeDigest :: MonadState TLSState m => Bool -> Header -> Bytes -> m Bytes makeDigest :: MonadState TLSState m => Bool -> Header -> Bytes -> m Bytes
makeDigest w hdr content = do makeDigest w hdr content = do
st <- get st <- get
let ver = stVersion st let ver = stVersion st
let cst = fromJust "crypt state" $ if w then stTxCryptState st else stRxCryptState st let cst = fromJust "crypt state" $ if w then stTxCryptState st else stRxCryptState st
let ms = fromJust "mac state" $ if w then stTxMacState st else stRxMacState st let ms = fromJust "mac state" $ if w then stTxMacState st else stRxMacState st
let cipher = fromJust "cipher" $ stCipher st let cipher = fromJust "cipher" $ stCipher st
let hashf = hashF $ cipherHash cipher let hashf = hashF $ cipherHash cipher
let (macF, msg) = let (macF, msg) =
if ver < TLS10 if ver < TLS10
then (macSSL hashf, B.concat [ encodeWord64 $ msSequence ms, encodeHeaderNoVer hdr, content ]) then (macSSL hashf, B.concat [ encodeWord64 $ msSequence ms, encodeHeaderNoVer hdr, content ])
else (hmac hashf 64, B.concat [ encodeWord64 $ msSequence ms, encodeHeader hdr, content ]) else (hmac hashf 64, B.concat [ encodeWord64 $ msSequence ms, encodeHeader hdr, content ])
let digest = macF (cstMacSecret cst) msg let digest = macF (cstMacSecret cst) msg
let newms = ms { msSequence = (msSequence ms) + 1 } let newms = ms { msSequence = (msSequence ms) + 1 }
modify (\_ -> if w then st { stTxMacState = Just newms } else st { stRxMacState = Just newms }) modify (\_ -> if w then st { stTxMacState = Just newms } else st { stRxMacState = Just newms })
return digest return digest
updateVerifiedData :: MonadState TLSState m => Bool -> Bytes -> m () updateVerifiedData :: MonadState TLSState m => Bool -> Bytes -> m ()
updateVerifiedData sending bs = do updateVerifiedData sending bs = do
cc <- isClientContext cc <- isClientContext
if cc /= sending if cc /= sending
then modify (\st -> st { stServerVerifiedData = bs }) then modify (\st -> st { stServerVerifiedData = bs })
else modify (\st -> st { stClientVerifiedData = bs }) else modify (\st -> st { stClientVerifiedData = bs })
finishHandshakeTypeMaterial :: HandshakeType -> Bool finishHandshakeTypeMaterial :: HandshakeType -> Bool
finishHandshakeTypeMaterial HandshakeType_ClientHello = True finishHandshakeTypeMaterial HandshakeType_ClientHello = True
@ -233,24 +233,24 @@ setServerRandom ran = updateHandshake "srand" (\hst -> hst { hstServerRandom = J
setMasterSecret :: MonadState TLSState m => Bytes -> m () setMasterSecret :: MonadState TLSState m => Bytes -> m ()
setMasterSecret masterSecret = do setMasterSecret masterSecret = do
hasValidHandshake "master secret" hasValidHandshake "master secret"
updateHandshake "master secret" (\hst -> hst { hstMasterSecret = Just masterSecret } ) updateHandshake "master secret" (\hst -> hst { hstMasterSecret = Just masterSecret } )
setKeyBlock setKeyBlock
return () return ()
setMasterSecretFromPre :: MonadState TLSState m => Bytes -> m () setMasterSecretFromPre :: MonadState TLSState m => Bytes -> m ()
setMasterSecretFromPre premasterSecret = do setMasterSecretFromPre premasterSecret = do
hasValidHandshake "generate master secret" hasValidHandshake "generate master secret"
st <- get st <- get
setMasterSecret $ genSecret st setMasterSecret $ genSecret st
where where
genSecret st = genSecret st =
let hst = fromJust "handshake" $ stHandshake st in let hst = fromJust "handshake" $ stHandshake st in
generateMasterSecret (stVersion st) generateMasterSecret (stVersion st)
premasterSecret premasterSecret
(hstClientRandom hst) (hstClientRandom hst)
(fromJust "server random" $ hstServerRandom hst) (fromJust "server random" $ hstServerRandom hst)
setPublicKey :: MonadState TLSState m => PublicKey -> m () setPublicKey :: MonadState TLSState m => PublicKey -> m ()
setPublicKey pk = updateHandshake "publickey" (\hst -> hst { hstRSAPublicKey = Just pk }) setPublicKey pk = updateHandshake "publickey" (\hst -> hst { hstRSAPublicKey = Just pk })
@ -260,14 +260,14 @@ setPrivateKey pk = updateHandshake "privatekey" (\hst -> hst { hstRSAPrivateKey
getSessionData :: MonadState TLSState m => m (Maybe SessionData) getSessionData :: MonadState TLSState m => m (Maybe SessionData)
getSessionData = do getSessionData = do
st <- get st <- get
return (stHandshake st >>= hstMasterSecret >>= wrapSessionData st) return (stHandshake st >>= hstMasterSecret >>= wrapSessionData st)
where wrapSessionData st masterSecret = do where wrapSessionData st masterSecret = do
return $ SessionData return $ SessionData
{ sessionVersion = stVersion st { sessionVersion = stVersion st
, sessionCipher = cipherID $ fromJust "cipher" $ stCipher st , sessionCipher = cipherID $ fromJust "cipher" $ stCipher st
, sessionSecret = masterSecret , sessionSecret = masterSecret
} }
setSession :: MonadState TLSState m => Session -> Bool -> m () setSession :: MonadState TLSState m => Session -> Bool -> m ()
setSession session resuming = modify (\st -> st { stSession = session, stSessionResuming = resuming }) setSession session resuming = modify (\st -> st { stSession = session, stSessionResuming = resuming })
@ -280,41 +280,41 @@ isSessionResuming = gets stSessionResuming
setKeyBlock :: MonadState TLSState m => m () setKeyBlock :: MonadState TLSState m => m ()
setKeyBlock = do setKeyBlock = do
st <- get st <- get
let hst = fromJust "handshake" $ stHandshake st let hst = fromJust "handshake" $ stHandshake st
let cc = stClientContext st let cc = stClientContext st
let cipher = fromJust "cipher" $ stCipher st let cipher = fromJust "cipher" $ stCipher st
let keyblockSize = cipherKeyBlockSize cipher let keyblockSize = cipherKeyBlockSize cipher
let bulk = cipherBulk cipher let bulk = cipherBulk cipher
let digestSize = hashSize $ cipherHash cipher let digestSize = hashSize $ cipherHash cipher
let keySize = bulkKeySize bulk let keySize = bulkKeySize bulk
let ivSize = bulkIVSize bulk let ivSize = bulkIVSize bulk
let kb = generateKeyBlock (stVersion st) (hstClientRandom hst) let kb = generateKeyBlock (stVersion st) (hstClientRandom hst)
(fromJust "server random" $ hstServerRandom hst) (fromJust "server random" $ hstServerRandom hst)
(fromJust "master secret" $ hstMasterSecret hst) keyblockSize (fromJust "master secret" $ hstMasterSecret hst) keyblockSize
let (cMACSecret, sMACSecret, cWriteKey, sWriteKey, cWriteIV, sWriteIV) = let (cMACSecret, sMACSecret, cWriteKey, sWriteKey, cWriteIV, sWriteIV) =
fromJust "p6" $ partition6 kb (digestSize, digestSize, keySize, keySize, ivSize, ivSize) fromJust "p6" $ partition6 kb (digestSize, digestSize, keySize, keySize, ivSize, ivSize)
let cstClient = TLSCryptState let cstClient = TLSCryptState
{ cstKey = cWriteKey { cstKey = cWriteKey
, cstIV = cWriteIV , cstIV = cWriteIV
, cstMacSecret = cMACSecret } , cstMacSecret = cMACSecret }
let cstServer = TLSCryptState let cstServer = TLSCryptState
{ cstKey = sWriteKey { cstKey = sWriteKey
, cstIV = sWriteIV , cstIV = sWriteIV
, cstMacSecret = sMACSecret } , cstMacSecret = sMACSecret }
let msClient = TLSMacState { msSequence = 0 } let msClient = TLSMacState { msSequence = 0 }
let msServer = TLSMacState { msSequence = 0 } let msServer = TLSMacState { msSequence = 0 }
put $ st put $ st
{ stTxCryptState = Just $ if cc then cstClient else cstServer { stTxCryptState = Just $ if cc then cstClient else cstServer
, stRxCryptState = Just $ if cc then cstServer else cstClient , stRxCryptState = Just $ if cc then cstServer else cstClient
, stTxMacState = Just $ if cc then msClient else msServer , stTxMacState = Just $ if cc then msClient else msServer
, stRxMacState = Just $ if cc then msServer else msClient , stRxMacState = Just $ if cc then msServer else msClient
} }
setCipher :: MonadState TLSState m => Cipher -> m () setCipher :: MonadState TLSState m => Cipher -> m ()
setCipher cipher = modify (\st -> st { stCipher = Just cipher }) setCipher cipher = modify (\st -> st { stCipher = Just cipher })
@ -358,42 +358,42 @@ isClientContext = get >>= return . stClientContext
-- create a new empty handshake state -- create a new empty handshake state
newEmptyHandshake :: Version -> ClientRandom -> HashCtx -> TLSHandshakeState newEmptyHandshake :: Version -> ClientRandom -> HashCtx -> TLSHandshakeState
newEmptyHandshake ver crand digestInit = TLSHandshakeState newEmptyHandshake ver crand digestInit = TLSHandshakeState
{ hstClientVersion = ver { hstClientVersion = ver
, hstClientRandom = crand , hstClientRandom = crand
, hstServerRandom = Nothing , hstServerRandom = Nothing
, hstMasterSecret = Nothing , hstMasterSecret = Nothing
, hstRSAPublicKey = Nothing , hstRSAPublicKey = Nothing
, hstRSAPrivateKey = Nothing , hstRSAPrivateKey = Nothing
, hstHandshakeDigest = digestInit , hstHandshakeDigest = digestInit
} }
startHandshakeClient :: MonadState TLSState m => Version -> ClientRandom -> m () startHandshakeClient :: MonadState TLSState m => Version -> ClientRandom -> m ()
startHandshakeClient ver crand = do startHandshakeClient ver crand = do
-- FIXME check if handshake is already not null -- FIXME check if handshake is already not null
let initCtx = if ver < TLS12 then hashMD5SHA1 else hashSHA256 let initCtx = if ver < TLS12 then hashMD5SHA1 else hashSHA256
chs <- get >>= return . stHandshake chs <- get >>= return . stHandshake
when (isNothing chs) $ when (isNothing chs) $
modify (\st -> st { stHandshake = Just $ newEmptyHandshake ver crand initCtx }) modify (\st -> st { stHandshake = Just $ newEmptyHandshake ver crand initCtx })
hasValidHandshake :: MonadState TLSState m => String -> m () hasValidHandshake :: MonadState TLSState m => String -> m ()
hasValidHandshake name = get >>= \st -> assert name [ ("valid handshake", isNothing $ stHandshake st) ] hasValidHandshake name = get >>= \st -> assert name [ ("valid handshake", isNothing $ stHandshake st) ]
updateHandshake :: MonadState TLSState m => String -> (TLSHandshakeState -> TLSHandshakeState) -> m () updateHandshake :: MonadState TLSState m => String -> (TLSHandshakeState -> TLSHandshakeState) -> m ()
updateHandshake n f = do updateHandshake n f = do
hasValidHandshake n hasValidHandshake n
modify (\st -> st { stHandshake = f <$> stHandshake st }) modify (\st -> st { stHandshake = f <$> stHandshake st })
updateHandshakeDigest :: MonadState TLSState m => Bytes -> m () updateHandshakeDigest :: MonadState TLSState m => Bytes -> m ()
updateHandshakeDigest content = updateHandshake "update digest" $ \hs -> updateHandshakeDigest content = updateHandshake "update digest" $ \hs ->
hs { hstHandshakeDigest = hashUpdate (hstHandshakeDigest hs) content } hs { hstHandshakeDigest = hashUpdate (hstHandshakeDigest hs) content }
getHandshakeDigest :: MonadState TLSState m => Bool -> m Bytes getHandshakeDigest :: MonadState TLSState m => Bool -> m Bytes
getHandshakeDigest client = do getHandshakeDigest client = do
st <- get st <- get
let hst = fromJust "handshake" $ stHandshake st let hst = fromJust "handshake" $ stHandshake st
let hashctx = hstHandshakeDigest hst let hashctx = hstHandshakeDigest hst
let msecret = fromJust "master secret" $ hstMasterSecret hst let msecret = fromJust "master secret" $ hstMasterSecret hst
return $ (if client then generateClientFinished else generateServerFinished) (stVersion st) msecret hashctx return $ (if client then generateClientFinished else generateServerFinished) (stVersion st) msecret hashctx
endHandshake :: MonadState TLSState m => m () endHandshake :: MonadState TLSState m => m ()
endHandshake = modify (\st -> st { stHandshake = Nothing }) endHandshake = modify (\st -> st { stHandshake = Nothing })

View file

@ -10,40 +10,40 @@
-- the Struct module contains all definitions and values of the TLS protocol -- the Struct module contains all definitions and values of the TLS protocol
-- --
module Network.TLS.Struct module Network.TLS.Struct
( Bytes ( Bytes
, Version(..) , Version(..)
, ConnectionEnd(..) , ConnectionEnd(..)
, CipherType(..) , CipherType(..)
, CipherData(..) , CipherData(..)
, Extension , Extension
, CertificateType(..) , CertificateType(..)
, HashAlgorithm(..) , HashAlgorithm(..)
, SignatureAlgorithm(..) , SignatureAlgorithm(..)
, ProtocolType(..) , ProtocolType(..)
, TLSError(..) , TLSError(..)
, ServerDHParams(..) , ServerDHParams(..)
, ServerRSAParams(..) , ServerRSAParams(..)
, ServerKeyXchgAlgorithmData(..) , ServerKeyXchgAlgorithmData(..)
, Packet(..) , Packet(..)
, Header(..) , Header(..)
, ServerRandom(..) , ServerRandom(..)
, ClientRandom(..) , ClientRandom(..)
, serverRandom , serverRandom
, clientRandom , clientRandom
, FinishedData , FinishedData
, SessionID , SessionID
, Session(..) , Session(..)
, SessionData(..) , SessionData(..)
, AlertLevel(..) , AlertLevel(..)
, AlertDescription(..) , AlertDescription(..)
, HandshakeType(..) , HandshakeType(..)
, Handshake(..) , Handshake(..)
, numericalVer , numericalVer
, verOfNum , verOfNum
, TypeValuable, valOfType, valToType , TypeValuable, valOfType, valToType
, packetType , packetType
, typeOfHandshake , typeOfHandshake
) where ) where
import Data.ByteString (ByteString) import Data.ByteString (ByteString)
import qualified Data.ByteString as B (length) import qualified Data.ByteString as B (length)
@ -64,77 +64,77 @@ data ConnectionEnd = ConnectionServer | ConnectionClient
data CipherType = CipherStream | CipherBlock | CipherAEAD data CipherType = CipherStream | CipherBlock | CipherAEAD
data CipherData = CipherData data CipherData = CipherData
{ cipherDataContent :: Bytes { cipherDataContent :: Bytes
, cipherDataMAC :: Maybe Bytes , cipherDataMAC :: Maybe Bytes
, cipherDataPadding :: Maybe Bytes , cipherDataPadding :: Maybe Bytes
} deriving (Show,Eq) } deriving (Show,Eq)
data CertificateType = data CertificateType =
CertificateType_RSA_Sign -- TLS10 CertificateType_RSA_Sign -- TLS10
| CertificateType_DSS_Sign -- TLS10 | CertificateType_DSS_Sign -- TLS10
| CertificateType_RSA_Fixed_DH -- TLS10 | CertificateType_RSA_Fixed_DH -- TLS10
| CertificateType_DSS_Fixed_DH -- TLS10 | CertificateType_DSS_Fixed_DH -- TLS10
| CertificateType_RSA_Ephemeral_DH -- TLS12 | CertificateType_RSA_Ephemeral_DH -- TLS12
| CertificateType_DSS_Ephemeral_DH -- TLS12 | CertificateType_DSS_Ephemeral_DH -- TLS12
| CertificateType_fortezza_dms -- TLS12 | CertificateType_fortezza_dms -- TLS12
| CertificateType_Unknown Word8 | CertificateType_Unknown Word8
deriving (Show,Eq) deriving (Show,Eq)
data HashAlgorithm = data HashAlgorithm =
HashNone HashNone
| HashMD5 | HashMD5
| HashSHA1 | HashSHA1
| HashSHA224 | HashSHA224
| HashSHA256 | HashSHA256
| HashSHA384 | HashSHA384
| HashSHA512 | HashSHA512
| HashOther Word8 | HashOther Word8
deriving (Show,Eq) deriving (Show,Eq)
data SignatureAlgorithm = data SignatureAlgorithm =
SignatureAnonymous SignatureAnonymous
| SignatureRSA | SignatureRSA
| SignatureDSS | SignatureDSS
| SignatureECDSA | SignatureECDSA
| SignatureOther Word8 | SignatureOther Word8
deriving (Show,Eq) deriving (Show,Eq)
data ProtocolType = data ProtocolType =
ProtocolType_ChangeCipherSpec ProtocolType_ChangeCipherSpec
| ProtocolType_Alert | ProtocolType_Alert
| ProtocolType_Handshake | ProtocolType_Handshake
| ProtocolType_AppData | ProtocolType_AppData
deriving (Eq, Show) deriving (Eq, Show)
-- | TLSError that might be returned through the TLS stack -- | TLSError that might be returned through the TLS stack
data TLSError = data TLSError =
Error_Misc String -- ^ mainly for instance of Error Error_Misc String -- ^ mainly for instance of Error
| Error_Protocol (String, Bool, AlertDescription) | Error_Protocol (String, Bool, AlertDescription)
| Error_Certificate String | Error_Certificate String
| Error_HandshakePolicy String -- ^ handshake policy failed. | Error_HandshakePolicy String -- ^ handshake policy failed.
| Error_Random String | Error_Random String
| Error_EOF | Error_EOF
| Error_Packet String | Error_Packet String
| Error_Packet_Size_Mismatch (Int, Int) | Error_Packet_Size_Mismatch (Int, Int)
| Error_Packet_unexpected String String | Error_Packet_unexpected String String
| Error_Packet_Parsing String | Error_Packet_Parsing String
| Error_Internal_Packet_ByteProcessed Int Int Int | Error_Internal_Packet_ByteProcessed Int Int Int
| Error_Unknown_Version Word8 Word8 | Error_Unknown_Version Word8 Word8
| Error_Unknown_Type String | Error_Unknown_Type String
deriving (Eq, Show, Typeable) deriving (Eq, Show, Typeable)
instance Error TLSError where instance Error TLSError where
noMsg = Error_Misc "" noMsg = Error_Misc ""
strMsg = Error_Misc strMsg = Error_Misc
instance Exception TLSError instance Exception TLSError
data Packet = data Packet =
Handshake [Handshake] Handshake [Handshake]
| Alert [(AlertLevel, AlertDescription)] | Alert [(AlertLevel, AlertDescription)]
| ChangeCipherSpec | ChangeCipherSpec
| AppData ByteString | AppData ByteString
deriving (Show,Eq) deriving (Show,Eq)
data Header = Header ProtocolType Version Word16 deriving (Show,Eq) data Header = Header ProtocolType Version Word16 deriving (Show,Eq)
@ -144,10 +144,10 @@ type SessionID = Bytes
newtype Session = Session (Maybe SessionID) deriving (Show, Eq) newtype Session = Session (Maybe SessionID) deriving (Show, Eq)
data SessionData = SessionData data SessionData = SessionData
{ sessionVersion :: Version { sessionVersion :: Version
, sessionCipher :: CipherID , sessionCipher :: CipherID
, sessionSecret :: Bytes , sessionSecret :: Bytes
} }
type CipherID = Word16 type CipherID = Word16
type CompressionID = Word8 type CompressionID = Word8
@ -164,84 +164,84 @@ clientRandom :: Bytes -> Maybe ClientRandom
clientRandom l = constrRandom32 ClientRandom l clientRandom l = constrRandom32 ClientRandom l
data AlertLevel = data AlertLevel =
AlertLevel_Warning AlertLevel_Warning
| AlertLevel_Fatal | AlertLevel_Fatal
deriving (Show,Eq) deriving (Show,Eq)
data AlertDescription = data AlertDescription =
CloseNotify CloseNotify
| UnexpectedMessage | UnexpectedMessage
| BadRecordMac | BadRecordMac
| DecryptionFailed -- ^ deprecated alert, should never be sent by compliant implementation | DecryptionFailed -- ^ deprecated alert, should never be sent by compliant implementation
| RecordOverflow | RecordOverflow
| DecompressionFailure | DecompressionFailure
| HandshakeFailure | HandshakeFailure
| BadCertificate | BadCertificate
| UnsupportedCertificate | UnsupportedCertificate
| CertificateRevoked | CertificateRevoked
| CertificateExpired | CertificateExpired
| CertificateUnknown | CertificateUnknown
| IllegalParameter | IllegalParameter
| UnknownCa | UnknownCa
| AccessDenied | AccessDenied
| DecodeError | DecodeError
| DecryptError | DecryptError
| ExportRestriction | ExportRestriction
| ProtocolVersion | ProtocolVersion
| InsufficientSecurity | InsufficientSecurity
| InternalError | InternalError
| UserCanceled | UserCanceled
| NoRenegotiation | NoRenegotiation
deriving (Show,Eq) deriving (Show,Eq)
data HandshakeType = data HandshakeType =
HandshakeType_HelloRequest HandshakeType_HelloRequest
| HandshakeType_ClientHello | HandshakeType_ClientHello
| HandshakeType_ServerHello | HandshakeType_ServerHello
| HandshakeType_Certificate | HandshakeType_Certificate
| HandshakeType_ServerKeyXchg | HandshakeType_ServerKeyXchg
| HandshakeType_CertRequest | HandshakeType_CertRequest
| HandshakeType_ServerHelloDone | HandshakeType_ServerHelloDone
| HandshakeType_CertVerify | HandshakeType_CertVerify
| HandshakeType_ClientKeyXchg | HandshakeType_ClientKeyXchg
| HandshakeType_Finished | HandshakeType_Finished
| HandshakeType_NPN -- Next Protocol Negotiation extension | HandshakeType_NPN -- Next Protocol Negotiation extension
deriving (Show,Eq) deriving (Show,Eq)
data ServerDHParams = ServerDHParams data ServerDHParams = ServerDHParams
{ dh_p :: Integer -- ^ prime modulus { dh_p :: Integer -- ^ prime modulus
, dh_g :: Integer -- ^ generator , dh_g :: Integer -- ^ generator
, dh_Ys :: Integer -- ^ public value (g^X mod p) , dh_Ys :: Integer -- ^ public value (g^X mod p)
} deriving (Show,Eq) } deriving (Show,Eq)
data ServerRSAParams = ServerRSAParams data ServerRSAParams = ServerRSAParams
{ rsa_modulus :: Integer { rsa_modulus :: Integer
, rsa_exponent :: Integer , rsa_exponent :: Integer
} deriving (Show,Eq) } deriving (Show,Eq)
data ServerKeyXchgAlgorithmData = data ServerKeyXchgAlgorithmData =
SKX_DH_Anon ServerDHParams SKX_DH_Anon ServerDHParams
| SKX_DHE_DSS ServerDHParams [Word8] | SKX_DHE_DSS ServerDHParams [Word8]
| SKX_DHE_RSA ServerDHParams [Word8] | SKX_DHE_RSA ServerDHParams [Word8]
| SKX_RSA (Maybe ServerRSAParams) | SKX_RSA (Maybe ServerRSAParams)
| SKX_DH_DSS (Maybe ServerRSAParams) | SKX_DH_DSS (Maybe ServerRSAParams)
| SKX_DH_RSA (Maybe ServerRSAParams) | SKX_DH_RSA (Maybe ServerRSAParams)
| SKX_Unknown Bytes | SKX_Unknown Bytes
deriving (Show,Eq) deriving (Show,Eq)
data Handshake = data Handshake =
ClientHello !Version !ClientRandom !Session ![CipherID] ![CompressionID] [Extension] ClientHello !Version !ClientRandom !Session ![CipherID] ![CompressionID] [Extension]
| ServerHello !Version !ServerRandom !Session !CipherID !CompressionID [Extension] | ServerHello !Version !ServerRandom !Session !CipherID !CompressionID [Extension]
| Certificates [X509] | Certificates [X509]
| HelloRequest | HelloRequest
| ServerHelloDone | ServerHelloDone
| ClientKeyXchg Bytes | ClientKeyXchg Bytes
| ServerKeyXchg ServerKeyXchgAlgorithmData | ServerKeyXchg ServerKeyXchgAlgorithmData
| CertRequest [CertificateType] (Maybe [ (HashAlgorithm, SignatureAlgorithm) ]) [Word8] | CertRequest [CertificateType] (Maybe [ (HashAlgorithm, SignatureAlgorithm) ]) [Word8]
| CertVerify [Word8] | CertVerify [Word8]
| Finished FinishedData | Finished FinishedData
| NextProtocolNegotiation Bytes -- NPN extension | NextProtocolNegotiation Bytes -- NPN extension
deriving (Show,Eq) deriving (Show,Eq)
packetType :: Packet -> ProtocolType packetType :: Packet -> ProtocolType
packetType (Handshake _) = ProtocolType_Handshake packetType (Handshake _) = ProtocolType_Handshake
@ -278,170 +278,170 @@ verOfNum (3, 3) = Just TLS12
verOfNum _ = Nothing verOfNum _ = Nothing
class TypeValuable a where class TypeValuable a where
valOfType :: a -> Word8 valOfType :: a -> Word8
valToType :: Word8 -> Maybe a valToType :: Word8 -> Maybe a
instance TypeValuable ConnectionEnd where instance TypeValuable ConnectionEnd where
valOfType ConnectionServer = 0 valOfType ConnectionServer = 0
valOfType ConnectionClient = 1 valOfType ConnectionClient = 1
valToType 0 = Just ConnectionServer valToType 0 = Just ConnectionServer
valToType 1 = Just ConnectionClient valToType 1 = Just ConnectionClient
valToType _ = Nothing valToType _ = Nothing
instance TypeValuable CipherType where instance TypeValuable CipherType where
valOfType CipherStream = 0 valOfType CipherStream = 0
valOfType CipherBlock = 1 valOfType CipherBlock = 1
valOfType CipherAEAD = 2 valOfType CipherAEAD = 2
valToType 0 = Just CipherStream valToType 0 = Just CipherStream
valToType 1 = Just CipherBlock valToType 1 = Just CipherBlock
valToType 2 = Just CipherAEAD valToType 2 = Just CipherAEAD
valToType _ = Nothing valToType _ = Nothing
instance TypeValuable ProtocolType where instance TypeValuable ProtocolType where
valOfType ProtocolType_ChangeCipherSpec = 20 valOfType ProtocolType_ChangeCipherSpec = 20
valOfType ProtocolType_Alert = 21 valOfType ProtocolType_Alert = 21
valOfType ProtocolType_Handshake = 22 valOfType ProtocolType_Handshake = 22
valOfType ProtocolType_AppData = 23 valOfType ProtocolType_AppData = 23
valToType 20 = Just ProtocolType_ChangeCipherSpec valToType 20 = Just ProtocolType_ChangeCipherSpec
valToType 21 = Just ProtocolType_Alert valToType 21 = Just ProtocolType_Alert
valToType 22 = Just ProtocolType_Handshake valToType 22 = Just ProtocolType_Handshake
valToType 23 = Just ProtocolType_AppData valToType 23 = Just ProtocolType_AppData
valToType _ = Nothing valToType _ = Nothing
instance TypeValuable HandshakeType where instance TypeValuable HandshakeType where
valOfType HandshakeType_HelloRequest = 0 valOfType HandshakeType_HelloRequest = 0
valOfType HandshakeType_ClientHello = 1 valOfType HandshakeType_ClientHello = 1
valOfType HandshakeType_ServerHello = 2 valOfType HandshakeType_ServerHello = 2
valOfType HandshakeType_Certificate = 11 valOfType HandshakeType_Certificate = 11
valOfType HandshakeType_ServerKeyXchg = 12 valOfType HandshakeType_ServerKeyXchg = 12
valOfType HandshakeType_CertRequest = 13 valOfType HandshakeType_CertRequest = 13
valOfType HandshakeType_ServerHelloDone = 14 valOfType HandshakeType_ServerHelloDone = 14
valOfType HandshakeType_CertVerify = 15 valOfType HandshakeType_CertVerify = 15
valOfType HandshakeType_ClientKeyXchg = 16 valOfType HandshakeType_ClientKeyXchg = 16
valOfType HandshakeType_Finished = 20 valOfType HandshakeType_Finished = 20
valOfType HandshakeType_NPN = 67 valOfType HandshakeType_NPN = 67
valToType 0 = Just HandshakeType_HelloRequest valToType 0 = Just HandshakeType_HelloRequest
valToType 1 = Just HandshakeType_ClientHello valToType 1 = Just HandshakeType_ClientHello
valToType 2 = Just HandshakeType_ServerHello valToType 2 = Just HandshakeType_ServerHello
valToType 11 = Just HandshakeType_Certificate valToType 11 = Just HandshakeType_Certificate
valToType 12 = Just HandshakeType_ServerKeyXchg valToType 12 = Just HandshakeType_ServerKeyXchg
valToType 13 = Just HandshakeType_CertRequest valToType 13 = Just HandshakeType_CertRequest
valToType 14 = Just HandshakeType_ServerHelloDone valToType 14 = Just HandshakeType_ServerHelloDone
valToType 15 = Just HandshakeType_CertVerify valToType 15 = Just HandshakeType_CertVerify
valToType 16 = Just HandshakeType_ClientKeyXchg valToType 16 = Just HandshakeType_ClientKeyXchg
valToType 20 = Just HandshakeType_Finished valToType 20 = Just HandshakeType_Finished
valToType 67 = Just HandshakeType_NPN valToType 67 = Just HandshakeType_NPN
valToType _ = Nothing valToType _ = Nothing
instance TypeValuable AlertLevel where instance TypeValuable AlertLevel where
valOfType AlertLevel_Warning = 1 valOfType AlertLevel_Warning = 1
valOfType AlertLevel_Fatal = 2 valOfType AlertLevel_Fatal = 2
valToType 1 = Just AlertLevel_Warning valToType 1 = Just AlertLevel_Warning
valToType 2 = Just AlertLevel_Fatal valToType 2 = Just AlertLevel_Fatal
valToType _ = Nothing valToType _ = Nothing
instance TypeValuable AlertDescription where instance TypeValuable AlertDescription where
valOfType CloseNotify = 0 valOfType CloseNotify = 0
valOfType UnexpectedMessage = 10 valOfType UnexpectedMessage = 10
valOfType BadRecordMac = 20 valOfType BadRecordMac = 20
valOfType DecryptionFailed = 21 valOfType DecryptionFailed = 21
valOfType RecordOverflow = 22 valOfType RecordOverflow = 22
valOfType DecompressionFailure = 30 valOfType DecompressionFailure = 30
valOfType HandshakeFailure = 40 valOfType HandshakeFailure = 40
valOfType BadCertificate = 42 valOfType BadCertificate = 42
valOfType UnsupportedCertificate = 43 valOfType UnsupportedCertificate = 43
valOfType CertificateRevoked = 44 valOfType CertificateRevoked = 44
valOfType CertificateExpired = 45 valOfType CertificateExpired = 45
valOfType CertificateUnknown = 46 valOfType CertificateUnknown = 46
valOfType IllegalParameter = 47 valOfType IllegalParameter = 47
valOfType UnknownCa = 48 valOfType UnknownCa = 48
valOfType AccessDenied = 49 valOfType AccessDenied = 49
valOfType DecodeError = 50 valOfType DecodeError = 50
valOfType DecryptError = 51 valOfType DecryptError = 51
valOfType ExportRestriction = 60 valOfType ExportRestriction = 60
valOfType ProtocolVersion = 70 valOfType ProtocolVersion = 70
valOfType InsufficientSecurity = 71 valOfType InsufficientSecurity = 71
valOfType InternalError = 80 valOfType InternalError = 80
valOfType UserCanceled = 90 valOfType UserCanceled = 90
valOfType NoRenegotiation = 100 valOfType NoRenegotiation = 100
valToType 0 = Just CloseNotify valToType 0 = Just CloseNotify
valToType 10 = Just UnexpectedMessage valToType 10 = Just UnexpectedMessage
valToType 20 = Just BadRecordMac valToType 20 = Just BadRecordMac
valToType 21 = Just DecryptionFailed valToType 21 = Just DecryptionFailed
valToType 22 = Just RecordOverflow valToType 22 = Just RecordOverflow
valToType 30 = Just DecompressionFailure valToType 30 = Just DecompressionFailure
valToType 40 = Just HandshakeFailure valToType 40 = Just HandshakeFailure
valToType 42 = Just BadCertificate valToType 42 = Just BadCertificate
valToType 43 = Just UnsupportedCertificate valToType 43 = Just UnsupportedCertificate
valToType 44 = Just CertificateRevoked valToType 44 = Just CertificateRevoked
valToType 45 = Just CertificateExpired valToType 45 = Just CertificateExpired
valToType 46 = Just CertificateUnknown valToType 46 = Just CertificateUnknown
valToType 47 = Just IllegalParameter valToType 47 = Just IllegalParameter
valToType 48 = Just UnknownCa valToType 48 = Just UnknownCa
valToType 49 = Just AccessDenied valToType 49 = Just AccessDenied
valToType 50 = Just DecodeError valToType 50 = Just DecodeError
valToType 51 = Just DecryptError valToType 51 = Just DecryptError
valToType 60 = Just ExportRestriction valToType 60 = Just ExportRestriction
valToType 70 = Just ProtocolVersion valToType 70 = Just ProtocolVersion
valToType 71 = Just InsufficientSecurity valToType 71 = Just InsufficientSecurity
valToType 80 = Just InternalError valToType 80 = Just InternalError
valToType 90 = Just UserCanceled valToType 90 = Just UserCanceled
valToType 100 = Just NoRenegotiation valToType 100 = Just NoRenegotiation
valToType _ = Nothing valToType _ = Nothing
instance TypeValuable CertificateType where instance TypeValuable CertificateType where
valOfType CertificateType_RSA_Sign = 1 valOfType CertificateType_RSA_Sign = 1
valOfType CertificateType_DSS_Sign = 2 valOfType CertificateType_DSS_Sign = 2
valOfType CertificateType_RSA_Fixed_DH = 3 valOfType CertificateType_RSA_Fixed_DH = 3
valOfType CertificateType_DSS_Fixed_DH = 4 valOfType CertificateType_DSS_Fixed_DH = 4
valOfType CertificateType_RSA_Ephemeral_DH = 5 valOfType CertificateType_RSA_Ephemeral_DH = 5
valOfType CertificateType_DSS_Ephemeral_DH = 6 valOfType CertificateType_DSS_Ephemeral_DH = 6
valOfType CertificateType_fortezza_dms = 20 valOfType CertificateType_fortezza_dms = 20
valOfType (CertificateType_Unknown i) = i valOfType (CertificateType_Unknown i) = i
valToType 1 = Just CertificateType_RSA_Sign valToType 1 = Just CertificateType_RSA_Sign
valToType 2 = Just CertificateType_DSS_Sign valToType 2 = Just CertificateType_DSS_Sign
valToType 3 = Just CertificateType_RSA_Fixed_DH valToType 3 = Just CertificateType_RSA_Fixed_DH
valToType 4 = Just CertificateType_DSS_Fixed_DH valToType 4 = Just CertificateType_DSS_Fixed_DH
valToType 5 = Just CertificateType_RSA_Ephemeral_DH valToType 5 = Just CertificateType_RSA_Ephemeral_DH
valToType 6 = Just CertificateType_DSS_Ephemeral_DH valToType 6 = Just CertificateType_DSS_Ephemeral_DH
valToType 20 = Just CertificateType_fortezza_dms valToType 20 = Just CertificateType_fortezza_dms
valToType i = Just (CertificateType_Unknown i) valToType i = Just (CertificateType_Unknown i)
instance TypeValuable HashAlgorithm where instance TypeValuable HashAlgorithm where
valOfType HashNone = 0 valOfType HashNone = 0
valOfType HashMD5 = 1 valOfType HashMD5 = 1
valOfType HashSHA1 = 2 valOfType HashSHA1 = 2
valOfType HashSHA224 = 3 valOfType HashSHA224 = 3
valOfType HashSHA256 = 4 valOfType HashSHA256 = 4
valOfType HashSHA384 = 5 valOfType HashSHA384 = 5
valOfType HashSHA512 = 6 valOfType HashSHA512 = 6
valOfType (HashOther i) = i valOfType (HashOther i) = i
valToType 0 = Just HashNone valToType 0 = Just HashNone
valToType 1 = Just HashMD5 valToType 1 = Just HashMD5
valToType 2 = Just HashSHA1 valToType 2 = Just HashSHA1
valToType 3 = Just HashSHA224 valToType 3 = Just HashSHA224
valToType 4 = Just HashSHA256 valToType 4 = Just HashSHA256
valToType 5 = Just HashSHA384 valToType 5 = Just HashSHA384
valToType 6 = Just HashSHA512 valToType 6 = Just HashSHA512
valToType i = Just (HashOther i) valToType i = Just (HashOther i)
instance TypeValuable SignatureAlgorithm where instance TypeValuable SignatureAlgorithm where
valOfType SignatureAnonymous = 0 valOfType SignatureAnonymous = 0
valOfType SignatureRSA = 1 valOfType SignatureRSA = 1
valOfType SignatureDSS = 2 valOfType SignatureDSS = 2
valOfType SignatureECDSA = 3 valOfType SignatureECDSA = 3
valOfType (SignatureOther i) = i valOfType (SignatureOther i) = i
valToType 0 = Just SignatureAnonymous valToType 0 = Just SignatureAnonymous
valToType 1 = Just SignatureRSA valToType 1 = Just SignatureRSA
valToType 2 = Just SignatureDSS valToType 2 = Just SignatureDSS
valToType 3 = Just SignatureECDSA valToType 3 = Just SignatureECDSA
valToType i = Just (SignatureOther i) valToType i = Just (SignatureOther i)

View file

@ -1,13 +1,13 @@
module Network.TLS.Util module Network.TLS.Util
( sub ( sub
, takelast , takelast
, partition3 , partition3
, partition6 , partition6
, fromJust , fromJust
, and' , and'
, (&&!) , (&&!)
, bytesEq , bytesEq
) where ) where
import Data.List (foldl') import Data.List (foldl')
import Network.TLS.Struct (Bytes) import Network.TLS.Struct (Bytes)
@ -15,32 +15,32 @@ import qualified Data.ByteString as B
sub :: Bytes -> Int -> Int -> Maybe Bytes sub :: Bytes -> Int -> Int -> Maybe Bytes
sub b offset len sub b offset len
| B.length b < offset + len = Nothing | B.length b < offset + len = Nothing
| otherwise = Just $ B.take len $ snd $ B.splitAt offset b | otherwise = Just $ B.take len $ snd $ B.splitAt offset b
takelast :: Int -> Bytes -> Maybe Bytes takelast :: Int -> Bytes -> Maybe Bytes
takelast i b takelast i b
| B.length b >= i = sub b (B.length b - i) i | B.length b >= i = sub b (B.length b - i) i
| otherwise = Nothing | otherwise = Nothing
partition3 :: Bytes -> (Int,Int,Int) -> Maybe (Bytes, Bytes, Bytes) partition3 :: Bytes -> (Int,Int,Int) -> Maybe (Bytes, Bytes, Bytes)
partition3 bytes (d1,d2,d3) = if B.length bytes /= s then Nothing else Just (p1,p2,p3) partition3 bytes (d1,d2,d3) = if B.length bytes /= s then Nothing else Just (p1,p2,p3)
where where
s = sum [d1,d2,d3] s = sum [d1,d2,d3]
(p1, r1) = B.splitAt d1 bytes (p1, r1) = B.splitAt d1 bytes
(p2, r2) = B.splitAt d2 r1 (p2, r2) = B.splitAt d2 r1
(p3, _) = B.splitAt d3 r2 (p3, _) = B.splitAt d3 r2
partition6 :: Bytes -> (Int,Int,Int,Int,Int,Int) -> Maybe (Bytes, Bytes, Bytes, Bytes, Bytes, Bytes) partition6 :: Bytes -> (Int,Int,Int,Int,Int,Int) -> Maybe (Bytes, Bytes, Bytes, Bytes, Bytes, Bytes)
partition6 bytes (d1,d2,d3,d4,d5,d6) = if B.length bytes < s then Nothing else Just (p1,p2,p3,p4,p5,p6) partition6 bytes (d1,d2,d3,d4,d5,d6) = if B.length bytes < s then Nothing else Just (p1,p2,p3,p4,p5,p6)
where where
s = sum [d1,d2,d3,d4,d5,d6] s = sum [d1,d2,d3,d4,d5,d6]
(p1, r1) = B.splitAt d1 bytes (p1, r1) = B.splitAt d1 bytes
(p2, r2) = B.splitAt d2 r1 (p2, r2) = B.splitAt d2 r1
(p3, r3) = B.splitAt d3 r2 (p3, r3) = B.splitAt d3 r2
(p4, r4) = B.splitAt d4 r3 (p4, r4) = B.splitAt d4 r3
(p5, r5) = B.splitAt d5 r4 (p5, r5) = B.splitAt d5 r4
(p6, _) = B.splitAt d6 r5 (p6, _) = B.splitAt d6 r5
fromJust :: String -> Maybe a -> a fromJust :: String -> Maybe a -> a
fromJust what Nothing = error ("fromJust " ++ what ++ ": Nothing") -- yuck fromJust what Nothing = error ("fromJust " ++ what ++ ": Nothing") -- yuck

View file

@ -9,34 +9,34 @@
-- all multibytes values are written as big endian. -- all multibytes values are written as big endian.
-- --
module Network.TLS.Wire module Network.TLS.Wire
( Get ( Get
, runGet , runGet
, remaining , remaining
, getWord8 , getWord8
, getWords8 , getWords8
, getWord16 , getWord16
, getWords16 , getWords16
, getWord24 , getWord24
, getBytes , getBytes
, getOpaque8 , getOpaque8
, getOpaque16 , getOpaque16
, getOpaque24 , getOpaque24
, processBytes , processBytes
, isEmpty , isEmpty
, Put , Put
, runPut , runPut
, putWord8 , putWord8
, putWords8 , putWords8
, putWord16 , putWord16
, putWords16 , putWords16
, putWord24 , putWord24
, putBytes , putBytes
, putOpaque8 , putOpaque8
, putOpaque16 , putOpaque16
, putOpaque24 , putOpaque24
, encodeWord16 , encodeWord16
, encodeWord64 , encodeWord64
) where ) where
import Data.Serialize.Get hiding (runGet) import Data.Serialize.Get hiding (runGet)
import qualified Data.Serialize.Get as G import qualified Data.Serialize.Get as G
@ -62,10 +62,10 @@ getWords16 = getWord16 >>= \lenb -> replicateM (fromIntegral lenb `div` 2) getWo
getWord24 :: Get Int getWord24 :: Get Int
getWord24 = do getWord24 = do
a <- fromIntegral <$> getWord8 a <- fromIntegral <$> getWord8
b <- fromIntegral <$> getWord8 b <- fromIntegral <$> getWord8
c <- fromIntegral <$> getWord8 c <- fromIntegral <$> getWord8
return $ (a `shiftL` 16) .|. (b `shiftL` 8) .|. c return $ (a `shiftL` 16) .|. (b `shiftL` 8) .|. c
getOpaque8 :: Get Bytes getOpaque8 :: Get Bytes
getOpaque8 = getWord8 >>= getBytes . fromIntegral getOpaque8 = getWord8 >>= getBytes . fromIntegral
@ -81,23 +81,23 @@ processBytes i f = isolate i f
putWords8 :: [Word8] -> Put putWords8 :: [Word8] -> Put
putWords8 l = do putWords8 l = do
putWord8 $ fromIntegral (length l) putWord8 $ fromIntegral (length l)
mapM_ putWord8 l mapM_ putWord8 l
putWord16 :: Word16 -> Put putWord16 :: Word16 -> Put
putWord16 = putWord16be putWord16 = putWord16be
putWords16 :: [Word16] -> Put putWords16 :: [Word16] -> Put
putWords16 l = do putWords16 l = do
putWord16 $ 2 * (fromIntegral $ length l) putWord16 $ 2 * (fromIntegral $ length l)
mapM_ putWord16 l mapM_ putWord16 l
putWord24 :: Int -> Put putWord24 :: Int -> Put
putWord24 i = do putWord24 i = do
let a = fromIntegral ((i `shiftR` 16) .&. 0xff) let a = fromIntegral ((i `shiftR` 16) .&. 0xff)
let b = fromIntegral ((i `shiftR` 8) .&. 0xff) let b = fromIntegral ((i `shiftR` 8) .&. 0xff)
let c = fromIntegral (i .&. 0xff) let c = fromIntegral (i .&. 0xff)
mapM_ putWord8 [a,b,c] mapM_ putWord8 [a,b,c]
putBytes :: Bytes -> Put putBytes :: Bytes -> Put
putBytes = putByteString putBytes = putByteString