properly handle multiple packet fragments.
as a bonus it cleans lots of differents part since the state machine is inside receiving/sending code
This commit is contained in:
parent
e189f37a67
commit
383cf4c021
4 changed files with 49 additions and 71 deletions
|
@ -87,7 +87,7 @@ runTLSClient f params rng = runTLSClientST f (TLSStateClient { scParams = params
|
|||
where state = (newTLSState rng) { stVersion = cpConnectVersion params, stClientContext = True }
|
||||
|
||||
{- | receive a single TLS packet or on error a TLSError -}
|
||||
recvPacket :: Handle -> TLSClient IO (Either TLSError Packet)
|
||||
recvPacket :: Handle -> TLSClient IO (Either TLSError [Packet])
|
||||
recvPacket handle = do
|
||||
hdr <- lift $ B.hGet handle 5 >>= return . decodeHeader
|
||||
case hdr of
|
||||
|
@ -102,33 +102,33 @@ sendPacket handle pkt = do
|
|||
dataToSend <- writePacket pkt
|
||||
lift $ B.hPut handle dataToSend
|
||||
|
||||
recvServerHello :: Handle -> TLSClient IO ()
|
||||
recvServerHello handle = do
|
||||
processServerInfo (Handshake (ServerHello ver _ _ cipher _ _)) = do
|
||||
ciphers <- cpCiphers . scParams <$> get
|
||||
allowedvers <- cpAllowedVersions . scParams <$> get
|
||||
callbacks <- cpCallbacks . scParams <$> get
|
||||
pkt <- recvPacket handle
|
||||
let hs = case pkt of
|
||||
Right (Handshake h) -> h
|
||||
Left err -> error ("error received: " ++ show err)
|
||||
Right x -> error ("unexpected packet received, expecting handshake " ++ show x)
|
||||
case hs of
|
||||
ServerHello ver _ _ cipher _ _ -> do
|
||||
case find ((==) ver) allowedvers of
|
||||
Nothing -> error ("received version which is not allowed: " ++ show ver)
|
||||
Just _ -> setVersion ver
|
||||
case find ((==) ver) allowedvers of
|
||||
Nothing -> error ("received version which is not allowed: " ++ show ver)
|
||||
Just _ -> setVersion ver
|
||||
case find ((==) cipher . cipherID) ciphers of
|
||||
Nothing -> error "no cipher in common with the server"
|
||||
Just c -> setCipher c
|
||||
|
||||
case find ((==) cipher . cipherID) ciphers of
|
||||
Nothing -> error "no cipher in common with the server"
|
||||
Just c -> setCipher c
|
||||
recvServerHello handle
|
||||
CertRequest _ _ _ -> modify (\sc -> sc { scCertRequested = True }) >> recvServerHello handle
|
||||
Certificates certs -> do
|
||||
valid <- lift $ maybe (return True) (\cb -> cb certs) (cbCertificates callbacks)
|
||||
unless valid $ error "certificates received deemed invalid by user"
|
||||
recvServerHello handle
|
||||
ServerHelloDone -> return ()
|
||||
_ -> error "unexpected handshake message received in server hello messages"
|
||||
processServerInfo (Handshake (CertRequest _ _ _)) = do
|
||||
modify (\sc -> sc { scCertRequested = True })
|
||||
|
||||
processServerInfo (Handshake (Certificates certs)) = do
|
||||
callbacks <- cpCallbacks . scParams <$> get
|
||||
valid <- lift $ maybe (return True) (\cb -> cb certs) (cbCertificates callbacks)
|
||||
unless valid $ error "certificates received deemed invalid by user"
|
||||
|
||||
processServerInfo _ = return ()
|
||||
|
||||
recvServerInfo :: Handle -> TLSClient IO ()
|
||||
recvServerInfo handle = do
|
||||
whileStatus (/= (StatusHandshake HsStatusServerHelloDone)) $ do
|
||||
pkts <- recvPacket handle
|
||||
case pkts of
|
||||
Left err -> error ("error received: " ++ show err)
|
||||
Right l -> forM_ l processServerInfo
|
||||
|
||||
connectSendClientHello :: Handle -> ClientRandom -> TLSClient IO ()
|
||||
connectSendClientHello handle crand = do
|
||||
|
@ -157,7 +157,7 @@ connectSendFinish handle = do
|
|||
connect :: Handle -> ClientRandom -> ClientKeyData -> TLSClient IO ()
|
||||
connect handle crand premasterRandom = do
|
||||
connectSendClientHello handle crand
|
||||
recvServerHello handle
|
||||
recvServerInfo handle
|
||||
connectSendClientCertificate handle
|
||||
|
||||
connectSendClientKeyXchg handle premasterRandom
|
||||
|
@ -172,16 +172,10 @@ connect handle crand premasterRandom = do
|
|||
connectSendFinish handle
|
||||
|
||||
{- receive changeCipherSpec -}
|
||||
pktCCS <- recvPacket handle
|
||||
case pktCCS of
|
||||
Right ChangeCipherSpec -> return ()
|
||||
x -> error ("unexpected reply. expecting change cipher spec " ++ show x)
|
||||
_ <- recvPacket handle
|
||||
|
||||
{- receive Finished -}
|
||||
pktFin <- recvPacket handle
|
||||
case pktFin of
|
||||
Right (Handshake (Finished _)) -> return ()
|
||||
x -> error ("unexpected reply. expecting finished " ++ show x)
|
||||
_ <- recvPacket handle
|
||||
|
||||
return ()
|
||||
|
||||
|
@ -205,8 +199,8 @@ recvData :: Handle -> TLSClient IO L.ByteString
|
|||
recvData handle = do
|
||||
pkt <- recvPacket handle
|
||||
case pkt of
|
||||
Right (AppData x) -> return $ L.fromChunks [x]
|
||||
Right (Handshake HelloRequest) -> do
|
||||
Right [AppData x] -> return $ L.fromChunks [x]
|
||||
Right [Handshake HelloRequest] -> do
|
||||
-- SECURITY FIXME audit the rng here..
|
||||
st <- getTLSState
|
||||
let (bytes, rng') = getRandomBytes (stRandomGen st) 32
|
||||
|
|
|
@ -54,7 +54,7 @@ returnEither :: Either TLSError a -> TLSRead a
|
|||
returnEither (Left err) = throwError err
|
||||
returnEither (Right a) = return a
|
||||
|
||||
readPacket :: MonadTLSState m => Header -> EncryptedData -> m (Either TLSError Packet)
|
||||
readPacket :: MonadTLSState m => Header -> EncryptedData -> m (Either TLSError [Packet])
|
||||
readPacket hdr content = runTLSRead (checkState hdr >> decryptContent hdr content >>= processPacket hdr)
|
||||
|
||||
checkState :: Header -> TLSRead ()
|
||||
|
@ -71,11 +71,11 @@ checkState (Header pt _ _) =
|
|||
allowed ProtocolType_ChangeCipherSpec (StatusHandshake HsStatusClientCertificateVerify) = True
|
||||
allowed _ _ = False
|
||||
|
||||
processPacket :: Header -> Bytes -> TLSRead Packet
|
||||
processPacket :: Header -> Bytes -> TLSRead [Packet]
|
||||
|
||||
processPacket (Header ProtocolType_AppData _ _) content = return $ AppData content
|
||||
processPacket (Header ProtocolType_AppData _ _) content = return [AppData content]
|
||||
|
||||
processPacket (Header ProtocolType_Alert _ _) content = return . Alert =<< returnEither (decodeAlert content)
|
||||
processPacket (Header ProtocolType_Alert _ _) content = return . (:[]) . Alert =<< returnEither (decodeAlert content)
|
||||
|
||||
processPacket (Header ProtocolType_ChangeCipherSpec _ _) content = do
|
||||
e <- updateStatusCC False
|
||||
|
@ -84,15 +84,14 @@ processPacket (Header ProtocolType_ChangeCipherSpec _ _) content = do
|
|||
returnEither $ decodeChangeCipherSpec content
|
||||
switchRxEncryption
|
||||
isClientContext >>= \cc -> when (not cc) setKeyBlock
|
||||
return ChangeCipherSpec
|
||||
return [ChangeCipherSpec]
|
||||
|
||||
processPacket (Header ProtocolType_Handshake ver _) dcontent = do
|
||||
handshakes <- returnEither (decodeHandshakes dcontent)
|
||||
hss <- forM handshakes $ \(ty, content) -> do
|
||||
forM handshakes $ \(ty, content) -> do
|
||||
hs <- processHandshake ver ty content
|
||||
when (finishHandshakeTypeMaterial ty) $ updateHandshakeDigestSplitted ty content
|
||||
return hs
|
||||
return $ head hss -- FIXME for compat until we fixes the expectations in server/client
|
||||
|
||||
processHandshake :: Version -> HandshakeType -> ByteString -> TLSRead Packet
|
||||
processHandshake ver ty econtent = do
|
||||
|
|
|
@ -88,7 +88,7 @@ runTLSServer f params rng = runTLSServerST f (TLSStateServer { scParams = params
|
|||
where state = (newTLSState rng) { stClientContext = False }
|
||||
|
||||
{- | receive a single TLS packet or on error a TLSError -}
|
||||
recvPacket :: Handle -> TLSServer IO (Either TLSError Packet)
|
||||
recvPacket :: Handle -> TLSServer IO (Either TLSError [Packet])
|
||||
recvPacket handle = do
|
||||
hdr <- lift $ B.hGet handle 5 >>= return . decodeHeader
|
||||
case hdr of
|
||||
|
@ -127,25 +127,6 @@ handleClientHello (ClientHello ver _ _ ciphers compressionID _) = do
|
|||
handleClientHello _ = do
|
||||
fail "unexpected handshake type received. expecting client hello"
|
||||
|
||||
expectingPacket :: (Either TLSError Packet) -> ProtocolType -> TLSServer IO ()
|
||||
expectingPacket pkt expectedType = do
|
||||
apkt <- case pkt of
|
||||
Right x -> return x
|
||||
Left tlserror -> fail ("expecting packet but got error " ++ show tlserror)
|
||||
when (packetType apkt /= expectedType) $ do
|
||||
fail ("unexpected packet received, expecting " ++ show expectedType)
|
||||
return ()
|
||||
|
||||
expectingHandshake :: (Either TLSError Packet) -> HandshakeType -> TLSServer IO ()
|
||||
expectingHandshake pkt expectedType = do
|
||||
hs <- case pkt of
|
||||
Right (Handshake hs) -> return hs
|
||||
Right _ -> fail ("unexpected packet received, expecting handshake " ++ show expectedType)
|
||||
Left tlserror -> fail ("expecting handshake but got error " ++ show tlserror)
|
||||
when (typeOfHandshake hs /= expectedType) $ do
|
||||
fail ("unexpected handshake received, expecting " ++ show expectedType)
|
||||
return ()
|
||||
|
||||
handshakeSendServerData :: Handle -> ServerRandom -> TLSServer IO ()
|
||||
handshakeSendServerData handle srand = do
|
||||
sp <- get >>= return . scParams
|
||||
|
@ -192,9 +173,7 @@ handshake handle srand = do
|
|||
handshakeSendServerData handle srand
|
||||
lift $ hFlush handle
|
||||
|
||||
recvPacket handle >>= \pkt -> expectingHandshake pkt HandshakeType_ClientKeyXchg
|
||||
recvPacket handle >>= \pkt -> expectingPacket pkt ProtocolType_ChangeCipherSpec
|
||||
recvPacket handle >>= \pkt -> expectingHandshake pkt HandshakeType_Finished
|
||||
whileStatus (/= (StatusHandshake HsStatusClientFinished)) (recvPacket handle)
|
||||
|
||||
sendPacket handle ChangeCipherSpec
|
||||
handshakeSendFinish handle
|
||||
|
@ -206,9 +185,9 @@ handshake handle srand = do
|
|||
{- | listen on a handle to a new TLS connection. -}
|
||||
listen :: Handle -> ServerRandom -> TLSServer IO ()
|
||||
listen handle srand = do
|
||||
pkt <- recvPacket handle
|
||||
case pkt of
|
||||
Right (Handshake hs) -> handleClientHello hs
|
||||
pkts <- recvPacket handle
|
||||
case pkts of
|
||||
Right [Handshake hs] -> handleClientHello hs
|
||||
x -> fail ("unexpected type received. expecting handshake ++ " ++ show x)
|
||||
handshake handle srand
|
||||
|
||||
|
@ -234,7 +213,7 @@ recvData :: Handle -> TLSServer IO L.ByteString
|
|||
recvData handle = do
|
||||
pkt <- recvPacket handle
|
||||
case pkt of
|
||||
Right (Handshake (ClientHello _ _ _ _ _ _)) -> do
|
||||
Right [Handshake (ClientHello _ _ _ _ _ _)] -> do
|
||||
-- SECURITY FIXME audit the rng here..
|
||||
st <- getTLSState
|
||||
let (bytes, rng') = getRandomBytes (stRandomGen st) 32
|
||||
|
@ -242,7 +221,7 @@ recvData handle = do
|
|||
let srand = fromJust $ serverRandom bytes
|
||||
handshake handle srand
|
||||
recvData handle
|
||||
Right (AppData x) -> return $ L.fromChunks [x]
|
||||
Right [AppData x] -> return $ L.fromChunks [x]
|
||||
Left err -> error ("error received: " ++ show err)
|
||||
_ -> error "unexpected item"
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ module Network.TLS.State
|
|||
, assert -- FIXME move somewhere else (Internal.hs ?)
|
||||
, updateStatusHs
|
||||
, updateStatusCC
|
||||
, whileStatus
|
||||
, finishHandshakeTypeMaterial
|
||||
, finishHandshakeMaterial
|
||||
, makeDigest
|
||||
|
@ -191,6 +192,11 @@ hsStatusTransitionTable =
|
|||
[ StatusHandshake HsStatusServerChangeCipher ])
|
||||
]
|
||||
|
||||
whileStatus :: (MonadTLSState m, Monad m) => (TLSStatus -> Bool) -> m a -> m ()
|
||||
whileStatus p a = do
|
||||
currentStatus <- getTLSState >>= return . stStatus
|
||||
when (p currentStatus) (a >> whileStatus p a)
|
||||
|
||||
updateStatus :: MonadTLSState m => TLSStatus -> m ()
|
||||
updateStatus x = modifyTLSState (\st -> st { stStatus = x })
|
||||
|
||||
|
|
Loading…
Reference in a new issue