hs-tls/Network/TLS/Server.hs

145 lines
4.8 KiB
Haskell

{-# LANGUAGE GeneralizedNewtypeDeriving, MultiParamTypeClasses #-}
-- |
-- Module : Network.TLS.Server
-- License : BSD-style
-- Maintainer : Vincent Hanquez <vincent@snarc.org>
-- Stability : experimental
-- Portability : unknown
--
-- The Server module contains the calls needed to create a listening TLS socket,
-- also known as a server socket.
--
module Network.TLS.Server
( server
-- * API, warning probably subject to change
, listen
, sendData
, recvData
) where
import Data.Maybe
import Data.List (intersect, find)
import Control.Monad.Trans
import Control.Monad.State
import Control.Applicative ((<$>))
import Network.TLS.Core
import Network.TLS.Cipher
import Network.TLS.Struct
import Network.TLS.State
import Network.TLS.SRandom
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as L
import System.IO (Handle, hFlush)
-- | Create a TLS context in server mode on the given handle.
--
-- The context starts from a fresh 'newTLSState' seeded with @rng@, with
-- 'stClientContext' set to 'False' to mark it as the server side.
server :: MonadIO m => TLSParams -> SRandomGen -> Handle -> m TLSCtx
server params rng handle =
  let serverState = (newTLSState rng) { stClientContext = False }
  in  liftIO (newCtx handle params serverState)
-- | Validate a received 'ClientHello' against the context's parameters and,
-- if it is acceptable, store the negotiated protocol version and cipher in
-- the TLS state.
--
-- Raises an 'IOError' (user error) when:
--
-- * the client's version is not in 'pAllowedVersions';
-- * none of the client's cipher IDs matches a cipher in 'pCiphers';
-- * the client does not offer compression method 0;
-- * the handshake message is not a 'ClientHello' at all.
--
-- The selected cipher is the first one, in the client's order of preference,
-- that the server also supports.
handleClientHello :: MonadIO m => TLSCtx -> Handshake -> m ()
handleClientHello ctx (ClientHello ver _ _ ciphers compressionID _) = do
  let cfg           = getParams ctx
      serverCiphers = pCiphers cfg
      -- First client-offered cipher ID that the server knows, resolved to
      -- the server's own 'Cipher' value. Matching on the Maybe below
      -- replaces the previous partial 'head' on the intersection list.
      negotiated    = listToMaybe
        [ c | cid <- ciphers, c <- serverCiphers, cipherID c == cid ]
  when (ver `notElem` pAllowedVersions cfg) $
    abort "unsupported version"
  cipher <- case negotiated of
    Just c  -> return c
    Nothing -> abort ("unsupported cipher: " ++ show ciphers ++ " : server : "
                      ++ show (map cipherID serverCiphers))
  when (0 `notElem` compressionID) $
    abort "unsupported compression"
  usingState_ ctx $ modify (\st -> st
    { stVersion = ver
    , stCipher  = Just cipher
    })
  where
    -- 'fail' would require 'MonadFail m', which 'MonadIO m' does not give on
    -- modern GHC. Raising through IO throws the same user 'IOError' that
    -- 'fail' produces in IO.
    abort :: MonadIO n => String -> n a
    abort msg = liftIO (ioError (userError msg))
handleClientHello _ _ =
  liftIO $ ioError $ userError "unexpected handshake type received. expecting client hello"
-- | Send the server's first handshake flight: ServerHello, Certificate, an
-- optional ServerKeyExchange, an optional CertificateRequest, and finally
-- ServerHelloDone.
--
-- Reads the negotiated version and cipher from the TLS state, which
-- 'handleClientHello' (or an earlier handshake) is expected to have stored.
-- If no cipher is stored, this errors with a descriptive message as soon as
-- the cipher is used.
handshakeSendServerData :: MonadIO m => TLSCtx -> m ()
handshakeSendServerData ctx = do
  -- The ServerHello random is built from 32 units of RNG output
  -- (presumably bytes).
  randomBytes <- getStateRNG ctx 32
  let srand = fromMaybe
                (error "handshakeSendServerData: RNG output is not a valid server random")
                (serverRandom randomBytes)
      sp    = getParams ctx
  mcipher <- usingState_ ctx (stCipher <$> get)
  ver     <- usingState_ ctx (stVersion <$> get)
  let cipher   = fromMaybe
                   (error "handshakeSendServerData: no cipher negotiated")
                   mcipher
      srvhello = ServerHello ver srand (Session Nothing) (cipherID cipher) 0 Nothing
      srvCerts = Certificates $ map fst $ pCertificates sp
  -- Install the private key that accompanies the first configured
  -- certificate, if there is one.
  -- TODO: report a sensible error instead of silently continuing when no
  -- private key is configured.
  case map snd (pCertificates sp) of
    (Just privkey : _) -> usingState_ ctx $ setPrivateKey privkey
    _                  -> return ()
  -- TODO (TLS 1.2): check that the certificates being sent have the
  -- required bits set in their extensions.
  let needkeyxchg = cipherExchangeNeedMoreData $ cipherKeyExchange cipher
  sendPacket ctx (Handshake srvhello)
  sendPacket ctx (Handshake srvCerts)
  when needkeyxchg $
    sendPacket ctx (Handshake $ ServerKeyXchg (SKX_RSA Nothing))
  -- FIXME: this request must be skipped for an anonymous server.
  when (pWantClientCert sp) $
    sendPacket ctx (Handshake $ CertRequest [CertificateType_RSA_Sign] Nothing [0,0,0])
  sendPacket ctx (Handshake ServerHelloDone)
-- | Send the server's Finished handshake message.
--
-- The payload is the digest returned by @getHandshakeDigest False@,
-- unpacked to a byte list. (The 'False' flag presumably selects the
-- server-side digest; confirm against the state module.)
handshakeSendFinish :: MonadIO m => TLSCtx -> m ()
handshakeSendFinish ctx =
  usingState_ ctx (getHandshakeDigest False)
    >>= sendPacket ctx . Handshake . Finished . B.unpack
-- | Run the server side of a handshake once a ClientHello has been received.
--
-- Steps:
--
-- 1. Send the server's hello flight and flush the handle.
-- 2. Keep receiving packets through 'whileStatus' until the status reaches
--    @StatusHandshake HsStatusClientFinished@.
-- 3. Reply with ChangeCipherSpec and the server's Finished message, then
--    flush again.
handshake :: MonadIO m => TLSCtx -> m ()
handshake ctx = do
  handshakeSendServerData ctx
  flush
  whileStatus ctx (/= StatusHandshake HsStatusClientFinished) (recvPacket ctx)
  sendPacket ctx ChangeCipherSpec
  handshakeSendFinish ctx
  flush
  where
    -- Push buffered output to the peer so it can answer the flight.
    flush :: MonadIO n => n ()
    flush = liftIO (hFlush (getHandle ctx))
-- | Listen on the context's handle for a new TLS connection.
--
-- The first message received must be a single handshake packet. It is
-- validated by 'handleClientHello', and then the full server handshake is
-- run.
--
-- Raises an 'IOError' (user error) if anything else is received first.
listen :: MonadIO m => TLSCtx -> m ()
listen ctx = do
  pkts <- recvPacket ctx
  case pkts of
    Right [Handshake hs] -> handleClientHello ctx hs
    -- Raise through IO rather than 'fail': 'MonadIO m' alone does not
    -- provide 'MonadFail m' on modern GHC. Fixes the stray "++ " that had
    -- leaked into the message literal.
    other -> liftIO $ ioError $ userError
               ("unexpected type received. expecting handshake: " ++ show other)
  handshake ctx
-- | Send a lazy 'L.ByteString' as application data.
--
-- Each strict chunk is cut into 'AppData' records of at most 16384 bytes
-- (2^14, the TLS maximum plaintext fragment length). A chunk that already
-- fits, including an empty one, is sent as a single record.
sendData :: MonadIO m => TLSCtx -> L.ByteString -> m ()
sendData ctx dataToSend = mapM_ sendFragments (L.toChunks dataToSend)
  where
    fragmentLimit :: Int
    fragmentLimit = 16384

    sendFragments chunk
      | B.length chunk <= fragmentLimit = sendPacket ctx (AppData chunk)
      | otherwise = do
          let (front, rest) = B.splitAt fragmentLimit chunk
          sendPacket ctx (AppData front)
          sendFragments rest
-- | Receive the next application-data record as a lazy 'L.ByteString'.
--
-- If a handshake 'ClientHello' arrives instead, a new server handshake is
-- run (renegotiation) and reception is retried. A receive error, or any
-- other packet sequence, aborts with 'error'.
--
-- NOTE(review): the renegotiation path calls 'handshake' without
-- 'handleClientHello', so the new hello's version and ciphers are neither
-- validated nor stored. Confirm this is intended.
recvData :: MonadIO m => TLSCtx -> m L.ByteString
recvData ctx = recvPacket ctx >>= dispatch
  where
    dispatch (Right [Handshake (ClientHello {})]) = handshake ctx >> recvData ctx
    dispatch (Right [AppData chunk])              = return (L.fromChunks [chunk])
    dispatch (Left err)                           = error ("error received: " ++ show err)
    dispatch _                                    = error "unexpected item"