diff --git a/core/Benchmarks/Benchmarks.hs b/core/Benchmarks/Benchmarks.hs index 9cea486..903ba1d 100644 --- a/core/Benchmarks/Benchmarks.hs +++ b/core/Benchmarks/Benchmarks.hs @@ -8,6 +8,8 @@ import Criterion.Main import Control.Concurrent.Chan import Network.TLS import Data.X509 +import Data.X509.Validation +import Data.Default.Class import qualified Data.ByteString as B import qualified Data.ByteString.Lazy as L @@ -15,15 +17,22 @@ import qualified Data.ByteString.Lazy as L recvDataNonNull ctx = recvData ctx >>= \l -> if B.null l then recvDataNonNull ctx else return l getParams connectVer cipher = (cParams, sParams) - where sParams = defaultParamsServer - { pAllowedVersions = [connectVer] - , pCiphers = [cipher] - , pCredentials = Credentials [ (CertificateChain [simpleX509 $ PubKeyRSA pubKey], PrivKeyRSA privKey) ] - } - cParams = defaultParamsClient - { pAllowedVersions = [connectVer] - , pCiphers = [cipher] + where sParams = def { serverSupported = supported + , serverShared = def { + sharedCredentials = Credentials [ (CertificateChain [simpleX509 $ PubKeyRSA pubKey], PrivKeyRSA privKey) ] + } + } + cParams = (defaultParamsClient "" B.empty) + { clientSupported = supported + , clientShared = def { sharedValidationCache = ValidationCache + { cacheAdd = \_ _ _ -> return () + , cacheQuery = \_ _ _ -> return ValidationCachePass + } + } } + supported = def { supportedCiphers = [cipher] + , supportedVersions = [connectVer] + } (pubKey, privKey) = getGlobalRSAPair runTLSPipe params tlsServer tlsClient d name = bench name $ do diff --git a/core/Network/TLS.hs b/core/Network/TLS.hs index 8a4a540..13a6778 100644 --- a/core/Network/TLS.hs +++ b/core/Network/TLS.hs @@ -8,19 +8,17 @@ module Network.TLS ( -- * Context configuration - Params(..) - , RoleParams(..) - , ClientParams(..) + ClientParams(..) , ServerParams(..) - , updateClientParams - , updateServerParams + , ClientHooks(..) + , ServerHooks(..) + , Supported(..) + , Shared(..) , Logging(..) , Measurement(..) , CertificateUsage(..) , CertificateRejectReason(..) , defaultParamsClient - , defaultParamsServer - , defaultLogging , MaxFragmentEnum(..) , HashAndSignatureAlgorithm , HashAlgorithm(..) @@ -36,7 +34,6 @@ module Network.TLS , SessionData(..) , SessionManager(..) , noSessionManager - , setSessionManager -- * Backend abstraction , Backend(..) @@ -94,18 +91,27 @@ module Network.TLS -- * Exceptions , TLSException(..) + + -- * Validation Cache + , ValidationCache + , exceptionValidationCache ) where import Network.TLS.Backend (Backend(..)) -import Network.TLS.Struct ( Version(..), TLSError(..), TLSException(..) +import Network.TLS.Struct ( TLSError(..), TLSException(..) , HashAndSignatureAlgorithm, HashAlgorithm(..), SignatureAlgorithm(..) , Header(..), ProtocolType(..), CertificateType(..) , AlertDescription(..)) import Network.TLS.Crypto (KxError(..)) import Network.TLS.Cipher +import Network.TLS.Hooks +import Network.TLS.Measurement import Network.TLS.Credentials import Network.TLS.Compression (CompressionC(..), Compression(..), nullCompression) import Network.TLS.Context +import Network.TLS.Parameters import Network.TLS.Core import Network.TLS.Session +import Network.TLS.X509 +import Network.TLS.Types import Data.X509 (PubKey(..), PrivKey(..)) diff --git a/core/Network/TLS/Context.hs b/core/Network/TLS/Context.hs index 225b7af..41971b7 100644 --- a/core/Network/TLS/Context.hs +++ b/core/Network/TLS/Context.hs @@ -8,44 +8,18 @@ module Network.TLS.Context ( -- * Context configuration - Params(..) - , RoleParams(..) - , ClientParams(..) - , ServerParams(..) - , updateClientParams - , updateServerParams - , Logging(..) - , SessionID - , SessionData(..) - , MaxFragmentEnum(..) - , Measurement(..) - , CertificateUsage(..) - , CertificateRejectReason(..) - , defaultLogging - , defaultParamsClient - , defaultParamsServer - , withSessionManager - , setSessionManager - , getClientParams - , getServerParams - , credentialsGet + TLSParams -- * Context object and accessor - , Context + , Context(..) , Hooks(..) - , ctxParams - , ctxConnection , ctxEOF , ctxHasSSLv2ClientHello , ctxDisableSSLv2ClientHello , ctxEstablished - , ctxCiphers , ctxLogging , ctxWithHooks - , ctxRxState - , ctxTxState - , ctxHandshake - , ctxNeedEmptyPacket + , modifyHooks , setEOF , setEstablished , contextFlush @@ -63,13 +37,6 @@ module Network.TLS.Context , Information(..) , contextGetInformation - -- * deprecated types - , TLSParams - , TLSLogging - , TLSCertificateUsage - , TLSCertificateRejectReason - , TLSCtx - -- * New contexts , contextNew -- * Deprecated new contexts methods @@ -91,177 +58,55 @@ module Network.TLS.Context ) where import Network.TLS.Backend -import Network.TLS.Extension +import Network.TLS.Context.Internal import Network.TLS.Struct import Network.TLS.Cipher (Cipher(..), CipherKeyExchangeType(..)) -import Network.TLS.Compression (Compression) import Network.TLS.Credentials import Network.TLS.State -import Network.TLS.Handshake.State import Network.TLS.Hooks import Network.TLS.Record.State import Network.TLS.Parameters import Network.TLS.Measurement import Network.TLS.Types (Role(..)) +import Network.TLS.Handshake (handshakeClient, handshakeClientWith, handshakeServer, handshakeServerWith) import Data.Maybe (isJust) -import qualified Data.ByteString as B import Crypto.Random import Control.Concurrent.MVar import Control.Monad.State -import Control.Exception (throwIO, Exception()) import Data.IORef -import Data.Tuple -- deprecated imports import Network.Socket (Socket) import System.IO (Handle) --- | Information related to a running context, e.g. current cipher -data Information = Information - { infoVersion :: Version - , infoCipher :: Cipher - , infoCompression :: Compression - } deriving (Show,Eq) --- | A TLS Context keep tls specific state, parameters and backend information. -data Context = Context - { ctxConnection :: Backend -- ^ return the backend object associated with this context - , ctxParams :: Params - , ctxCiphers :: [Cipher] -- ^ prepared list of allowed ciphers according to parameters - , ctxState :: MVar TLSState - , ctxMeasurement :: IORef Measurement - , ctxEOF_ :: IORef Bool -- ^ has the handle EOFed or not. - , ctxEstablished_ :: IORef Bool -- ^ has the handshake been done and been successful. - , ctxNeedEmptyPacket :: IORef Bool -- ^ empty packet workaround for CBC guessability. - , ctxSSLv2ClientHello :: IORef Bool -- ^ enable the reception of compatibility SSLv2 client hello. - -- the flag will be set to false regardless of its initial value - -- after the first packet received. - , ctxTxState :: MVar RecordState -- ^ current tx state - , ctxRxState :: MVar RecordState -- ^ current rx state - , ctxHandshake :: MVar (Maybe HandshakeState) -- ^ optional handshake state - , ctxHooks :: IORef Hooks -- ^ hooks for this context - , ctxLockWrite :: MVar () -- ^ lock to use for writing data (including updating the state) - , ctxLockRead :: MVar () -- ^ lock to use for reading data (including updating the state) - , ctxLockState :: MVar () -- ^ lock used during read/write when receiving and sending packet. - -- it is usually nested in a write or read lock. - } +class TLSParams a where + getTLSCommonParams :: a -> CommonParams + getTLSRole :: a -> Role + getCiphers :: a -> [Cipher] + doHandshake :: a -> Context -> IO () + doHandshakeWith :: a -> Context -> Handshake -> IO () --- deprecated types, setup as aliases for compatibility. -type TLSParams = Params -type TLSCtx = Context -type TLSLogging = Logging -type TLSCertificateUsage = CertificateUsage -type TLSCertificateRejectReason = CertificateRejectReason +instance TLSParams ClientParams where + getTLSCommonParams cparams = ( clientSupported cparams + , clientShared cparams + , clientCommonHooks cparams + ) + getTLSRole _ = ClientRole + getCiphers cparams = supportedCiphers $ clientSupported cparams + doHandshake = handshakeClient + doHandshakeWith = handshakeClientWith -updateMeasure :: Context -> (Measurement -> Measurement) -> IO () -updateMeasure ctx f = do - x <- readIORef (ctxMeasurement ctx) - writeIORef (ctxMeasurement ctx) $! f x - -withMeasure :: Context -> (Measurement -> IO a) -> IO a -withMeasure ctx f = readIORef (ctxMeasurement ctx) >>= f - -contextFlush :: Context -> IO () -contextFlush = backendFlush . ctxConnection - -contextClose :: Context -> IO () -contextClose = backendClose . ctxConnection - --- | Information about the current context -contextGetInformation :: Context -> IO (Maybe Information) -contextGetInformation ctx = do - ver <- usingState_ ctx $ gets stVersion - (cipher,comp) <- failOnEitherError $ runRxState ctx $ gets $ \st -> (stCipher st, stCompression st) - case (ver, cipher) of - (Just v, Just c) -> return $ Just $ Information v c comp - _ -> return Nothing - -contextSend :: Context -> Bytes -> IO () -contextSend c b = updateMeasure c (addBytesSent $ B.length b) >> (backendSend $ ctxConnection c) b - -contextRecv :: Context -> Int -> IO Bytes -contextRecv c sz = updateMeasure c (addBytesReceived sz) >> (backendRecv $ ctxConnection c) sz - -ctxEOF :: Context -> IO Bool -ctxEOF ctx = readIORef $ ctxEOF_ ctx - -ctxHasSSLv2ClientHello :: Context -> IO Bool -ctxHasSSLv2ClientHello ctx = readIORef $ ctxSSLv2ClientHello ctx - -ctxDisableSSLv2ClientHello :: Context -> IO () -ctxDisableSSLv2ClientHello ctx = writeIORef (ctxSSLv2ClientHello ctx) False - -setEOF :: Context -> IO () -setEOF ctx = writeIORef (ctxEOF_ ctx) True - -ctxEstablished :: Context -> IO Bool -ctxEstablished ctx = readIORef $ ctxEstablished_ ctx - -ctxWithHooks :: Context -> (Hooks -> IO a) -> IO a -ctxWithHooks ctx f = readIORef (ctxHooks ctx) >>= f - -setEstablished :: Context -> Bool -> IO () -setEstablished ctx v = writeIORef (ctxEstablished_ ctx) v - -ctxLogging :: Context -> Logging -ctxLogging = pLogging . ctxParams - --- | create a new context using the backend and parameters specified. -contextNew :: (MonadIO m, CPRG rng, HasBackend backend) - => backend -- ^ Backend abstraction with specific method to interact with the connection type. - -> Params -- ^ Parameters of the context. - -> rng -- ^ Random number generator associated with this context. - -> m Context -contextNew backend params rng = liftIO $ do - initializeBackend backend - let role = case roleParams params of - Client {} -> ClientRole - Server {} -> ServerRole - let st = newTLSState rng role - - stvar <- newMVar st - eof <- newIORef False - established <- newIORef False - stats <- newIORef newMeasurement - -- we enable the reception of SSLv2 ClientHello message only in the - -- server context, where we might be dealing with an old/compat client. - sslv2Compat <- newIORef (role == ServerRole) - needEmptyPacket <- newIORef False - hooks <- newIORef defaultHooks - tx <- newMVar newRecordState - rx <- newMVar newRecordState - hs <- newMVar Nothing +instance TLSParams ServerParams where + getTLSCommonParams sparams = ( serverSupported sparams + , serverShared sparams + , serverCommonHooks sparams + ) + getTLSRole _ = ServerRole -- on the server we filter our allowed ciphers here according -- to the credentials and DHE parameters loaded - let ciphers = case roleParams params of - Client {} -> pCiphers params - Server sParams -> filterServer sParams $ pCiphers params - lockWrite <- newMVar () - lockRead <- newMVar () - lockState <- newMVar () - - when (null ciphers) $ error "no ciphers available with those parameters" - - return $ Context - { ctxConnection = getBackend backend - , ctxParams = params - , ctxCiphers = ciphers - , ctxState = stvar - , ctxTxState = tx - , ctxRxState = rx - , ctxHandshake = hs - , ctxMeasurement = stats - , ctxEOF_ = eof - , ctxEstablished_ = established - , ctxSSLv2ClientHello = sslv2Compat - , ctxNeedEmptyPacket = needEmptyPacket - , ctxHooks = hooks - , ctxLockWrite = lockWrite - , ctxLockRead = lockRead - , ctxLockState = lockState - } - where filterServer sParams ciphers = filter authorizedCKE ciphers + getCiphers sparams = filter authorizedCKE (supportedCiphers $ serverSupported sparams) where authorizedCKE cipher = case cipherKeyExchange cipher of CipherKeyExchange_RSA -> canEncryptRSA @@ -277,26 +122,83 @@ contextNew backend params rng = liftIO $ do CipherKeyExchange_ECDH_RSA -> False CipherKeyExchange_ECDHE_ECDSA -> False - canDHE = isJust $ serverDHEParams sParams + canDHE = isJust $ serverDHEParams sparams canSignDSS = SignatureDSS `elem` signingAlgs canSignRSA = SignatureRSA `elem` signingAlgs canEncryptRSA = isJust $ credentialsFindForDecrypting creds signingAlgs = credentialsListSigningAlgorithms creds - creds = credentialsGet params + creds = sharedCredentials $ serverShared sparams + doHandshake = handshakeServer + doHandshakeWith = handshakeServerWith + +-- | create a new context using the backend and parameters specified. +contextNew :: (MonadIO m, CPRG rng, HasBackend backend, TLSParams params) + => backend -- ^ Backend abstraction with specific method to interact with the connection type. + -> params -- ^ Parameters of the context. + -> rng -- ^ Random number generator associated with this context. + -> m Context +contextNew backend params rng = liftIO $ do + initializeBackend backend + + let role = getTLSRole params + st = newTLSState rng role + (supported, shared, commonHooks) = getTLSCommonParams params + ciphers = getCiphers params + + when (null ciphers) $ error "no ciphers available with those parameters" + + stvar <- newMVar st + eof <- newIORef False + established <- newIORef False + stats <- newIORef newMeasurement + -- we enable the reception of SSLv2 ClientHello message only in the + -- server context, where we might be dealing with an old/compat client. + sslv2Compat <- newIORef (role == ServerRole) + needEmptyPacket <- newIORef False + hooks <- newIORef defaultHooks + tx <- newMVar newRecordState + rx <- newMVar newRecordState + hs <- newMVar Nothing + lockWrite <- newMVar () + lockRead <- newMVar () + lockState <- newMVar () + + return $ Context + { ctxConnection = getBackend backend + , ctxShared = shared + , ctxSupported = supported + , ctxCommonHooks = commonHooks + , ctxCiphers = ciphers + , ctxState = stvar + , ctxTxState = tx + , ctxRxState = rx + , ctxHandshake = hs + , ctxDoHandshake = doHandshake params + , ctxDoHandshakeWith = doHandshakeWith params + , ctxMeasurement = stats + , ctxEOF_ = eof + , ctxEstablished_ = established + , ctxSSLv2ClientHello = sslv2Compat + , ctxNeedEmptyPacket = needEmptyPacket + , ctxHooks = hooks + , ctxLockWrite = lockWrite + , ctxLockRead = lockRead + , ctxLockState = lockState + } -- | create a new context on an handle. -contextNewOnHandle :: (MonadIO m, CPRG rng) +contextNewOnHandle :: (MonadIO m, CPRG rng, TLSParams params) => Handle -- ^ Handle of the connection. - -> Params -- ^ Parameters of the context. + -> params -- ^ Parameters of the context. -> rng -- ^ Random number generator associated with this context. -> m Context contextNewOnHandle handle params st = contextNew handle params st {-# DEPRECATED contextNewOnHandle "use contextNew" #-} -- | create a new context on a socket. -contextNewOnSocket :: (MonadIO m, CPRG rng) +contextNewOnSocket :: (MonadIO m, CPRG rng, TLSParams params) => Socket -- ^ Socket of the connection. - -> Params -- ^ Parameters of the context. + -> params -- ^ Parameters of the context. -> rng -- ^ Random number generator associated with this context. -> m Context contextNewOnSocket sock params st = contextNew sock params st @@ -305,62 +207,3 @@ contextNewOnSocket sock params st = contextNew sock params st contextHookSetHandshakeRecv :: Context -> (Handshake -> IO Handshake) -> IO () contextHookSetHandshakeRecv context f = liftIO $ modifyIORef (ctxHooks context) (\hooks -> hooks { hookRecvHandshake = f }) - -throwCore :: (MonadIO m, Exception e) => e -> m a -throwCore = liftIO . throwIO - -failOnEitherError :: MonadIO m => m (Either TLSError a) -> m a -failOnEitherError f = do - ret <- f - case ret of - Left err -> throwCore err - Right r -> return r - -usingState :: Context -> TLSSt a -> IO (Either TLSError a) -usingState ctx f = - modifyMVar (ctxState ctx) $ \st -> - let (a, newst) = runTLSState f st - in newst `seq` return (newst, a) - -usingState_ :: Context -> TLSSt a -> IO a -usingState_ ctx f = failOnEitherError $ usingState ctx f - -usingHState :: Context -> HandshakeM a -> IO a -usingHState ctx f = liftIO $ modifyMVar (ctxHandshake ctx) $ \mst -> - case mst of - Nothing -> throwCore $ Error_Misc "missing handshake" - Just st -> return $ swap (Just `fmap` runHandshake st f) - -getHState :: Context -> IO (Maybe HandshakeState) -getHState ctx = liftIO $ readMVar (ctxHandshake ctx) - -runTxState :: Context -> RecordM a -> IO (Either TLSError a) -runTxState ctx f = do - ver <- usingState_ ctx (getVersionWithDefault $ maximum $ pAllowedVersions $ ctxParams ctx) - modifyMVar (ctxTxState ctx) $ \st -> - case runRecordM f ver st of - Left err -> return (st, Left err) - Right (a, newSt) -> return (newSt, Right a) - -runRxState :: Context -> RecordM a -> IO (Either TLSError a) -runRxState ctx f = do - ver <- usingState_ ctx getVersion - modifyMVar (ctxRxState ctx) $ \st -> - case runRecordM f ver st of - Left err -> return (st, Left err) - Right (a, newSt) -> return (newSt, Right a) - -getStateRNG :: Context -> Int -> IO Bytes -getStateRNG ctx n = usingState_ ctx $ genRandom n - -withReadLock :: Context -> IO a -> IO a -withReadLock ctx f = withMVar (ctxLockRead ctx) (const f) - -withWriteLock :: Context -> IO a -> IO a -withWriteLock ctx f = withMVar (ctxLockWrite ctx) (const f) - -withRWLock :: Context -> IO a -> IO a -withRWLock ctx f = withReadLock ctx $ withWriteLock ctx f - -withStateLock :: Context -> IO a -> IO a -withStateLock ctx f = withMVar (ctxLockState ctx) (const f) diff --git a/core/Network/TLS/Context/Internal.hs b/core/Network/TLS/Context/Internal.hs new file mode 100644 index 0000000..0488c34 --- /dev/null +++ b/core/Network/TLS/Context/Internal.hs @@ -0,0 +1,224 @@ +-- | +-- Module : Network.TLS.Context.Internal +-- License : BSD-style +-- Maintainer : Vincent Hanquez +-- Stability : experimental +-- Portability : unknown +-- +module Network.TLS.Context.Internal + ( + -- * Context configuration + ClientParams(..) + , ServerParams(..) + , defaultParamsClient + , SessionID + , SessionData(..) + , MaxFragmentEnum(..) + , Measurement(..) + + -- * Context object and accessor + , Context(..) + , Hooks(..) + , ctxEOF + , ctxHasSSLv2ClientHello + , ctxDisableSSLv2ClientHello + , ctxEstablished + , ctxLogging + , ctxWithHooks + , modifyHooks + , setEOF + , setEstablished + , contextFlush + , contextClose + , contextSend + , contextRecv + , updateMeasure + , withMeasure + , withReadLock + , withWriteLock + , withStateLock + , withRWLock + + -- * information + , Information(..) + , contextGetInformation + + -- * Using context states + , throwCore + , usingState + , usingState_ + , runTxState + , runRxState + , usingHState + , getHState + , getStateRNG + ) where + +import Network.TLS.Backend +import Network.TLS.Extension +import Network.TLS.Cipher +import Network.TLS.Struct +import Network.TLS.Compression (Compression) +import Network.TLS.State +import Network.TLS.Handshake.State +import Network.TLS.Hooks +import Network.TLS.Record.State +import Network.TLS.Parameters +import Network.TLS.Measurement +import qualified Data.ByteString as B + +import Control.Concurrent.MVar +import Control.Monad.State +import Control.Exception (throwIO, Exception()) +import Data.IORef +import Data.Tuple + + +-- | Information related to a running context, e.g. current cipher +data Information = Information + { infoVersion :: Version + , infoCipher :: Cipher + , infoCompression :: Compression + } deriving (Show,Eq) + +-- | A TLS Context keep tls specific state, parameters and backend information. +data Context = Context + { ctxConnection :: Backend -- ^ return the backend object associated with this context + , ctxSupported :: Supported + , ctxShared :: Shared + , ctxCommonHooks :: CommonHooks + , ctxCiphers :: [Cipher] -- ^ prepared list of allowed ciphers according to parameters + , ctxState :: MVar TLSState + , ctxMeasurement :: IORef Measurement + , ctxEOF_ :: IORef Bool -- ^ has the handle EOFed or not. + , ctxEstablished_ :: IORef Bool -- ^ has the handshake been done and been successful. + , ctxNeedEmptyPacket :: IORef Bool -- ^ empty packet workaround for CBC guessability. + , ctxSSLv2ClientHello :: IORef Bool -- ^ enable the reception of compatibility SSLv2 client hello. + -- the flag will be set to false regardless of its initial value + -- after the first packet received. + , ctxTxState :: MVar RecordState -- ^ current tx state + , ctxRxState :: MVar RecordState -- ^ current rx state + , ctxHandshake :: MVar (Maybe HandshakeState) -- ^ optional handshake state + , ctxDoHandshake :: Context -> IO () + , ctxDoHandshakeWith :: Context -> Handshake -> IO () + , ctxHooks :: IORef Hooks -- ^ hooks for this context + , ctxLockWrite :: MVar () -- ^ lock to use for writing data (including updating the state) + , ctxLockRead :: MVar () -- ^ lock to use for reading data (including updating the state) + , ctxLockState :: MVar () -- ^ lock used during read/write when receiving and sending packet. + -- it is usually nested in a write or read lock. + } + +updateMeasure :: Context -> (Measurement -> Measurement) -> IO () +updateMeasure ctx f = do + x <- readIORef (ctxMeasurement ctx) + writeIORef (ctxMeasurement ctx) $! f x + +withMeasure :: Context -> (Measurement -> IO a) -> IO a +withMeasure ctx f = readIORef (ctxMeasurement ctx) >>= f + +contextFlush :: Context -> IO () +contextFlush = backendFlush . ctxConnection + +contextClose :: Context -> IO () +contextClose = backendClose . ctxConnection + +-- | Information about the current context +contextGetInformation :: Context -> IO (Maybe Information) +contextGetInformation ctx = do + ver <- usingState_ ctx $ gets stVersion + (cipher,comp) <- failOnEitherError $ runRxState ctx $ gets $ \st -> (stCipher st, stCompression st) + case (ver, cipher) of + (Just v, Just c) -> return $ Just $ Information v c comp + _ -> return Nothing + +contextSend :: Context -> Bytes -> IO () +contextSend c b = updateMeasure c (addBytesSent $ B.length b) >> (backendSend $ ctxConnection c) b + +contextRecv :: Context -> Int -> IO Bytes +contextRecv c sz = updateMeasure c (addBytesReceived sz) >> (backendRecv $ ctxConnection c) sz + +ctxEOF :: Context -> IO Bool +ctxEOF ctx = readIORef $ ctxEOF_ ctx + +ctxHasSSLv2ClientHello :: Context -> IO Bool +ctxHasSSLv2ClientHello ctx = readIORef $ ctxSSLv2ClientHello ctx + +ctxDisableSSLv2ClientHello :: Context -> IO () +ctxDisableSSLv2ClientHello ctx = writeIORef (ctxSSLv2ClientHello ctx) False + +setEOF :: Context -> IO () +setEOF ctx = writeIORef (ctxEOF_ ctx) True + +ctxEstablished :: Context -> IO Bool +ctxEstablished ctx = readIORef $ ctxEstablished_ ctx + +ctxWithHooks :: Context -> (Hooks -> IO a) -> IO a +ctxWithHooks ctx f = readIORef (ctxHooks ctx) >>= f + +modifyHooks :: Context -> (Hooks -> Hooks) -> IO () +modifyHooks ctx f = modifyIORef (ctxHooks ctx) f + +setEstablished :: Context -> Bool -> IO () +setEstablished ctx v = writeIORef (ctxEstablished_ ctx) v + +ctxLogging :: Context -> Logging +ctxLogging = logging . ctxCommonHooks + +throwCore :: (MonadIO m, Exception e) => e -> m a +throwCore = liftIO . throwIO + +failOnEitherError :: MonadIO m => m (Either TLSError a) -> m a +failOnEitherError f = do + ret <- f + case ret of + Left err -> throwCore err + Right r -> return r + +usingState :: Context -> TLSSt a -> IO (Either TLSError a) +usingState ctx f = + modifyMVar (ctxState ctx) $ \st -> + let (a, newst) = runTLSState f st + in newst `seq` return (newst, a) + +usingState_ :: Context -> TLSSt a -> IO a +usingState_ ctx f = failOnEitherError $ usingState ctx f + +usingHState :: Context -> HandshakeM a -> IO a +usingHState ctx f = liftIO $ modifyMVar (ctxHandshake ctx) $ \mst -> + case mst of + Nothing -> throwCore $ Error_Misc "missing handshake" + Just st -> return $ swap (Just `fmap` runHandshake st f) + +getHState :: Context -> IO (Maybe HandshakeState) +getHState ctx = liftIO $ readMVar (ctxHandshake ctx) + +runTxState :: Context -> RecordM a -> IO (Either TLSError a) +runTxState ctx f = do + ver <- usingState_ ctx (getVersionWithDefault $ maximum $ supportedVersions $ ctxSupported ctx) + modifyMVar (ctxTxState ctx) $ \st -> + case runRecordM f ver st of + Left err -> return (st, Left err) + Right (a, newSt) -> return (newSt, Right a) + +runRxState :: Context -> RecordM a -> IO (Either TLSError a) +runRxState ctx f = do + ver <- usingState_ ctx getVersion + modifyMVar (ctxRxState ctx) $ \st -> + case runRecordM f ver st of + Left err -> return (st, Left err) + Right (a, newSt) -> return (newSt, Right a) + +getStateRNG :: Context -> Int -> IO Bytes +getStateRNG ctx n = usingState_ ctx $ genRandom n + +withReadLock :: Context -> IO a -> IO a +withReadLock ctx f = withMVar (ctxLockRead ctx) (const f) + +withWriteLock :: Context -> IO a -> IO a +withWriteLock ctx f = withMVar (ctxLockWrite ctx) (const f) + +withRWLock :: Context -> IO a -> IO a +withRWLock ctx f = withReadLock ctx $ withWriteLock ctx f + +withStateLock :: Context -> IO a -> IO a +withStateLock ctx f = withMVar (ctxLockState ctx) (const f) diff --git a/core/Network/TLS/Core.hs b/core/Network/TLS/Core.hs index 8c36bd6..e2909cb 100644 --- a/core/Network/TLS/Core.hs +++ b/core/Network/TLS/Core.hs @@ -29,6 +29,7 @@ module Network.TLS.Core import Network.TLS.Context import Network.TLS.Struct import Network.TLS.State (getSession) +import Network.TLS.Parameters import Network.TLS.IO import Network.TLS.Session import Network.TLS.Handshake @@ -79,17 +80,22 @@ recvData ctx = liftIO $ do terminate err AlertLevel_Fatal InternalError (show err) process (Handshake [ch@(ClientHello {})]) = - -- on server context receiving a client hello == renegotiation + withRWLock ctx ((ctxDoHandshakeWith ctx) ctx ch) >> recvData ctx + {- case roleParams $ ctxParams ctx of Server sparams -> withRWLock ctx (handshakeServerWith sparams ctx ch) >> recvData ctx Client {} -> let reason = "unexpected client hello in client context" in terminate (Error_Misc reason) AlertLevel_Fatal UnexpectedMessage reason - process (Handshake [HelloRequest]) = + -} + process (Handshake [hr@HelloRequest]) = + withRWLock ctx ((ctxDoHandshakeWith ctx) ctx hr) >> recvData ctx + {- -- on client context, receiving a hello request == renegotiation case roleParams $ ctxParams ctx of Server {} -> let reason = "unexpected hello request in server context" in terminate (Error_Misc reason) AlertLevel_Fatal UnexpectedMessage reason Client cparams -> withRWLock ctx (handshakeClient cparams ctx) >> recvData ctx + -} process (Alert [(AlertLevel_Warning, CloseNotify)]) = tryBye >> setEOF ctx >> return B.empty process (Alert [(AlertLevel_Fatal, desc)]) = do @@ -107,7 +113,7 @@ recvData ctx = liftIO $ do session <- usingState_ ctx getSession case session of Session Nothing -> return () - Session (Just sid) -> withSessionManager (ctxParams ctx) (\s -> sessionInvalidate s sid) + Session (Just sid) -> sessionInvalidate (sharedSessionManager $ ctxShared ctx) sid catchException (sendPacket ctx $ Alert [(level, desc)]) (\_ -> return ()) setEOF ctx E.throwIO (Terminated False reason err) diff --git a/core/Network/TLS/Handshake.hs b/core/Network/TLS/Handshake.hs index 2baa77a..f232a31 100644 --- a/core/Network/TLS/Handshake.hs +++ b/core/Network/TLS/Handshake.hs @@ -7,11 +7,13 @@ -- module Network.TLS.Handshake ( handshake + , handshakeClientWith , handshakeServerWith , handshakeClient + , handshakeServer ) where -import Network.TLS.Context +import Network.TLS.Context.Internal import Network.TLS.Struct import Network.TLS.IO import Network.TLS.Util (catchException) @@ -26,11 +28,8 @@ import Control.Exception (fromException) -- | Handshake for a new TLS connection -- This is to be called at the beginning of a connection, and during renegotiation handshake :: MonadIO m => Context -> m () -handshake ctx = do - let handshakeF = case roleParams $ ctxParams ctx of - Server sparams -> handshakeServer sparams - Client cparams -> handshakeClient cparams - liftIO $ handleException $ withRWLock ctx (handshakeF ctx) +handshake ctx = + liftIO $ handleException $ withRWLock ctx (ctxDoHandshake ctx $ ctx) where handleException f = catchException f $ \exception -> do let tlserror = maybe (Error_Misc $ show exception) id $ fromException exception setEstablished ctx False diff --git a/core/Network/TLS/Handshake/Certificate.hs b/core/Network/TLS/Handshake/Certificate.hs index b57cbaf..730130b 100644 --- a/core/Network/TLS/Handshake/Certificate.hs +++ b/core/Network/TLS/Handshake/Certificate.hs @@ -10,8 +10,9 @@ module Network.TLS.Handshake.Certificate , rejectOnException ) where -import Network.TLS.Context +import Network.TLS.Context.Internal import Network.TLS.Struct +import Network.TLS.X509 import Control.Monad.State import Control.Exception (SomeException) @@ -26,5 +27,5 @@ certificateRejected CertificateRejectUnknownCA = certificateRejected (CertificateRejectOther s) = throwCore $ Error_Protocol ("certificate rejected: " ++ s, True, CertificateUnknown) -rejectOnException :: SomeException -> IO TLSCertificateUsage +rejectOnException :: SomeException -> IO CertificateUsage rejectOnException e = return $ CertificateUsageReject $ CertificateRejectOther $ show e diff --git a/core/Network/TLS/Handshake/Client.hs b/core/Network/TLS/Handshake/Client.hs index 1ba28d2..59e32d9 100644 --- a/core/Network/TLS/Handshake/Client.hs +++ b/core/Network/TLS/Handshake/Client.hs @@ -8,10 +8,12 @@ -- module Network.TLS.Handshake.Client ( handshakeClient + , handshakeClientWith ) where import Network.TLS.Crypto -import Network.TLS.Context +import Network.TLS.Context.Internal +import Network.TLS.Parameters import Network.TLS.Struct import Network.TLS.Cipher import Network.TLS.Compression @@ -41,6 +43,10 @@ import Network.TLS.Handshake.Signature import Network.TLS.Handshake.Key import Network.TLS.Handshake.State +handshakeClientWith :: ClientParams -> Context -> Handshake -> IO () +handshakeClientWith cparams ctx HelloRequest = handshakeClient cparams ctx +handshakeClientWith _ _ _ = throwCore $ Error_Protocol ("unexpected handshake message received in handshakeClientWith", True, HandshakeFailure) + -- client part of handshake. send a bunch of handshake of client -- values intertwined with response from the server. handshakeClient :: ClientParams -> Context -> IO () @@ -50,31 +56,32 @@ handshakeClient cparams ctx = do recvServerHello sentExtensions sessionResuming <- usingState_ ctx isSessionResuming if sessionResuming - then sendChangeCipherAndFinish ctx ClientRole + then sendChangeCipherAndFinish sendMaybeNPN ctx ClientRole else do sendClientData cparams ctx - sendChangeCipherAndFinish ctx ClientRole + sendChangeCipherAndFinish sendMaybeNPN ctx ClientRole recvChangeCipherAndFinish ctx handshakeTerminate ctx - where params = ctxParams ctx - ciphers = pCiphers params - compressions = pCompressions params + where ciphers = ctxCiphers ctx + compressions = supportedCompressions $ ctxSupported ctx getExtensions = sequence [sniExtension,secureReneg,npnExtention] >>= return . catMaybes toExtensionRaw :: Extension e => e -> ExtensionRaw toExtensionRaw ext = (extensionID ext, extensionEncode ext) secureReneg = - if pUseSecureRenegotiation params + if supportedSecureRenegotiation $ ctxSupported ctx then usingState_ ctx (getVerifiedData ClientRole) >>= \vd -> return $ Just $ toExtensionRaw $ SecureRenegotiation vd Nothing else return Nothing - npnExtention = if isJust $ onNPNServerSuggest cparams + npnExtention = if isJust $ onNPNServerSuggest $ clientHooks cparams then return $ Just $ toExtensionRaw $ NextProtocolNegotiation [] else return Nothing - sniExtension = return ((\h -> toExtensionRaw $ ServerName [(ServerNameHostName h)]) <$> clientUseServerName cparams) + sniExtension = if clientUseServerNameIndication cparams + then return $ Just $ toExtensionRaw $ ServerName [ServerNameHostName $ fst $ clientServerIdentification cparams] + else return Nothing sendClientHello = do crand <- getStateRNG ctx 32 >>= return . ClientRandom let clientSession = Session . maybe Nothing (Just . fst) $ clientWantSessionResume cparams - highestVer = maximum $ pAllowedVersions params + highestVer = maximum $ supportedVersions $ ctxSupported ctx extensions <- getExtensions startHandshake ctx highestVer crand usingState_ ctx $ setVersionIfUnset highestVer @@ -84,6 +91,18 @@ handshakeClient cparams ctx = do ] return $ map fst extensions + sendMaybeNPN = do + suggest <- usingState_ ctx $ getServerNextProtocolSuggest + case (onNPNServerSuggest $ clientHooks cparams, suggest) of + -- client offered, server picked up. send NPN handshake. + (Just io, Just protos) -> do proto <- liftIO $ io protos + sendPacket ctx (Handshake [HsNextProtocolNegotiation proto]) + usingState_ ctx $ setNegotiatedProtocol proto + -- client offered, server didn't pick up. do nothing. + (Just _, Nothing) -> return () + -- client didn't offer. do nothing. + (Nothing, _) -> return () + recvServerHello sentExts = runRecvState ctx (RecvStateHandshake $ onServerHello ctx cparams sentExts) -- | send client Data after receiving all server data (hello/certificates/key). @@ -108,7 +127,7 @@ sendClientData cparams ctx = sendCertificate >> sendClientKeyXchg >> sendCertifi return () Just req -> do - certChain <- liftIO $ onCertificateRequest cparams req `catchException` + certChain <- liftIO $ (onCertificateRequest $ clientHooks cparams) req `catchException` throwMiscErrorOnException "certificate request callback failed" usingHState ctx $ setClientCertSent False @@ -176,7 +195,7 @@ sendClientData cparams ctx = sendCertificate >> sendClientKeyXchg >> sendCertifi malg <- case usedVersion of TLS12 -> do Just (_, Just hashSigs, _) <- usingHState ctx $ getClientCertRequest - let suppHashSigs = pHashSignatures $ ctxParams ctx + let suppHashSigs = supportedHashSignatures $ ctxSupported ctx hashSigs' = filter (\ a -> a `elem` hashSigs) suppHashSigs when (null hashSigs') $ @@ -218,14 +237,14 @@ throwMiscErrorOnException msg e = onServerHello :: Context -> ClientParams -> [ExtensionID] -> Handshake -> IO (RecvState IO) onServerHello ctx cparams sentExts (ServerHello rver serverRan serverSession cipher compression exts) = do when (rver == SSL2) $ throwCore $ Error_Protocol ("ssl2 is not supported", True, ProtocolVersion) - case find ((==) rver) allowedvers of + case find ((==) rver) (supportedVersions $ ctxSupported ctx) of Nothing -> throwCore $ Error_Protocol ("server version " ++ show rver ++ " is not supported", True, ProtocolVersion) Just _ -> return () -- find the compression and cipher methods that the server want to use. - cipherAlg <- case find ((==) cipher . cipherID) ciphers of + cipherAlg <- case find ((==) cipher . cipherID) (ctxCiphers ctx) of Nothing -> throwCore $ Error_Protocol ("server choose unknown cipher", True, HandshakeFailure) Just alg -> return alg - compressAlg <- case find ((==) compression . compressionID) compressions of + compressAlg <- case find ((==) compression . compressionID) (supportedCompressions $ ctxSupported ctx) of Nothing -> throwCore $ Error_Protocol ("server choose unknown compression", True, HandshakeFailure) Just alg -> return alg @@ -251,25 +270,25 @@ onServerHello ctx cparams sentExts (ServerHello rver serverRan serverSession cip _ -> return () case resumingSession of - Nothing -> return $ RecvStateHandshake (processCertificate ctx) + Nothing -> return $ RecvStateHandshake (processCertificate cparams ctx) Just sessionData -> do usingHState ctx (setMasterSecret rver ClientRole $ sessionSecret sessionData) return $ RecvStateNext expectChangeCipher - where params = ctxParams ctx - allowedvers = pAllowedVersions params - ciphers = pCiphers params - compressions = pCompressions params onServerHello _ _ _ p = unexpected (show p) (Just "server hello") -processCertificate :: Context -> Handshake -> IO (RecvState IO) -processCertificate ctx (Certificates certs) = do - usage <- liftIO $ catchException (onCertificatesRecv params certs) rejectOnException +processCertificate :: ClientParams -> Context -> Handshake -> IO (RecvState IO) +processCertificate cparams ctx (Certificates certs) = do + usage <- catchException (wrapCertificateChecks <$> checkCert) rejectOnException case usage of CertificateUsageAccept -> return () CertificateUsageReject reason -> certificateRejected reason return $ RecvStateHandshake (processServerKeyExchange ctx) - where params = ctxParams ctx -processCertificate ctx p = processServerKeyExchange ctx p + where shared = clientShared cparams + checkCert = (onServerCertificate $ clientHooks cparams) (sharedCAStore shared) + (sharedValidationCache shared) + (clientServerIdentification cparams) + certs +processCertificate _ ctx p = processServerKeyExchange ctx p expectChangeCipher :: Packet -> IO (RecvState IO) expectChangeCipher ChangeCipherSpec = return $ RecvStateHandshake expectFinish diff --git a/core/Network/TLS/Handshake/Common.hs b/core/Network/TLS/Handshake/Common.hs index 3e0c675..36a0e0a 100644 --- a/core/Network/TLS/Handshake/Common.hs +++ b/core/Network/TLS/Handshake/Common.hs @@ -16,7 +16,8 @@ module Network.TLS.Handshake.Common import Control.Concurrent.MVar -import Network.TLS.Context +import Network.TLS.Parameters +import Network.TLS.Context.Internal import Network.TLS.Session import Network.TLS.Struct import Network.TLS.IO @@ -45,8 +46,8 @@ unexpected msg expected = throwCore $ Error_Packet_unexpected msg (maybe "" (" e newSession :: Context -> IO Session newSession ctx - | pUseSession $ ctxParams ctx = getStateRNG ctx 32 >>= return . Session . Just - | otherwise = return $ Session Nothing + | supportedSession $ ctxSupported ctx = getStateRNG ctx 32 >>= return . Session . Just + | otherwise = return $ Session Nothing -- | when a new handshake is done, wrap up & clean up. handshakeTerminate :: Context -> IO () @@ -56,7 +57,7 @@ handshakeTerminate ctx = do case session of Session (Just sessionId) -> do sessionData <- getSessionData ctx - withSessionManager (ctxParams ctx) (\s -> liftIO $ sessionEstablish s sessionId (fromJust "session-data" sessionData)) + liftIO $ sessionEstablish (sharedSessionManager $ ctxShared ctx) sessionId (fromJust "session-data" sessionData) _ -> return () -- forget all handshake data now and reset bytes counters. liftIO $ modifyMVar_ (ctxHandshake ctx) (return . const Nothing) @@ -65,24 +66,14 @@ handshakeTerminate ctx = do setEstablished ctx True return () -sendChangeCipherAndFinish :: Context -> Role -> IO () -sendChangeCipherAndFinish ctx role = do +sendChangeCipherAndFinish :: IO () -- ^ message possibly sent between ChangeCipherSpec and Finished. + -> Context + -> Role + -> IO () +sendChangeCipherAndFinish betweenCall ctx role = do sendPacket ctx ChangeCipherSpec - - when (role == ClientRole) $ do - let cparams = getClientParams $ ctxParams ctx - suggest <- usingState_ ctx $ getServerNextProtocolSuggest - case (onNPNServerSuggest cparams, suggest) of - -- client offered, server picked up. send NPN handshake. - (Just io, Just protos) -> do proto <- liftIO $ io protos - sendPacket ctx (Handshake [HsNextProtocolNegotiation proto]) - usingState_ ctx $ setNegotiatedProtocol proto - -- client offered, server didn't pick up. do nothing. - (Just _, Nothing) -> return () - -- client didn't offer. do nothing. - (Nothing, _) -> return () + betweenCall liftIO $ contextFlush ctx - cf <- usingState_ ctx getVersion >>= \ver -> usingHState ctx $ getHandshakeDigest ver role sendPacket ctx (Handshake [Finished cf]) liftIO $ contextFlush ctx diff --git a/core/Network/TLS/Handshake/Key.hs b/core/Network/TLS/Handshake/Key.hs index 3ea65b6..fd6eb9c 100644 --- a/core/Network/TLS/Handshake/Key.hs +++ b/core/Network/TLS/Handshake/Key.hs @@ -22,7 +22,7 @@ import Network.TLS.Handshake.State import Network.TLS.State (withRNG, getVersion) import Network.TLS.Crypto import Network.TLS.Types -import Network.TLS.Context +import Network.TLS.Context.Internal {- if the RSA encryption fails we just return an empty bytestring, and let the protocol - fail by itself; however it would be probably better to just report it since it's an internal problem. diff --git a/core/Network/TLS/Handshake/Process.hs b/core/Network/TLS/Handshake/Process.hs index 030ffb6..ff95a1a 100644 --- a/core/Network/TLS/Handshake/Process.hs +++ b/core/Network/TLS/Handshake/Process.hs @@ -23,7 +23,7 @@ import Network.TLS.Util import Network.TLS.Packet import Network.TLS.Struct import Network.TLS.State -import Network.TLS.Context +import Network.TLS.Context.Internal import Network.TLS.Crypto import Network.TLS.Handshake.State import Network.TLS.Handshake.Key diff --git a/core/Network/TLS/Handshake/Server.hs b/core/Network/TLS/Handshake/Server.hs index f2df07e..b66279b 100644 --- a/core/Network/TLS/Handshake/Server.hs +++ b/core/Network/TLS/Handshake/Server.hs @@ -11,7 +11,8 @@ module Network.TLS.Handshake.Server , handshakeServerWith ) where -import Network.TLS.Context +import Network.TLS.Parameters +import Network.TLS.Context.Internal import Network.TLS.Session import Network.TLS.Struct import Network.TLS.Cipher @@ -80,7 +81,7 @@ handshakeServer sparams ctx = liftIO $ do handshakeServerWith :: ServerParams -> Context -> Handshake -> IO () handshakeServerWith sparams ctx clientHello@(ClientHello clientVersion _ clientSession ciphers compressions exts _) = do -- check if policy allow this new handshake to happens - handshakeAuthorized <- withMeasure ctx (onHandshake $ ctxParams ctx) + handshakeAuthorized <- withMeasure ctx (onHandshake $ ctxCommonHooks ctx) unless handshakeAuthorized (throwCore $ Error_HandshakePolicy "server: handshake denied") updateMeasure ctx incrementNbHandshakes @@ -88,7 +89,7 @@ handshakeServerWith sparams ctx clientHello@(ClientHello clientVersion _ clientS processHandshake ctx clientHello when (clientVersion == SSL2) $ throwCore $ Error_Protocol ("ssl2 is not supported", True, ProtocolVersion) - chosenVersion <- case findHighestVersionFrom clientVersion (pAllowedVersions params) of + chosenVersion <- case findHighestVersionFrom clientVersion (supportedVersions $ ctxSupported ctx) of Nothing -> throwCore $ Error_Protocol ("client version " ++ show clientVersion ++ " is not supported", True, ProtocolVersion) Just v -> return v @@ -98,8 +99,8 @@ handshakeServerWith sparams ctx clientHello@(ClientHello clientVersion _ clientS Error_Protocol ("no compression in common with the client", True, HandshakeFailure) let ciphersFilteredVersion = filter (cipherAllowedForVersion chosenVersion) commonCiphers - usedCipher = (onCipherChoosing sparams) chosenVersion ciphersFilteredVersion - creds = credentialsGet params + usedCipher = (onCipherChoosing $ serverHooks sparams) chosenVersion ciphersFilteredVersion + creds = sharedCredentials $ ctxShared ctx cred <- case cipherKeyExchange usedCipher of CipherKeyExchange_RSA -> return $ credentialsFindForDecrypting creds CipherKeyExchange_DH_Anon -> return $ Nothing @@ -108,20 +109,19 @@ handshakeServerWith sparams ctx clientHello@(ClientHello clientVersion _ clientS _ -> throwCore $ Error_Protocol ("key exchange algorithm not implemented", True, HandshakeFailure) resumeSessionData <- case clientSession of - (Session (Just clientSessionId)) -> withSessionManager params (\s -> liftIO $ sessionResume s clientSessionId) + (Session (Just clientSessionId)) -> liftIO $ sessionResume (sharedSessionManager $ ctxShared ctx) clientSessionId (Session Nothing) -> return Nothing doHandshake sparams cred ctx chosenVersion usedCipher usedCompression clientSession resumeSessionData exts where - params = ctxParams ctx commonCipherIDs = intersect ciphers (map cipherID $ ctxCiphers ctx) commonCiphers = filter (flip elem commonCipherIDs . cipherID) (ctxCiphers ctx) - commonCompressions = compressionIntersectID (pCompressions params) compressions + commonCompressions = compressionIntersectID (supportedCompressions $ ctxSupported ctx) compressions usedCompression = head commonCompressions -handshakeServerWith _ _ _ = fail "unexpected handshake type received. expecting client hello" +handshakeServerWith _ _ _ = throwCore $ Error_Protocol ("unexpected handshake message received in handshakeServerWith", True, HandshakeFailure) doHandshake :: ServerParams -> Maybe Credential -> Context -> Version -> Cipher -> Compression -> Session -> Maybe SessionData @@ -133,13 +133,13 @@ doHandshake sparams mcred ctx chosenVersion usedCipher usedCompression clientSes liftIO $ contextFlush ctx -- Receive client info until client Finished. recvClientData sparams ctx - sendChangeCipherAndFinish ctx ServerRole + sendChangeCipherAndFinish (return ()) ctx ServerRole Just sessionData -> do usingState_ ctx (setSession clientSession True) serverhello <- makeServerHello clientSession sendPacket ctx $ Handshake [serverhello] usingHState ctx $ setMasterSecret chosenVersion ServerRole $ sessionSecret sessionData - sendChangeCipherAndFinish ctx ServerRole + sendChangeCipherAndFinish (return ()) ctx ServerRole recvChangeCipherAndFinish ctx handshakeTerminate ctx where @@ -169,7 +169,7 @@ doHandshake sparams mcred ctx chosenVersion usedCipher usedCompression clientSes else return [] nextProtocols <- if clientRequestedNPN - then liftIO $ onSuggestNextProtocols sparams + then liftIO $ onSuggestNextProtocols $ serverHooks sparams else return Nothing npnExt <- case nextProtocols of Just protos -> do usingState_ ctx $ do setExtensionNPN True @@ -212,7 +212,7 @@ doHandshake sparams mcred ctx chosenVersion usedCipher usedCompression clientSes let certTypes = [ CertificateType_RSA_Sign ] hashSigs = if usedVersion < TLS12 then Nothing - else Just (pHashSignatures $ ctxParams ctx) + else Just (supportedHashSignatures $ ctxSupported ctx) creq = CertRequest certTypes hashSigs (map extractCAname $ serverCACertificates sparams) usingHState ctx $ setCertReqSent True @@ -240,7 +240,7 @@ doHandshake sparams mcred ctx chosenVersion usedCipher usedCompression clientSes usedVersion <- usingState_ ctx getVersion let mhash = case usedVersion of - TLS12 -> case filter ((==) sigAlg . snd) $ pHashSignatures $ ctxParams ctx of + TLS12 -> case filter ((==) sigAlg . snd) $ supportedHashSignatures $ ctxSupported ctx of [] -> error ("no hash signature for " ++ show sigAlg) x:_ -> Just (fst x) _ -> Nothing @@ -269,7 +269,7 @@ recvClientData sparams ctx = runRecvState ctx (RecvStateHandshake processClientC -- Call application callback to see whether the -- certificate chain is acceptable. -- - usage <- liftIO $ catchException (onClientCertificate sparams certs) rejectOnException + usage <- liftIO $ catchException (onClientCertificate (serverHooks sparams) certs) rejectOnException case usage of CertificateUsageAccept -> return () CertificateUsageReject reason -> certificateRejected reason @@ -320,7 +320,7 @@ recvClientData sparams ctx = runRecvState ctx (RecvStateHandshake processClientC -- the signature is wrong. In either case, -- ask the application if it wants to -- proceed, we will do that. - res <- liftIO $ onUnverifiedClientCert sparams + res <- liftIO $ onUnverifiedClientCert (serverHooks sparams) if res then do -- When verification fails, but the diff --git a/core/Network/TLS/Handshake/Signature.hs b/core/Network/TLS/Handshake/Signature.hs index 175bda8..e055c7d 100644 --- a/core/Network/TLS/Handshake/Signature.hs +++ b/core/Network/TLS/Handshake/Signature.hs @@ -18,7 +18,7 @@ module Network.TLS.Handshake.Signature import Crypto.PubKey.HashDescr import Network.TLS.Crypto -import Network.TLS.Context +import Network.TLS.Context.Internal import Network.TLS.Struct import Network.TLS.Packet (generateCertificateVerify_SSL, encodeSignedDHParams) import Network.TLS.State diff --git a/core/Network/TLS/Hooks.hs b/core/Network/TLS/Hooks.hs index 54e092f..4413967 100644 --- a/core/Network/TLS/Hooks.hs +++ b/core/Network/TLS/Hooks.hs @@ -7,13 +7,13 @@ -- module Network.TLS.Hooks ( Logging(..) - , defaultLogging , Hooks(..) , defaultHooks ) where import qualified Data.ByteString as B import Network.TLS.Struct (Header, Handshake(..)) +import Data.Default.Class -- | Hooks for logging data Logging = Logging @@ -31,6 +31,9 @@ defaultLogging = Logging , loggingIORecv = (\_ _ -> return ()) } +instance Default Logging where + def = defaultLogging + -- | A collection of hooks actions. data Hooks = Hooks { hookRecvHandshake :: Handshake -> IO Handshake @@ -41,3 +44,5 @@ defaultHooks = Hooks { hookRecvHandshake = \hs -> return hs } +instance Default Hooks where + def = defaultHooks diff --git a/core/Network/TLS/IO.hs b/core/Network/TLS/IO.hs index 1eb1050..39c0ba4 100644 --- a/core/Network/TLS/IO.hs +++ b/core/Network/TLS/IO.hs @@ -13,10 +13,11 @@ module Network.TLS.IO , recvPacket ) where -import Network.TLS.Context +import Network.TLS.Context.Internal import Network.TLS.Struct import Network.TLS.Record import Network.TLS.Packet +import Network.TLS.Hooks import Network.TLS.Sending import Network.TLS.Receiving import qualified Data.ByteString as B diff --git a/core/Network/TLS/Parameters.hs b/core/Network/TLS/Parameters.hs index 79f806e..c1544ee 100644 --- a/core/Network/TLS/Parameters.hs +++ b/core/Network/TLS/Parameters.hs @@ -5,33 +5,23 @@ -- Stability : experimental -- Portability : unknown -- --- extension RecordWildCards only needed because of some GHC bug --- relative to insufficient polymorphic field -{-# LANGUAGE RecordWildCards #-} module Network.TLS.Parameters ( - -- * Parameters - Params(..) - , RoleParams(..) - , ClientParams(..) + ClientParams(..) , ServerParams(..) - , updateClientParams - , updateServerParams - , Logging(..) - , SessionID - , SessionData(..) + , CommonParams + , ClientHooks(..) + , ServerHooks(..) + , CommonHooks(..) + , Supported(..) + , Shared(..) + -- * special default + , defaultParamsClient + -- * Parameters , MaxFragmentEnum(..) - , Measurement(..) + , Logging(..) , CertificateUsage(..) , CertificateRejectReason(..) - , defaultLogging - , defaultParamsClient - , defaultParamsServer - , withSessionManager - , setSessionManager - , getClientParams - , getServerParams - , credentialsGet ) where import Network.BSD (HostName) @@ -48,15 +38,137 @@ import Network.TLS.Hooks import Network.TLS.Measurement import Network.TLS.X509 import Data.Monoid -import Data.List (intercalate) +import Data.Default.Class import qualified Data.ByteString as B -data ClientParams = ClientParams - { clientUseMaxFragmentLength :: Maybe MaxFragmentEnum - , clientUseServerName :: Maybe HostName - , clientWantSessionResume :: Maybe (SessionID, SessionData) -- ^ try to establish a connection using this session. +type CommonParams = (Supported, Shared, CommonHooks) - -- | This action is called when the server sends a +data ClientParams = ClientParams + { clientUseMaxFragmentLength :: Maybe MaxFragmentEnum + -- | Define the name of the server, along with an extra service identification blob. + -- this is important that the hostname part is properly filled for security reason, + -- as it allow to properly associate the remote side with the given certificate + -- during a handshake. + -- + -- The extra blob is useful to differentiate services running on the same host, but that + -- might have different certificates given. It's only used as part of the X509 validation + -- infrastructure. + , clientServerIdentification :: (HostName, Bytes) + -- | Allow the use of the Server Name Indication TLS extension during handshake, which allow + -- the client to specify which host name, it's trying to access. This is useful to distinguish + -- CNAME aliasing (e.g. web virtual host). + , clientUseServerNameIndication :: Bool + -- | try to establish a connection using this session. + , clientWantSessionResume :: Maybe (SessionID, SessionData) + , clientShared :: Shared + , clientHooks :: ClientHooks + , clientCommonHooks :: CommonHooks + , clientSupported :: Supported + } deriving (Show) + +defaultParamsClient :: HostName -> Bytes -> ClientParams +defaultParamsClient serverName serverId = ClientParams + { clientWantSessionResume = Nothing + , clientUseMaxFragmentLength = Nothing + , clientServerIdentification = (serverName, serverId) + , clientUseServerNameIndication = True + , clientShared = def + , clientHooks = def + , clientCommonHooks = def + , clientSupported = def + } + +data ServerParams = ServerParams + { -- | request a certificate from client. + serverWantClientCert :: Bool + + -- | This is a list of certificates from which the + -- disinguished names are sent in certificate request + -- messages. For TLS1.0, it should not be empty. + , serverCACertificates :: [SignedCertificate] + + -- | Server Optional Diffie Hellman parameters. If this value is not + -- properly set, no Diffie Hellman key exchange will take place. + , serverDHEParams :: Maybe DHParams + + , serverShared :: Shared + , serverHooks :: ServerHooks + , serverCommonHooks :: CommonHooks + , serverSupported :: Supported + } deriving (Show) + +defaultParamsServer :: ServerParams +defaultParamsServer = ServerParams + { serverWantClientCert = False + , serverCACertificates = [] + , serverDHEParams = Nothing + , serverHooks = def + , serverShared = def + , serverCommonHooks = def + , serverSupported = def + } + +instance Default ServerParams where + def = defaultParamsServer + +-- | List all the supported algorithms, versions, ciphers, etc supported. +data Supported = Supported + { + -- | Supported Versions by this context + -- On the client side, the highest version will be used to establish the connection. + -- On the server side, the highest version that is less or equal than the client version will be chosed. + supportedVersions :: [Version] + -- | Supported cipher methods + , supportedCiphers :: [Cipher] + -- | supported compressions methods + , supportedCompressions :: [Compression] + -- | All supported hash/signature algorithms pair for client + -- certificate verification, ordered by decreasing priority. + , supportedHashSignatures :: [HashAndSignatureAlgorithm] + -- | Set if we support secure renegotiation. + , supportedSecureRenegotiation :: Bool + -- | Set if we support session. + , supportedSession :: Bool + } deriving (Show,Eq) + +defaultSupported :: Supported +defaultSupported = Supported + { supportedVersions = [TLS10,TLS11,TLS12] + , supportedCiphers = [] + , supportedCompressions = [nullCompression] + , supportedHashSignatures = [ (Struct.HashSHA512, SignatureRSA) + , (Struct.HashSHA384, SignatureRSA) + , (Struct.HashSHA256, SignatureRSA) + , (Struct.HashSHA224, SignatureRSA) + , (Struct.HashSHA1, SignatureDSS) + ] + , supportedSecureRenegotiation = True + , supportedSession = True + } + +instance Default Supported where + def = defaultSupported + +data Shared = Shared + { sharedCredentials :: Credentials + , sharedSessionManager :: SessionManager + , sharedCAStore :: CertificateStore + , sharedValidationCache :: ValidationCache + } + +instance Show Shared where + show _ = "Shared" +instance Default Shared where + def = Shared + { sharedCAStore = mempty + , sharedCredentials = mempty + , sharedSessionManager = noSessionManager + , sharedValidationCache = def + } + +-- | A set of callbacks run by the clients for various corners of TLS establishment +data ClientHooks = ClientHooks + { -- | This action is called when the server sends a -- certificate request. The parameter is the information -- from the request. The action should select a certificate -- chain of one of the given certificate types where the @@ -75,24 +187,32 @@ data ClientParams = ClientParams -- Returning a certificate chain not matching the -- distinguished names may lead to problems or not, -- depending whether the server accepts it. - , onCertificateRequest :: ([CertificateType], + onCertificateRequest :: ([CertificateType], Maybe [HashAndSignatureAlgorithm], [DistinguishedName]) -> IO (Maybe (CertificateChain, PrivKey)) , onNPNServerSuggest :: Maybe ([B.ByteString] -> IO B.ByteString) + , onServerCertificate :: CertificateStore -> ValidationCache -> ServiceID -> CertificateChain -> IO [FailedReason] } -data ServerParams = ServerParams - { serverWantClientCert :: Bool -- ^ request a certificate from client. +defaultClientHooks :: ClientHooks +defaultClientHooks = ClientHooks + { onCertificateRequest = \ _ -> return Nothing + , onNPNServerSuggest = Nothing + , onServerCertificate = validateDefault + } - -- | This is a list of certificates from which the - -- disinguished names are sent in certificate request - -- messages. For TLS1.0, it should not be empty. - , serverCACertificates :: [SignedCertificate] +instance Show ClientHooks where + show _ = "ClientHooks" +instance Default ClientHooks where + def = defaultClientHooks +-- | A set of callbacks run by the server for various corners of the TLS establishment +data ServerHooks = ServerHooks + { -- | This action is called when a client certificate chain -- is received from the client. When it returns a -- CertificateUsageReject value, the handshake is aborted. - , onClientCertificate :: CertificateChain -> IO CertificateUsage + onClientCertificate :: CertificateChain -> IO CertificateUsage -- | This action is called when the client certificate -- cannot be verified. A 'Nothing' argument indicates a @@ -109,113 +229,35 @@ data ServerParams = ServerParams -- The client cipher list cannot be empty. , onCipherChoosing :: Version -> [Cipher] -> Cipher - -- | Server Optional Diffie Hellman parameters - , serverDHEParams :: Maybe DHParams - -- | suggested next protocols accoring to the next protocol negotiation extension. - , onSuggestNextProtocols :: IO (Maybe [B.ByteString]) + , onSuggestNextProtocols :: IO (Maybe [B.ByteString]) } -data RoleParams = Client ClientParams | Server ServerParams +defaultServerHooks :: ServerHooks +defaultServerHooks = ServerHooks + { onCipherChoosing = \_ -> head + , onClientCertificate = \_ -> return $ CertificateUsageReject $ CertificateRejectOther "no client certificates expected" + , onUnverifiedClientCert = return False + , onSuggestNextProtocols = return Nothing + } -data Params = Params - { pAllowedVersions :: [Version] -- ^ allowed versions that we can use. - -- the default version used for connection is the highest version in the list - , pCiphers :: [Cipher] -- ^ all ciphers supported ordered by priority. - , pCompressions :: [Compression] -- ^ all compression supported ordered by priority. - , pHashSignatures :: [HashAndSignatureAlgorithm] -- ^ All supported hash/signature algorithms pair for client certificate verification, ordered by decreasing priority. - , pUseSecureRenegotiation :: Bool -- ^ notify that we want to use secure renegotation - , pUseSession :: Bool -- ^ generate new session if specified - , pCertificates :: Maybe (CertificateChain, Maybe PrivKey) -- ^ the cert chain for this context with the associated keys if any. - , pCredentials :: Credentials -- ^ credentials - , pLogging :: Logging -- ^ callback for logging +instance Show ServerHooks where + show _ = "ClientHooks" +instance Default ServerHooks where + def = defaultServerHooks + +data CommonHooks = CommonHooks + { onCertificatesRecv :: CertificateChain -> IO CertificateUsage -- ^ callback to verify received cert chain. , onHandshake :: Measurement -> IO Bool -- ^ callback on a beggining of handshake - , onCertificatesRecv :: CertificateChain -> IO CertificateUsage -- ^ callback to verify received cert chain. - , pSessionManager :: SessionManager - , roleParams :: RoleParams - } -{-# DEPRECATED pCertificates "use pCredentials instead of pCertificates. removed in tls-1.3" #-} - -credentialsGet :: Params -> Credentials -credentialsGet params = pCredentials params `mappend` - case pCertificates params of - Just (cchain, Just priv) -> Credentials [(cchain, priv)] - _ -> Credentials [] - --- | Set a new session manager in a parameters structure. -setSessionManager :: SessionManager -> Params -> Params -setSessionManager manager (Params {..}) = Params { pSessionManager = manager, .. } - -withSessionManager :: Params -> (SessionManager -> a) -> a -withSessionManager (Params { pSessionManager = man }) f = f man - -getClientParams :: Params -> ClientParams -getClientParams params = - case roleParams params of - Client clientParams -> clientParams - _ -> error "server params in client context" - -getServerParams :: Params -> ServerParams -getServerParams params = - case roleParams params of - Server serverParams -> serverParams - _ -> error "client params in server context" - -defaultParamsClient :: Params -defaultParamsClient = Params - { pAllowedVersions = [TLS10,TLS11,TLS12] - , pCiphers = [] - , pCompressions = [nullCompression] - , pHashSignatures = [ (Struct.HashSHA512, SignatureRSA) - , (Struct.HashSHA384, SignatureRSA) - , (Struct.HashSHA256, SignatureRSA) - , (Struct.HashSHA224, SignatureRSA) - , (Struct.HashSHA1, SignatureDSS) - ] - , pUseSecureRenegotiation = True - , pUseSession = True - , pCertificates = Nothing - , pCredentials = mempty - , pLogging = defaultLogging - , onHandshake = (\_ -> return True) - , onCertificatesRecv = (\_ -> return CertificateUsageAccept) - , pSessionManager = noSessionManager - , roleParams = Client $ ClientParams - { clientWantSessionResume = Nothing - , clientUseMaxFragmentLength = Nothing - , clientUseServerName = Nothing - , onCertificateRequest = \ _ -> return Nothing - , onNPNServerSuggest = Nothing - } + , logging :: Logging -- ^ callback for logging } -defaultParamsServer :: Params -defaultParamsServer = defaultParamsClient { roleParams = Server role } - where role = ServerParams - { serverWantClientCert = False - , onCipherChoosing = \_ -> head - , serverCACertificates = [] - , serverDHEParams = Nothing - , onClientCertificate = \ _ -> return $ CertificateUsageReject $ CertificateRejectOther "no client certificates expected" - , onUnverifiedClientCert = return False - , onSuggestNextProtocols = return Nothing - } +instance Show CommonHooks where + show _ = "CommonHooks" -updateRoleParams :: (ClientParams -> ClientParams) -> (ServerParams -> ServerParams) -> Params -> Params -updateRoleParams fc fs params = case roleParams params of - Client c -> params { roleParams = Client (fc c) } - Server s -> params { roleParams = Server (fs s) } - -updateClientParams :: (ClientParams -> ClientParams) -> Params -> Params -updateClientParams f = updateRoleParams f id - -updateServerParams :: (ServerParams -> ServerParams) -> Params -> Params -updateServerParams f = updateRoleParams id f - -instance Show Params where - show p = "Params { " ++ (intercalate "," $ map (\(k,v) -> k ++ "=" ++ v) - [ ("allowedVersions", show $ pAllowedVersions p) - , ("ciphers", show $ pCiphers p) - , ("compressions", show $ pCompressions p) - , ("certificates", show $ pCertificates p) - ]) ++ " }" +instance Default CommonHooks where + def = CommonHooks + { onCertificatesRecv = \_ -> return CertificateUsageAccept + , logging = def + , onHandshake = \_ -> return True + } diff --git a/core/Network/TLS/Receiving.hs b/core/Network/TLS/Receiving.hs index cdd2a5e..f835f6d 100644 --- a/core/Network/TLS/Receiving.hs +++ b/core/Network/TLS/Receiving.hs @@ -16,7 +16,7 @@ import Control.Monad.State import Control.Monad.Error import Control.Concurrent.MVar -import Network.TLS.Context +import Network.TLS.Context.Internal import Network.TLS.Struct import Network.TLS.Record import Network.TLS.Packet diff --git a/core/Network/TLS/Sending.hs b/core/Network/TLS/Sending.hs index 1977c7c..178ca4f 100644 --- a/core/Network/TLS/Sending.hs +++ b/core/Network/TLS/Sending.hs @@ -23,7 +23,8 @@ import Network.TLS.Cap import Network.TLS.Struct import Network.TLS.Record import Network.TLS.Packet -import Network.TLS.Context +import Network.TLS.Context.Internal +import Network.TLS.Parameters import Network.TLS.State import Network.TLS.Handshake.State import Network.TLS.Cipher @@ -67,7 +68,7 @@ writePacket ctx pkt = do -- so we use cstIV as is, however in other case we generate an explicit IV prepareRecord :: Context -> RecordM a -> IO (Either TLSError a) prepareRecord ctx f = do - ver <- usingState_ ctx (getVersionWithDefault $ maximum $ pAllowedVersions $ ctxParams ctx) + ver <- usingState_ ctx (getVersionWithDefault $ maximum $ supportedVersions $ ctxSupported ctx) txState <- readMVar $ ctxTxState ctx let sz = case stCipher $ txState of Nothing -> 0 diff --git a/core/Network/TLS/Types.hs b/core/Network/TLS/Types.hs index 31a6ca7..0f3be1f 100644 --- a/core/Network/TLS/Types.hs +++ b/core/Network/TLS/Types.hs @@ -32,7 +32,7 @@ data SessionData = SessionData { sessionVersion :: Version , sessionCipher :: CipherID , sessionSecret :: ByteString - } + } deriving (Show,Eq) -- | Cipher identification type CipherID = Word16 diff --git a/core/Network/TLS/X509.hs b/core/Network/TLS/X509.hs index 8e37878..197637b 100644 --- a/core/Network/TLS/X509.hs +++ b/core/Network/TLS/X509.hs @@ -16,9 +16,18 @@ module Network.TLS.X509 , getCertificateChainLeaf , CertificateRejectReason(..) , CertificateUsage(..) + , CertificateStore + , ValidationCache + , exceptionValidationCache + , validateDefault + , FailedReason + , ServiceID + , wrapCertificateChecks ) where import Data.X509 +import Data.X509.Validation +import Data.X509.CertificateStore isNullCertificateChain :: CertificateChain -> Bool isNullCertificateChain (CertificateChain l) = null l @@ -41,3 +50,10 @@ data CertificateUsage = | CertificateUsageReject CertificateRejectReason -- ^ usage of certificate rejected deriving (Show,Eq) +wrapCertificateChecks :: [FailedReason] -> CertificateUsage +wrapCertificateChecks [] = CertificateUsageAccept +wrapCertificateChecks l + | Expired `elem` l = CertificateUsageReject $ CertificateRejectExpired + | InFuture `elem` l = CertificateUsageReject $ CertificateRejectExpired + | UnknownCA `elem` l = CertificateUsageReject $ CertificateRejectUnknownCA + | otherwise = CertificateUsageReject $ CertificateRejectOther (show l) diff --git a/core/Tests/Connection.hs b/core/Tests/Connection.hs index ef22c18..dcd7403 100644 --- a/core/Tests/Connection.hs +++ b/core/Tests/Connection.hs @@ -14,6 +14,8 @@ import PubKey import PipeChan import Network.TLS import Data.X509 +import Data.X509.Validation +import Data.Default.Class import Control.Applicative import Control.Concurrent.Chan import Control.Concurrent @@ -70,11 +72,11 @@ streamCipher = blockCipher } } -supportedCiphers :: [Cipher] -supportedCiphers = [blockCipher,blockCipherDHE_RSA,blockCipherDHE_DSS,streamCipher] +knownCiphers :: [Cipher] +knownCiphers = [blockCipher,blockCipherDHE_RSA,blockCipherDHE_DSS,streamCipher] -supportedVersions :: [Version] -supportedVersions = [SSL3,TLS10,TLS11,TLS12] +knownVersions :: [Version] +knownVersions = [SSL3,TLS10,TLS11,TLS12] arbitraryPairParams = do (dsaPub, dsaPriv) <- (\(p,r) -> (PubKeyDSA p, PrivKeyDSA r)) <$> arbitraryDSAPair @@ -83,44 +85,54 @@ arbitraryPairParams = do cert <- arbitraryX509WithKey (pub, priv) return (CertificateChain [cert], priv) ) [ (pubKey, privKey), (dsaPub, dsaPriv) ] - connectVersion <- elements supportedVersions - let allowedVersions = [ v | v <- supportedVersions, v <= connectVersion ] + connectVersion <- elements knownVersions + let allowedVersions = [ v | v <- knownVersions, v <= connectVersion ] serAllowedVersions <- (:[]) `fmap` elements allowedVersions serverCiphers <- arbitraryCiphers clientCiphers <- oneof [arbitraryCiphers] `suchThat` (\cs -> or [x `elem` serverCiphers | x <- cs]) secNeg <- arbitrary - --let cred = (CertificateChain [servCert], PrivKeyRSA privKey) - let serverState = defaultParamsServer - { pAllowedVersions = serAllowedVersions - , pCiphers = serverCiphers - , pCredentials = Credentials creds - , pUseSecureRenegotiation = secNeg - , pLogging = logging "server: " - , roleParams = roleParams $ updateServerParams (\sp -> sp { serverDHEParams = Just dhParams }) defaultParamsServer + +-- , pLogging = logging "server: " +-- , pLogging = logging "client: " + + let serverState = def + { serverSupported = def { supportedCiphers = serverCiphers + , supportedVersions = serAllowedVersions + , supportedSecureRenegotiation = secNeg + } + , serverDHEParams = Just dhParams + , serverShared = def { sharedCredentials = Credentials creds } } - let clientState = defaultParamsClient - { pAllowedVersions = allowedVersions - , pCiphers = clientCiphers - , pUseSecureRenegotiation = secNeg - , pLogging = logging "client: " + let clientState = (defaultParamsClient "" B.empty) + { clientSupported = def { supportedCiphers = clientCiphers + , supportedVersions = allowedVersions + , supportedSecureRenegotiation = secNeg + } + , clientShared = def { sharedValidationCache = ValidationCache + { cacheAdd = \_ _ _ -> return () + , cacheQuery = \_ _ _ -> return ValidationCachePass + } + } } return (clientState, serverState) where logging pre = if debug - then defaultLogging { loggingPacketSent = putStrLn . ((pre ++ ">> ") ++) + then def { loggingPacketSent = putStrLn . ((pre ++ ">> ") ++) , loggingPacketRecv = putStrLn . ((pre ++ "<< ") ++) } - else defaultLogging - arbitraryCiphers = resize (length supportedCiphers + 1) $ listOf1 (elements supportedCiphers) + else def + arbitraryCiphers = resize (length knownCiphers + 1) $ listOf1 (elements knownCiphers) -setPairParamsSessionManager :: SessionManager -> (Params, Params) -> (Params, Params) +setPairParamsSessionManager :: SessionManager -> (ClientParams, ServerParams) -> (ClientParams, ServerParams) setPairParamsSessionManager manager (clientState, serverState) = (nc,ns) - where nc = setSessionManager manager clientState - ns = setSessionManager manager serverState + where nc = clientState { clientShared = updateSessionManager $ clientShared clientState } + ns = serverState { serverShared = updateSessionManager $ serverShared serverState } + updateSessionManager shared = shared { sharedSessionManager = manager } -setPairParamsSessionResuming sessionStuff (clientState, serverState) = (nc,serverState) - where nc = updateClientParams (\cparams -> cparams { clientWantSessionResume = Just sessionStuff }) clientState +setPairParamsSessionResuming sessionStuff (clientState, serverState) = + ( clientState { clientWantSessionResume = Just sessionStuff } + , serverState) newPairContext pipe (cParams, sParams) = do let noFlush = return () diff --git a/core/Tests/Tests.hs b/core/Tests/Tests.hs index c7418cb..f5360b6 100644 --- a/core/Tests/Tests.hs +++ b/core/Tests/Tests.hs @@ -72,8 +72,12 @@ prop_handshake_initiate = do prop_handshake_npn_initiate :: PropertyM IO () prop_handshake_npn_initiate = do (clientParam,serverParam) <- pick arbitraryPairParams - let clientParam' = updateClientParams (\cp -> cp { onNPNServerSuggest = Just $ \protos -> return (head protos) }) clientParam - serverParam' = updateServerParams (\sp -> sp { onSuggestNextProtocols = return $ Just [C8.pack "spdy/2", C8.pack "http/1.1"] }) serverParam + let clientParam' = clientParam { clientHooks = (clientHooks clientParam) + { onNPNServerSuggest = Just $ \protos -> return (head protos) } + } + serverParam' = serverParam { serverHooks = (serverHooks serverParam) + { onSuggestNextProtocols = return $ Just [C8.pack "spdy/2", C8.pack "http/1.1"] } + } params' = (clientParam',serverParam') runTLSPipe params' tlsServer tlsClient where tlsServer ctx queue = do diff --git a/core/tls.cabal b/core/tls.cabal index 28474a2..b932ea7 100644 --- a/core/tls.cabal +++ b/core/tls.cabal @@ -37,6 +37,7 @@ Library , cereal >= 0.3 , bytestring , network + , data-default-class , crypto-random >= 0.0 && < 0.1 , crypto-numbers , crypto-pubkey-types >= 0.4 @@ -45,6 +46,7 @@ Library , asn1-encoding , x509 >= 1.4.3 && < 1.5.0 , x509-store + , x509-validation >= 1.5.0 && < 1.6.0 Exposed-modules: Network.TLS Network.TLS.Cipher Network.TLS.Compression @@ -53,6 +55,7 @@ Library Network.TLS.Struct Network.TLS.Core Network.TLS.Context + Network.TLS.Context.Internal Network.TLS.Credentials Network.TLS.Backend Network.TLS.Crypto @@ -102,6 +105,7 @@ Test-Suite test-tls Build-Depends: base >= 3 && < 5 , mtl , cereal >= 0.3 + , data-default-class , QuickCheck >= 2 , test-framework , test-framework-quickcheck2 @@ -109,6 +113,7 @@ Test-Suite test-tls , crypto-pubkey >= 0.2 , bytestring , x509 + , x509-validation , tls , time , crypto-random @@ -122,6 +127,8 @@ Benchmark bench-tls Build-depends: base >= 4 && < 5 , tls , x509 + , x509-validation + , data-default-class , crypto-random , criterion , cprng-aes