Allow handshake records to be split across records.

* Store continuations in record state
* parse handshake records one by one.
This commit is contained in:
Vincent Hanquez 2014-03-22 06:54:37 +00:00
parent 14c3325c75
commit 7d0e1d5267
5 changed files with 30 additions and 24 deletions

View file

@ -11,6 +11,7 @@ module Network.TLS.Internal
, module Network.TLS.Packet
, module Network.TLS.Receiving
, module Network.TLS.Sending
, module Network.TLS.Wire
, sendPacket
, recvPacket
) where
@ -19,4 +20,5 @@ import Network.TLS.Struct
import Network.TLS.Packet
import Network.TLS.Receiving
import Network.TLS.Sending
import Network.TLS.Wire
import Network.TLS.Core (sendPacket, recvPacket)

View file

@ -26,7 +26,7 @@ module Network.TLS.Packet
, encodeAlerts
-- * marshall functions for handshake messages
, decodeHandshakes
, decodeHandshakeRecord
, decodeHandshake
, decodeDeprecatedHandshake
, encodeHandshake
@ -157,21 +157,12 @@ encodeAlerts l = runPut $ mapM_ encodeAlert l
where encodeAlert (al, ad) = putWord8 (valOfType al) >> putWord8 (valOfType ad)
{- decode and encode HANDSHAKE -}
decodeHandshakeHeader :: Get (HandshakeType, Bytes)
decodeHandshakeHeader = do
decodeHandshakeRecord :: ByteString -> GetResult (HandshakeType, Bytes)
decodeHandshakeRecord = runGet "handshake-record" $ do
ty <- getHandshakeType
content <- getOpaque24
return (ty, content)
decodeHandshakes :: ByteString -> Either TLSError [(HandshakeType, Bytes)]
decodeHandshakes b = runGetErr "handshakes" getAll b
where getAll = do
x <- decodeHandshakeHeader
empty <- isEmpty
if empty
then return [x]
else liftM ((:) x) getAll
decodeHandshake :: CurrentParams -> HandshakeType -> ByteString -> Either TLSError Handshake
decodeHandshake cp ty = runGetErr ("handshake[" ++ show ty ++ "]") $ case ty of
HandshakeType_HelloRequest -> decodeHelloRequest

View file

@ -20,6 +20,7 @@ import Network.TLS.Context.Internal
import Network.TLS.Struct
import Network.TLS.Record
import Network.TLS.Packet
import Network.TLS.Wire
import Network.TLS.State
import Network.TLS.Handshake.State
import Network.TLS.Cipher
@ -27,10 +28,6 @@ import Network.TLS.Util
import Data.Byteable
returnEither :: Either TLSError a -> TLSSt a
returnEither (Left err) = throwError err
returnEither (Right a) = return a
processPacket :: Context -> Record Plaintext -> IO (Either TLSError Packet)
processPacket _ (Record ProtocolType_AppData _ fragment) = return $ Right $ AppData $ toBytes fragment
@ -47,17 +44,26 @@ processPacket ctx (Record ProtocolType_Handshake ver fragment) = do
keyxchg <- getHState ctx >>= \hs -> return $ (hs >>= hstPendingCipher >>= Just . cipherKeyExchange)
usingState ctx $ do
npn <- getExtensionNPN
let currentparams = CurrentParams
let currentParams = CurrentParams
{ cParamsVersion = ver
, cParamsKeyXchgType = keyxchg
, cParamsSupportNPN = npn
}
handshakes <- returnEither (decodeHandshakes $ toBytes fragment)
hss <- forM handshakes $ \(ty, content) -> do
case decodeHandshake currentparams ty content of
Left err -> throwError err
Right hs -> return hs
-- get back the optional continuation, and parse as many handshake record as possible.
mCont <- gets stHandshakeRecordCont
modify (\st -> st { stHandshakeRecordCont = Nothing })
hss <- parseMany currentParams mCont (toBytes fragment)
return $ Handshake hss
where parseMany currentParams mCont bs =
case maybe decodeHandshakeRecord id mCont $ bs of
GotError err -> throwError err
GotPartial cont -> modify (\st -> st { stHandshakeRecordCont = Just cont }) >> return []
GotSuccess (ty,content) ->
either throwError (return . (:[])) $ decodeHandshake currentParams ty content
GotSuccessRemaining (ty,content) left ->
case decodeHandshake currentParams ty content of
Left err -> throwError err
Right hh -> (hh:) `fmap` parseMany currentParams Nothing left
processPacket _ (Record ProtocolType_DeprecatedHandshake _ fragment) =
case decodeDeprecatedHandshake $ toBytes fragment of

View file

@ -51,6 +51,7 @@ import Control.Applicative
import Network.TLS.Struct
import Network.TLS.RNG
import Network.TLS.Types (Role(..))
import Network.TLS.Wire (GetContinuation)
import qualified Data.ByteString as B
import Control.Monad.State
import Control.Monad.Error
@ -64,13 +65,14 @@ data TLSState = TLSState
, stClientVerifiedData :: Bytes -- RFC 5746
, stServerVerifiedData :: Bytes -- RFC 5746
, stExtensionNPN :: Bool -- NPN draft extension
, stHandshakeRecordCont :: Maybe (GetContinuation (HandshakeType, Bytes))
, stNegotiatedProtocol :: Maybe B.ByteString -- NPN protocol
, stServerNextProtocolSuggest :: Maybe [B.ByteString]
, stClientCertificateChain :: Maybe CertificateChain
, stRandomGen :: StateRNG
, stVersion :: Maybe Version
, stClientContext :: Role
} deriving (Show)
}
newtype TLSSt a = TLSSt { runTLSSt :: ErrorT TLSError (State TLSState) a }
deriving (Monad, MonadError TLSError, Functor, Applicative)
@ -93,6 +95,7 @@ newTLSState rng clientContext = TLSState
, stClientVerifiedData = B.empty
, stServerVerifiedData = B.empty
, stExtensionNPN = False
, stHandshakeRecordCont = Nothing
, stNegotiatedProtocol = Nothing
, stServerNextProtocolSuggest = Nothing
, stClientCertificateChain = Nothing

View file

@ -103,5 +103,9 @@ prop_header_marshalling_id x = (decodeHeader $ encodeHeader x) == Right x
prop_handshake_marshalling_id :: Handshake -> Bool
prop_handshake_marshalling_id x = (decodeHs $ encodeHandshake x) == Right x
where decodeHs b = either (Left . id) (uncurry (decodeHandshake cp) . head) $ decodeHandshakes b
where decodeHs b = case decodeHandshakeRecord b of
GotPartial _ -> error "got partial"
GotError e -> error ("got error: " ++ show e)
GotSuccessRemaining _ _ -> error "got remaining byte left"
GotSuccess (ty, content) -> decodeHandshake cp ty content
cp = CurrentParams { cParamsVersion = TLS10, cParamsKeyXchgType = Just CipherKeyExchange_RSA, cParamsSupportNPN = True }