hs-tls/core/Network/TLS/Handshake/Common.hs

141 lines
5.5 KiB
Haskell
Raw Normal View History

{-# LANGUAGE DeriveDataTypeable, OverloadedStrings #-}
module Network.TLS.Handshake.Common
( HandshakeFailed(..)
, handshakeFailed
, errorToAlert
, unexpected
, newSession
, handshakeTerminate
-- * sending packets
, sendChangeCipherAndFinish
-- * receiving packets
, recvChangeCipherAndFinish
, RecvState(..)
, runRecvState
, recvPacketHandshake
) where
import Control.Concurrent.MVar
import Network.TLS.Context
import Network.TLS.Session
import Network.TLS.Struct
import Network.TLS.IO
import Network.TLS.State hiding (getNegotiatedProtocol)
import Network.TLS.Handshake.Process
import Network.TLS.Record.State
import Network.TLS.Measurement
2013-07-23 07:14:48 +00:00
import Network.TLS.Types
import Network.TLS.Cipher
import Network.TLS.Util
import Data.Data
import Data.ByteString.Char8 ()
import Control.Monad.State
import Control.Exception (throwIO, Exception())
-- | Exception thrown by 'handshakeFailed' when a handshake cannot be
-- completed; it wraps the 'TLSError' describing the failure.
data HandshakeFailed = HandshakeFailed TLSError
    deriving (Show, Eq, Typeable)

instance Exception HandshakeFailed
-- | Abort the current handshake by throwing 'HandshakeFailed' around the
-- given error.
handshakeFailed :: TLSError -> IO ()
handshakeFailed = throwIO . HandshakeFailed
-- | Build the fatal alert packet to send to the peer for a given error.
--
-- * Protocol errors carry their own alert description, which is used as is.
-- * An unexpected packet (as raised by 'unexpected') is reported as
--   @unexpected_message@, the alert RFC 5246 (7.2.2) prescribes for an
--   inappropriate message, instead of the generic @internal_error@.
-- * Every other error becomes @internal_error@.
errorToAlert :: TLSError -> Packet
errorToAlert (Error_Protocol (_, _, ad))   = Alert [(AlertLevel_Fatal, ad)]
errorToAlert (Error_Packet_unexpected _ _) = Alert [(AlertLevel_Fatal, UnexpectedMessage)]
errorToAlert _                             = Alert [(AlertLevel_Fatal, InternalError)]
-- | Abort with an 'Error_Packet_unexpected' error.
--
-- The first argument describes what was received. The optional second
-- argument names what was expected and is rendered as
-- @" expected: <what>"@; when it is 'Nothing' the detail is empty.
unexpected :: MonadIO m => String -> Maybe String -> m a
unexpected msg expected = throwCore $ Error_Packet_unexpected msg detail
  where
    detail = maybe "" (" expected: " ++) expected
-- | Create the session for a new handshake.
--
-- When session use is disabled in the context parameters ('pUseSession'),
-- the session has no identifier. Otherwise the identifier is 32 bytes
-- drawn from the context RNG.
newSession :: MonadIO m => Context -> m Session
newSession ctx
    | not (pUseSession params) = return $ Session Nothing
    | otherwise = do
        sessionId <- getStateRNG ctx 32
        return $ Session (Just sessionId)
  where
    params = ctxParams ctx
-- | Wrap up and clean up once a new handshake is done.
--
-- If the session carries an identifier, the session manager is told about
-- the established session through 'sessionEstablish'. The handshake data is
-- then dropped ('endHandshake'), the byte counters are reset and the
-- context is marked as established.
--
-- NOTE(review): 'fromJust' aborts with "session-data" if no master secret is
-- recorded at this point ('getSessionData' returned 'Nothing').
handshakeTerminate :: MonadIO m => Context -> m ()
handshakeTerminate ctx = do
    session <- usingState_ ctx getSession
    case session of
        -- sessions without an identifier are never reported
        Session (Just sessionId) -> do
            mdata <- getSessionData ctx
            let established = fromJust "session-data" mdata
            withSessionManager (ctxParams ctx) $ \manager ->
                liftIO $ sessionEstablish manager sessionId established
        _ -> return ()
    -- forget all handshake data and reset the byte counters
    usingState_ ctx endHandshake
    updateMeasure ctx resetBytesCounters
    -- the secure connection is now up and running
    setEstablished ctx True
    return ()
-- | Send ChangeCipherSpec, then the Finished message for the given role.
--
-- On the client side, a NextProtocolNegotiation handshake is sent between
-- the two when the client installed an 'onNPNServerSuggest' callback and the
-- server suggested protocols. The chosen protocol is also recorded with
-- 'setNegotiatedProtocol'. The context is flushed after ChangeCipherSpec
-- (and the optional NPN message) and again after Finished.
sendChangeCipherAndFinish :: MonadIO m => Context -> Role -> m ()
sendChangeCipherAndFinish ctx role = do
    sendPacket ctx ChangeCipherSpec
    when (role == ClientRole) sendNextProtocol
    liftIO $ contextFlush ctx

    ver <- usingState_ ctx getVersion
    cf  <- usingHState ctx $ getHandshakeDigest ver role
    sendPacket ctx (Handshake [Finished cf])
    liftIO $ contextFlush ctx
  where
    -- Only when the client offered NPN and the server picked it up is a
    -- protocol chosen and sent; in every other case nothing happens.
    sendNextProtocol = do
        suggest <- usingState_ ctx getServerNextProtocolSuggest
        case (onNPNServerSuggest (getClientParams $ ctxParams ctx), suggest) of
            (Just choose, Just protos) -> do
                proto <- liftIO $ choose protos
                sendPacket ctx (Handshake [HsNextProtocolNegotiation proto])
                usingState_ ctx $ setNegotiatedProtocol proto
            _ -> return ()
-- | Receive the peer's ChangeCipherSpec packet followed by its Finished
-- handshake message. Any other packet or message aborts via 'unexpected'.
recvChangeCipherAndFinish :: MonadIO m => Context -> m ()
recvChangeCipherAndFinish ctx = runRecvState ctx (RecvStateNext onChangeCipher)
  where
    onChangeCipher pkt = case pkt of
        ChangeCipherSpec -> return $ RecvStateHandshake onFinished
        _                -> unexpected (show pkt) (Just "change cipher")

    onFinished hs = case hs of
        Finished _ -> return RecvStateDone
        _          -> unexpected (show hs) (Just "Handshake Finished")
-- | States of the receive loop driven by 'runRecvState'.
data RecvState m
    = RecvStateNext (Packet -> m (RecvState m))
      -- ^ receive one packet and pass it to the callback
    | RecvStateHandshake (Handshake -> m (RecvState m))
      -- ^ receive a handshake record and pass each message to the callback
    | RecvStateDone
      -- ^ nothing more to receive
-- | Receive one packet and return the handshake messages it carries.
--
-- A receive error is rethrown with 'throwCore'. A packet of any other type
-- is reported through 'unexpected', so that it surfaces as a 'TLSError'
-- like every other failure in this module. The previous 'fail' threw an
-- 'IOError' in 'IO' and does not typecheck under a bare 'MonadIO'
-- constraint since 'fail' moved to 'MonadFail'.
recvPacketHandshake :: MonadIO m => Context -> m [Handshake]
recvPacketHandshake ctx = do
    pkts <- recvPacket ctx
    case pkts of
        Right (Handshake l) -> return l
        Right x             -> unexpected (show x) (Just "handshake")
        Left err            -> throwCore err
-- | Drive a 'RecvState' machine until it reaches 'RecvStateDone'.
--
-- * 'RecvStateNext' receives one packet and hands it to its callback; a
--   receive error is rethrown with 'throwCore'.
-- * 'RecvStateHandshake' receives one handshake packet and feeds each
--   message it contains to the current callback in turn. 'processHandshake'
--   runs on every message after that message's callback has run. If the
--   state stops accepting handshake messages while some remain in the
--   packet, it aborts with @unexpected "spurious handshake"@.
runRecvState :: MonadIO m => Context -> RecvState m -> m ()
runRecvState ctx = drive
  where
    drive :: MonadIO m => RecvState m -> m ()
    drive RecvStateDone = return ()
    drive (RecvStateNext onPacket) = do
        received <- recvPacket ctx
        next     <- either throwCore onPacket received
        drive next
    drive state = do
        msgs <- recvPacketHandshake ctx
        next <- feed state msgs
        drive next

    -- Feed the messages of one packet to the state, one at a time.
    feed :: MonadIO m => RecvState m -> [Handshake] -> m (RecvState m)
    feed state [] = return state
    feed (RecvStateHandshake onMsg) (msg:rest) = do
        state' <- onMsg msg
        processHandshake ctx msg
        feed state' rest
    feed _ _ = unexpected "spurious handshake" Nothing
-- | Collect the data describing the current session.
--
-- Returns 'Nothing' when the handshake state holds no master secret.
-- Otherwise the result combines the negotiated version, the ID of the
-- cipher in the transmit state and the master secret.
--
-- NOTE(review): 'fromJust' aborts with "cipher" if the transmit state has no
-- cipher while a master secret is present.
getSessionData :: MonadIO m => Context -> m (Maybe SessionData)
getSessionData ctx = do
    version   <- usingState_ ctx getVersion
    masterSec <- usingHState ctx (gets hstMasterSecret)
    txState   <- liftIO $ readMVar (ctxTxState ctx)
    return $! fmap (toSessionData version txState) masterSec
  where
    toSessionData version txState secret = SessionData
        { sessionVersion = version
        , sessionCipher  = cipherID $ fromJust "cipher" $ stCipher txState
        , sessionSecret  = secret
        }