hs-tls/Network/TLS/Context.hs
Martin Grabmueller 3c46042ce5 Integrate client certificate settings into RoleParams,
remember client cert chain for use after handshake has
finished.
2012-07-16 14:36:44 +02:00

333 lines
13 KiB
Haskell

-- |
-- Module : Network.TLS.Context
-- License : BSD-style
-- Maintainer : Vincent Hanquez <vincent@snarc.org>
-- Stability : experimental
-- Portability : unknown
--
{-# LANGUAGE ExistentialQuantification, Rank2Types #-}
-- RecordWildCards (below) is only needed to work around a GHC issue with
-- insufficiently polymorphic record fields (see 'setSessionManager').
{-# LANGUAGE RecordWildCards #-}
module Network.TLS.Context
(
-- * Context configuration
Params(..)
, RoleParams(..)
, ClientParams(..)
, ServerParams(..)
, updateClientParams
, updateServerParams
, Logging(..)
, SessionID
, SessionData(..)
, Measurement(..)
, CertificateUsage(..)
, CertificateRejectReason(..)
, defaultLogging
, defaultParamsClient
, defaultParamsServer
, withSessionManager
, setSessionManager
-- * Context object and accessors
, Backend(..)
, Context
, ctxParams
, ctxConnection
, ctxEOF
, ctxEstablished
, ctxLogging
, setEOF
, setEstablished
, contextFlush
, contextClose
, contextSend
, contextRecv
, updateMeasure
, withMeasure
-- * Deprecated types
, TLSParams
, TLSLogging
, TLSCertificateUsage
, TLSCertificateRejectReason
, TLSCtx
-- * Deprecated values
, defaultParams
-- * New contexts
, contextNew
, contextNewOnHandle
-- * Using context states
, throwCore
, usingState
, usingState_
, getStateRNG
) where
import Network.TLS.Struct
import Network.TLS.Session
import Network.TLS.Cipher
import Network.TLS.Compression
import Network.TLS.Crypto
import Network.TLS.State
import Network.TLS.Measurement
import Data.Maybe
import Data.Certificate.X509
import Data.List (intercalate)
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Crypto.Random
import Control.Concurrent.MVar
import Control.Monad.State
import Control.Exception (throwIO, Exception())
import Data.IORef
import System.IO (Handle, hSetBuffering, BufferMode(..), hFlush, hClose)
import Prelude hiding (catch)
-- | Record of logging callbacks.
--
-- Every field is an 'IO' action receiving the value to log. The call sites
-- are not in this module, so the per-field notes below are inferred from the
-- field names and types only.
data Logging = Logging
    { loggingPacketSent :: String -> IO ()
      -- ^ presumably given a textual description of a packet being sent.
    , loggingPacketRecv :: String -> IO ()
      -- ^ presumably given a textual description of a received packet.
    , loggingIOSent     :: B.ByteString -> IO ()
      -- ^ presumably given the raw bytes written to the connection.
    , loggingIORecv     :: Header -> B.ByteString -> IO ()
      -- ^ presumably given the header and raw bytes read from the connection.
    }
-- | Parameters that only apply to the client role.
data ClientParams = ClientParams
    { clientWantSessionResume :: Maybe (SessionID, SessionData)
      -- ^ try to establish a connection using this session.
    , onCertificateRequest
        :: ( [CertificateType]
           , Maybe [(HashAlgorithm, SignatureAlgorithm)]
           , [DistinguishedName]
           )
        -> IO [(X509, Maybe PrivateKey)]
      -- ^ callback receiving certificate types, optional hash/signature
      -- algorithm pairs and distinguished names, and returning a list of
      -- certificates with their optional private keys. Presumably invoked
      -- when the server requests a client certificate (call site not in this
      -- module).
    }
-- | Parameters that only apply to the server role.
data ServerParams = ServerParams
    { serverWantClientCert   :: Bool
      -- ^ request a certificate from client.
    , serverCACertificates   :: [X509]
      -- ^ list of CA certificates; how they are used is not visible in this
      -- module.
    , onClientCertificate    :: [X509] -> IO CertificateUsage
      -- ^ callback deciding whether a list of certificates (presumably the
      -- chain sent by the client) is accepted or rejected.
    , onUnverifiedClientCert :: Maybe KxError -> IO Bool
      -- ^ callback given an optional 'KxError' and returning a 'Bool';
      -- presumably decides whether to continue when the client certificate
      -- could not be verified (call site not in this module).
    }
-- | Role-specific parameters: a context is configured either as a client or
-- as a server.
data RoleParams
    = Client ClientParams
    | Server ServerParams
-- | Parameters of a TLS context.
--
-- The type of the session manager is existentially quantified: it can be any
-- type with a 'SessionManager' instance. Use 'withSessionManager' to access
-- it and 'setSessionManager' to replace it.
data Params = forall s . SessionManager s => Params
    { pConnectVersion         :: Version
      -- ^ version to use on client connection.
    , pAllowedVersions        :: [Version]
      -- ^ allowed versions that we can use.
    , pCiphers                :: [Cipher]
      -- ^ all ciphers supported ordered by priority.
    , pCompressions           :: [Compression]
      -- ^ all compression supported ordered by priority.
    , pUseSecureRenegotiation :: Bool
      -- ^ notify that we want to use secure renegotiation
    , pUseSession             :: Bool
      -- ^ generate new session if specified
    , pCertificates           :: [(X509, Maybe PrivateKey)]
      -- ^ the cert chain for this context with the associated keys if any.
    , pLogging                :: Logging
      -- ^ callback for logging
    , onHandshake             :: Measurement -> IO Bool
      -- ^ callback on a beginning of handshake
    , onCertificatesRecv      :: [X509] -> IO CertificateUsage
      -- ^ callback to verify received cert chain.
    , pSessionManager         :: s
      -- ^ the session manager; its concrete type is hidden.
    , onSuggestNextProtocols  :: IO (Maybe [B.ByteString])
      -- ^ suggested next protocols according to the next protocol
      -- negotiation extension.
    , onNPNServerSuggest      :: Maybe ([B.ByteString] -> IO B.ByteString)
      -- ^ optional callback mapping a list of protocols to a single one;
      -- presumably the client-side choice among the protocols suggested by
      -- the server (call site not in this module).
    , roleParams              :: RoleParams
      -- ^ client- or server-specific parameters.
    }
-- | Replace the session manager of a parameters structure.
--
-- The structure is rebuilt from its fields rather than updated in place,
-- because the new manager may have a different type than the old one.
setSessionManager :: SessionManager s => s -> Params -> Params
setSessionManager newManager Params{..} =
    Params { pSessionManager = newManager, .. }
-- | Apply a function to the session manager stored in the parameters.
--
-- The function has to work for every 'SessionManager' instance, since the
-- concrete manager type is hidden inside 'Params'.
withSessionManager :: Params -> (forall s . SessionManager s => s -> a) -> a
withSessionManager params k =
    case params of
        Params { pSessionManager = manager } -> k manager
-- | Logging callbacks that ignore their arguments and do nothing.
defaultLogging :: Logging
defaultLogging = Logging
    { loggingPacketSent = const noop
    , loggingPacketRecv = const noop
    , loggingIOSent     = const noop
    , loggingIORecv     = \_ -> const noop
    }
  where
    noop :: IO ()
    noop = return ()
-- | Default parameters for a client context.
--
-- Note that these defaults contain no ciphers ('pCiphers' is empty) and that
-- 'onCertificatesRecv' accepts every certificate chain it is given, so
-- callers that need ciphers or certificate validation must set those fields
-- themselves.
defaultParamsClient :: Params
defaultParamsClient = Params
    { pConnectVersion         = TLS10
    , pAllowedVersions        = [TLS10, TLS11, TLS12]
    , pCiphers                = []
    , pCompressions           = [nullCompression]
    , pUseSecureRenegotiation = True
    , pUseSession             = True
    , pCertificates           = []
    , pLogging                = defaultLogging
    , onHandshake             = const (return True)
    , onCertificatesRecv      = const (return CertificateUsageAccept)
    , pSessionManager         = NoSessionManager
    , onSuggestNextProtocols  = return Nothing
    , onNPNServerSuggest      = Nothing
    , roleParams              = Client clientDefaults
    }
  where
    -- No session to resume, and no client certificate offered on request.
    clientDefaults = ClientParams
        { clientWantSessionResume = Nothing
        , onCertificateRequest    = const (return [])
        }
-- | Default parameters for a server context.
--
-- Identical to 'defaultParamsClient' except for the role: the server does not
-- request a client certificate, has no CA certificates, rejects any client
-- certificate it is asked about and answers 'False' for unverified ones.
defaultParamsServer :: Params
defaultParamsServer = defaultParamsClient { roleParams = Server serverDefaults }
  where
    serverDefaults = ServerParams
        { serverWantClientCert   = False
        , serverCACertificates   = []
        , onClientCertificate    = const (return (CertificateUsageReject rejection))
        , onUnverifiedClientCert = const (return False)
        }
    rejection = CertificateRejectOther "no client certificates expected"
-- | Apply one of two functions to the role-specific parameters, depending on
-- the role stored in the given 'Params': the first function for a client,
-- the second one for a server.
updateRoleParams :: (ClientParams -> ClientParams) -> (ServerParams -> ServerParams) -> Params -> Params
updateRoleParams onClient onServer params = rebuild (roleParams params)
  where
    rebuild (Client cp) = params { roleParams = Client (onClient cp) }
    rebuild (Server sp) = params { roleParams = Server (onServer sp) }
-- | Modify the client-specific parameters. Parameters configured for the
-- server role are returned with their role parameters unchanged.
updateClientParams :: (ClientParams -> ClientParams) -> Params -> Params
updateClientParams f params = updateRoleParams f id params
-- | Modify the server-specific parameters. Parameters configured for the
-- client role are returned with their role parameters unchanged.
updateServerParams :: (ServerParams -> ServerParams) -> Params -> Params
updateServerParams f params = updateRoleParams id f params
{-# DEPRECATED defaultParams "use defaultParamsClient" #-}
-- | Deprecated alias of 'defaultParamsClient'.
defaultParams :: Params
defaultParams = defaultParamsClient
-- | Textual summary of a subset of the parameters: versions, ciphers,
-- compressions and the number of certificates. Callbacks and the session
-- manager are not shown.
instance Show Params where
    show p = "Params { " ++ intercalate "," rendered ++ " }"
      where
        rendered = [ key ++ "=" ++ value | (key, value) <- entries ]
        entries =
            [ ("connectVersion", show (pConnectVersion p))
            , ("allowedVersions", show (pAllowedVersions p))
            , ("ciphers", show (pCiphers p))
            , ("compressions", show (pCompressions p))
            , ("certificates", show (length (pCertificates p)))
            ]
-- | Reason why a certificate or a certificate chain has been rejected.
data CertificateRejectReason
    = CertificateRejectExpired
    | CertificateRejectRevoked
    | CertificateRejectUnknownCA
    | CertificateRejectOther String  -- ^ any other reason, as free-form text
    deriving (Show, Eq)
-- | Possible return values of the certificate usage callbacks.
data CertificateUsage
    = CertificateUsageAccept
      -- ^ usage of certificate accepted
    | CertificateUsageReject CertificateRejectReason
      -- ^ usage of certificate rejected
    deriving (Show, Eq)
-- | Connection IO backend: the set of operations a context needs from the
-- underlying connection.
data Backend = Backend
    { backendFlush :: IO ()
      -- ^ Flush the connection sending buffer, if any.
    , backendClose :: IO ()
      -- ^ Close the connection.
    , backendSend  :: ByteString -> IO ()
      -- ^ Send a bytestring through the connection.
    , backendRecv  :: Int -> IO ByteString
      -- ^ Receive specified number of bytes from the connection.
    }
-- | A TLS context keeps the TLS-specific state, the parameters and the
-- backend information of one connection.
data Context = Context
    { ctxConnection   :: Backend
      -- ^ return the backend object associated with this context
    , ctxParams       :: Params
      -- ^ the parameters the context was created with.
    , ctxState        :: MVar TLSState
      -- ^ the TLS state, accessed through 'usingState'.
    , ctxMeasurement  :: IORef Measurement
      -- ^ the measurement, accessed through 'updateMeasure' and 'withMeasure'.
    , ctxEOF_         :: IORef Bool
      -- ^ has the handle EOFed or not.
    , ctxEstablished_ :: IORef Bool
      -- ^ has the handshake been done and been successful.
    }
-- Deprecated type names, kept as plain aliases of the current types for
-- backward compatibility.
type TLSParams                  = Params
type TLSLogging                 = Logging
type TLSCertificateUsage        = CertificateUsage
type TLSCertificateRejectReason = CertificateRejectReason
type TLSCtx                     = Context
-- | Replace the context's measurement by the result of applying the given
-- function to it. The new value is evaluated (to weak head normal form)
-- before being written back.
--
-- The read and the write are two separate 'IORef' operations.
updateMeasure :: MonadIO m => Context -> (Measurement -> Measurement) -> m ()
updateMeasure ctx f = liftIO (readIORef ref >>= \current -> writeIORef ref $! f current)
  where
    ref = ctxMeasurement ctx
-- | Read the context's current measurement and run the given action on it.
withMeasure :: MonadIO m => Context -> (Measurement -> IO a) -> m a
withMeasure ctx f = liftIO $ do
    current <- readIORef (ctxMeasurement ctx)
    f current
-- | Run the flush action of the context's backend.
contextFlush :: Context -> IO ()
contextFlush ctx = backendFlush (ctxConnection ctx)
-- | Run the close action of the context's backend.
contextClose :: Context -> IO ()
contextClose ctx = backendClose (ctxConnection ctx)
-- | Record the length of the given bytes as sent in the context's
-- measurement, then hand them to the backend's send action.
contextSend :: Context -> Bytes -> IO ()
contextSend c b = do
    updateMeasure c (addBytesSent (B.length b))
    backendSend (ctxConnection c) b
-- | Ask the backend for @sz@ bytes and record the received data in the
-- context's measurement.
--
-- The byte counter is updated /after/ the read, with the length of the data
-- actually returned: a backend may hand back fewer bytes than requested (for
-- instance 'B.hGet' at end of file), and a read that throws must not be
-- counted as received data.
contextRecv :: Context -> Int -> IO Bytes
contextRecv c sz = do
    received <- backendRecv (ctxConnection c) sz
    updateMeasure c (addBytesReceived (B.length received))
    return received
-- | Read the context's EOF flag.
ctxEOF :: MonadIO m => Context -> m Bool
ctxEOF = liftIO . readIORef . ctxEOF_
-- | Set the context's EOF flag to 'True'. This module offers no way to
-- reset it.
setEOF :: MonadIO m => Context -> m ()
setEOF ctx = liftIO (ctxEOF_ ctx `writeIORef` True)
-- | Read the context's "established" flag.
ctxEstablished :: MonadIO m => Context -> m Bool
ctxEstablished = liftIO . readIORef . ctxEstablished_
-- | Overwrite the context's "established" flag with the given value.
setEstablished :: MonadIO m => Context -> Bool -> m ()
setEstablished ctx = liftIO . writeIORef (ctxEstablished_ ctx)
-- | The logging callbacks stored in the context's parameters.
ctxLogging :: Context -> Logging
ctxLogging ctx = pLogging (ctxParams ctx)
-- | Create a new context using the backend and parameters specified.
--
-- The initial TLS state is built from the random generator and flagged as a
-- client context exactly when the parameters carry the 'Client' role. The
-- EOF and established flags both start as 'False', and the measurement
-- starts as 'newMeasurement'.
contextNew :: (MonadIO m, CryptoRandomGen rng)
           => Backend -- ^ Backend abstraction with specific methods to interact with the connection type.
           -> Params  -- ^ Parameters of the context.
           -> rng     -- ^ Random number generator associated with this context.
           -> m Context
contextNew backend params rng = liftIO $ do
    stateVar       <- newMVar initialState
    eofRef         <- newIORef False
    establishedRef <- newIORef False
    measureRef     <- newIORef newMeasurement
    return Context
        { ctxConnection   = backend
        , ctxParams       = params
        , ctxState        = stateVar
        , ctxMeasurement  = measureRef
        , ctxEOF_         = eofRef
        , ctxEstablished_ = establishedRef
        }
  where
    isClient = case roleParams params of
        Client {} -> True
        Server {} -> False
    initialState = (newTLSState rng) { stClientContext = isClient }
-- | Create a new context on a handle.
--
-- Buffering is disabled on the handle first; the backend then maps flush,
-- close, send and receive to 'hFlush', 'hClose', 'B.hPut' and 'B.hGet' on
-- that handle.
contextNewOnHandle :: (MonadIO m, CryptoRandomGen rng)
                   => Handle -- ^ Handle of the connection.
                   -> Params -- ^ Parameters of the context.
                   -> rng    -- ^ Random number generator associated with this context.
                   -> m Context
contextNewOnHandle handle params st = do
    liftIO (hSetBuffering handle NoBuffering)
    contextNew handleBackend params st
  where
    handleBackend = Backend
        { backendFlush = hFlush handle
        , backendClose = hClose handle
        , backendSend  = B.hPut handle
        , backendRecv  = B.hGet handle
        }
-- | Throw an exception from any 'MonadIO' monad, using 'throwIO'.
throwCore :: (MonadIO m, Exception e) => e -> m a
throwCore err = liftIO (throwIO err)
-- | Run a 'TLSSt' computation on the context's TLS state.
--
-- The state 'MVar' is taken for the duration of the computation (through
-- 'modifyMVar'). The state produced by 'runTLSState' is evaluated and stored
-- back whatever the outcome, and the outcome itself, an error or a value, is
-- returned.
usingState :: MonadIO m => Context -> TLSSt a -> m (Either TLSError a)
usingState ctx f = liftIO (modifyMVar (ctxState ctx) step)
  where
    step st = case runTLSState f st of
        (outcome, st') -> st' `seq` return (st', outcome)
-- | Like 'usingState', but an error result is thrown as an exception (with
-- 'throwCore') instead of being returned.
usingState_ :: MonadIO m => Context -> TLSSt a -> m a
usingState_ ctx f = usingState ctx f >>= either throwCore return
-- | Run 'genTLSRandom' with the given count on the context's TLS state and
-- return its result; a failure is thrown as an exception by 'usingState_'.
--
-- Presumably yields that many random bytes drawn from the state's generator
-- ('genTLSRandom' is defined outside this module).
getStateRNG :: MonadIO m => Context -> Int -> m Bytes
getStateRNG ctx = usingState_ ctx . genTLSRandom