move to a proper role type for client|server
This commit is contained in:
parent
dd30cc05b0
commit
5ca744a8bf
6 changed files with 39 additions and 30 deletions
|
@ -92,6 +92,7 @@ import Network.TLS.State
|
|||
import Network.TLS.Handshake.State
|
||||
import Network.TLS.Measurement
|
||||
import Network.TLS.X509
|
||||
import Network.TLS.Types (Role(..))
|
||||
import Data.List (intercalate)
|
||||
import Data.ByteString (ByteString)
|
||||
import qualified Data.ByteString as B
|
||||
|
@ -376,10 +377,10 @@ contextNew :: (MonadIO m, CPRG rng)
|
|||
-> rng -- ^ Random number generator associated with this context.
|
||||
-> m Context
|
||||
contextNew backend params rng = liftIO $ do
|
||||
let clientContext = case roleParams params of
|
||||
Client {} -> True
|
||||
Server {} -> False
|
||||
let st = newTLSState rng clientContext
|
||||
let role = case roleParams params of
|
||||
Client {} -> ClientRole
|
||||
Server {} -> ServerRole
|
||||
let st = newTLSState rng role
|
||||
|
||||
stvar <- newMVar st
|
||||
eof <- newIORef False
|
||||
|
@ -387,7 +388,7 @@ contextNew backend params rng = liftIO $ do
|
|||
stats <- newIORef newMeasurement
|
||||
-- we enable the reception of SSLv2 ClientHello message only in the
|
||||
-- server context, where we might be dealing with an old/compat client.
|
||||
sslv2Compat <- newIORef (not clientContext)
|
||||
sslv2Compat <- newIORef (role == ServerRole)
|
||||
hooks <- newIORef defaultHooks
|
||||
lockWrite <- newMVar ()
|
||||
lockRead <- newMVar ()
|
||||
|
|
|
@ -22,6 +22,7 @@ import Control.Monad.Error
|
|||
import Data.ByteString (ByteString)
|
||||
import qualified Data.ByteString as B
|
||||
|
||||
import Network.TLS.Types (Role(..))
|
||||
import Network.TLS.Util
|
||||
import Network.TLS.Struct
|
||||
import Network.TLS.Record
|
||||
|
@ -70,16 +71,16 @@ processPacket (Record ProtocolType_DeprecatedHandshake _ fragment) =
|
|||
|
||||
processHandshake :: Handshake -> TLSSt ()
|
||||
processHandshake hs = do
|
||||
clientmode <- isClientContext
|
||||
role <- isClientContext
|
||||
case hs of
|
||||
ClientHello cver ran _ _ _ ex _ -> unless clientmode $ do
|
||||
ClientHello cver ran _ _ _ ex _ -> when (role == ServerRole) $ do
|
||||
mapM_ processClientExtension ex
|
||||
startHandshakeClient cver ran
|
||||
Certificates certs -> processCertificates clientmode certs
|
||||
ClientKeyXchg content -> unless clientmode $ do
|
||||
Certificates certs -> processCertificates role certs
|
||||
ClientKeyXchg content -> when (role == ServerRole) $ do
|
||||
processClientKeyXchg content
|
||||
HsNextProtocolNegotiation selected_protocol ->
|
||||
unless clientmode $ setNegotiatedProtocol selected_protocol
|
||||
when (role == ServerRole) $ setNegotiatedProtocol selected_protocol
|
||||
Finished fdata -> processClientFinished fdata
|
||||
_ -> return ()
|
||||
let encoded = encodeHandshake hs
|
||||
|
@ -148,17 +149,17 @@ processClientKeyXchg encryptedPremaster = do
|
|||
processClientFinished :: FinishedData -> TLSSt ()
|
||||
processClientFinished fdata = do
|
||||
cc <- isClientContext
|
||||
expected <- getHandshakeDigest (not cc)
|
||||
expected <- getHandshakeDigest (cc == ServerRole)
|
||||
when (expected /= fdata) $ do
|
||||
throwError $ Error_Protocol("bad record mac", True, BadRecordMac)
|
||||
updateVerifiedData False fdata
|
||||
updateVerifiedData ServerRole fdata
|
||||
return ()
|
||||
|
||||
processCertificates :: Bool -> CertificateChain -> TLSSt ()
|
||||
processCertificates False (CertificateChain []) = return ()
|
||||
processCertificates True (CertificateChain []) =
|
||||
processCertificates :: Role -> CertificateChain -> TLSSt ()
|
||||
processCertificates ServerRole (CertificateChain []) = return ()
|
||||
processCertificates ClientRole (CertificateChain []) =
|
||||
throwError $ Error_Protocol ("server certificate missing", True, HandshakeFailure)
|
||||
processCertificates clientmode (CertificateChain (c:_))
|
||||
| clientmode = withHandshakeM $ setPublicKey pubkey
|
||||
| otherwise = withHandshakeM $ setClientPublicKey pubkey
|
||||
processCertificates role (CertificateChain (c:_))
|
||||
| role == ClientRole = withHandshakeM $ setPublicKey pubkey
|
||||
| otherwise = withHandshakeM $ setClientPublicKey pubkey
|
||||
where pubkey = certPubKey $ getCertificate c
|
||||
|
|
|
@ -38,6 +38,7 @@ import Network.TLS.Wire
|
|||
import Network.TLS.Packet
|
||||
import Network.TLS.MAC
|
||||
import Network.TLS.Util
|
||||
import Network.TLS.Types (Role(..))
|
||||
|
||||
import qualified Data.ByteString as B
|
||||
|
||||
|
@ -59,7 +60,7 @@ data TransmissionState = TransmissionState
|
|||
} deriving (Show)
|
||||
|
||||
data RecordState = RecordState
|
||||
{ stClientContext :: Bool
|
||||
{ stClientContext :: Role
|
||||
, stVersion :: !Version
|
||||
, stTxState :: TransmissionState
|
||||
, stRxState :: TransmissionState
|
||||
|
@ -95,7 +96,7 @@ incrTransmissionState :: TransmissionState -> TransmissionState
|
|||
incrTransmissionState ts = ts { stMacState = MacState (ms + 1) }
|
||||
where (MacState ms) = stMacState ts
|
||||
|
||||
newRecordState :: CPRG g => g -> Bool -> RecordState
|
||||
newRecordState :: CPRG g => g -> Role -> RecordState
|
||||
newRecordState rng clientContext = RecordState
|
||||
{ stClientContext = clientContext
|
||||
, stVersion = TLS10
|
||||
|
|
|
@ -17,6 +17,7 @@ import Data.ByteString (ByteString)
|
|||
import qualified Data.ByteString as B
|
||||
|
||||
import Network.TLS.Util
|
||||
import Network.TLS.Types (Role(..))
|
||||
import Network.TLS.Struct
|
||||
import Network.TLS.Record
|
||||
import Network.TLS.Packet
|
||||
|
@ -46,7 +47,7 @@ writePacket :: Packet -> TLSSt ByteString
|
|||
writePacket pkt@(Handshake hss) = do
|
||||
forM_ hss $ \hs -> do
|
||||
case hs of
|
||||
Finished fdata -> updateVerifiedData True fdata
|
||||
Finished fdata -> updateVerifiedData ClientRole fdata
|
||||
_ -> return ()
|
||||
let encoded = encodeHandshake hs
|
||||
when (certVerifyHandshakeMaterial hs) $ withHandshakeM $ addHandshakeMessage encoded
|
||||
|
|
|
@ -69,6 +69,7 @@ import Network.TLS.Cipher
|
|||
import Network.TLS.Record.State
|
||||
import Network.TLS.Handshake.State
|
||||
import Network.TLS.RNG
|
||||
import Network.TLS.Types (Role(..))
|
||||
import qualified Data.ByteString as B
|
||||
import Control.Applicative ((<$>))
|
||||
import Control.Monad
|
||||
|
@ -128,7 +129,7 @@ runRecordStateSt f = do
|
|||
(Left e, _) -> throwError e
|
||||
(Right a, newSt) -> put newSt >> return a
|
||||
|
||||
newTLSState :: CPRG g => g -> Bool -> TLSState
|
||||
newTLSState :: CPRG g => g -> Role -> TLSState
|
||||
newTLSState rng clientContext = TLSState
|
||||
{ stHandshake = Nothing
|
||||
, stSession = Session Nothing
|
||||
|
@ -143,7 +144,7 @@ newTLSState rng clientContext = TLSState
|
|||
, stClientCertificateChain = Nothing
|
||||
}
|
||||
|
||||
updateVerifiedData :: MonadState TLSState m => Bool -> Bytes -> m ()
|
||||
updateVerifiedData :: MonadState TLSState m => Role -> Bytes -> m ()
|
||||
updateVerifiedData sending bs = do
|
||||
cc <- isClientContext
|
||||
if cc /= sending
|
||||
|
@ -233,7 +234,7 @@ isSessionResuming = gets stSessionResuming
|
|||
needEmptyPacket :: MonadState RecordState m => m Bool
|
||||
needEmptyPacket = gets f
|
||||
where f st = (stVersion st <= TLS10)
|
||||
&& stClientContext st
|
||||
&& stClientContext st == ClientRole
|
||||
&& (maybe False (\c -> bulkBlockSize (cipherBulk c) > 0) (stCipher $ stTxState st))
|
||||
|
||||
setKeyBlock :: MonadState TLSState m => m ()
|
||||
|
@ -267,14 +268,14 @@ setKeyBlock = modify setPendingState
|
|||
msServer = MacState { msSequence = 0 }
|
||||
|
||||
pendingTx = TransmissionState
|
||||
{ stCryptState = if cc then cstClient else cstServer
|
||||
, stMacState = if cc then msClient else msServer
|
||||
{ stCryptState = if cc == ClientRole then cstClient else cstServer
|
||||
, stMacState = if cc == ClientRole then msClient else msServer
|
||||
, stCipher = Just cipher
|
||||
, stCompression = stPendingCompression rst
|
||||
}
|
||||
pendingRx = TransmissionState
|
||||
{ stCryptState = if cc then cstServer else cstClient
|
||||
, stMacState = if cc then msServer else msClient
|
||||
{ stCryptState = if cc == ClientRole then cstServer else cstClient
|
||||
, stMacState = if cc == ClientRole then msServer else msClient
|
||||
, stCipher = Just cipher
|
||||
, stCompression = stPendingCompression rst
|
||||
}
|
||||
|
@ -326,10 +327,9 @@ getCipherKeyExchangeType = gets (\st -> cipherKeyExchange <$> stPendingCipher st
|
|||
getVerifiedData :: MonadState TLSState m => Bool -> m Bytes
|
||||
getVerifiedData client = gets (if client then stClientVerifiedData else stServerVerifiedData)
|
||||
|
||||
isClientContext :: MonadState TLSState m => m Bool
|
||||
isClientContext :: MonadState TLSState m => m Role
|
||||
isClientContext = getRecordState stClientContext
|
||||
|
||||
|
||||
startHandshakeClient :: MonadState TLSState m => Version -> ClientRandom -> m ()
|
||||
startHandshakeClient ver crand = do
|
||||
-- FIXME check if handshake is already not null
|
||||
|
|
|
@ -11,6 +11,7 @@ module Network.TLS.Types
|
|||
, SessionData(..)
|
||||
, CipherID
|
||||
, CompressionID
|
||||
, Role(..)
|
||||
) where
|
||||
|
||||
import Data.ByteString (ByteString)
|
||||
|
@ -36,3 +37,7 @@ type CipherID = Word16
|
|||
|
||||
-- | Compression identification
|
||||
type CompressionID = Word8
|
||||
|
||||
-- | Role
|
||||
data Role = ClientRole | ServerRole
|
||||
deriving (Show,Eq)
|
||||
|
|
Loading…
Reference in a new issue