move to a proper role type for client|server

This commit is contained in:
Vincent Hanquez 2013-07-21 10:16:01 +01:00
parent dd30cc05b0
commit 5ca744a8bf
6 changed files with 39 additions and 30 deletions

View file

@ -92,6 +92,7 @@ import Network.TLS.State
import Network.TLS.Handshake.State
import Network.TLS.Measurement
import Network.TLS.X509
import Network.TLS.Types (Role(..))
import Data.List (intercalate)
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
@ -376,10 +377,10 @@ contextNew :: (MonadIO m, CPRG rng)
-> rng -- ^ Random number generator associated with this context.
-> m Context
contextNew backend params rng = liftIO $ do
let clientContext = case roleParams params of
Client {} -> True
Server {} -> False
let st = newTLSState rng clientContext
let role = case roleParams params of
Client {} -> ClientRole
Server {} -> ServerRole
let st = newTLSState rng role
stvar <- newMVar st
eof <- newIORef False
@ -387,7 +388,7 @@ contextNew backend params rng = liftIO $ do
stats <- newIORef newMeasurement
-- we enable the reception of SSLv2 ClientHello message only in the
-- server context, where we might be dealing with an old/compat client.
sslv2Compat <- newIORef (not clientContext)
sslv2Compat <- newIORef (role == ServerRole)
hooks <- newIORef defaultHooks
lockWrite <- newMVar ()
lockRead <- newMVar ()

View file

@ -22,6 +22,7 @@ import Control.Monad.Error
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Network.TLS.Types (Role(..))
import Network.TLS.Util
import Network.TLS.Struct
import Network.TLS.Record
@ -70,16 +71,16 @@ processPacket (Record ProtocolType_DeprecatedHandshake _ fragment) =
processHandshake :: Handshake -> TLSSt ()
processHandshake hs = do
clientmode <- isClientContext
role <- isClientContext
case hs of
ClientHello cver ran _ _ _ ex _ -> unless clientmode $ do
ClientHello cver ran _ _ _ ex _ -> when (role == ServerRole) $ do
mapM_ processClientExtension ex
startHandshakeClient cver ran
Certificates certs -> processCertificates clientmode certs
ClientKeyXchg content -> unless clientmode $ do
Certificates certs -> processCertificates role certs
ClientKeyXchg content -> when (role == ServerRole) $ do
processClientKeyXchg content
HsNextProtocolNegotiation selected_protocol ->
unless clientmode $ setNegotiatedProtocol selected_protocol
when (role == ServerRole) $ setNegotiatedProtocol selected_protocol
Finished fdata -> processClientFinished fdata
_ -> return ()
let encoded = encodeHandshake hs
@ -148,17 +149,17 @@ processClientKeyXchg encryptedPremaster = do
processClientFinished :: FinishedData -> TLSSt ()
processClientFinished fdata = do
cc <- isClientContext
expected <- getHandshakeDigest (not cc)
expected <- getHandshakeDigest (cc == ServerRole)
when (expected /= fdata) $ do
throwError $ Error_Protocol("bad record mac", True, BadRecordMac)
updateVerifiedData False fdata
updateVerifiedData ServerRole fdata
return ()
processCertificates :: Bool -> CertificateChain -> TLSSt ()
processCertificates False (CertificateChain []) = return ()
processCertificates True (CertificateChain []) =
processCertificates :: Role -> CertificateChain -> TLSSt ()
processCertificates ServerRole (CertificateChain []) = return ()
processCertificates ClientRole (CertificateChain []) =
throwError $ Error_Protocol ("server certificate missing", True, HandshakeFailure)
processCertificates clientmode (CertificateChain (c:_))
| clientmode = withHandshakeM $ setPublicKey pubkey
| otherwise = withHandshakeM $ setClientPublicKey pubkey
processCertificates role (CertificateChain (c:_))
| role == ClientRole = withHandshakeM $ setPublicKey pubkey
| otherwise = withHandshakeM $ setClientPublicKey pubkey
where pubkey = certPubKey $ getCertificate c

View file

@ -38,6 +38,7 @@ import Network.TLS.Wire
import Network.TLS.Packet
import Network.TLS.MAC
import Network.TLS.Util
import Network.TLS.Types (Role(..))
import qualified Data.ByteString as B
@ -59,7 +60,7 @@ data TransmissionState = TransmissionState
} deriving (Show)
data RecordState = RecordState
{ stClientContext :: Bool
{ stClientContext :: Role
, stVersion :: !Version
, stTxState :: TransmissionState
, stRxState :: TransmissionState
@ -95,7 +96,7 @@ incrTransmissionState :: TransmissionState -> TransmissionState
incrTransmissionState ts = ts { stMacState = MacState (ms + 1) }
where (MacState ms) = stMacState ts
newRecordState :: CPRG g => g -> Bool -> RecordState
newRecordState :: CPRG g => g -> Role -> RecordState
newRecordState rng clientContext = RecordState
{ stClientContext = clientContext
, stVersion = TLS10

View file

@ -17,6 +17,7 @@ import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Network.TLS.Util
import Network.TLS.Types (Role(..))
import Network.TLS.Struct
import Network.TLS.Record
import Network.TLS.Packet
@ -46,7 +47,7 @@ writePacket :: Packet -> TLSSt ByteString
writePacket pkt@(Handshake hss) = do
forM_ hss $ \hs -> do
case hs of
Finished fdata -> updateVerifiedData True fdata
Finished fdata -> updateVerifiedData ClientRole fdata
_ -> return ()
let encoded = encodeHandshake hs
when (certVerifyHandshakeMaterial hs) $ withHandshakeM $ addHandshakeMessage encoded

View file

@ -69,6 +69,7 @@ import Network.TLS.Cipher
import Network.TLS.Record.State
import Network.TLS.Handshake.State
import Network.TLS.RNG
import Network.TLS.Types (Role(..))
import qualified Data.ByteString as B
import Control.Applicative ((<$>))
import Control.Monad
@ -128,7 +129,7 @@ runRecordStateSt f = do
(Left e, _) -> throwError e
(Right a, newSt) -> put newSt >> return a
newTLSState :: CPRG g => g -> Bool -> TLSState
newTLSState :: CPRG g => g -> Role -> TLSState
newTLSState rng clientContext = TLSState
{ stHandshake = Nothing
, stSession = Session Nothing
@ -143,7 +144,7 @@ newTLSState rng clientContext = TLSState
, stClientCertificateChain = Nothing
}
updateVerifiedData :: MonadState TLSState m => Bool -> Bytes -> m ()
updateVerifiedData :: MonadState TLSState m => Role -> Bytes -> m ()
updateVerifiedData sending bs = do
cc <- isClientContext
if cc /= sending
@ -233,7 +234,7 @@ isSessionResuming = gets stSessionResuming
needEmptyPacket :: MonadState RecordState m => m Bool
needEmptyPacket = gets f
where f st = (stVersion st <= TLS10)
&& stClientContext st
&& stClientContext st == ClientRole
&& (maybe False (\c -> bulkBlockSize (cipherBulk c) > 0) (stCipher $ stTxState st))
setKeyBlock :: MonadState TLSState m => m ()
@ -267,14 +268,14 @@ setKeyBlock = modify setPendingState
msServer = MacState { msSequence = 0 }
pendingTx = TransmissionState
{ stCryptState = if cc then cstClient else cstServer
, stMacState = if cc then msClient else msServer
{ stCryptState = if cc == ClientRole then cstClient else cstServer
, stMacState = if cc == ClientRole then msClient else msServer
, stCipher = Just cipher
, stCompression = stPendingCompression rst
}
pendingRx = TransmissionState
{ stCryptState = if cc then cstServer else cstClient
, stMacState = if cc then msServer else msClient
{ stCryptState = if cc == ClientRole then cstServer else cstClient
, stMacState = if cc == ClientRole then msServer else msClient
, stCipher = Just cipher
, stCompression = stPendingCompression rst
}
@ -326,10 +327,9 @@ getCipherKeyExchangeType = gets (\st -> cipherKeyExchange <$> stPendingCipher st
getVerifiedData :: MonadState TLSState m => Bool -> m Bytes
getVerifiedData client = gets (if client then stClientVerifiedData else stServerVerifiedData)
isClientContext :: MonadState TLSState m => m Bool
isClientContext :: MonadState TLSState m => m Role
isClientContext = getRecordState stClientContext
startHandshakeClient :: MonadState TLSState m => Version -> ClientRandom -> m ()
startHandshakeClient ver crand = do
-- FIXME check if handshake is already not null

View file

@ -11,6 +11,7 @@ module Network.TLS.Types
, SessionData(..)
, CipherID
, CompressionID
, Role(..)
) where
import Data.ByteString (ByteString)
@ -36,3 +37,7 @@ type CipherID = Word16
-- | Compression identification
type CompressionID = Word8
-- | Role
data Role = ClientRole | ServerRole
deriving (Show,Eq)