diff --git a/core/Tests/Certificate.hs b/core/Tests/Certificate.hs index f9122e5..bd595a1 100644 --- a/core/Tests/Certificate.hs +++ b/core/Tests/Certificate.hs @@ -1,7 +1,7 @@ module Certificate - ( arbitraryX509 - , arbitraryX509WithPublicKey - ) where + ( arbitraryX509 + , arbitraryX509WithPublicKey + ) where import Test.QuickCheck import Data.X509 @@ -14,14 +14,14 @@ import PubKey arbitraryDN = return $ DistinguishedName [] arbitraryTime = do - year <- choose (1951, 2050) - month <- choose (1, 12) - day <- choose (1, 30) - hour <- choose (0, 23) - minute <- choose (0, 59) - second <- choose (0, 59) - --z <- arbitrary - return $ UTCTime (fromGregorian year month day) (secondsToDiffTime (hour * 3600 + minute * 60 + second)) + year <- choose (1951, 2050) + month <- choose (1, 12) + day <- choose (1, 30) + hour <- choose (0, 23) + minute <- choose (0, 59) + second <- choose (0, 59) + --z <- arbitrary + return $ UTCTime (fromGregorian year month day) (secondsToDiffTime (hour * 3600 + minute * 60 + second)) maxSerial = 16777216 @@ -66,12 +66,12 @@ arbitraryX509Cert pubKey = do -} arbitraryX509WithPublicKey pubKey = do - cert <- arbitraryCertificate (PubKeyRSA pubKey) - sig <- resize 40 $ listOf1 arbitrary - let sigalg = SignatureALG HashMD5 PubKeyALG_RSA - let (signedExact, ()) = objectToSignedExact (\_ -> (B.pack sig,sigalg,())) cert - return signedExact + cert <- arbitraryCertificate (PubKeyRSA pubKey) + sig <- resize 40 $ listOf1 arbitrary + let sigalg = SignatureALG HashMD5 PubKeyALG_RSA + let (signedExact, ()) = objectToSignedExact (\_ -> (B.pack sig,sigalg,())) cert + return signedExact arbitraryX509 = do - let pubKey = fst $ getGlobalRSAPair - arbitraryX509WithPublicKey pubKey + let pubKey = fst $ getGlobalRSAPair + arbitraryX509WithPublicKey pubKey diff --git a/core/Tests/PipeChan.hs b/core/Tests/PipeChan.hs index d656c1c..889463b 100644 --- a/core/Tests/PipeChan.hs +++ b/core/Tests/PipeChan.hs @@ -1,13 +1,13 @@ -- create a similar concept than a unix pipe. module PipeChan - ( PipeChan(..) - , newPipe - , runPipe - , readPipeA - , readPipeB - , writePipeA - , writePipeB - ) where + ( PipeChan(..) + , newPipe + , runPipe + , readPipeA + , readPipeB + , writePipeA + , writePipeB + ) where import Control.Applicative import Control.Concurrent.Chan @@ -42,15 +42,15 @@ writePipeB (PipeChan _ _ _ s) = writeChan $ getReadUniPipe s -- helper to read buffered data. readBuffered buf chan sz = do - left <- readIORef buf - if B.length left >= sz - then do - let (ret, nleft) = B.splitAt sz left - writeIORef buf nleft - return ret - else do - let newSize = (sz - B.length left) - newData <- readChan chan - writeIORef buf newData - remain <- readBuffered buf chan newSize - return (left `B.append` remain) + left <- readIORef buf + if B.length left >= sz + then do + let (ret, nleft) = B.splitAt sz left + writeIORef buf nleft + return ret + else do + let newSize = (sz - B.length left) + newData <- readChan chan + writeIORef buf newData + remain <- readBuffered buf chan newSize + return (left `B.append` remain) diff --git a/core/Tests/PubKey.hs b/core/Tests/PubKey.hs index f5715b2..31a4892 100644 --- a/core/Tests/PubKey.hs +++ b/core/Tests/PubKey.hs @@ -1,8 +1,8 @@ module PubKey - ( arbitraryRSAPair - , globalRSAPair - , getGlobalRSAPair - ) where + ( arbitraryRSAPair + , globalRSAPair + , getGlobalRSAPair + ) where import Test.QuickCheck @@ -16,8 +16,8 @@ import System.IO.Unsafe arbitraryRSAPair :: Gen (RSA.PublicKey, RSA.PrivateKey) arbitraryRSAPair = do - rng <- (maybe (error "making rng") id . RNG.make . B.pack) `fmap` vector 64 - arbitraryRSAPairWithRNG rng + rng <- (maybe (error "making rng") id . RNG.make . B.pack) `fmap` vector 64 + arbitraryRSAPairWithRNG rng arbitraryRSAPairWithRNG rng = return $ fst $ RSA.generate rng 128 0x10001 diff --git a/core/Tests/Tests.hs b/core/Tests/Tests.hs index 2c582b7..93daf8e 100644 --- a/core/Tests/Tests.hs +++ b/core/Tests/Tests.hs @@ -30,43 +30,42 @@ genByteString :: Int -> Gen B.ByteString genByteString i = B.pack <$> vector i instance Arbitrary Version where - arbitrary = elements [ SSL2, SSL3, TLS10, TLS11, TLS12 ] + arbitrary = elements [ SSL2, SSL3, TLS10, TLS11, TLS12 ] instance Arbitrary ProtocolType where - arbitrary = elements - [ ProtocolType_ChangeCipherSpec - , ProtocolType_Alert - , ProtocolType_Handshake - , ProtocolType_AppData ] + arbitrary = elements + [ ProtocolType_ChangeCipherSpec + , ProtocolType_Alert + , ProtocolType_Handshake + , ProtocolType_AppData ] #if MIN_VERSION_QuickCheck(2,3,0) #else instance Arbitrary Word8 where - arbitrary = fromIntegral <$> (choose (0,255) :: Gen Int) + arbitrary = fromIntegral <$> (choose (0,255) :: Gen Int) instance Arbitrary Word16 where - arbitrary = fromIntegral <$> (choose (0,65535) :: Gen Int) + arbitrary = fromIntegral <$> (choose (0,65535) :: Gen Int) #endif instance Arbitrary Header where - arbitrary = Header <$> arbitrary <*> arbitrary <*> arbitrary + arbitrary = Header <$> arbitrary <*> arbitrary <*> arbitrary instance Arbitrary ClientRandom where - arbitrary = ClientRandom <$> (genByteString 32) + arbitrary = ClientRandom <$> (genByteString 32) instance Arbitrary ServerRandom where - arbitrary = ServerRandom <$> (genByteString 32) + arbitrary = ServerRandom <$> (genByteString 32) instance Arbitrary Session where - arbitrary = do - i <- choose (1,2) :: Gen Int - case i of - 2 -> liftM (Session . Just) (genByteString 32) - _ -> return $ Session Nothing + arbitrary = do + i <- choose (1,2) :: Gen Int + case i of + 2 -> liftM (Session . Just) (genByteString 32) + _ -> return $ Session Nothing instance Arbitrary CertVerifyData where - arbitrary = do - liftM CertVerifyData (genByteString 128) + arbitrary = liftM CertVerifyData (genByteString 128) arbitraryCiphersIDs :: Gen [Word16] arbitraryCiphersIDs = choose (0,200) >>= vector @@ -78,38 +77,38 @@ someWords8 :: Int -> Gen [Word8] someWords8 i = replicateM i (fromIntegral <$> (choose (0,255) :: Gen Int)) instance Arbitrary CertificateType where - arbitrary = elements - [ CertificateType_RSA_Sign, CertificateType_DSS_Sign - , CertificateType_RSA_Fixed_DH, CertificateType_DSS_Fixed_DH - , CertificateType_RSA_Ephemeral_DH, CertificateType_DSS_Ephemeral_DH - , CertificateType_fortezza_dms ] + arbitrary = elements + [ CertificateType_RSA_Sign, CertificateType_DSS_Sign + , CertificateType_RSA_Fixed_DH, CertificateType_DSS_Fixed_DH + , CertificateType_RSA_Ephemeral_DH, CertificateType_DSS_Ephemeral_DH + , CertificateType_fortezza_dms ] instance Arbitrary Handshake where - arbitrary = oneof - [ ClientHello - <$> arbitrary - <*> arbitrary - <*> arbitrary - <*> arbitraryCiphersIDs - <*> arbitraryCompressionIDs - <*> (return []) - <*> (return Nothing) - , ServerHello - <$> arbitrary - <*> arbitrary - <*> arbitrary - <*> arbitrary - <*> arbitrary - <*> (return []) - , liftM Certificates (CertificateChain <$> (resize 2 $ listOf $ arbitraryX509)) - , pure HelloRequest - , pure ServerHelloDone - , ClientKeyXchg <$> genByteString 48 - --, liftM ServerKeyXchg - , liftM3 CertRequest arbitrary (return Nothing) (return []) - , liftM2 CertVerify (return Nothing) arbitrary - , Finished <$> (genByteString 12) - ] + arbitrary = oneof + [ ClientHello + <$> arbitrary + <*> arbitrary + <*> arbitrary + <*> arbitraryCiphersIDs + <*> arbitraryCompressionIDs + <*> (return []) + <*> (return Nothing) + , ServerHello + <$> arbitrary + <*> arbitrary + <*> arbitrary + <*> arbitrary + <*> arbitrary + <*> (return []) + , liftM Certificates (CertificateChain <$> (resize 2 $ listOf $ arbitraryX509)) + , pure HelloRequest + , pure ServerHelloDone + , ClientKeyXchg <$> genByteString 48 + --, liftM ServerKeyXchg + , liftM3 CertRequest arbitrary (return Nothing) (return []) + , liftM2 CertVerify (return Nothing) arbitrary + , Finished <$> (genByteString 12) + ] {- quickcheck property -} @@ -118,45 +117,44 @@ prop_header_marshalling_id x = (decodeHeader $ encodeHeader x) == Right x prop_handshake_marshalling_id :: Handshake -> Bool prop_handshake_marshalling_id x = (decodeHs $ encodeHandshake x) == Right x - where - decodeHs b = either (Left . id) (uncurry (decodeHandshake cp) . head) $ decodeHandshakes b - cp = CurrentParams { cParamsVersion = TLS10, cParamsKeyXchgType = CipherKeyExchange_RSA, cParamsSupportNPN = True } + where decodeHs b = either (Left . id) (uncurry (decodeHandshake cp) . head) $ decodeHandshakes b + cp = CurrentParams { cParamsVersion = TLS10, cParamsKeyXchgType = CipherKeyExchange_RSA, cParamsSupportNPN = True } prop_pipe_work :: PropertyM IO () prop_pipe_work = do - pipe <- run newPipe - _ <- run (runPipe pipe) + pipe <- run newPipe + _ <- run (runPipe pipe) - let bSize = 16 - n <- pick (choose (1, 32)) + let bSize = 16 + n <- pick (choose (1, 32)) - let d1 = B.replicate (bSize * n) 40 - let d2 = B.replicate (bSize * n) 45 + let d1 = B.replicate (bSize * n) 40 + let d2 = B.replicate (bSize * n) 45 - d1' <- run (writePipeA pipe d1 >> readPipeB pipe (B.length d1)) - d1 `assertEq` d1' + d1' <- run (writePipeA pipe d1 >> readPipeB pipe (B.length d1)) + d1 `assertEq` d1' - d2' <- run (writePipeB pipe d2 >> readPipeA pipe (B.length d2)) - d2 `assertEq` d2' + d2' <- run (writePipeB pipe d2 >> readPipeA pipe (B.length d2)) + d2 `assertEq` d2' - return () + return () establish_data_pipe params tlsServer tlsClient = do - -- initial setup - pipe <- newPipe - _ <- (runPipe pipe) - startQueue <- newChan - resultQueue <- newChan + -- initial setup + pipe <- newPipe + _ <- (runPipe pipe) + startQueue <- newChan + resultQueue <- newChan - (cCtx, sCtx) <- newPairContext pipe params + (cCtx, sCtx) <- newPairContext pipe params - _ <- forkIO $ E.catch (tlsServer sCtx resultQueue) (printAndRaise "server") - _ <- forkIO $ E.catch (tlsClient startQueue cCtx) (printAndRaise "client") + _ <- forkIO $ E.catch (tlsServer sCtx resultQueue) (printAndRaise "server") + _ <- forkIO $ E.catch (tlsClient startQueue cCtx) (printAndRaise "client") - return (startQueue, resultQueue) - where - printAndRaise :: String -> SomeException -> IO () - printAndRaise s e = putStrLn (s ++ " exception: " ++ show e) >> throw e + return (startQueue, resultQueue) + where + printAndRaise :: String -> SomeException -> IO () + printAndRaise s e = putStrLn (s ++ " exception: " ++ show e) >> throw e recvDataNonNull ctx = recvData ctx >>= \l -> if B.null l then recvDataNonNull ctx else return l @@ -273,21 +271,20 @@ assertEq expected got = unless (expected == got) $ error ("got " ++ show got ++ main :: IO () main = defaultMain - [ tests_marshalling - , tests_handshake - ] - where - -- lowlevel tests to check the packet marshalling. - tests_marshalling = testGroup "Marshalling" - [ testProperty "Header" prop_header_marshalling_id - , testProperty "Handshake" prop_handshake_marshalling_id - ] + [ tests_marshalling + , tests_handshake + ] + where -- lowlevel tests to check the packet marshalling. + tests_marshalling = testGroup "Marshalling" + [ testProperty "Header" prop_header_marshalling_id + , testProperty "Handshake" prop_handshake_marshalling_id + ] - -- high level tests between a client and server with fake ciphers. - tests_handshake = testGroup "Handshakes" - [ testProperty "setup" (monadicIO prop_pipe_work) - , testProperty "initiate" (monadicIO prop_handshake_initiate) - , testProperty "initiate with npn" (monadicIO prop_handshake_npn_initiate) - , testProperty "renegociation" (monadicIO prop_handshake_renegociation) - , testProperty "resumption" (monadicIO prop_handshake_session_resumption) - ] + -- high level tests between a client and server with fake ciphers. + tests_handshake = testGroup "Handshakes" + [ testProperty "setup" (monadicIO prop_pipe_work) + , testProperty "initiate" (monadicIO prop_handshake_initiate) + , testProperty "initiate with npn" (monadicIO prop_handshake_npn_initiate) + , testProperty "renegociation" (monadicIO prop_handshake_renegociation) + , testProperty "resumption" (monadicIO prop_handshake_session_resumption) + ]