diff --git a/Network/TLS/Client.hs b/Network/TLS/Client.hs index b1fbd10..74a309e 100644 --- a/Network/TLS/Client.hs +++ b/Network/TLS/Client.hs @@ -87,7 +87,7 @@ runTLSClient f params rng = runTLSClientST f (TLSStateClient { scParams = params where state = (newTLSState rng) { stVersion = cpConnectVersion params, stClientContext = True } {- | receive a single TLS packet or on error a TLSError -} -recvPacket :: Handle -> TLSClient IO (Either TLSError Packet) +recvPacket :: Handle -> TLSClient IO (Either TLSError [Packet]) recvPacket handle = do hdr <- lift $ B.hGet handle 5 >>= return . decodeHeader case hdr of @@ -102,33 +102,33 @@ sendPacket handle pkt = do dataToSend <- writePacket pkt lift $ B.hPut handle dataToSend -recvServerHello :: Handle -> TLSClient IO () -recvServerHello handle = do +processServerInfo (Handshake (ServerHello ver _ _ cipher _ _)) = do ciphers <- cpCiphers . scParams <$> get allowedvers <- cpAllowedVersions . scParams <$> get - callbacks <- cpCallbacks . scParams <$> get - pkt <- recvPacket handle - let hs = case pkt of - Right (Handshake h) -> h - Left err -> error ("error received: " ++ show err) - Right x -> error ("unexpected packet received, expecting handshake " ++ show x) - case hs of - ServerHello ver _ _ cipher _ _ -> do - case find ((==) ver) allowedvers of - Nothing -> error ("received version which is not allowed: " ++ show ver) - Just _ -> setVersion ver + case find ((==) ver) allowedvers of + Nothing -> error ("received version which is not allowed: " ++ show ver) + Just _ -> setVersion ver + case find ((==) cipher . cipherID) ciphers of + Nothing -> error "no cipher in common with the server" + Just c -> setCipher c - case find ((==) cipher . cipherID) ciphers of - Nothing -> error "no cipher in common with the server" - Just c -> setCipher c - recvServerHello handle - CertRequest _ _ _ -> modify (\sc -> sc { scCertRequested = True }) >> recvServerHello handle - Certificates certs -> do - valid <- lift $ maybe (return True) (\cb -> cb certs) (cbCertificates callbacks) - unless valid $ error "certificates received deemed invalid by user" - recvServerHello handle - ServerHelloDone -> return () - _ -> error "unexpected handshake message received in server hello messages" +processServerInfo (Handshake (CertRequest _ _ _)) = do + modify (\sc -> sc { scCertRequested = True }) + +processServerInfo (Handshake (Certificates certs)) = do + callbacks <- cpCallbacks . scParams <$> get + valid <- lift $ maybe (return True) (\cb -> cb certs) (cbCertificates callbacks) + unless valid $ error "certificates received deemed invalid by user" + +processServerInfo _ = return () + +recvServerInfo :: Handle -> TLSClient IO () +recvServerInfo handle = do + whileStatus (/= (StatusHandshake HsStatusServerHelloDone)) $ do + pkts <- recvPacket handle + case pkts of + Left err -> error ("error received: " ++ show err) + Right l -> forM_ l processServerInfo connectSendClientHello :: Handle -> ClientRandom -> TLSClient IO () connectSendClientHello handle crand = do @@ -157,7 +157,7 @@ connectSendFinish handle = do connect :: Handle -> ClientRandom -> ClientKeyData -> TLSClient IO () connect handle crand premasterRandom = do connectSendClientHello handle crand - recvServerHello handle + recvServerInfo handle connectSendClientCertificate handle connectSendClientKeyXchg handle premasterRandom @@ -172,16 +172,10 @@ connect handle crand premasterRandom = do connectSendFinish handle {- receive changeCipherSpec -} - pktCCS <- recvPacket handle - case pktCCS of - Right ChangeCipherSpec -> return () - x -> error ("unexpected reply. expecting change cipher spec " ++ show x) + _ <- recvPacket handle {- receive Finished -} - pktFin <- recvPacket handle - case pktFin of - Right (Handshake (Finished _)) -> return () - x -> error ("unexpected reply. expecting finished " ++ show x) + _ <- recvPacket handle return () @@ -205,8 +199,8 @@ recvData :: Handle -> TLSClient IO L.ByteString recvData handle = do pkt <- recvPacket handle case pkt of - Right (AppData x) -> return $ L.fromChunks [x] - Right (Handshake HelloRequest) -> do + Right [AppData x] -> return $ L.fromChunks [x] + Right [Handshake HelloRequest] -> do -- SECURITY FIXME audit the rng here.. st <- getTLSState let (bytes, rng') = getRandomBytes (stRandomGen st) 32 diff --git a/Network/TLS/Receiving.hs b/Network/TLS/Receiving.hs index 1d50bff..36cc823 100644 --- a/Network/TLS/Receiving.hs +++ b/Network/TLS/Receiving.hs @@ -54,7 +54,7 @@ returnEither :: Either TLSError a -> TLSRead a returnEither (Left err) = throwError err returnEither (Right a) = return a -readPacket :: MonadTLSState m => Header -> EncryptedData -> m (Either TLSError Packet) +readPacket :: MonadTLSState m => Header -> EncryptedData -> m (Either TLSError [Packet]) readPacket hdr content = runTLSRead (checkState hdr >> decryptContent hdr content >>= processPacket hdr) checkState :: Header -> TLSRead () @@ -71,11 +71,11 @@ checkState (Header pt _ _) = allowed ProtocolType_ChangeCipherSpec (StatusHandshake HsStatusClientCertificateVerify) = True allowed _ _ = False -processPacket :: Header -> Bytes -> TLSRead Packet +processPacket :: Header -> Bytes -> TLSRead [Packet] -processPacket (Header ProtocolType_AppData _ _) content = return $ AppData content +processPacket (Header ProtocolType_AppData _ _) content = return [AppData content] -processPacket (Header ProtocolType_Alert _ _) content = return . Alert =<< returnEither (decodeAlert content) +processPacket (Header ProtocolType_Alert _ _) content = return . (:[]) . Alert =<< returnEither (decodeAlert content) processPacket (Header ProtocolType_ChangeCipherSpec _ _) content = do e <- updateStatusCC False @@ -84,15 +84,14 @@ processPacket (Header ProtocolType_ChangeCipherSpec _ _) content = do returnEither $ decodeChangeCipherSpec content switchRxEncryption isClientContext >>= \cc -> when (not cc) setKeyBlock - return ChangeCipherSpec + return [ChangeCipherSpec] processPacket (Header ProtocolType_Handshake ver _) dcontent = do handshakes <- returnEither (decodeHandshakes dcontent) - hss <- forM handshakes $ \(ty, content) -> do + forM handshakes $ \(ty, content) -> do hs <- processHandshake ver ty content when (finishHandshakeTypeMaterial ty) $ updateHandshakeDigestSplitted ty content return hs - return $ head hss -- FIXME for compat until we fixes the expectations in server/client processHandshake :: Version -> HandshakeType -> ByteString -> TLSRead Packet processHandshake ver ty econtent = do diff --git a/Network/TLS/Server.hs b/Network/TLS/Server.hs index e9b2cd9..0e849a3 100644 --- a/Network/TLS/Server.hs +++ b/Network/TLS/Server.hs @@ -88,7 +88,7 @@ runTLSServer f params rng = runTLSServerST f (TLSStateServer { scParams = params where state = (newTLSState rng) { stClientContext = False } {- | receive a single TLS packet or on error a TLSError -} -recvPacket :: Handle -> TLSServer IO (Either TLSError Packet) +recvPacket :: Handle -> TLSServer IO (Either TLSError [Packet]) recvPacket handle = do hdr <- lift $ B.hGet handle 5 >>= return . decodeHeader case hdr of @@ -127,25 +127,6 @@ handleClientHello (ClientHello ver _ _ ciphers compressionID _) = do handleClientHello _ = do fail "unexpected handshake type received. expecting client hello" -expectingPacket :: (Either TLSError Packet) -> ProtocolType -> TLSServer IO () -expectingPacket pkt expectedType = do - apkt <- case pkt of - Right x -> return x - Left tlserror -> fail ("expecting packet but got error " ++ show tlserror) - when (packetType apkt /= expectedType) $ do - fail ("unexpected packet received, expecting " ++ show expectedType) - return () - -expectingHandshake :: (Either TLSError Packet) -> HandshakeType -> TLSServer IO () -expectingHandshake pkt expectedType = do - hs <- case pkt of - Right (Handshake hs) -> return hs - Right _ -> fail ("unexpected packet received, expecting handshake " ++ show expectedType) - Left tlserror -> fail ("expecting handshake but got error " ++ show tlserror) - when (typeOfHandshake hs /= expectedType) $ do - fail ("unexpected handshake received, expecting " ++ show expectedType) - return () - handshakeSendServerData :: Handle -> ServerRandom -> TLSServer IO () handshakeSendServerData handle srand = do sp <- get >>= return . scParams @@ -192,9 +173,7 @@ handshake handle srand = do handshakeSendServerData handle srand lift $ hFlush handle - recvPacket handle >>= \pkt -> expectingHandshake pkt HandshakeType_ClientKeyXchg - recvPacket handle >>= \pkt -> expectingPacket pkt ProtocolType_ChangeCipherSpec - recvPacket handle >>= \pkt -> expectingHandshake pkt HandshakeType_Finished + whileStatus (/= (StatusHandshake HsStatusClientFinished)) (recvPacket handle) sendPacket handle ChangeCipherSpec handshakeSendFinish handle @@ -206,9 +185,9 @@ handshake handle srand = do {- | listen on a handle to a new TLS connection. -} listen :: Handle -> ServerRandom -> TLSServer IO () listen handle srand = do - pkt <- recvPacket handle - case pkt of - Right (Handshake hs) -> handleClientHello hs + pkts <- recvPacket handle + case pkts of + Right [Handshake hs] -> handleClientHello hs x -> fail ("unexpected type received. expecting handshake ++ " ++ show x) handshake handle srand @@ -234,7 +213,7 @@ recvData :: Handle -> TLSServer IO L.ByteString recvData handle = do pkt <- recvPacket handle case pkt of - Right (Handshake (ClientHello _ _ _ _ _ _)) -> do + Right [Handshake (ClientHello _ _ _ _ _ _)] -> do -- SECURITY FIXME audit the rng here.. st <- getTLSState let (bytes, rng') = getRandomBytes (stRandomGen st) 32 @@ -242,7 +221,7 @@ recvData handle = do let srand = fromJust $ serverRandom bytes handshake handle srand recvData handle - Right (AppData x) -> return $ L.fromChunks [x] + Right [AppData x] -> return $ L.fromChunks [x] Left err -> error ("error received: " ++ show err) _ -> error "unexpected item" diff --git a/Network/TLS/State.hs b/Network/TLS/State.hs index 0d2a485..d098d42 100644 --- a/Network/TLS/State.hs +++ b/Network/TLS/State.hs @@ -20,6 +20,7 @@ module Network.TLS.State , assert -- FIXME move somewhere else (Internal.hs ?) , updateStatusHs , updateStatusCC + , whileStatus , finishHandshakeTypeMaterial , finishHandshakeMaterial , makeDigest @@ -191,6 +192,11 @@ hsStatusTransitionTable = [ StatusHandshake HsStatusServerChangeCipher ]) ] +whileStatus :: (MonadTLSState m, Monad m) => (TLSStatus -> Bool) -> m a -> m () +whileStatus p a = do + currentStatus <- getTLSState >>= return . stStatus + when (p currentStatus) (a >> whileStatus p a) + updateStatus :: MonadTLSState m => TLSStatus -> m () updateStatus x = modifyTLSState (\st -> st { stStatus = x })