Allow handshake records to be split across records.
* Store continuations in record state * parse handshake records one by one.
This commit is contained in:
parent
14c3325c75
commit
7d0e1d5267
5 changed files with 30 additions and 24 deletions
|
@ -11,6 +11,7 @@ module Network.TLS.Internal
|
|||
, module Network.TLS.Packet
|
||||
, module Network.TLS.Receiving
|
||||
, module Network.TLS.Sending
|
||||
, module Network.TLS.Wire
|
||||
, sendPacket
|
||||
, recvPacket
|
||||
) where
|
||||
|
@ -19,4 +20,5 @@ import Network.TLS.Struct
|
|||
import Network.TLS.Packet
|
||||
import Network.TLS.Receiving
|
||||
import Network.TLS.Sending
|
||||
import Network.TLS.Wire
|
||||
import Network.TLS.Core (sendPacket, recvPacket)
|
||||
|
|
|
@ -26,7 +26,7 @@ module Network.TLS.Packet
|
|||
, encodeAlerts
|
||||
|
||||
-- * marshall functions for handshake messages
|
||||
, decodeHandshakes
|
||||
, decodeHandshakeRecord
|
||||
, decodeHandshake
|
||||
, decodeDeprecatedHandshake
|
||||
, encodeHandshake
|
||||
|
@ -157,21 +157,12 @@ encodeAlerts l = runPut $ mapM_ encodeAlert l
|
|||
where encodeAlert (al, ad) = putWord8 (valOfType al) >> putWord8 (valOfType ad)
|
||||
|
||||
{- decode and encode HANDSHAKE -}
|
||||
decodeHandshakeHeader :: Get (HandshakeType, Bytes)
|
||||
decodeHandshakeHeader = do
|
||||
decodeHandshakeRecord :: ByteString -> GetResult (HandshakeType, Bytes)
|
||||
decodeHandshakeRecord = runGet "handshake-record" $ do
|
||||
ty <- getHandshakeType
|
||||
content <- getOpaque24
|
||||
return (ty, content)
|
||||
|
||||
decodeHandshakes :: ByteString -> Either TLSError [(HandshakeType, Bytes)]
|
||||
decodeHandshakes b = runGetErr "handshakes" getAll b
|
||||
where getAll = do
|
||||
x <- decodeHandshakeHeader
|
||||
empty <- isEmpty
|
||||
if empty
|
||||
then return [x]
|
||||
else liftM ((:) x) getAll
|
||||
|
||||
decodeHandshake :: CurrentParams -> HandshakeType -> ByteString -> Either TLSError Handshake
|
||||
decodeHandshake cp ty = runGetErr ("handshake[" ++ show ty ++ "]") $ case ty of
|
||||
HandshakeType_HelloRequest -> decodeHelloRequest
|
||||
|
|
|
@ -20,6 +20,7 @@ import Network.TLS.Context.Internal
|
|||
import Network.TLS.Struct
|
||||
import Network.TLS.Record
|
||||
import Network.TLS.Packet
|
||||
import Network.TLS.Wire
|
||||
import Network.TLS.State
|
||||
import Network.TLS.Handshake.State
|
||||
import Network.TLS.Cipher
|
||||
|
@ -27,10 +28,6 @@ import Network.TLS.Util
|
|||
|
||||
import Data.Byteable
|
||||
|
||||
returnEither :: Either TLSError a -> TLSSt a
|
||||
returnEither (Left err) = throwError err
|
||||
returnEither (Right a) = return a
|
||||
|
||||
processPacket :: Context -> Record Plaintext -> IO (Either TLSError Packet)
|
||||
|
||||
processPacket _ (Record ProtocolType_AppData _ fragment) = return $ Right $ AppData $ toBytes fragment
|
||||
|
@ -47,17 +44,26 @@ processPacket ctx (Record ProtocolType_Handshake ver fragment) = do
|
|||
keyxchg <- getHState ctx >>= \hs -> return $ (hs >>= hstPendingCipher >>= Just . cipherKeyExchange)
|
||||
usingState ctx $ do
|
||||
npn <- getExtensionNPN
|
||||
let currentparams = CurrentParams
|
||||
let currentParams = CurrentParams
|
||||
{ cParamsVersion = ver
|
||||
, cParamsKeyXchgType = keyxchg
|
||||
, cParamsSupportNPN = npn
|
||||
}
|
||||
handshakes <- returnEither (decodeHandshakes $ toBytes fragment)
|
||||
hss <- forM handshakes $ \(ty, content) -> do
|
||||
case decodeHandshake currentparams ty content of
|
||||
Left err -> throwError err
|
||||
Right hs -> return hs
|
||||
-- get back the optional continuation, and parse as many handshake record as possible.
|
||||
mCont <- gets stHandshakeRecordCont
|
||||
modify (\st -> st { stHandshakeRecordCont = Nothing })
|
||||
hss <- parseMany currentParams mCont (toBytes fragment)
|
||||
return $ Handshake hss
|
||||
where parseMany currentParams mCont bs =
|
||||
case maybe decodeHandshakeRecord id mCont $ bs of
|
||||
GotError err -> throwError err
|
||||
GotPartial cont -> modify (\st -> st { stHandshakeRecordCont = Just cont }) >> return []
|
||||
GotSuccess (ty,content) ->
|
||||
either throwError (return . (:[])) $ decodeHandshake currentParams ty content
|
||||
GotSuccessRemaining (ty,content) left ->
|
||||
case decodeHandshake currentParams ty content of
|
||||
Left err -> throwError err
|
||||
Right hh -> (hh:) `fmap` parseMany currentParams Nothing left
|
||||
|
||||
processPacket _ (Record ProtocolType_DeprecatedHandshake _ fragment) =
|
||||
case decodeDeprecatedHandshake $ toBytes fragment of
|
||||
|
|
|
@ -51,6 +51,7 @@ import Control.Applicative
|
|||
import Network.TLS.Struct
|
||||
import Network.TLS.RNG
|
||||
import Network.TLS.Types (Role(..))
|
||||
import Network.TLS.Wire (GetContinuation)
|
||||
import qualified Data.ByteString as B
|
||||
import Control.Monad.State
|
||||
import Control.Monad.Error
|
||||
|
@ -64,13 +65,14 @@ data TLSState = TLSState
|
|||
, stClientVerifiedData :: Bytes -- RFC 5746
|
||||
, stServerVerifiedData :: Bytes -- RFC 5746
|
||||
, stExtensionNPN :: Bool -- NPN draft extension
|
||||
, stHandshakeRecordCont :: Maybe (GetContinuation (HandshakeType, Bytes))
|
||||
, stNegotiatedProtocol :: Maybe B.ByteString -- NPN protocol
|
||||
, stServerNextProtocolSuggest :: Maybe [B.ByteString]
|
||||
, stClientCertificateChain :: Maybe CertificateChain
|
||||
, stRandomGen :: StateRNG
|
||||
, stVersion :: Maybe Version
|
||||
, stClientContext :: Role
|
||||
} deriving (Show)
|
||||
}
|
||||
|
||||
newtype TLSSt a = TLSSt { runTLSSt :: ErrorT TLSError (State TLSState) a }
|
||||
deriving (Monad, MonadError TLSError, Functor, Applicative)
|
||||
|
@ -93,6 +95,7 @@ newTLSState rng clientContext = TLSState
|
|||
, stClientVerifiedData = B.empty
|
||||
, stServerVerifiedData = B.empty
|
||||
, stExtensionNPN = False
|
||||
, stHandshakeRecordCont = Nothing
|
||||
, stNegotiatedProtocol = Nothing
|
||||
, stServerNextProtocolSuggest = Nothing
|
||||
, stClientCertificateChain = Nothing
|
||||
|
|
|
@ -103,5 +103,9 @@ prop_header_marshalling_id x = (decodeHeader $ encodeHeader x) == Right x
|
|||
|
||||
prop_handshake_marshalling_id :: Handshake -> Bool
|
||||
prop_handshake_marshalling_id x = (decodeHs $ encodeHandshake x) == Right x
|
||||
where decodeHs b = either (Left . id) (uncurry (decodeHandshake cp) . head) $ decodeHandshakes b
|
||||
where decodeHs b = case decodeHandshakeRecord b of
|
||||
GotPartial _ -> error "got partial"
|
||||
GotError e -> error ("got error: " ++ show e)
|
||||
GotSuccessRemaining _ _ -> error "got remaining byte left"
|
||||
GotSuccess (ty, content) -> decodeHandshake cp ty content
|
||||
cp = CurrentParams { cParamsVersion = TLS10, cParamsKeyXchgType = Just CipherKeyExchange_RSA, cParamsSupportNPN = True }
|
||||
|
|
Loading…
Reference in a new issue