use the new getList
This commit is contained in:
parent
d49bff619b
commit
e2eb3ba95c
1 changed files with 20 additions and 43 deletions
|
@ -53,7 +53,6 @@ import Network.TLS.Cap
|
|||
import Data.Either (partitionEithers)
|
||||
import Data.Maybe (fromJust)
|
||||
import Data.Bits ((.|.))
|
||||
import Data.Word(Word16)
|
||||
import Control.Applicative ((<$>))
|
||||
import Control.Monad
|
||||
import Data.Certificate.X509 (decodeCertificate, encodeCertificate, X509, encodeDN, decodeDN)
|
||||
|
@ -208,11 +207,12 @@ decodeServerHelloDone = return ServerHelloDone
|
|||
|
||||
decodeCertificates :: Get Handshake
|
||||
decodeCertificates = do
|
||||
certs <- getWord24 >>= getCerts >>= return . map (decodeCertificate . L.fromChunks . (:[]))
|
||||
let (l, r) = partitionEithers certs
|
||||
if length l > 0
|
||||
then fail ("error certificate parsing: " ++ show l)
|
||||
else return $ Certificates r
|
||||
certsRaw <- getWord24 >>= \len -> getList (fromIntegral len) getCertRaw
|
||||
let (badCerts, certs) = partitionEithers $ map (decodeCertificate . L.fromChunks . (:[])) certsRaw
|
||||
if not $ null badCerts
|
||||
then fail ("error certificate parsing: " ++ show badCerts)
|
||||
else return $ Certificates certs
|
||||
where getCertRaw = getOpaque24 >>= \cert -> return (3 + B.length cert, cert)
|
||||
|
||||
decodeFinished :: Get Handshake
|
||||
decodeFinished = Finished <$> (remaining >>= getBytes)
|
||||
|
@ -223,50 +223,35 @@ decodeNextProtocolNegotiation = do
|
|||
_ <- getOpaque8 -- ignore padding
|
||||
return $ HsNextProtocolNegotiation opaque
|
||||
|
||||
getSignatureHashAlgorithm :: Get (HashAlgorithm, SignatureAlgorithm)
|
||||
getSignatureHashAlgorithm :: Get HashAndSignatureAlgorithm
|
||||
getSignatureHashAlgorithm = do
|
||||
h <- fromJust . valToType <$> getWord8
|
||||
s <- fromJust . valToType <$> getWord8
|
||||
return (h,s)
|
||||
|
||||
getSignatureHashAlgorithms :: Int -> Get [ (HashAlgorithm, SignatureAlgorithm) ]
|
||||
getSignatureHashAlgorithms 0 = return []
|
||||
getSignatureHashAlgorithms len = liftM2 (:) getSignatureHashAlgorithm (getSignatureHashAlgorithms (len-2))
|
||||
|
||||
decodeCertRequest :: CurrentParams -> Get Handshake
|
||||
decodeCertRequest cp = do
|
||||
certTypes <- map (fromJust . valToType . fromIntegral) <$> getWords8
|
||||
|
||||
sigHashAlgs <- if cParamsVersion cp >= TLS12
|
||||
then do
|
||||
sighashlen <- getWord16
|
||||
Just <$> getSignatureHashAlgorithms (fromIntegral sighashlen)
|
||||
else return Nothing
|
||||
then Just <$> (getWord16 >>= getSignatureHashAlgorithms)
|
||||
else return Nothing
|
||||
dNameLen <- getWord16
|
||||
-- FIXME: Decide whether to remove this check completely or to make it an option.
|
||||
-- when (cParamsVersion cp < TLS12 && dNameLen < 3) $ fail "certrequest distinguishname not of the correct size"
|
||||
dNames <- decodeDNames dNameLen
|
||||
dNames <- getList (fromIntegral dNameLen) getDName
|
||||
return $ CertRequest certTypes sigHashAlgs dNames
|
||||
where
|
||||
-- Parse a list of distinguished names, which must be exactly
|
||||
-- 'len' bytes long.
|
||||
decodeDNames :: Word16 -> Get [DistinguishedName]
|
||||
decodeDNames len | len == 0 = return []
|
||||
decodeDNames len = do
|
||||
thisLen <- getWord16
|
||||
when (thisLen == 0) $ fail "certrequest: invalid DN length"
|
||||
dName <- getBytes $ fromIntegral thisLen
|
||||
l <- decodeDNames (len - (2 + thisLen))
|
||||
dn <- decodeDName dName
|
||||
return $ dn : l
|
||||
|
||||
-- Decode the given bytes into a distinguished name.
|
||||
decodeDName :: Bytes -> Get DistinguishedName
|
||||
decodeDName d =
|
||||
case decodeDN (L.fromChunks [d]) of
|
||||
Left err -> fail $ "certrequest: " ++ show err
|
||||
Right s -> return $ DistinguishedName s
|
||||
where
|
||||
getSignatureHashAlgorithms len = getList (fromIntegral len) (getSignatureHashAlgorithm >>= \sh -> return (2, sh))
|
||||
getDName = do
|
||||
dName <- getOpaque16
|
||||
when (B.length dName == 0) $ fail "certrequest: invalid DN length"
|
||||
dn <- decodeDName dName
|
||||
return (2 + B.length dName, dn)
|
||||
|
||||
decodeDName d = case decodeDN (L.fromChunks [d]) of
|
||||
Left err -> fail ("certrequest: " ++ show err)
|
||||
Right s -> return $ DistinguishedName s
|
||||
|
||||
decodeCertVerify :: CurrentParams -> Get Handshake
|
||||
decodeCertVerify cp = do
|
||||
|
@ -419,14 +404,6 @@ putSession :: Session -> Put
|
|||
putSession (Session Nothing) = putWord8 0
|
||||
putSession (Session (Just s)) = putOpaque8 s
|
||||
|
||||
getCerts :: Int -> Get [Bytes]
|
||||
getCerts 0 = return []
|
||||
getCerts len = do
|
||||
certlen <- getWord24
|
||||
cert <- getBytes certlen
|
||||
certxs <- getCerts (len - certlen - 3)
|
||||
return (cert : certxs)
|
||||
|
||||
putCert :: X509 -> Put
|
||||
putCert cert = putOpaque24 (B.concat $ L.toChunks $ encodeCertificate cert)
|
||||
|
||||
|
|
Loading…
Reference in a new issue