208 lines
6.7 KiB
Haskell
208 lines
6.7 KiB
Haskell
|
-- |
-- Module      : Network.TLS.Context
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
|
||
|
module Network.TLS.Context
    (
    -- * Context configuration
      TLSParams(..)
    , TLSLogging(..)
    , Measurement(..)
    , TLSCertificateUsage(..)
    , TLSCertificateRejectReason(..)
    , defaultLogging
    , defaultParams

    -- * Context object and accessor
    , TLSCtx
    , ctxParams
    , ctxConnection
    , ctxEOF
    , ctxLogging
    , setEOF
    , connectionFlush
    , connectionSend
    , connectionRecv
    , updateMeasure
    , withMeasure

    -- * New contexts
    , newCtxWith
    , newCtx

    -- * Using context states
    , throwCore
    , usingState
    , usingState_
    , getStateRNG
    ) where
|
||
|
|
||
|
import Network.TLS.Struct
|
||
|
import Network.TLS.Cipher
|
||
|
import Network.TLS.Compression
|
||
|
import Network.TLS.Crypto
|
||
|
import Network.TLS.State
|
||
|
import Network.TLS.Measurement
|
||
|
import Data.Maybe
|
||
|
import Data.Certificate.X509
|
||
|
import Data.List (intercalate)
|
||
|
import qualified Data.ByteString as B
|
||
|
|
||
|
import Control.Concurrent.MVar
|
||
|
import Control.Monad.State
|
||
|
import Control.Exception (throwIO, Exception(), onException)
|
||
|
import Data.IORef
|
||
|
import System.IO (Handle, hSetBuffering, BufferMode(..), hFlush)
|
||
|
import Prelude hiding (catch)
|
||
|
|
||
|
-- | Set of logging callbacks carried by 'TLSParams'.
--
-- The call sites of these callbacks are outside this module; the field
-- descriptions below are inferred from the field names and types only
-- (NOTE(review): confirm against the packet send/receive code).
data TLSLogging = TLSLogging
    { -- | presumably called with a textual description of a sent packet.
      loggingPacketSent :: String -> IO ()
      -- | presumably called with a textual description of a received packet.
    , loggingPacketRecv :: String -> IO ()
      -- | presumably called with the raw bytes written to the connection.
    , loggingIOSent     :: Bytes -> IO ()
      -- | presumably called with the record header and the raw bytes read.
    , loggingIORecv     :: Header -> Bytes -> IO ()
    }
|
||
|
|
||
|
-- | Configuration of a TLS context: protocol versions, ciphers,
-- compressions, certificates and user callbacks.
data TLSParams = TLSParams
    { -- | version to use on client connection.
      pConnectVersion         :: Version
      -- | allowed versions that we can use.
    , pAllowedVersions        :: [Version]
      -- | all ciphers supported, ordered by priority.
    , pCiphers                :: [Cipher]
      -- | all compressions supported, ordered by priority.
    , pCompressions           :: [Compression]
      -- | request a certificate from the client. Used by servers only.
    , pWantClientCert         :: Bool
      -- | notify that we want to use secure renegotiation.
    , pUseSecureRenegotiation :: Bool
      -- | the certificate chain for this context, with the associated
      -- private keys if any.
    , pCertificates           :: [(X509, Maybe PrivateKey)]
      -- | callbacks used for logging.
    , pLogging                :: TLSLogging
      -- | callback invoked at the beginning of a handshake.
    , onHandshake             :: Measurement -> IO Bool
      -- | callback to verify a received certificate chain.
    , onCertificatesRecv      :: [X509] -> IO TLSCertificateUsage
    }
|
||
|
|
||
|
-- | Logging callbacks that silently discard everything they are given.
defaultLogging :: TLSLogging
defaultLogging = TLSLogging
    { loggingPacketSent = discard
    , loggingPacketRecv = discard
    , loggingIOSent     = discard
    , loggingIORecv     = \_ -> discard
    }
  where
    -- ignore the argument and do nothing.
    discard :: a -> IO ()
    discard = const (return ())
|
||
|
|
||
|
-- | Default parameters, intended to be customised with record update
-- syntax.
--
-- Note that the defaults provide no ciphers and no certificates, and that
-- the default 'onCertificatesRecv' callback accepts every certificate chain
-- without looking at it; callers that need peer verification must override
-- it.
defaultParams :: TLSParams
defaultParams = TLSParams
    { pConnectVersion         = TLS10
    , pAllowedVersions        = [TLS10,TLS11,TLS12]
    , pCiphers                = []
    , pCompressions           = [nullCompression]
    , pWantClientCert         = False
    , pUseSecureRenegotiation = True
    , pCertificates           = []
    , pLogging                = defaultLogging
    , onHandshake             = const (return True)
    , onCertificatesRecv      = const (return CertificateUsageAccept)
    }
|
||
|
|
||
|
-- | Human readable rendering of a subset of the parameters. Callbacks are
-- not rendered, and certificates are summarised by their count.
instance Show TLSParams where
    show p = "TLSParams { " ++ intercalate "," rendered ++ " }"
      where
        -- each entry rendered as key=value
        rendered = [ key ++ "=" ++ val | (key, val) <- entries ]
        entries  =
            [ ("connectVersion",   show (pConnectVersion p))
            , ("allowedVersions",  show (pAllowedVersions p))
            , ("ciphers",          show (pCiphers p))
            , ("compressions",     show (pCompressions p))
            , ("want-client-cert", show (pWantClientCert p))
            , ("certificates",     show (length (pCertificates p)))
            ]
|
||
|
|
||
|
-- | Reason for rejecting a certificate or a certificate chain.
data TLSCertificateRejectReason
    = CertificateRejectExpired
    | CertificateRejectRevoked
    | CertificateRejectUnknownCA
    | CertificateRejectOther String -- ^ any other reason, as free-form text
    deriving (Show, Eq)
|
||
|
|
||
|
-- | Possible return values of the certificate usage callback
-- ('onCertificatesRecv').
data TLSCertificateUsage
    = CertificateUsageAccept
      -- ^ usage of certificate accepted
    | CertificateUsageReject TLSCertificateRejectReason
      -- ^ usage of certificate rejected, with the reason
    deriving (Show, Eq)
|
||
|
|
||
|
-- | A TLS context is a connection object augmented with TLS specific state
-- and parameters.
data TLSCtx a = TLSCtx
    { -- | return the connection object associated with this context
      ctxConnection      :: a
      -- | parameters this context was created with
    , ctxParams          :: TLSParams
      -- | mutable TLS state, accessed through 'usingState'
    , ctxState           :: MVar TLSState
      -- | mutable measurements, accessed through 'updateMeasure' and
      -- 'withMeasure'
    , ctxMeasurement     :: IORef Measurement
      -- | whether the connection has reached EOF or not
    , ctxEOF_            :: IORef Bool
      -- | action run by 'connectionFlush'
    , ctxConnectionFlush :: IO ()
      -- | action run by 'connectionSend' to write bytes
    , ctxConnectionSend  :: Bytes -> IO ()
      -- | action run by 'connectionRecv', given the number of bytes requested
    , ctxConnectionRecv  :: Int -> IO Bytes
    }
|
||
|
|
||
|
-- | Apply a function to the measurements stored in the context.
--
-- The updated value is forced (to weak head normal form) before being
-- written back, so that repeated updates do not pile up a chain of
-- unevaluated thunks inside the 'IORef', as a lazy 'modifyIORef' would.
-- Like 'modifyIORef', the read-then-write is not atomic.
updateMeasure :: MonadIO m => TLSCtx c -> (Measurement -> Measurement) -> m ()
updateMeasure ctx f = liftIO $ do
    current <- readIORef ref
    writeIORef ref $! f current
  where
    ref = ctxMeasurement ctx
|
||
|
|
||
|
-- | Read the current measurements of the context and run the given IO
-- action on them, returning its result.
withMeasure :: MonadIO m => TLSCtx c -> (Measurement -> IO a) -> m a
withMeasure ctx f = liftIO $ do
    measure <- readIORef (ctxMeasurement ctx)
    f measure
|
||
|
|
||
|
-- | Run the flush action the context was created with.
connectionFlush :: TLSCtx c -> IO ()
connectionFlush = ctxConnectionFlush
|
||
|
|
||
|
-- | Send bytes through the context's send action, after adding their
-- length to the bytes-sent measurement.
connectionSend :: TLSCtx c -> Bytes -> IO ()
connectionSend c b = do
    updateMeasure c (addBytesSent (B.length b))
    ctxConnectionSend c b
|
||
|
|
||
|
-- | Receive up to @sz@ bytes through the context's receive action.
--
-- The bytes-received measurement is increased by the length of the data
-- actually returned by the receive action, not by the requested size: the
-- two differ whenever the backend returns a short read (for instance at
-- EOF). As a consequence the measurement is only updated when the receive
-- action returns normally.
connectionRecv :: TLSCtx c -> Int -> IO Bytes
connectionRecv c sz = do
    received <- ctxConnectionRecv c sz
    updateMeasure c (addBytesReceived (B.length received))
    return received
|
||
|
|
||
|
-- | Read the EOF flag of the context.
ctxEOF :: MonadIO m => TLSCtx a -> m Bool
ctxEOF = liftIO . readIORef . ctxEOF_
|
||
|
|
||
|
-- | Set the EOF flag of the context to 'True'.
setEOF :: MonadIO m => TLSCtx c -> m ()
setEOF ctx = liftIO (writeIORef eofFlag True)
  where
    eofFlag = ctxEOF_ ctx
|
||
|
|
||
|
-- | Logging callbacks taken from the parameters of the context.
ctxLogging :: TLSCtx a -> TLSLogging
ctxLogging ctx = pLogging (ctxParams ctx)
|
||
|
|
||
|
-- | Create a new context around an arbitrary connection object, using the
-- given flush, send and receive actions.
--
-- The context starts with the given TLS state, an EOF flag set to 'False'
-- and fresh measurements ('newMeasurement').
newCtxWith :: c -> IO () -> (Bytes -> IO ()) -> (Int -> IO Bytes) -> TLSParams -> TLSState -> IO (TLSCtx c)
newCtxWith c flushF sendF recvF params st = do
    stateVar   <- newMVar st
    eofRef     <- newIORef False
    measureRef <- newIORef newMeasurement
    let ctx = TLSCtx
            { ctxConnection      = c
            , ctxParams          = params
            , ctxState           = stateVar
            , ctxMeasurement     = measureRef
            , ctxEOF_            = eofRef
            , ctxConnectionFlush = flushF
            , ctxConnectionSend  = sendF
            , ctxConnectionRecv  = recvF
            }
    return ctx
|
||
|
|
||
|
-- | Create a new context on top of a 'Handle'.
--
-- Buffering is disabled on the handle, and the context flushes, writes and
-- reads through 'hFlush', 'B.hPut' and 'B.hGet' on it.
newCtx :: Handle -> TLSParams -> TLSState -> IO (TLSCtx Handle)
newCtx handle params st =
    hSetBuffering handle NoBuffering >>
    newCtxWith handle (hFlush handle) (B.hPut handle) (B.hGet handle) params st
|
||
|
|
||
|
-- | Throw an exception from any 'MonadIO' monad, using 'throwIO'.
throwCore :: (MonadIO m, Exception e) => e -> m a
throwCore e = liftIO (throwIO e)
|
||
|
|
||
|
|
||
|
-- | Run a 'TLSSt' computation against the state held by the context, store
-- the resulting state back, and return the computation's result (either a
-- 'TLSError' or a value).
--
-- 'modifyMVar' is used so that the state is taken and put back with
-- asynchronous exceptions masked, and so that the previous state is
-- restored if an exception is raised while computing the new one. The new
-- state is forced with 'seq' inside the critical section: without it, a
-- lazy thunk would be stored in the 'MVar' and any exception hidden in it
-- would only fire later, outside of the protection, in whichever thread
-- next evaluates the state.
usingState :: MonadIO m => TLSCtx c -> TLSSt a -> m (Either TLSError a)
usingState ctx f = liftIO $ modifyMVar (ctxState ctx) step
  where
    step st =
        let (a, newst) = runTLSState f st
         in newst `seq` return (newst, a)
|
||
|
|
||
|
-- | Like 'usingState', but a 'TLSError' result is thrown as an exception
-- (through 'throwCore') instead of being returned.
usingState_ :: MonadIO m => TLSCtx c -> TLSSt a -> m a
usingState_ ctx f = usingState ctx f >>= either throwCore return
|
||
|
|
||
|
-- | Run 'genTLSRandom' with the given size argument against the state of
-- the context, throwing on error (see 'usingState_').
getStateRNG :: MonadIO m => TLSCtx c -> Int -> m Bytes
getStateRNG ctx = usingState_ ctx . genTLSRandom
|
||
|
|