Change the way parameters are created.

This is still WIP and this commit is truly horrific. Sadly, it's just
too much effort to do clean commit with this, and it doesn't mix with
experimentation either.
This commit is contained in:
Vincent Hanquez 2014-01-25 16:51:51 +00:00
parent 8e128a0412
commit 4e5ff7f53d
23 changed files with 697 additions and 511 deletions

View file

@ -8,6 +8,8 @@ import Criterion.Main
import Control.Concurrent.Chan import Control.Concurrent.Chan
import Network.TLS import Network.TLS
import Data.X509 import Data.X509
import Data.X509.Validation
import Data.Default.Class
import qualified Data.ByteString as B import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as L import qualified Data.ByteString.Lazy as L
@ -15,15 +17,22 @@ import qualified Data.ByteString.Lazy as L
recvDataNonNull ctx = recvData ctx >>= \l -> if B.null l then recvDataNonNull ctx else return l recvDataNonNull ctx = recvData ctx >>= \l -> if B.null l then recvDataNonNull ctx else return l
getParams connectVer cipher = (cParams, sParams) getParams connectVer cipher = (cParams, sParams)
where sParams = defaultParamsServer where sParams = def { serverSupported = supported
{ pAllowedVersions = [connectVer] , serverShared = def {
, pCiphers = [cipher] sharedCredentials = Credentials [ (CertificateChain [simpleX509 $ PubKeyRSA pubKey], PrivKeyRSA privKey) ]
, pCredentials = Credentials [ (CertificateChain [simpleX509 $ PubKeyRSA pubKey], PrivKeyRSA privKey) ] }
} }
cParams = defaultParamsClient cParams = (defaultParamsClient "" B.empty)
{ pAllowedVersions = [connectVer] { clientSupported = supported
, pCiphers = [cipher] , clientShared = def { sharedValidationCache = ValidationCache
{ cacheAdd = \_ _ _ -> return ()
, cacheQuery = \_ _ _ -> return ValidationCachePass
}
}
} }
supported = def { supportedCiphers = [cipher]
, supportedVersions = [connectVer]
}
(pubKey, privKey) = getGlobalRSAPair (pubKey, privKey) = getGlobalRSAPair
runTLSPipe params tlsServer tlsClient d name = bench name $ do runTLSPipe params tlsServer tlsClient d name = bench name $ do

View file

@ -8,19 +8,17 @@
module Network.TLS module Network.TLS
( (
-- * Context configuration -- * Context configuration
Params(..) ClientParams(..)
, RoleParams(..)
, ClientParams(..)
, ServerParams(..) , ServerParams(..)
, updateClientParams , ClientHooks(..)
, updateServerParams , ServerHooks(..)
, Supported(..)
, Shared(..)
, Logging(..) , Logging(..)
, Measurement(..) , Measurement(..)
, CertificateUsage(..) , CertificateUsage(..)
, CertificateRejectReason(..) , CertificateRejectReason(..)
, defaultParamsClient , defaultParamsClient
, defaultParamsServer
, defaultLogging
, MaxFragmentEnum(..) , MaxFragmentEnum(..)
, HashAndSignatureAlgorithm , HashAndSignatureAlgorithm
, HashAlgorithm(..) , HashAlgorithm(..)
@ -36,7 +34,6 @@ module Network.TLS
, SessionData(..) , SessionData(..)
, SessionManager(..) , SessionManager(..)
, noSessionManager , noSessionManager
, setSessionManager
-- * Backend abstraction -- * Backend abstraction
, Backend(..) , Backend(..)
@ -94,18 +91,27 @@ module Network.TLS
-- * Exceptions -- * Exceptions
, TLSException(..) , TLSException(..)
-- * Validation Cache
, ValidationCache
, exceptionValidationCache
) where ) where
import Network.TLS.Backend (Backend(..)) import Network.TLS.Backend (Backend(..))
import Network.TLS.Struct ( Version(..), TLSError(..), TLSException(..) import Network.TLS.Struct ( TLSError(..), TLSException(..)
, HashAndSignatureAlgorithm, HashAlgorithm(..), SignatureAlgorithm(..) , HashAndSignatureAlgorithm, HashAlgorithm(..), SignatureAlgorithm(..)
, Header(..), ProtocolType(..), CertificateType(..) , Header(..), ProtocolType(..), CertificateType(..)
, AlertDescription(..)) , AlertDescription(..))
import Network.TLS.Crypto (KxError(..)) import Network.TLS.Crypto (KxError(..))
import Network.TLS.Cipher import Network.TLS.Cipher
import Network.TLS.Hooks
import Network.TLS.Measurement
import Network.TLS.Credentials import Network.TLS.Credentials
import Network.TLS.Compression (CompressionC(..), Compression(..), nullCompression) import Network.TLS.Compression (CompressionC(..), Compression(..), nullCompression)
import Network.TLS.Context import Network.TLS.Context
import Network.TLS.Parameters
import Network.TLS.Core import Network.TLS.Core
import Network.TLS.Session import Network.TLS.Session
import Network.TLS.X509
import Network.TLS.Types
import Data.X509 (PubKey(..), PrivKey(..)) import Data.X509 (PubKey(..), PrivKey(..))

View file

@ -8,44 +8,18 @@
module Network.TLS.Context module Network.TLS.Context
( (
-- * Context configuration -- * Context configuration
Params(..) TLSParams
, RoleParams(..)
, ClientParams(..)
, ServerParams(..)
, updateClientParams
, updateServerParams
, Logging(..)
, SessionID
, SessionData(..)
, MaxFragmentEnum(..)
, Measurement(..)
, CertificateUsage(..)
, CertificateRejectReason(..)
, defaultLogging
, defaultParamsClient
, defaultParamsServer
, withSessionManager
, setSessionManager
, getClientParams
, getServerParams
, credentialsGet
-- * Context object and accessor -- * Context object and accessor
, Context , Context(..)
, Hooks(..) , Hooks(..)
, ctxParams
, ctxConnection
, ctxEOF , ctxEOF
, ctxHasSSLv2ClientHello , ctxHasSSLv2ClientHello
, ctxDisableSSLv2ClientHello , ctxDisableSSLv2ClientHello
, ctxEstablished , ctxEstablished
, ctxCiphers
, ctxLogging , ctxLogging
, ctxWithHooks , ctxWithHooks
, ctxRxState , modifyHooks
, ctxTxState
, ctxHandshake
, ctxNeedEmptyPacket
, setEOF , setEOF
, setEstablished , setEstablished
, contextFlush , contextFlush
@ -63,13 +37,6 @@ module Network.TLS.Context
, Information(..) , Information(..)
, contextGetInformation , contextGetInformation
-- * deprecated types
, TLSParams
, TLSLogging
, TLSCertificateUsage
, TLSCertificateRejectReason
, TLSCtx
-- * New contexts -- * New contexts
, contextNew , contextNew
-- * Deprecated new contexts methods -- * Deprecated new contexts methods
@ -91,177 +58,55 @@ module Network.TLS.Context
) where ) where
import Network.TLS.Backend import Network.TLS.Backend
import Network.TLS.Extension import Network.TLS.Context.Internal
import Network.TLS.Struct import Network.TLS.Struct
import Network.TLS.Cipher (Cipher(..), CipherKeyExchangeType(..)) import Network.TLS.Cipher (Cipher(..), CipherKeyExchangeType(..))
import Network.TLS.Compression (Compression)
import Network.TLS.Credentials import Network.TLS.Credentials
import Network.TLS.State import Network.TLS.State
import Network.TLS.Handshake.State
import Network.TLS.Hooks import Network.TLS.Hooks
import Network.TLS.Record.State import Network.TLS.Record.State
import Network.TLS.Parameters import Network.TLS.Parameters
import Network.TLS.Measurement import Network.TLS.Measurement
import Network.TLS.Types (Role(..)) import Network.TLS.Types (Role(..))
import Network.TLS.Handshake (handshakeClient, handshakeClientWith, handshakeServer, handshakeServerWith)
import Data.Maybe (isJust) import Data.Maybe (isJust)
import qualified Data.ByteString as B
import Crypto.Random import Crypto.Random
import Control.Concurrent.MVar import Control.Concurrent.MVar
import Control.Monad.State import Control.Monad.State
import Control.Exception (throwIO, Exception())
import Data.IORef import Data.IORef
import Data.Tuple
-- deprecated imports -- deprecated imports
import Network.Socket (Socket) import Network.Socket (Socket)
import System.IO (Handle) import System.IO (Handle)
-- | Information related to a running context, e.g. current cipher class TLSParams a where
data Information = Information getTLSCommonParams :: a -> CommonParams
{ infoVersion :: Version getTLSRole :: a -> Role
, infoCipher :: Cipher getCiphers :: a -> [Cipher]
, infoCompression :: Compression doHandshake :: a -> Context -> IO ()
} deriving (Show,Eq) doHandshakeWith :: a -> Context -> Handshake -> IO ()
-- | A TLS Context keep tls specific state, parameters and backend information.
data Context = Context
{ ctxConnection :: Backend -- ^ return the backend object associated with this context
, ctxParams :: Params
, ctxCiphers :: [Cipher] -- ^ prepared list of allowed ciphers according to parameters
, ctxState :: MVar TLSState
, ctxMeasurement :: IORef Measurement
, ctxEOF_ :: IORef Bool -- ^ has the handle EOFed or not.
, ctxEstablished_ :: IORef Bool -- ^ has the handshake been done and been successful.
, ctxNeedEmptyPacket :: IORef Bool -- ^ empty packet workaround for CBC guessability.
, ctxSSLv2ClientHello :: IORef Bool -- ^ enable the reception of compatibility SSLv2 client hello.
-- the flag will be set to false regardless of its initial value
-- after the first packet received.
, ctxTxState :: MVar RecordState -- ^ current tx state
, ctxRxState :: MVar RecordState -- ^ current rx state
, ctxHandshake :: MVar (Maybe HandshakeState) -- ^ optional handshake state
, ctxHooks :: IORef Hooks -- ^ hooks for this context
, ctxLockWrite :: MVar () -- ^ lock to use for writing data (including updating the state)
, ctxLockRead :: MVar () -- ^ lock to use for reading data (including updating the state)
, ctxLockState :: MVar () -- ^ lock used during read/write when receiving and sending packet.
-- it is usually nested in a write or read lock.
}
-- deprecated types, setup as aliases for compatibility. instance TLSParams ClientParams where
type TLSParams = Params getTLSCommonParams cparams = ( clientSupported cparams
type TLSCtx = Context , clientShared cparams
type TLSLogging = Logging , clientCommonHooks cparams
type TLSCertificateUsage = CertificateUsage )
type TLSCertificateRejectReason = CertificateRejectReason getTLSRole _ = ClientRole
getCiphers cparams = supportedCiphers $ clientSupported cparams
doHandshake = handshakeClient
doHandshakeWith = handshakeClientWith
updateMeasure :: Context -> (Measurement -> Measurement) -> IO () instance TLSParams ServerParams where
updateMeasure ctx f = do getTLSCommonParams sparams = ( serverSupported sparams
x <- readIORef (ctxMeasurement ctx) , serverShared sparams
writeIORef (ctxMeasurement ctx) $! f x , serverCommonHooks sparams
)
withMeasure :: Context -> (Measurement -> IO a) -> IO a getTLSRole _ = ServerRole
withMeasure ctx f = readIORef (ctxMeasurement ctx) >>= f
contextFlush :: Context -> IO ()
contextFlush = backendFlush . ctxConnection
contextClose :: Context -> IO ()
contextClose = backendClose . ctxConnection
-- | Information about the current context
contextGetInformation :: Context -> IO (Maybe Information)
contextGetInformation ctx = do
ver <- usingState_ ctx $ gets stVersion
(cipher,comp) <- failOnEitherError $ runRxState ctx $ gets $ \st -> (stCipher st, stCompression st)
case (ver, cipher) of
(Just v, Just c) -> return $ Just $ Information v c comp
_ -> return Nothing
contextSend :: Context -> Bytes -> IO ()
contextSend c b = updateMeasure c (addBytesSent $ B.length b) >> (backendSend $ ctxConnection c) b
contextRecv :: Context -> Int -> IO Bytes
contextRecv c sz = updateMeasure c (addBytesReceived sz) >> (backendRecv $ ctxConnection c) sz
ctxEOF :: Context -> IO Bool
ctxEOF ctx = readIORef $ ctxEOF_ ctx
ctxHasSSLv2ClientHello :: Context -> IO Bool
ctxHasSSLv2ClientHello ctx = readIORef $ ctxSSLv2ClientHello ctx
ctxDisableSSLv2ClientHello :: Context -> IO ()
ctxDisableSSLv2ClientHello ctx = writeIORef (ctxSSLv2ClientHello ctx) False
setEOF :: Context -> IO ()
setEOF ctx = writeIORef (ctxEOF_ ctx) True
ctxEstablished :: Context -> IO Bool
ctxEstablished ctx = readIORef $ ctxEstablished_ ctx
ctxWithHooks :: Context -> (Hooks -> IO a) -> IO a
ctxWithHooks ctx f = readIORef (ctxHooks ctx) >>= f
setEstablished :: Context -> Bool -> IO ()
setEstablished ctx v = writeIORef (ctxEstablished_ ctx) v
ctxLogging :: Context -> Logging
ctxLogging = pLogging . ctxParams
-- | create a new context using the backend and parameters specified.
contextNew :: (MonadIO m, CPRG rng, HasBackend backend)
=> backend -- ^ Backend abstraction with specific method to interact with the connection type.
-> Params -- ^ Parameters of the context.
-> rng -- ^ Random number generator associated with this context.
-> m Context
contextNew backend params rng = liftIO $ do
initializeBackend backend
let role = case roleParams params of
Client {} -> ClientRole
Server {} -> ServerRole
let st = newTLSState rng role
stvar <- newMVar st
eof <- newIORef False
established <- newIORef False
stats <- newIORef newMeasurement
-- we enable the reception of SSLv2 ClientHello message only in the
-- server context, where we might be dealing with an old/compat client.
sslv2Compat <- newIORef (role == ServerRole)
needEmptyPacket <- newIORef False
hooks <- newIORef defaultHooks
tx <- newMVar newRecordState
rx <- newMVar newRecordState
hs <- newMVar Nothing
-- on the server we filter our allowed ciphers here according -- on the server we filter our allowed ciphers here according
-- to the credentials and DHE parameters loaded -- to the credentials and DHE parameters loaded
let ciphers = case roleParams params of getCiphers sparams = filter authorizedCKE (supportedCiphers $ serverSupported sparams)
Client {} -> pCiphers params
Server sParams -> filterServer sParams $ pCiphers params
lockWrite <- newMVar ()
lockRead <- newMVar ()
lockState <- newMVar ()
when (null ciphers) $ error "no ciphers available with those parameters"
return $ Context
{ ctxConnection = getBackend backend
, ctxParams = params
, ctxCiphers = ciphers
, ctxState = stvar
, ctxTxState = tx
, ctxRxState = rx
, ctxHandshake = hs
, ctxMeasurement = stats
, ctxEOF_ = eof
, ctxEstablished_ = established
, ctxSSLv2ClientHello = sslv2Compat
, ctxNeedEmptyPacket = needEmptyPacket
, ctxHooks = hooks
, ctxLockWrite = lockWrite
, ctxLockRead = lockRead
, ctxLockState = lockState
}
where filterServer sParams ciphers = filter authorizedCKE ciphers
where authorizedCKE cipher = where authorizedCKE cipher =
case cipherKeyExchange cipher of case cipherKeyExchange cipher of
CipherKeyExchange_RSA -> canEncryptRSA CipherKeyExchange_RSA -> canEncryptRSA
@ -277,26 +122,83 @@ contextNew backend params rng = liftIO $ do
CipherKeyExchange_ECDH_RSA -> False CipherKeyExchange_ECDH_RSA -> False
CipherKeyExchange_ECDHE_ECDSA -> False CipherKeyExchange_ECDHE_ECDSA -> False
canDHE = isJust $ serverDHEParams sParams canDHE = isJust $ serverDHEParams sparams
canSignDSS = SignatureDSS `elem` signingAlgs canSignDSS = SignatureDSS `elem` signingAlgs
canSignRSA = SignatureRSA `elem` signingAlgs canSignRSA = SignatureRSA `elem` signingAlgs
canEncryptRSA = isJust $ credentialsFindForDecrypting creds canEncryptRSA = isJust $ credentialsFindForDecrypting creds
signingAlgs = credentialsListSigningAlgorithms creds signingAlgs = credentialsListSigningAlgorithms creds
creds = credentialsGet params creds = sharedCredentials $ serverShared sparams
doHandshake = handshakeServer
doHandshakeWith = handshakeServerWith
-- | create a new context using the backend and parameters specified.
contextNew :: (MonadIO m, CPRG rng, HasBackend backend, TLSParams params)
=> backend -- ^ Backend abstraction with specific method to interact with the connection type.
-> params -- ^ Parameters of the context.
-> rng -- ^ Random number generator associated with this context.
-> m Context
contextNew backend params rng = liftIO $ do
initializeBackend backend
let role = getTLSRole params
st = newTLSState rng role
(supported, shared, commonHooks) = getTLSCommonParams params
ciphers = getCiphers params
when (null ciphers) $ error "no ciphers available with those parameters"
stvar <- newMVar st
eof <- newIORef False
established <- newIORef False
stats <- newIORef newMeasurement
-- we enable the reception of SSLv2 ClientHello message only in the
-- server context, where we might be dealing with an old/compat client.
sslv2Compat <- newIORef (role == ServerRole)
needEmptyPacket <- newIORef False
hooks <- newIORef defaultHooks
tx <- newMVar newRecordState
rx <- newMVar newRecordState
hs <- newMVar Nothing
lockWrite <- newMVar ()
lockRead <- newMVar ()
lockState <- newMVar ()
return $ Context
{ ctxConnection = getBackend backend
, ctxShared = shared
, ctxSupported = supported
, ctxCommonHooks = commonHooks
, ctxCiphers = ciphers
, ctxState = stvar
, ctxTxState = tx
, ctxRxState = rx
, ctxHandshake = hs
, ctxDoHandshake = doHandshake params
, ctxDoHandshakeWith = doHandshakeWith params
, ctxMeasurement = stats
, ctxEOF_ = eof
, ctxEstablished_ = established
, ctxSSLv2ClientHello = sslv2Compat
, ctxNeedEmptyPacket = needEmptyPacket
, ctxHooks = hooks
, ctxLockWrite = lockWrite
, ctxLockRead = lockRead
, ctxLockState = lockState
}
-- | create a new context on an handle. -- | create a new context on an handle.
contextNewOnHandle :: (MonadIO m, CPRG rng) contextNewOnHandle :: (MonadIO m, CPRG rng, TLSParams params)
=> Handle -- ^ Handle of the connection. => Handle -- ^ Handle of the connection.
-> Params -- ^ Parameters of the context. -> params -- ^ Parameters of the context.
-> rng -- ^ Random number generator associated with this context. -> rng -- ^ Random number generator associated with this context.
-> m Context -> m Context
contextNewOnHandle handle params st = contextNew handle params st contextNewOnHandle handle params st = contextNew handle params st
{-# DEPRECATED contextNewOnHandle "use contextNew" #-} {-# DEPRECATED contextNewOnHandle "use contextNew" #-}
-- | create a new context on a socket. -- | create a new context on a socket.
contextNewOnSocket :: (MonadIO m, CPRG rng) contextNewOnSocket :: (MonadIO m, CPRG rng, TLSParams params)
=> Socket -- ^ Socket of the connection. => Socket -- ^ Socket of the connection.
-> Params -- ^ Parameters of the context. -> params -- ^ Parameters of the context.
-> rng -- ^ Random number generator associated with this context. -> rng -- ^ Random number generator associated with this context.
-> m Context -> m Context
contextNewOnSocket sock params st = contextNew sock params st contextNewOnSocket sock params st = contextNew sock params st
@ -305,62 +207,3 @@ contextNewOnSocket sock params st = contextNew sock params st
contextHookSetHandshakeRecv :: Context -> (Handshake -> IO Handshake) -> IO () contextHookSetHandshakeRecv :: Context -> (Handshake -> IO Handshake) -> IO ()
contextHookSetHandshakeRecv context f = contextHookSetHandshakeRecv context f =
liftIO $ modifyIORef (ctxHooks context) (\hooks -> hooks { hookRecvHandshake = f }) liftIO $ modifyIORef (ctxHooks context) (\hooks -> hooks { hookRecvHandshake = f })
throwCore :: (MonadIO m, Exception e) => e -> m a
throwCore = liftIO . throwIO
failOnEitherError :: MonadIO m => m (Either TLSError a) -> m a
failOnEitherError f = do
ret <- f
case ret of
Left err -> throwCore err
Right r -> return r
usingState :: Context -> TLSSt a -> IO (Either TLSError a)
usingState ctx f =
modifyMVar (ctxState ctx) $ \st ->
let (a, newst) = runTLSState f st
in newst `seq` return (newst, a)
usingState_ :: Context -> TLSSt a -> IO a
usingState_ ctx f = failOnEitherError $ usingState ctx f
usingHState :: Context -> HandshakeM a -> IO a
usingHState ctx f = liftIO $ modifyMVar (ctxHandshake ctx) $ \mst ->
case mst of
Nothing -> throwCore $ Error_Misc "missing handshake"
Just st -> return $ swap (Just `fmap` runHandshake st f)
getHState :: Context -> IO (Maybe HandshakeState)
getHState ctx = liftIO $ readMVar (ctxHandshake ctx)
runTxState :: Context -> RecordM a -> IO (Either TLSError a)
runTxState ctx f = do
ver <- usingState_ ctx (getVersionWithDefault $ maximum $ pAllowedVersions $ ctxParams ctx)
modifyMVar (ctxTxState ctx) $ \st ->
case runRecordM f ver st of
Left err -> return (st, Left err)
Right (a, newSt) -> return (newSt, Right a)
runRxState :: Context -> RecordM a -> IO (Either TLSError a)
runRxState ctx f = do
ver <- usingState_ ctx getVersion
modifyMVar (ctxRxState ctx) $ \st ->
case runRecordM f ver st of
Left err -> return (st, Left err)
Right (a, newSt) -> return (newSt, Right a)
getStateRNG :: Context -> Int -> IO Bytes
getStateRNG ctx n = usingState_ ctx $ genRandom n
withReadLock :: Context -> IO a -> IO a
withReadLock ctx f = withMVar (ctxLockRead ctx) (const f)
withWriteLock :: Context -> IO a -> IO a
withWriteLock ctx f = withMVar (ctxLockWrite ctx) (const f)
withRWLock :: Context -> IO a -> IO a
withRWLock ctx f = withReadLock ctx $ withWriteLock ctx f
withStateLock :: Context -> IO a -> IO a
withStateLock ctx f = withMVar (ctxLockState ctx) (const f)

View file

@ -0,0 +1,224 @@
-- |
-- Module : Network.TLS.Context.Internal
-- License : BSD-style
-- Maintainer : Vincent Hanquez <vincent@snarc.org>
-- Stability : experimental
-- Portability : unknown
--
module Network.TLS.Context.Internal
(
-- * Context configuration
ClientParams(..)
, ServerParams(..)
, defaultParamsClient
, SessionID
, SessionData(..)
, MaxFragmentEnum(..)
, Measurement(..)
-- * Context object and accessor
, Context(..)
, Hooks(..)
, ctxEOF
, ctxHasSSLv2ClientHello
, ctxDisableSSLv2ClientHello
, ctxEstablished
, ctxLogging
, ctxWithHooks
, modifyHooks
, setEOF
, setEstablished
, contextFlush
, contextClose
, contextSend
, contextRecv
, updateMeasure
, withMeasure
, withReadLock
, withWriteLock
, withStateLock
, withRWLock
-- * information
, Information(..)
, contextGetInformation
-- * Using context states
, throwCore
, usingState
, usingState_
, runTxState
, runRxState
, usingHState
, getHState
, getStateRNG
) where
import Network.TLS.Backend
import Network.TLS.Extension
import Network.TLS.Cipher
import Network.TLS.Struct
import Network.TLS.Compression (Compression)
import Network.TLS.State
import Network.TLS.Handshake.State
import Network.TLS.Hooks
import Network.TLS.Record.State
import Network.TLS.Parameters
import Network.TLS.Measurement
import qualified Data.ByteString as B
import Control.Concurrent.MVar
import Control.Monad.State
import Control.Exception (throwIO, Exception())
import Data.IORef
import Data.Tuple
-- | Information related to a running context, e.g. current cipher
data Information = Information
{ infoVersion :: Version
, infoCipher :: Cipher
, infoCompression :: Compression
} deriving (Show,Eq)
-- | A TLS Context keep tls specific state, parameters and backend information.
data Context = Context
{ ctxConnection :: Backend -- ^ return the backend object associated with this context
, ctxSupported :: Supported
, ctxShared :: Shared
, ctxCommonHooks :: CommonHooks
, ctxCiphers :: [Cipher] -- ^ prepared list of allowed ciphers according to parameters
, ctxState :: MVar TLSState
, ctxMeasurement :: IORef Measurement
, ctxEOF_ :: IORef Bool -- ^ has the handle EOFed or not.
, ctxEstablished_ :: IORef Bool -- ^ has the handshake been done and been successful.
, ctxNeedEmptyPacket :: IORef Bool -- ^ empty packet workaround for CBC guessability.
, ctxSSLv2ClientHello :: IORef Bool -- ^ enable the reception of compatibility SSLv2 client hello.
-- the flag will be set to false regardless of its initial value
-- after the first packet received.
, ctxTxState :: MVar RecordState -- ^ current tx state
, ctxRxState :: MVar RecordState -- ^ current rx state
, ctxHandshake :: MVar (Maybe HandshakeState) -- ^ optional handshake state
, ctxDoHandshake :: Context -> IO ()
, ctxDoHandshakeWith :: Context -> Handshake -> IO ()
, ctxHooks :: IORef Hooks -- ^ hooks for this context
, ctxLockWrite :: MVar () -- ^ lock to use for writing data (including updating the state)
, ctxLockRead :: MVar () -- ^ lock to use for reading data (including updating the state)
, ctxLockState :: MVar () -- ^ lock used during read/write when receiving and sending packet.
-- it is usually nested in a write or read lock.
}
updateMeasure :: Context -> (Measurement -> Measurement) -> IO ()
updateMeasure ctx f = do
x <- readIORef (ctxMeasurement ctx)
writeIORef (ctxMeasurement ctx) $! f x
withMeasure :: Context -> (Measurement -> IO a) -> IO a
withMeasure ctx f = readIORef (ctxMeasurement ctx) >>= f
contextFlush :: Context -> IO ()
contextFlush = backendFlush . ctxConnection
contextClose :: Context -> IO ()
contextClose = backendClose . ctxConnection
-- | Information about the current context
contextGetInformation :: Context -> IO (Maybe Information)
contextGetInformation ctx = do
ver <- usingState_ ctx $ gets stVersion
(cipher,comp) <- failOnEitherError $ runRxState ctx $ gets $ \st -> (stCipher st, stCompression st)
case (ver, cipher) of
(Just v, Just c) -> return $ Just $ Information v c comp
_ -> return Nothing
contextSend :: Context -> Bytes -> IO ()
contextSend c b = updateMeasure c (addBytesSent $ B.length b) >> (backendSend $ ctxConnection c) b
contextRecv :: Context -> Int -> IO Bytes
contextRecv c sz = updateMeasure c (addBytesReceived sz) >> (backendRecv $ ctxConnection c) sz
ctxEOF :: Context -> IO Bool
ctxEOF ctx = readIORef $ ctxEOF_ ctx
ctxHasSSLv2ClientHello :: Context -> IO Bool
ctxHasSSLv2ClientHello ctx = readIORef $ ctxSSLv2ClientHello ctx
ctxDisableSSLv2ClientHello :: Context -> IO ()
ctxDisableSSLv2ClientHello ctx = writeIORef (ctxSSLv2ClientHello ctx) False
setEOF :: Context -> IO ()
setEOF ctx = writeIORef (ctxEOF_ ctx) True
ctxEstablished :: Context -> IO Bool
ctxEstablished ctx = readIORef $ ctxEstablished_ ctx
ctxWithHooks :: Context -> (Hooks -> IO a) -> IO a
ctxWithHooks ctx f = readIORef (ctxHooks ctx) >>= f
modifyHooks :: Context -> (Hooks -> Hooks) -> IO ()
modifyHooks ctx f = modifyIORef (ctxHooks ctx) f
setEstablished :: Context -> Bool -> IO ()
setEstablished ctx v = writeIORef (ctxEstablished_ ctx) v
ctxLogging :: Context -> Logging
ctxLogging = logging . ctxCommonHooks
throwCore :: (MonadIO m, Exception e) => e -> m a
throwCore = liftIO . throwIO
failOnEitherError :: MonadIO m => m (Either TLSError a) -> m a
failOnEitherError f = do
ret <- f
case ret of
Left err -> throwCore err
Right r -> return r
usingState :: Context -> TLSSt a -> IO (Either TLSError a)
usingState ctx f =
modifyMVar (ctxState ctx) $ \st ->
let (a, newst) = runTLSState f st
in newst `seq` return (newst, a)
usingState_ :: Context -> TLSSt a -> IO a
usingState_ ctx f = failOnEitherError $ usingState ctx f
usingHState :: Context -> HandshakeM a -> IO a
usingHState ctx f = liftIO $ modifyMVar (ctxHandshake ctx) $ \mst ->
case mst of
Nothing -> throwCore $ Error_Misc "missing handshake"
Just st -> return $ swap (Just `fmap` runHandshake st f)
getHState :: Context -> IO (Maybe HandshakeState)
getHState ctx = liftIO $ readMVar (ctxHandshake ctx)
runTxState :: Context -> RecordM a -> IO (Either TLSError a)
runTxState ctx f = do
ver <- usingState_ ctx (getVersionWithDefault $ maximum $ supportedVersions $ ctxSupported ctx)
modifyMVar (ctxTxState ctx) $ \st ->
case runRecordM f ver st of
Left err -> return (st, Left err)
Right (a, newSt) -> return (newSt, Right a)
runRxState :: Context -> RecordM a -> IO (Either TLSError a)
runRxState ctx f = do
ver <- usingState_ ctx getVersion
modifyMVar (ctxRxState ctx) $ \st ->
case runRecordM f ver st of
Left err -> return (st, Left err)
Right (a, newSt) -> return (newSt, Right a)
getStateRNG :: Context -> Int -> IO Bytes
getStateRNG ctx n = usingState_ ctx $ genRandom n
withReadLock :: Context -> IO a -> IO a
withReadLock ctx f = withMVar (ctxLockRead ctx) (const f)
withWriteLock :: Context -> IO a -> IO a
withWriteLock ctx f = withMVar (ctxLockWrite ctx) (const f)
withRWLock :: Context -> IO a -> IO a
withRWLock ctx f = withReadLock ctx $ withWriteLock ctx f
withStateLock :: Context -> IO a -> IO a
withStateLock ctx f = withMVar (ctxLockState ctx) (const f)

View file

@ -29,6 +29,7 @@ module Network.TLS.Core
import Network.TLS.Context import Network.TLS.Context
import Network.TLS.Struct import Network.TLS.Struct
import Network.TLS.State (getSession) import Network.TLS.State (getSession)
import Network.TLS.Parameters
import Network.TLS.IO import Network.TLS.IO
import Network.TLS.Session import Network.TLS.Session
import Network.TLS.Handshake import Network.TLS.Handshake
@ -79,17 +80,22 @@ recvData ctx = liftIO $ do
terminate err AlertLevel_Fatal InternalError (show err) terminate err AlertLevel_Fatal InternalError (show err)
process (Handshake [ch@(ClientHello {})]) = process (Handshake [ch@(ClientHello {})]) =
-- on server context receiving a client hello == renegotiation withRWLock ctx ((ctxDoHandshakeWith ctx) ctx ch) >> recvData ctx
{-
case roleParams $ ctxParams ctx of case roleParams $ ctxParams ctx of
Server sparams -> withRWLock ctx (handshakeServerWith sparams ctx ch) >> recvData ctx Server sparams -> withRWLock ctx (handshakeServerWith sparams ctx ch) >> recvData ctx
Client {} -> let reason = "unexpected client hello in client context" in Client {} -> let reason = "unexpected client hello in client context" in
terminate (Error_Misc reason) AlertLevel_Fatal UnexpectedMessage reason terminate (Error_Misc reason) AlertLevel_Fatal UnexpectedMessage reason
process (Handshake [HelloRequest]) = -}
process (Handshake [hr@HelloRequest]) =
withRWLock ctx ((ctxDoHandshakeWith ctx) ctx hr) >> recvData ctx
{-
-- on client context, receiving a hello request == renegotiation -- on client context, receiving a hello request == renegotiation
case roleParams $ ctxParams ctx of case roleParams $ ctxParams ctx of
Server {} -> let reason = "unexpected hello request in server context" in Server {} -> let reason = "unexpected hello request in server context" in
terminate (Error_Misc reason) AlertLevel_Fatal UnexpectedMessage reason terminate (Error_Misc reason) AlertLevel_Fatal UnexpectedMessage reason
Client cparams -> withRWLock ctx (handshakeClient cparams ctx) >> recvData ctx Client cparams -> withRWLock ctx (handshakeClient cparams ctx) >> recvData ctx
-}
process (Alert [(AlertLevel_Warning, CloseNotify)]) = tryBye >> setEOF ctx >> return B.empty process (Alert [(AlertLevel_Warning, CloseNotify)]) = tryBye >> setEOF ctx >> return B.empty
process (Alert [(AlertLevel_Fatal, desc)]) = do process (Alert [(AlertLevel_Fatal, desc)]) = do
@ -107,7 +113,7 @@ recvData ctx = liftIO $ do
session <- usingState_ ctx getSession session <- usingState_ ctx getSession
case session of case session of
Session Nothing -> return () Session Nothing -> return ()
Session (Just sid) -> withSessionManager (ctxParams ctx) (\s -> sessionInvalidate s sid) Session (Just sid) -> sessionInvalidate (sharedSessionManager $ ctxShared ctx) sid
catchException (sendPacket ctx $ Alert [(level, desc)]) (\_ -> return ()) catchException (sendPacket ctx $ Alert [(level, desc)]) (\_ -> return ())
setEOF ctx setEOF ctx
E.throwIO (Terminated False reason err) E.throwIO (Terminated False reason err)

View file

@ -7,11 +7,13 @@
-- --
module Network.TLS.Handshake module Network.TLS.Handshake
( handshake ( handshake
, handshakeClientWith
, handshakeServerWith , handshakeServerWith
, handshakeClient , handshakeClient
, handshakeServer
) where ) where
import Network.TLS.Context import Network.TLS.Context.Internal
import Network.TLS.Struct import Network.TLS.Struct
import Network.TLS.IO import Network.TLS.IO
import Network.TLS.Util (catchException) import Network.TLS.Util (catchException)
@ -26,11 +28,8 @@ import Control.Exception (fromException)
-- | Handshake for a new TLS connection -- | Handshake for a new TLS connection
-- This is to be called at the beginning of a connection, and during renegotiation -- This is to be called at the beginning of a connection, and during renegotiation
handshake :: MonadIO m => Context -> m () handshake :: MonadIO m => Context -> m ()
handshake ctx = do handshake ctx =
let handshakeF = case roleParams $ ctxParams ctx of liftIO $ handleException $ withRWLock ctx (ctxDoHandshake ctx $ ctx)
Server sparams -> handshakeServer sparams
Client cparams -> handshakeClient cparams
liftIO $ handleException $ withRWLock ctx (handshakeF ctx)
where handleException f = catchException f $ \exception -> do where handleException f = catchException f $ \exception -> do
let tlserror = maybe (Error_Misc $ show exception) id $ fromException exception let tlserror = maybe (Error_Misc $ show exception) id $ fromException exception
setEstablished ctx False setEstablished ctx False

View file

@ -10,8 +10,9 @@ module Network.TLS.Handshake.Certificate
, rejectOnException , rejectOnException
) where ) where
import Network.TLS.Context import Network.TLS.Context.Internal
import Network.TLS.Struct import Network.TLS.Struct
import Network.TLS.X509
import Control.Monad.State import Control.Monad.State
import Control.Exception (SomeException) import Control.Exception (SomeException)
@ -26,5 +27,5 @@ certificateRejected CertificateRejectUnknownCA =
certificateRejected (CertificateRejectOther s) = certificateRejected (CertificateRejectOther s) =
throwCore $ Error_Protocol ("certificate rejected: " ++ s, True, CertificateUnknown) throwCore $ Error_Protocol ("certificate rejected: " ++ s, True, CertificateUnknown)
rejectOnException :: SomeException -> IO TLSCertificateUsage rejectOnException :: SomeException -> IO CertificateUsage
rejectOnException e = return $ CertificateUsageReject $ CertificateRejectOther $ show e rejectOnException e = return $ CertificateUsageReject $ CertificateRejectOther $ show e

View file

@ -8,10 +8,12 @@
-- --
module Network.TLS.Handshake.Client module Network.TLS.Handshake.Client
( handshakeClient ( handshakeClient
, handshakeClientWith
) where ) where
import Network.TLS.Crypto import Network.TLS.Crypto
import Network.TLS.Context import Network.TLS.Context.Internal
import Network.TLS.Parameters
import Network.TLS.Struct import Network.TLS.Struct
import Network.TLS.Cipher import Network.TLS.Cipher
import Network.TLS.Compression import Network.TLS.Compression
@ -41,6 +43,10 @@ import Network.TLS.Handshake.Signature
import Network.TLS.Handshake.Key import Network.TLS.Handshake.Key
import Network.TLS.Handshake.State import Network.TLS.Handshake.State
handshakeClientWith :: ClientParams -> Context -> Handshake -> IO ()
handshakeClientWith cparams ctx HelloRequest = handshakeClient cparams ctx
handshakeClientWith _ _ _ = throwCore $ Error_Protocol ("unexpected handshake message received in handshakeClientWith", True, HandshakeFailure)
-- client part of handshake. send a bunch of handshake of client -- client part of handshake. send a bunch of handshake of client
-- values intertwined with response from the server. -- values intertwined with response from the server.
handshakeClient :: ClientParams -> Context -> IO () handshakeClient :: ClientParams -> Context -> IO ()
@ -50,31 +56,32 @@ handshakeClient cparams ctx = do
recvServerHello sentExtensions recvServerHello sentExtensions
sessionResuming <- usingState_ ctx isSessionResuming sessionResuming <- usingState_ ctx isSessionResuming
if sessionResuming if sessionResuming
then sendChangeCipherAndFinish ctx ClientRole then sendChangeCipherAndFinish sendMaybeNPN ctx ClientRole
else do sendClientData cparams ctx else do sendClientData cparams ctx
sendChangeCipherAndFinish ctx ClientRole sendChangeCipherAndFinish sendMaybeNPN ctx ClientRole
recvChangeCipherAndFinish ctx recvChangeCipherAndFinish ctx
handshakeTerminate ctx handshakeTerminate ctx
where params = ctxParams ctx where ciphers = ctxCiphers ctx
ciphers = pCiphers params compressions = supportedCompressions $ ctxSupported ctx
compressions = pCompressions params
getExtensions = sequence [sniExtension,secureReneg,npnExtention] >>= return . catMaybes getExtensions = sequence [sniExtension,secureReneg,npnExtention] >>= return . catMaybes
toExtensionRaw :: Extension e => e -> ExtensionRaw toExtensionRaw :: Extension e => e -> ExtensionRaw
toExtensionRaw ext = (extensionID ext, extensionEncode ext) toExtensionRaw ext = (extensionID ext, extensionEncode ext)
secureReneg = secureReneg =
if pUseSecureRenegotiation params if supportedSecureRenegotiation $ ctxSupported ctx
then usingState_ ctx (getVerifiedData ClientRole) >>= \vd -> return $ Just $ toExtensionRaw $ SecureRenegotiation vd Nothing then usingState_ ctx (getVerifiedData ClientRole) >>= \vd -> return $ Just $ toExtensionRaw $ SecureRenegotiation vd Nothing
else return Nothing else return Nothing
npnExtention = if isJust $ onNPNServerSuggest cparams npnExtention = if isJust $ onNPNServerSuggest $ clientHooks cparams
then return $ Just $ toExtensionRaw $ NextProtocolNegotiation [] then return $ Just $ toExtensionRaw $ NextProtocolNegotiation []
else return Nothing else return Nothing
sniExtension = return ((\h -> toExtensionRaw $ ServerName [(ServerNameHostName h)]) <$> clientUseServerName cparams) sniExtension = if clientUseServerNameIndication cparams
then return $ Just $ toExtensionRaw $ ServerName [ServerNameHostName $ fst $ clientServerIdentification cparams]
else return Nothing
sendClientHello = do sendClientHello = do
crand <- getStateRNG ctx 32 >>= return . ClientRandom crand <- getStateRNG ctx 32 >>= return . ClientRandom
let clientSession = Session . maybe Nothing (Just . fst) $ clientWantSessionResume cparams let clientSession = Session . maybe Nothing (Just . fst) $ clientWantSessionResume cparams
highestVer = maximum $ pAllowedVersions params highestVer = maximum $ supportedVersions $ ctxSupported ctx
extensions <- getExtensions extensions <- getExtensions
startHandshake ctx highestVer crand startHandshake ctx highestVer crand
usingState_ ctx $ setVersionIfUnset highestVer usingState_ ctx $ setVersionIfUnset highestVer
@ -84,6 +91,18 @@ handshakeClient cparams ctx = do
] ]
return $ map fst extensions return $ map fst extensions
sendMaybeNPN = do
suggest <- usingState_ ctx $ getServerNextProtocolSuggest
case (onNPNServerSuggest $ clientHooks cparams, suggest) of
-- client offered, server picked up. send NPN handshake.
(Just io, Just protos) -> do proto <- liftIO $ io protos
sendPacket ctx (Handshake [HsNextProtocolNegotiation proto])
usingState_ ctx $ setNegotiatedProtocol proto
-- client offered, server didn't pick up. do nothing.
(Just _, Nothing) -> return ()
-- client didn't offer. do nothing.
(Nothing, _) -> return ()
recvServerHello sentExts = runRecvState ctx (RecvStateHandshake $ onServerHello ctx cparams sentExts) recvServerHello sentExts = runRecvState ctx (RecvStateHandshake $ onServerHello ctx cparams sentExts)
-- | send client Data after receiving all server data (hello/certificates/key). -- | send client Data after receiving all server data (hello/certificates/key).
@ -108,7 +127,7 @@ sendClientData cparams ctx = sendCertificate >> sendClientKeyXchg >> sendCertifi
return () return ()
Just req -> do Just req -> do
certChain <- liftIO $ onCertificateRequest cparams req `catchException` certChain <- liftIO $ (onCertificateRequest $ clientHooks cparams) req `catchException`
throwMiscErrorOnException "certificate request callback failed" throwMiscErrorOnException "certificate request callback failed"
usingHState ctx $ setClientCertSent False usingHState ctx $ setClientCertSent False
@ -176,7 +195,7 @@ sendClientData cparams ctx = sendCertificate >> sendClientKeyXchg >> sendCertifi
malg <- case usedVersion of malg <- case usedVersion of
TLS12 -> do TLS12 -> do
Just (_, Just hashSigs, _) <- usingHState ctx $ getClientCertRequest Just (_, Just hashSigs, _) <- usingHState ctx $ getClientCertRequest
let suppHashSigs = pHashSignatures $ ctxParams ctx let suppHashSigs = supportedHashSignatures $ ctxSupported ctx
hashSigs' = filter (\ a -> a `elem` hashSigs) suppHashSigs hashSigs' = filter (\ a -> a `elem` hashSigs) suppHashSigs
when (null hashSigs') $ when (null hashSigs') $
@ -218,14 +237,14 @@ throwMiscErrorOnException msg e =
onServerHello :: Context -> ClientParams -> [ExtensionID] -> Handshake -> IO (RecvState IO) onServerHello :: Context -> ClientParams -> [ExtensionID] -> Handshake -> IO (RecvState IO)
onServerHello ctx cparams sentExts (ServerHello rver serverRan serverSession cipher compression exts) = do onServerHello ctx cparams sentExts (ServerHello rver serverRan serverSession cipher compression exts) = do
when (rver == SSL2) $ throwCore $ Error_Protocol ("ssl2 is not supported", True, ProtocolVersion) when (rver == SSL2) $ throwCore $ Error_Protocol ("ssl2 is not supported", True, ProtocolVersion)
case find ((==) rver) allowedvers of case find ((==) rver) (supportedVersions $ ctxSupported ctx) of
Nothing -> throwCore $ Error_Protocol ("server version " ++ show rver ++ " is not supported", True, ProtocolVersion) Nothing -> throwCore $ Error_Protocol ("server version " ++ show rver ++ " is not supported", True, ProtocolVersion)
Just _ -> return () Just _ -> return ()
-- find the compression and cipher methods that the server want to use. -- find the compression and cipher methods that the server want to use.
cipherAlg <- case find ((==) cipher . cipherID) ciphers of cipherAlg <- case find ((==) cipher . cipherID) (ctxCiphers ctx) of
Nothing -> throwCore $ Error_Protocol ("server choose unknown cipher", True, HandshakeFailure) Nothing -> throwCore $ Error_Protocol ("server choose unknown cipher", True, HandshakeFailure)
Just alg -> return alg Just alg -> return alg
compressAlg <- case find ((==) compression . compressionID) compressions of compressAlg <- case find ((==) compression . compressionID) (supportedCompressions $ ctxSupported ctx) of
Nothing -> throwCore $ Error_Protocol ("server choose unknown compression", True, HandshakeFailure) Nothing -> throwCore $ Error_Protocol ("server choose unknown compression", True, HandshakeFailure)
Just alg -> return alg Just alg -> return alg
@ -251,25 +270,25 @@ onServerHello ctx cparams sentExts (ServerHello rver serverRan serverSession cip
_ -> return () _ -> return ()
case resumingSession of case resumingSession of
Nothing -> return $ RecvStateHandshake (processCertificate ctx) Nothing -> return $ RecvStateHandshake (processCertificate cparams ctx)
Just sessionData -> do Just sessionData -> do
usingHState ctx (setMasterSecret rver ClientRole $ sessionSecret sessionData) usingHState ctx (setMasterSecret rver ClientRole $ sessionSecret sessionData)
return $ RecvStateNext expectChangeCipher return $ RecvStateNext expectChangeCipher
where params = ctxParams ctx
allowedvers = pAllowedVersions params
ciphers = pCiphers params
compressions = pCompressions params
onServerHello _ _ _ p = unexpected (show p) (Just "server hello") onServerHello _ _ _ p = unexpected (show p) (Just "server hello")
processCertificate :: Context -> Handshake -> IO (RecvState IO) processCertificate :: ClientParams -> Context -> Handshake -> IO (RecvState IO)
processCertificate ctx (Certificates certs) = do processCertificate cparams ctx (Certificates certs) = do
usage <- liftIO $ catchException (onCertificatesRecv params certs) rejectOnException usage <- catchException (wrapCertificateChecks <$> checkCert) rejectOnException
case usage of case usage of
CertificateUsageAccept -> return () CertificateUsageAccept -> return ()
CertificateUsageReject reason -> certificateRejected reason CertificateUsageReject reason -> certificateRejected reason
return $ RecvStateHandshake (processServerKeyExchange ctx) return $ RecvStateHandshake (processServerKeyExchange ctx)
where params = ctxParams ctx where shared = clientShared cparams
processCertificate ctx p = processServerKeyExchange ctx p checkCert = (onServerCertificate $ clientHooks cparams) (sharedCAStore shared)
(sharedValidationCache shared)
(clientServerIdentification cparams)
certs
processCertificate _ ctx p = processServerKeyExchange ctx p
expectChangeCipher :: Packet -> IO (RecvState IO) expectChangeCipher :: Packet -> IO (RecvState IO)
expectChangeCipher ChangeCipherSpec = return $ RecvStateHandshake expectFinish expectChangeCipher ChangeCipherSpec = return $ RecvStateHandshake expectFinish

View file

@ -16,7 +16,8 @@ module Network.TLS.Handshake.Common
import Control.Concurrent.MVar import Control.Concurrent.MVar
import Network.TLS.Context import Network.TLS.Parameters
import Network.TLS.Context.Internal
import Network.TLS.Session import Network.TLS.Session
import Network.TLS.Struct import Network.TLS.Struct
import Network.TLS.IO import Network.TLS.IO
@ -45,8 +46,8 @@ unexpected msg expected = throwCore $ Error_Packet_unexpected msg (maybe "" (" e
newSession :: Context -> IO Session newSession :: Context -> IO Session
newSession ctx newSession ctx
| pUseSession $ ctxParams ctx = getStateRNG ctx 32 >>= return . Session . Just | supportedSession $ ctxSupported ctx = getStateRNG ctx 32 >>= return . Session . Just
| otherwise = return $ Session Nothing | otherwise = return $ Session Nothing
-- | when a new handshake is done, wrap up & clean up. -- | when a new handshake is done, wrap up & clean up.
handshakeTerminate :: Context -> IO () handshakeTerminate :: Context -> IO ()
@ -56,7 +57,7 @@ handshakeTerminate ctx = do
case session of case session of
Session (Just sessionId) -> do Session (Just sessionId) -> do
sessionData <- getSessionData ctx sessionData <- getSessionData ctx
withSessionManager (ctxParams ctx) (\s -> liftIO $ sessionEstablish s sessionId (fromJust "session-data" sessionData)) liftIO $ sessionEstablish (sharedSessionManager $ ctxShared ctx) sessionId (fromJust "session-data" sessionData)
_ -> return () _ -> return ()
-- forget all handshake data now and reset bytes counters. -- forget all handshake data now and reset bytes counters.
liftIO $ modifyMVar_ (ctxHandshake ctx) (return . const Nothing) liftIO $ modifyMVar_ (ctxHandshake ctx) (return . const Nothing)
@ -65,24 +66,14 @@ handshakeTerminate ctx = do
setEstablished ctx True setEstablished ctx True
return () return ()
sendChangeCipherAndFinish :: Context -> Role -> IO () sendChangeCipherAndFinish :: IO () -- ^ message possibly sent between ChangeCipherSpec and Finished.
sendChangeCipherAndFinish ctx role = do -> Context
-> Role
-> IO ()
sendChangeCipherAndFinish betweenCall ctx role = do
sendPacket ctx ChangeCipherSpec sendPacket ctx ChangeCipherSpec
betweenCall
when (role == ClientRole) $ do
let cparams = getClientParams $ ctxParams ctx
suggest <- usingState_ ctx $ getServerNextProtocolSuggest
case (onNPNServerSuggest cparams, suggest) of
-- client offered, server picked up. send NPN handshake.
(Just io, Just protos) -> do proto <- liftIO $ io protos
sendPacket ctx (Handshake [HsNextProtocolNegotiation proto])
usingState_ ctx $ setNegotiatedProtocol proto
-- client offered, server didn't pick up. do nothing.
(Just _, Nothing) -> return ()
-- client didn't offer. do nothing.
(Nothing, _) -> return ()
liftIO $ contextFlush ctx liftIO $ contextFlush ctx
cf <- usingState_ ctx getVersion >>= \ver -> usingHState ctx $ getHandshakeDigest ver role cf <- usingState_ ctx getVersion >>= \ver -> usingHState ctx $ getHandshakeDigest ver role
sendPacket ctx (Handshake [Finished cf]) sendPacket ctx (Handshake [Finished cf])
liftIO $ contextFlush ctx liftIO $ contextFlush ctx

View file

@ -22,7 +22,7 @@ import Network.TLS.Handshake.State
import Network.TLS.State (withRNG, getVersion) import Network.TLS.State (withRNG, getVersion)
import Network.TLS.Crypto import Network.TLS.Crypto
import Network.TLS.Types import Network.TLS.Types
import Network.TLS.Context import Network.TLS.Context.Internal
{- if the RSA encryption fails we just return an empty bytestring, and let the protocol {- if the RSA encryption fails we just return an empty bytestring, and let the protocol
- fail by itself; however it would be probably better to just report it since it's an internal problem. - fail by itself; however it would be probably better to just report it since it's an internal problem.

View file

@ -23,7 +23,7 @@ import Network.TLS.Util
import Network.TLS.Packet import Network.TLS.Packet
import Network.TLS.Struct import Network.TLS.Struct
import Network.TLS.State import Network.TLS.State
import Network.TLS.Context import Network.TLS.Context.Internal
import Network.TLS.Crypto import Network.TLS.Crypto
import Network.TLS.Handshake.State import Network.TLS.Handshake.State
import Network.TLS.Handshake.Key import Network.TLS.Handshake.Key

View file

@ -11,7 +11,8 @@ module Network.TLS.Handshake.Server
, handshakeServerWith , handshakeServerWith
) where ) where
import Network.TLS.Context import Network.TLS.Parameters
import Network.TLS.Context.Internal
import Network.TLS.Session import Network.TLS.Session
import Network.TLS.Struct import Network.TLS.Struct
import Network.TLS.Cipher import Network.TLS.Cipher
@ -80,7 +81,7 @@ handshakeServer sparams ctx = liftIO $ do
handshakeServerWith :: ServerParams -> Context -> Handshake -> IO () handshakeServerWith :: ServerParams -> Context -> Handshake -> IO ()
handshakeServerWith sparams ctx clientHello@(ClientHello clientVersion _ clientSession ciphers compressions exts _) = do handshakeServerWith sparams ctx clientHello@(ClientHello clientVersion _ clientSession ciphers compressions exts _) = do
-- check if policy allow this new handshake to happens -- check if policy allow this new handshake to happens
handshakeAuthorized <- withMeasure ctx (onHandshake $ ctxParams ctx) handshakeAuthorized <- withMeasure ctx (onHandshake $ ctxCommonHooks ctx)
unless handshakeAuthorized (throwCore $ Error_HandshakePolicy "server: handshake denied") unless handshakeAuthorized (throwCore $ Error_HandshakePolicy "server: handshake denied")
updateMeasure ctx incrementNbHandshakes updateMeasure ctx incrementNbHandshakes
@ -88,7 +89,7 @@ handshakeServerWith sparams ctx clientHello@(ClientHello clientVersion _ clientS
processHandshake ctx clientHello processHandshake ctx clientHello
when (clientVersion == SSL2) $ throwCore $ Error_Protocol ("ssl2 is not supported", True, ProtocolVersion) when (clientVersion == SSL2) $ throwCore $ Error_Protocol ("ssl2 is not supported", True, ProtocolVersion)
chosenVersion <- case findHighestVersionFrom clientVersion (pAllowedVersions params) of chosenVersion <- case findHighestVersionFrom clientVersion (supportedVersions $ ctxSupported ctx) of
Nothing -> throwCore $ Error_Protocol ("client version " ++ show clientVersion ++ " is not supported", True, ProtocolVersion) Nothing -> throwCore $ Error_Protocol ("client version " ++ show clientVersion ++ " is not supported", True, ProtocolVersion)
Just v -> return v Just v -> return v
@ -98,8 +99,8 @@ handshakeServerWith sparams ctx clientHello@(ClientHello clientVersion _ clientS
Error_Protocol ("no compression in common with the client", True, HandshakeFailure) Error_Protocol ("no compression in common with the client", True, HandshakeFailure)
let ciphersFilteredVersion = filter (cipherAllowedForVersion chosenVersion) commonCiphers let ciphersFilteredVersion = filter (cipherAllowedForVersion chosenVersion) commonCiphers
usedCipher = (onCipherChoosing sparams) chosenVersion ciphersFilteredVersion usedCipher = (onCipherChoosing $ serverHooks sparams) chosenVersion ciphersFilteredVersion
creds = credentialsGet params creds = sharedCredentials $ ctxShared ctx
cred <- case cipherKeyExchange usedCipher of cred <- case cipherKeyExchange usedCipher of
CipherKeyExchange_RSA -> return $ credentialsFindForDecrypting creds CipherKeyExchange_RSA -> return $ credentialsFindForDecrypting creds
CipherKeyExchange_DH_Anon -> return $ Nothing CipherKeyExchange_DH_Anon -> return $ Nothing
@ -108,20 +109,19 @@ handshakeServerWith sparams ctx clientHello@(ClientHello clientVersion _ clientS
_ -> throwCore $ Error_Protocol ("key exchange algorithm not implemented", True, HandshakeFailure) _ -> throwCore $ Error_Protocol ("key exchange algorithm not implemented", True, HandshakeFailure)
resumeSessionData <- case clientSession of resumeSessionData <- case clientSession of
(Session (Just clientSessionId)) -> withSessionManager params (\s -> liftIO $ sessionResume s clientSessionId) (Session (Just clientSessionId)) -> liftIO $ sessionResume (sharedSessionManager $ ctxShared ctx) clientSessionId
(Session Nothing) -> return Nothing (Session Nothing) -> return Nothing
doHandshake sparams cred ctx chosenVersion usedCipher usedCompression clientSession resumeSessionData exts doHandshake sparams cred ctx chosenVersion usedCipher usedCompression clientSession resumeSessionData exts
where where
params = ctxParams ctx
commonCipherIDs = intersect ciphers (map cipherID $ ctxCiphers ctx) commonCipherIDs = intersect ciphers (map cipherID $ ctxCiphers ctx)
commonCiphers = filter (flip elem commonCipherIDs . cipherID) (ctxCiphers ctx) commonCiphers = filter (flip elem commonCipherIDs . cipherID) (ctxCiphers ctx)
commonCompressions = compressionIntersectID (pCompressions params) compressions commonCompressions = compressionIntersectID (supportedCompressions $ ctxSupported ctx) compressions
usedCompression = head commonCompressions usedCompression = head commonCompressions
handshakeServerWith _ _ _ = fail "unexpected handshake type received. expecting client hello" handshakeServerWith _ _ _ = throwCore $ Error_Protocol ("unexpected handshake message received in handshakeServerWith", True, HandshakeFailure)
doHandshake :: ServerParams -> Maybe Credential -> Context -> Version -> Cipher doHandshake :: ServerParams -> Maybe Credential -> Context -> Version -> Cipher
-> Compression -> Session -> Maybe SessionData -> Compression -> Session -> Maybe SessionData
@ -133,13 +133,13 @@ doHandshake sparams mcred ctx chosenVersion usedCipher usedCompression clientSes
liftIO $ contextFlush ctx liftIO $ contextFlush ctx
-- Receive client info until client Finished. -- Receive client info until client Finished.
recvClientData sparams ctx recvClientData sparams ctx
sendChangeCipherAndFinish ctx ServerRole sendChangeCipherAndFinish (return ()) ctx ServerRole
Just sessionData -> do Just sessionData -> do
usingState_ ctx (setSession clientSession True) usingState_ ctx (setSession clientSession True)
serverhello <- makeServerHello clientSession serverhello <- makeServerHello clientSession
sendPacket ctx $ Handshake [serverhello] sendPacket ctx $ Handshake [serverhello]
usingHState ctx $ setMasterSecret chosenVersion ServerRole $ sessionSecret sessionData usingHState ctx $ setMasterSecret chosenVersion ServerRole $ sessionSecret sessionData
sendChangeCipherAndFinish ctx ServerRole sendChangeCipherAndFinish (return ()) ctx ServerRole
recvChangeCipherAndFinish ctx recvChangeCipherAndFinish ctx
handshakeTerminate ctx handshakeTerminate ctx
where where
@ -169,7 +169,7 @@ doHandshake sparams mcred ctx chosenVersion usedCipher usedCompression clientSes
else return [] else return []
nextProtocols <- nextProtocols <-
if clientRequestedNPN if clientRequestedNPN
then liftIO $ onSuggestNextProtocols sparams then liftIO $ onSuggestNextProtocols $ serverHooks sparams
else return Nothing else return Nothing
npnExt <- case nextProtocols of npnExt <- case nextProtocols of
Just protos -> do usingState_ ctx $ do setExtensionNPN True Just protos -> do usingState_ ctx $ do setExtensionNPN True
@ -212,7 +212,7 @@ doHandshake sparams mcred ctx chosenVersion usedCipher usedCompression clientSes
let certTypes = [ CertificateType_RSA_Sign ] let certTypes = [ CertificateType_RSA_Sign ]
hashSigs = if usedVersion < TLS12 hashSigs = if usedVersion < TLS12
then Nothing then Nothing
else Just (pHashSignatures $ ctxParams ctx) else Just (supportedHashSignatures $ ctxSupported ctx)
creq = CertRequest certTypes hashSigs creq = CertRequest certTypes hashSigs
(map extractCAname $ serverCACertificates sparams) (map extractCAname $ serverCACertificates sparams)
usingHState ctx $ setCertReqSent True usingHState ctx $ setCertReqSent True
@ -240,7 +240,7 @@ doHandshake sparams mcred ctx chosenVersion usedCipher usedCompression clientSes
usedVersion <- usingState_ ctx getVersion usedVersion <- usingState_ ctx getVersion
let mhash = case usedVersion of let mhash = case usedVersion of
TLS12 -> case filter ((==) sigAlg . snd) $ pHashSignatures $ ctxParams ctx of TLS12 -> case filter ((==) sigAlg . snd) $ supportedHashSignatures $ ctxSupported ctx of
[] -> error ("no hash signature for " ++ show sigAlg) [] -> error ("no hash signature for " ++ show sigAlg)
x:_ -> Just (fst x) x:_ -> Just (fst x)
_ -> Nothing _ -> Nothing
@ -269,7 +269,7 @@ recvClientData sparams ctx = runRecvState ctx (RecvStateHandshake processClientC
-- Call application callback to see whether the -- Call application callback to see whether the
-- certificate chain is acceptable. -- certificate chain is acceptable.
-- --
usage <- liftIO $ catchException (onClientCertificate sparams certs) rejectOnException usage <- liftIO $ catchException (onClientCertificate (serverHooks sparams) certs) rejectOnException
case usage of case usage of
CertificateUsageAccept -> return () CertificateUsageAccept -> return ()
CertificateUsageReject reason -> certificateRejected reason CertificateUsageReject reason -> certificateRejected reason
@ -320,7 +320,7 @@ recvClientData sparams ctx = runRecvState ctx (RecvStateHandshake processClientC
-- the signature is wrong. In either case, -- the signature is wrong. In either case,
-- ask the application if it wants to -- ask the application if it wants to
-- proceed, we will do that. -- proceed, we will do that.
res <- liftIO $ onUnverifiedClientCert sparams res <- liftIO $ onUnverifiedClientCert (serverHooks sparams)
if res if res
then do then do
-- When verification fails, but the -- When verification fails, but the

View file

@ -18,7 +18,7 @@ module Network.TLS.Handshake.Signature
import Crypto.PubKey.HashDescr import Crypto.PubKey.HashDescr
import Network.TLS.Crypto import Network.TLS.Crypto
import Network.TLS.Context import Network.TLS.Context.Internal
import Network.TLS.Struct import Network.TLS.Struct
import Network.TLS.Packet (generateCertificateVerify_SSL, encodeSignedDHParams) import Network.TLS.Packet (generateCertificateVerify_SSL, encodeSignedDHParams)
import Network.TLS.State import Network.TLS.State

View file

@ -7,13 +7,13 @@
-- --
module Network.TLS.Hooks module Network.TLS.Hooks
( Logging(..) ( Logging(..)
, defaultLogging
, Hooks(..) , Hooks(..)
, defaultHooks , defaultHooks
) where ) where
import qualified Data.ByteString as B import qualified Data.ByteString as B
import Network.TLS.Struct (Header, Handshake(..)) import Network.TLS.Struct (Header, Handshake(..))
import Data.Default.Class
-- | Hooks for logging -- | Hooks for logging
data Logging = Logging data Logging = Logging
@ -31,6 +31,9 @@ defaultLogging = Logging
, loggingIORecv = (\_ _ -> return ()) , loggingIORecv = (\_ _ -> return ())
} }
instance Default Logging where
def = defaultLogging
-- | A collection of hooks actions. -- | A collection of hooks actions.
data Hooks = Hooks data Hooks = Hooks
{ hookRecvHandshake :: Handshake -> IO Handshake { hookRecvHandshake :: Handshake -> IO Handshake
@ -41,3 +44,5 @@ defaultHooks = Hooks
{ hookRecvHandshake = \hs -> return hs { hookRecvHandshake = \hs -> return hs
} }
instance Default Hooks where
def = defaultHooks

View file

@ -13,10 +13,11 @@ module Network.TLS.IO
, recvPacket , recvPacket
) where ) where
import Network.TLS.Context import Network.TLS.Context.Internal
import Network.TLS.Struct import Network.TLS.Struct
import Network.TLS.Record import Network.TLS.Record
import Network.TLS.Packet import Network.TLS.Packet
import Network.TLS.Hooks
import Network.TLS.Sending import Network.TLS.Sending
import Network.TLS.Receiving import Network.TLS.Receiving
import qualified Data.ByteString as B import qualified Data.ByteString as B

View file

@ -5,33 +5,23 @@
-- Stability : experimental -- Stability : experimental
-- Portability : unknown -- Portability : unknown
-- --
-- extension RecordWildCards only needed because of some GHC bug
-- relative to insufficient polymorphic field
{-# LANGUAGE RecordWildCards #-}
module Network.TLS.Parameters module Network.TLS.Parameters
( (
-- * Parameters ClientParams(..)
Params(..)
, RoleParams(..)
, ClientParams(..)
, ServerParams(..) , ServerParams(..)
, updateClientParams , CommonParams
, updateServerParams , ClientHooks(..)
, Logging(..) , ServerHooks(..)
, SessionID , CommonHooks(..)
, SessionData(..) , Supported(..)
, Shared(..)
-- * special default
, defaultParamsClient
-- * Parameters
, MaxFragmentEnum(..) , MaxFragmentEnum(..)
, Measurement(..) , Logging(..)
, CertificateUsage(..) , CertificateUsage(..)
, CertificateRejectReason(..) , CertificateRejectReason(..)
, defaultLogging
, defaultParamsClient
, defaultParamsServer
, withSessionManager
, setSessionManager
, getClientParams
, getServerParams
, credentialsGet
) where ) where
import Network.BSD (HostName) import Network.BSD (HostName)
@ -48,15 +38,137 @@ import Network.TLS.Hooks
import Network.TLS.Measurement import Network.TLS.Measurement
import Network.TLS.X509 import Network.TLS.X509
import Data.Monoid import Data.Monoid
import Data.List (intercalate) import Data.Default.Class
import qualified Data.ByteString as B import qualified Data.ByteString as B
data ClientParams = ClientParams type CommonParams = (Supported, Shared, CommonHooks)
{ clientUseMaxFragmentLength :: Maybe MaxFragmentEnum
, clientUseServerName :: Maybe HostName
, clientWantSessionResume :: Maybe (SessionID, SessionData) -- ^ try to establish a connection using this session.
-- | This action is called when the server sends a data ClientParams = ClientParams
{ clientUseMaxFragmentLength :: Maybe MaxFragmentEnum
-- | Define the name of the server, along with an extra service identification blob.
-- this is important that the hostname part is properly filled for security reason,
-- as it allow to properly associate the remote side with the given certificate
-- during a handshake.
--
-- The extra blob is useful to differentiate services running on the same host, but that
-- might have different certificates given. It's only used as part of the X509 validation
-- infrastructure.
, clientServerIdentification :: (HostName, Bytes)
-- | Allow the use of the Server Name Indication TLS extension during handshake, which allow
-- the client to specify which host name, it's trying to access. This is useful to distinguish
-- CNAME aliasing (e.g. web virtual host).
, clientUseServerNameIndication :: Bool
-- | try to establish a connection using this session.
, clientWantSessionResume :: Maybe (SessionID, SessionData)
, clientShared :: Shared
, clientHooks :: ClientHooks
, clientCommonHooks :: CommonHooks
, clientSupported :: Supported
} deriving (Show)
defaultParamsClient :: HostName -> Bytes -> ClientParams
defaultParamsClient serverName serverId = ClientParams
{ clientWantSessionResume = Nothing
, clientUseMaxFragmentLength = Nothing
, clientServerIdentification = (serverName, serverId)
, clientUseServerNameIndication = True
, clientShared = def
, clientHooks = def
, clientCommonHooks = def
, clientSupported = def
}
data ServerParams = ServerParams
{ -- | request a certificate from client.
serverWantClientCert :: Bool
-- | This is a list of certificates from which the
-- disinguished names are sent in certificate request
-- messages. For TLS1.0, it should not be empty.
, serverCACertificates :: [SignedCertificate]
-- | Server Optional Diffie Hellman parameters. If this value is not
-- properly set, no Diffie Hellman key exchange will take place.
, serverDHEParams :: Maybe DHParams
, serverShared :: Shared
, serverHooks :: ServerHooks
, serverCommonHooks :: CommonHooks
, serverSupported :: Supported
} deriving (Show)
defaultParamsServer :: ServerParams
defaultParamsServer = ServerParams
{ serverWantClientCert = False
, serverCACertificates = []
, serverDHEParams = Nothing
, serverHooks = def
, serverShared = def
, serverCommonHooks = def
, serverSupported = def
}
instance Default ServerParams where
def = defaultParamsServer
-- | List all the supported algorithms, versions, ciphers, etc supported.
data Supported = Supported
{
-- | Supported Versions by this context
-- On the client side, the highest version will be used to establish the connection.
-- On the server side, the highest version that is less or equal than the client version will be chosed.
supportedVersions :: [Version]
-- | Supported cipher methods
, supportedCiphers :: [Cipher]
-- | supported compressions methods
, supportedCompressions :: [Compression]
-- | All supported hash/signature algorithms pair for client
-- certificate verification, ordered by decreasing priority.
, supportedHashSignatures :: [HashAndSignatureAlgorithm]
-- | Set if we support secure renegotiation.
, supportedSecureRenegotiation :: Bool
-- | Set if we support session.
, supportedSession :: Bool
} deriving (Show,Eq)
defaultSupported :: Supported
defaultSupported = Supported
{ supportedVersions = [TLS10,TLS11,TLS12]
, supportedCiphers = []
, supportedCompressions = [nullCompression]
, supportedHashSignatures = [ (Struct.HashSHA512, SignatureRSA)
, (Struct.HashSHA384, SignatureRSA)
, (Struct.HashSHA256, SignatureRSA)
, (Struct.HashSHA224, SignatureRSA)
, (Struct.HashSHA1, SignatureDSS)
]
, supportedSecureRenegotiation = True
, supportedSession = True
}
instance Default Supported where
def = defaultSupported
data Shared = Shared
{ sharedCredentials :: Credentials
, sharedSessionManager :: SessionManager
, sharedCAStore :: CertificateStore
, sharedValidationCache :: ValidationCache
}
instance Show Shared where
show _ = "Shared"
instance Default Shared where
def = Shared
{ sharedCAStore = mempty
, sharedCredentials = mempty
, sharedSessionManager = noSessionManager
, sharedValidationCache = def
}
-- | A set of callbacks run by the clients for various corners of TLS establishment
data ClientHooks = ClientHooks
{ -- | This action is called when the server sends a
-- certificate request. The parameter is the information -- certificate request. The parameter is the information
-- from the request. The action should select a certificate -- from the request. The action should select a certificate
-- chain of one of the given certificate types where the -- chain of one of the given certificate types where the
@ -75,24 +187,32 @@ data ClientParams = ClientParams
-- Returning a certificate chain not matching the -- Returning a certificate chain not matching the
-- distinguished names may lead to problems or not, -- distinguished names may lead to problems or not,
-- depending whether the server accepts it. -- depending whether the server accepts it.
, onCertificateRequest :: ([CertificateType], onCertificateRequest :: ([CertificateType],
Maybe [HashAndSignatureAlgorithm], Maybe [HashAndSignatureAlgorithm],
[DistinguishedName]) -> IO (Maybe (CertificateChain, PrivKey)) [DistinguishedName]) -> IO (Maybe (CertificateChain, PrivKey))
, onNPNServerSuggest :: Maybe ([B.ByteString] -> IO B.ByteString) , onNPNServerSuggest :: Maybe ([B.ByteString] -> IO B.ByteString)
, onServerCertificate :: CertificateStore -> ValidationCache -> ServiceID -> CertificateChain -> IO [FailedReason]
} }
data ServerParams = ServerParams defaultClientHooks :: ClientHooks
{ serverWantClientCert :: Bool -- ^ request a certificate from client. defaultClientHooks = ClientHooks
{ onCertificateRequest = \ _ -> return Nothing
, onNPNServerSuggest = Nothing
, onServerCertificate = validateDefault
}
-- | This is a list of certificates from which the instance Show ClientHooks where
-- disinguished names are sent in certificate request show _ = "ClientHooks"
-- messages. For TLS1.0, it should not be empty. instance Default ClientHooks where
, serverCACertificates :: [SignedCertificate] def = defaultClientHooks
-- | A set of callbacks run by the server for various corners of the TLS establishment
data ServerHooks = ServerHooks
{
-- | This action is called when a client certificate chain -- | This action is called when a client certificate chain
-- is received from the client. When it returns a -- is received from the client. When it returns a
-- CertificateUsageReject value, the handshake is aborted. -- CertificateUsageReject value, the handshake is aborted.
, onClientCertificate :: CertificateChain -> IO CertificateUsage onClientCertificate :: CertificateChain -> IO CertificateUsage
-- | This action is called when the client certificate -- | This action is called when the client certificate
-- cannot be verified. A 'Nothing' argument indicates a -- cannot be verified. A 'Nothing' argument indicates a
@ -109,113 +229,35 @@ data ServerParams = ServerParams
-- The client cipher list cannot be empty. -- The client cipher list cannot be empty.
, onCipherChoosing :: Version -> [Cipher] -> Cipher , onCipherChoosing :: Version -> [Cipher] -> Cipher
-- | Server Optional Diffie Hellman parameters
, serverDHEParams :: Maybe DHParams
-- | suggested next protocols accoring to the next protocol negotiation extension. -- | suggested next protocols accoring to the next protocol negotiation extension.
, onSuggestNextProtocols :: IO (Maybe [B.ByteString]) , onSuggestNextProtocols :: IO (Maybe [B.ByteString])
} }
data RoleParams = Client ClientParams | Server ServerParams defaultServerHooks :: ServerHooks
defaultServerHooks = ServerHooks
{ onCipherChoosing = \_ -> head
, onClientCertificate = \_ -> return $ CertificateUsageReject $ CertificateRejectOther "no client certificates expected"
, onUnverifiedClientCert = return False
, onSuggestNextProtocols = return Nothing
}
data Params = Params instance Show ServerHooks where
{ pAllowedVersions :: [Version] -- ^ allowed versions that we can use. show _ = "ClientHooks"
-- the default version used for connection is the highest version in the list instance Default ServerHooks where
, pCiphers :: [Cipher] -- ^ all ciphers supported ordered by priority. def = defaultServerHooks
, pCompressions :: [Compression] -- ^ all compression supported ordered by priority.
, pHashSignatures :: [HashAndSignatureAlgorithm] -- ^ All supported hash/signature algorithms pair for client certificate verification, ordered by decreasing priority. data CommonHooks = CommonHooks
, pUseSecureRenegotiation :: Bool -- ^ notify that we want to use secure renegotation { onCertificatesRecv :: CertificateChain -> IO CertificateUsage -- ^ callback to verify received cert chain.
, pUseSession :: Bool -- ^ generate new session if specified
, pCertificates :: Maybe (CertificateChain, Maybe PrivKey) -- ^ the cert chain for this context with the associated keys if any.
, pCredentials :: Credentials -- ^ credentials
, pLogging :: Logging -- ^ callback for logging
, onHandshake :: Measurement -> IO Bool -- ^ callback on a beggining of handshake , onHandshake :: Measurement -> IO Bool -- ^ callback on a beggining of handshake
, onCertificatesRecv :: CertificateChain -> IO CertificateUsage -- ^ callback to verify received cert chain. , logging :: Logging -- ^ callback for logging
, pSessionManager :: SessionManager
, roleParams :: RoleParams
}
{-# DEPRECATED pCertificates "use pCredentials instead of pCertificates. removed in tls-1.3" #-}
credentialsGet :: Params -> Credentials
credentialsGet params = pCredentials params `mappend`
case pCertificates params of
Just (cchain, Just priv) -> Credentials [(cchain, priv)]
_ -> Credentials []
-- | Set a new session manager in a parameters structure.
setSessionManager :: SessionManager -> Params -> Params
setSessionManager manager (Params {..}) = Params { pSessionManager = manager, .. }
withSessionManager :: Params -> (SessionManager -> a) -> a
withSessionManager (Params { pSessionManager = man }) f = f man
getClientParams :: Params -> ClientParams
getClientParams params =
case roleParams params of
Client clientParams -> clientParams
_ -> error "server params in client context"
getServerParams :: Params -> ServerParams
getServerParams params =
case roleParams params of
Server serverParams -> serverParams
_ -> error "client params in server context"
defaultParamsClient :: Params
defaultParamsClient = Params
{ pAllowedVersions = [TLS10,TLS11,TLS12]
, pCiphers = []
, pCompressions = [nullCompression]
, pHashSignatures = [ (Struct.HashSHA512, SignatureRSA)
, (Struct.HashSHA384, SignatureRSA)
, (Struct.HashSHA256, SignatureRSA)
, (Struct.HashSHA224, SignatureRSA)
, (Struct.HashSHA1, SignatureDSS)
]
, pUseSecureRenegotiation = True
, pUseSession = True
, pCertificates = Nothing
, pCredentials = mempty
, pLogging = defaultLogging
, onHandshake = (\_ -> return True)
, onCertificatesRecv = (\_ -> return CertificateUsageAccept)
, pSessionManager = noSessionManager
, roleParams = Client $ ClientParams
{ clientWantSessionResume = Nothing
, clientUseMaxFragmentLength = Nothing
, clientUseServerName = Nothing
, onCertificateRequest = \ _ -> return Nothing
, onNPNServerSuggest = Nothing
}
} }
defaultParamsServer :: Params instance Show CommonHooks where
defaultParamsServer = defaultParamsClient { roleParams = Server role } show _ = "CommonHooks"
where role = ServerParams
{ serverWantClientCert = False
, onCipherChoosing = \_ -> head
, serverCACertificates = []
, serverDHEParams = Nothing
, onClientCertificate = \ _ -> return $ CertificateUsageReject $ CertificateRejectOther "no client certificates expected"
, onUnverifiedClientCert = return False
, onSuggestNextProtocols = return Nothing
}
updateRoleParams :: (ClientParams -> ClientParams) -> (ServerParams -> ServerParams) -> Params -> Params instance Default CommonHooks where
updateRoleParams fc fs params = case roleParams params of def = CommonHooks
Client c -> params { roleParams = Client (fc c) } { onCertificatesRecv = \_ -> return CertificateUsageAccept
Server s -> params { roleParams = Server (fs s) } , logging = def
, onHandshake = \_ -> return True
updateClientParams :: (ClientParams -> ClientParams) -> Params -> Params }
updateClientParams f = updateRoleParams f id
updateServerParams :: (ServerParams -> ServerParams) -> Params -> Params
updateServerParams f = updateRoleParams id f
instance Show Params where
show p = "Params { " ++ (intercalate "," $ map (\(k,v) -> k ++ "=" ++ v)
[ ("allowedVersions", show $ pAllowedVersions p)
, ("ciphers", show $ pCiphers p)
, ("compressions", show $ pCompressions p)
, ("certificates", show $ pCertificates p)
]) ++ " }"

View file

@ -16,7 +16,7 @@ import Control.Monad.State
import Control.Monad.Error import Control.Monad.Error
import Control.Concurrent.MVar import Control.Concurrent.MVar
import Network.TLS.Context import Network.TLS.Context.Internal
import Network.TLS.Struct import Network.TLS.Struct
import Network.TLS.Record import Network.TLS.Record
import Network.TLS.Packet import Network.TLS.Packet

View file

@ -23,7 +23,8 @@ import Network.TLS.Cap
import Network.TLS.Struct import Network.TLS.Struct
import Network.TLS.Record import Network.TLS.Record
import Network.TLS.Packet import Network.TLS.Packet
import Network.TLS.Context import Network.TLS.Context.Internal
import Network.TLS.Parameters
import Network.TLS.State import Network.TLS.State
import Network.TLS.Handshake.State import Network.TLS.Handshake.State
import Network.TLS.Cipher import Network.TLS.Cipher
@ -67,7 +68,7 @@ writePacket ctx pkt = do
-- so we use cstIV as is, however in other case we generate an explicit IV -- so we use cstIV as is, however in other case we generate an explicit IV
prepareRecord :: Context -> RecordM a -> IO (Either TLSError a) prepareRecord :: Context -> RecordM a -> IO (Either TLSError a)
prepareRecord ctx f = do prepareRecord ctx f = do
ver <- usingState_ ctx (getVersionWithDefault $ maximum $ pAllowedVersions $ ctxParams ctx) ver <- usingState_ ctx (getVersionWithDefault $ maximum $ supportedVersions $ ctxSupported ctx)
txState <- readMVar $ ctxTxState ctx txState <- readMVar $ ctxTxState ctx
let sz = case stCipher $ txState of let sz = case stCipher $ txState of
Nothing -> 0 Nothing -> 0

View file

@ -32,7 +32,7 @@ data SessionData = SessionData
{ sessionVersion :: Version { sessionVersion :: Version
, sessionCipher :: CipherID , sessionCipher :: CipherID
, sessionSecret :: ByteString , sessionSecret :: ByteString
} } deriving (Show,Eq)
-- | Cipher identification -- | Cipher identification
type CipherID = Word16 type CipherID = Word16

View file

@ -16,9 +16,18 @@ module Network.TLS.X509
, getCertificateChainLeaf , getCertificateChainLeaf
, CertificateRejectReason(..) , CertificateRejectReason(..)
, CertificateUsage(..) , CertificateUsage(..)
, CertificateStore
, ValidationCache
, exceptionValidationCache
, validateDefault
, FailedReason
, ServiceID
, wrapCertificateChecks
) where ) where
import Data.X509 import Data.X509
import Data.X509.Validation
import Data.X509.CertificateStore
isNullCertificateChain :: CertificateChain -> Bool isNullCertificateChain :: CertificateChain -> Bool
isNullCertificateChain (CertificateChain l) = null l isNullCertificateChain (CertificateChain l) = null l
@ -41,3 +50,10 @@ data CertificateUsage =
| CertificateUsageReject CertificateRejectReason -- ^ usage of certificate rejected | CertificateUsageReject CertificateRejectReason -- ^ usage of certificate rejected
deriving (Show,Eq) deriving (Show,Eq)
wrapCertificateChecks :: [FailedReason] -> CertificateUsage
wrapCertificateChecks [] = CertificateUsageAccept
wrapCertificateChecks l
| Expired `elem` l = CertificateUsageReject $ CertificateRejectExpired
| InFuture `elem` l = CertificateUsageReject $ CertificateRejectExpired
| UnknownCA `elem` l = CertificateUsageReject $ CertificateRejectUnknownCA
| otherwise = CertificateUsageReject $ CertificateRejectOther (show l)

View file

@ -14,6 +14,8 @@ import PubKey
import PipeChan import PipeChan
import Network.TLS import Network.TLS
import Data.X509 import Data.X509
import Data.X509.Validation
import Data.Default.Class
import Control.Applicative import Control.Applicative
import Control.Concurrent.Chan import Control.Concurrent.Chan
import Control.Concurrent import Control.Concurrent
@ -70,11 +72,11 @@ streamCipher = blockCipher
} }
} }
supportedCiphers :: [Cipher] knownCiphers :: [Cipher]
supportedCiphers = [blockCipher,blockCipherDHE_RSA,blockCipherDHE_DSS,streamCipher] knownCiphers = [blockCipher,blockCipherDHE_RSA,blockCipherDHE_DSS,streamCipher]
supportedVersions :: [Version] knownVersions :: [Version]
supportedVersions = [SSL3,TLS10,TLS11,TLS12] knownVersions = [SSL3,TLS10,TLS11,TLS12]
arbitraryPairParams = do arbitraryPairParams = do
(dsaPub, dsaPriv) <- (\(p,r) -> (PubKeyDSA p, PrivKeyDSA r)) <$> arbitraryDSAPair (dsaPub, dsaPriv) <- (\(p,r) -> (PubKeyDSA p, PrivKeyDSA r)) <$> arbitraryDSAPair
@ -83,44 +85,54 @@ arbitraryPairParams = do
cert <- arbitraryX509WithKey (pub, priv) cert <- arbitraryX509WithKey (pub, priv)
return (CertificateChain [cert], priv) return (CertificateChain [cert], priv)
) [ (pubKey, privKey), (dsaPub, dsaPriv) ] ) [ (pubKey, privKey), (dsaPub, dsaPriv) ]
connectVersion <- elements supportedVersions connectVersion <- elements knownVersions
let allowedVersions = [ v | v <- supportedVersions, v <= connectVersion ] let allowedVersions = [ v | v <- knownVersions, v <= connectVersion ]
serAllowedVersions <- (:[]) `fmap` elements allowedVersions serAllowedVersions <- (:[]) `fmap` elements allowedVersions
serverCiphers <- arbitraryCiphers serverCiphers <- arbitraryCiphers
clientCiphers <- oneof [arbitraryCiphers] `suchThat` (\cs -> or [x `elem` serverCiphers | x <- cs]) clientCiphers <- oneof [arbitraryCiphers] `suchThat` (\cs -> or [x `elem` serverCiphers | x <- cs])
secNeg <- arbitrary secNeg <- arbitrary
--let cred = (CertificateChain [servCert], PrivKeyRSA privKey)
let serverState = defaultParamsServer -- , pLogging = logging "server: "
{ pAllowedVersions = serAllowedVersions -- , pLogging = logging "client: "
, pCiphers = serverCiphers
, pCredentials = Credentials creds let serverState = def
, pUseSecureRenegotiation = secNeg { serverSupported = def { supportedCiphers = serverCiphers
, pLogging = logging "server: " , supportedVersions = serAllowedVersions
, roleParams = roleParams $ updateServerParams (\sp -> sp { serverDHEParams = Just dhParams }) defaultParamsServer , supportedSecureRenegotiation = secNeg
}
, serverDHEParams = Just dhParams
, serverShared = def { sharedCredentials = Credentials creds }
} }
let clientState = defaultParamsClient let clientState = (defaultParamsClient "" B.empty)
{ pAllowedVersions = allowedVersions { clientSupported = def { supportedCiphers = clientCiphers
, pCiphers = clientCiphers , supportedVersions = allowedVersions
, pUseSecureRenegotiation = secNeg , supportedSecureRenegotiation = secNeg
, pLogging = logging "client: " }
, clientShared = def { sharedValidationCache = ValidationCache
{ cacheAdd = \_ _ _ -> return ()
, cacheQuery = \_ _ _ -> return ValidationCachePass
}
}
} }
return (clientState, serverState) return (clientState, serverState)
where where
logging pre = logging pre =
if debug if debug
then defaultLogging { loggingPacketSent = putStrLn . ((pre ++ ">> ") ++) then def { loggingPacketSent = putStrLn . ((pre ++ ">> ") ++)
, loggingPacketRecv = putStrLn . ((pre ++ "<< ") ++) } , loggingPacketRecv = putStrLn . ((pre ++ "<< ") ++) }
else defaultLogging else def
arbitraryCiphers = resize (length supportedCiphers + 1) $ listOf1 (elements supportedCiphers) arbitraryCiphers = resize (length knownCiphers + 1) $ listOf1 (elements knownCiphers)
setPairParamsSessionManager :: SessionManager -> (Params, Params) -> (Params, Params) setPairParamsSessionManager :: SessionManager -> (ClientParams, ServerParams) -> (ClientParams, ServerParams)
setPairParamsSessionManager manager (clientState, serverState) = (nc,ns) setPairParamsSessionManager manager (clientState, serverState) = (nc,ns)
where nc = setSessionManager manager clientState where nc = clientState { clientShared = updateSessionManager $ clientShared clientState }
ns = setSessionManager manager serverState ns = serverState { serverShared = updateSessionManager $ serverShared serverState }
updateSessionManager shared = shared { sharedSessionManager = manager }
setPairParamsSessionResuming sessionStuff (clientState, serverState) = (nc,serverState) setPairParamsSessionResuming sessionStuff (clientState, serverState) =
where nc = updateClientParams (\cparams -> cparams { clientWantSessionResume = Just sessionStuff }) clientState ( clientState { clientWantSessionResume = Just sessionStuff }
, serverState)
newPairContext pipe (cParams, sParams) = do newPairContext pipe (cParams, sParams) = do
let noFlush = return () let noFlush = return ()

View file

@ -72,8 +72,12 @@ prop_handshake_initiate = do
prop_handshake_npn_initiate :: PropertyM IO () prop_handshake_npn_initiate :: PropertyM IO ()
prop_handshake_npn_initiate = do prop_handshake_npn_initiate = do
(clientParam,serverParam) <- pick arbitraryPairParams (clientParam,serverParam) <- pick arbitraryPairParams
let clientParam' = updateClientParams (\cp -> cp { onNPNServerSuggest = Just $ \protos -> return (head protos) }) clientParam let clientParam' = clientParam { clientHooks = (clientHooks clientParam)
serverParam' = updateServerParams (\sp -> sp { onSuggestNextProtocols = return $ Just [C8.pack "spdy/2", C8.pack "http/1.1"] }) serverParam { onNPNServerSuggest = Just $ \protos -> return (head protos) }
}
serverParam' = serverParam { serverHooks = (serverHooks serverParam)
{ onSuggestNextProtocols = return $ Just [C8.pack "spdy/2", C8.pack "http/1.1"] }
}
params' = (clientParam',serverParam') params' = (clientParam',serverParam')
runTLSPipe params' tlsServer tlsClient runTLSPipe params' tlsServer tlsClient
where tlsServer ctx queue = do where tlsServer ctx queue = do

View file

@ -37,6 +37,7 @@ Library
, cereal >= 0.3 , cereal >= 0.3
, bytestring , bytestring
, network , network
, data-default-class
, crypto-random >= 0.0 && < 0.1 , crypto-random >= 0.0 && < 0.1
, crypto-numbers , crypto-numbers
, crypto-pubkey-types >= 0.4 , crypto-pubkey-types >= 0.4
@ -45,6 +46,7 @@ Library
, asn1-encoding , asn1-encoding
, x509 >= 1.4.3 && < 1.5.0 , x509 >= 1.4.3 && < 1.5.0
, x509-store , x509-store
, x509-validation >= 1.5.0 && < 1.6.0
Exposed-modules: Network.TLS Exposed-modules: Network.TLS
Network.TLS.Cipher Network.TLS.Cipher
Network.TLS.Compression Network.TLS.Compression
@ -53,6 +55,7 @@ Library
Network.TLS.Struct Network.TLS.Struct
Network.TLS.Core Network.TLS.Core
Network.TLS.Context Network.TLS.Context
Network.TLS.Context.Internal
Network.TLS.Credentials Network.TLS.Credentials
Network.TLS.Backend Network.TLS.Backend
Network.TLS.Crypto Network.TLS.Crypto
@ -102,6 +105,7 @@ Test-Suite test-tls
Build-Depends: base >= 3 && < 5 Build-Depends: base >= 3 && < 5
, mtl , mtl
, cereal >= 0.3 , cereal >= 0.3
, data-default-class
, QuickCheck >= 2 , QuickCheck >= 2
, test-framework , test-framework
, test-framework-quickcheck2 , test-framework-quickcheck2
@ -109,6 +113,7 @@ Test-Suite test-tls
, crypto-pubkey >= 0.2 , crypto-pubkey >= 0.2
, bytestring , bytestring
, x509 , x509
, x509-validation
, tls , tls
, time , time
, crypto-random , crypto-random
@ -122,6 +127,8 @@ Benchmark bench-tls
Build-depends: base >= 4 && < 5 Build-depends: base >= 4 && < 5
, tls , tls
, x509 , x509
, x509-validation
, data-default-class
, crypto-random , crypto-random
, criterion , criterion
, cprng-aes , cprng-aes