modify client API to be like the server API.

This commit is contained in:
Vincent Hanquez 2011-03-01 20:01:40 +00:00
parent 5481816e0e
commit f260c5b9cf
4 changed files with 81 additions and 150 deletions

View file

@ -11,155 +11,93 @@
-- aka. a client socket. -- aka. a client socket.
-- --
module Network.TLS.Client module Network.TLS.Client
( TLSParams(..) ( client
, TLSStateClient
, TLSClient (..)
, runTLSClient
-- * low level packet sending receiving.
, recvPacket
, sendPacket
-- * API, warning probably subject to change -- * API, warning probably subject to change
, initiate , initiate
, connect
, sendData , sendData
, recvData , recvData
, close , close
) where ) where
import Data.Maybe import Data.Maybe
import Control.Applicative ((<$>))
import Control.Monad.Trans import Control.Monad.Trans
import Control.Monad.State import Control.Monad.State
import Network.TLS.Cipher import Network.TLS.Cipher
import Network.TLS.Compression import Network.TLS.Compression
import Network.TLS.Struct import Network.TLS.Struct
import Network.TLS.Packet
import Network.TLS.State import Network.TLS.State
import Network.TLS.Sending
import Network.TLS.Receiving
import Network.TLS.SRandom import Network.TLS.SRandom
import Network.TLS.Core (TLSParams(..)) import Network.TLS.Core
import qualified Data.ByteString as B import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as L import qualified Data.ByteString.Lazy as L
import System.IO (Handle, hFlush) import System.IO (Handle, hFlush)
import Data.List (find) import Data.List (find)
data TLSStateClient = TLSStateClient client :: MonadIO m => TLSParams -> SRandomGen -> Handle -> m TLSCtx
{ scParams :: TLSParams -- ^ client params and config for this connection client params rng handle = liftIO $ newCtx handle params state
, scTLSState :: TLSState -- ^ client TLS State for this connection where state = (newTLSState rng) { stClientContext = True }
, scCertRequested :: Bool -- ^ mark that the server requested a certificate
}
newtype TLSClient m a = TLSClient { runTLSC :: StateT TLSStateClient m a } processServerInfo :: MonadIO m => TLSCtx -> Packet -> m ()
deriving (Monad, MonadState TLSStateClient) processServerInfo ctx (Handshake (ServerHello ver _ _ cipher _ _)) = do
let ciphers = pCiphers $ getParams ctx
let allowedvers = pAllowedVersions $ getParams ctx
instance MonadTrans TLSClient where
lift = TLSClient . lift
instance (Functor m, Monad m) => Functor (TLSClient m) where
fmap f = TLSClient . fmap f . runTLSC
runTLSClientST :: TLSClient m a -> TLSStateClient -> m (a, TLSStateClient)
runTLSClientST f s = runStateT (runTLSC f) s
runTLSClient :: TLSClient m a -> TLSParams -> SRandomGen -> m (a, TLSStateClient)
runTLSClient f params rng = runTLSClientST f (TLSStateClient { scParams = params, scTLSState = state, scCertRequested = False })
where state = (newTLSState rng) { stVersion = pConnectVersion params, stClientContext = True }
usingState :: Monad m => TLSSt a -> TLSClient m (Either TLSError a)
usingState f =
get >>= return . scTLSState >>= execAndStore
where
execAndStore st = do
let (a, newst) = runTLSState f st
modify (\stateclient -> stateclient { scTLSState = newst })
return a
usingState_ f = do
ret <- usingState f
case ret of
Left err -> error "assertion failed, error in path without an error"
Right r -> return r
getStateRNG n = usingState_ (withTLSRNG (\rng -> getRandomBytes rng n))
{- | receive a single TLS packet or on error a TLSError -}
recvPacket :: Handle -> TLSClient IO (Either TLSError [Packet])
recvPacket handle = do
hdr <- lift $ B.hGet handle 5 >>= return . decodeHeader
case hdr of
Left err -> return $ Left err
Right header@(Header _ _ readlen) -> do
content <- lift $ B.hGet handle (fromIntegral readlen)
usingState $ readPacket header (EncryptedData content)
{- | send a single TLS packet -}
sendPacket :: Handle -> Packet -> TLSClient IO ()
sendPacket handle pkt = do
dataToSend <- usingState_ $ writePacket pkt
lift $ B.hPut handle dataToSend
processServerInfo :: Packet -> TLSClient IO ()
processServerInfo (Handshake (ServerHello ver _ _ cipher _ _)) = do
ciphers <- pCiphers . scParams <$> get
allowedvers <- pAllowedVersions . scParams <$> get
case find ((==) ver) allowedvers of case find ((==) ver) allowedvers of
Nothing -> error ("received version which is not allowed: " ++ show ver) Nothing -> error ("received version which is not allowed: " ++ show ver)
Just _ -> usingState_ $ setVersion ver Just _ -> usingState_ ctx $ setVersion ver
case find ((==) cipher . cipherID) ciphers of case find ((==) cipher . cipherID) ciphers of
Nothing -> error "no cipher in common with the server" Nothing -> error "no cipher in common with the server"
Just c -> usingState_ $ setCipher c Just c -> usingState_ ctx $ setCipher c
processServerInfo (Handshake (CertRequest _ _ _)) = do processServerInfo _ (Handshake (CertRequest _ _ _)) = do
modify (\sc -> sc { scCertRequested = True }) return ()
--modify (\sc -> sc { scCertRequested = True })
processServerInfo (Handshake (Certificates certs)) = do processServerInfo ctx (Handshake (Certificates certs)) = do
cb <- onCertificatesRecv . scParams <$> get let cb = onCertificatesRecv $ getParams ctx
valid <- lift $ cb certs valid <- liftIO $ cb certs
unless valid $ error "certificates received deemed invalid by user" unless valid $ error "certificates received deemed invalid by user"
processServerInfo _ = return () processServerInfo _ _ = return ()
recvServerInfo :: Handle -> TLSClient IO () recvServerInfo :: MonadIO m => TLSCtx -> m ()
recvServerInfo handle = do recvServerInfo ctx = do
whileStatus (/= (StatusHandshake HsStatusServerHelloDone)) $ do whileStatus ctx (/= (StatusHandshake HsStatusServerHelloDone)) $ do
pkts <- recvPacket handle pkts <- recvPacket ctx
case pkts of case pkts of
Left err -> error ("error received: " ++ show err) Left err -> error ("error received: " ++ show err)
Right l -> forM_ l processServerInfo Right l -> mapM_ (processServerInfo ctx) l
connectSendClientHello :: MonadIO m => TLSCtx -> m ()
connectSendClientHello ctx = do
crand <- getStateRNG ctx 32 >>= return . fromJust . clientRandom
sendPacket ctx $ Handshake (ClientHello ver crand (Session Nothing) (map cipherID ciphers) (map compressionID compressions) Nothing)
where where
whileStatus p a = do params = getParams ctx
b <- usingState_ (p . stStatus <$> get) ver = pConnectVersion params
when b (a >> whileStatus p a) ciphers = pCiphers params
compressions = pCompressions params
connectSendClientHello :: Handle -> TLSClient IO () connectSendClientCertificate :: MonadIO m => TLSCtx -> m ()
connectSendClientHello handle = do connectSendClientCertificate ctx = do
crand <- fromJust . clientRandom <$> getStateRNG 32 certRequested <- return False -- scCertRequested <$> get
ver <- pConnectVersion . scParams <$> get
ciphers <- pCiphers . scParams <$> get
compressions <- pCompressions . scParams <$> get
sendPacket handle $ Handshake (ClientHello ver crand (Session Nothing) (map cipherID ciphers) (map compressionID compressions) Nothing)
connectSendClientCertificate :: Handle -> TLSClient IO ()
connectSendClientCertificate handle = do
certRequested <- scCertRequested <$> get
when certRequested $ do when certRequested $ do
clientCerts <- map fst . pCertificates . scParams <$> get let clientCerts = map fst $ pCertificates $ getParams ctx
sendPacket handle $ Handshake (Certificates clientCerts) sendPacket ctx $ Handshake (Certificates clientCerts)
connectSendClientKeyXchg :: Handle -> TLSClient IO () connectSendClientKeyXchg :: MonadIO m => TLSCtx -> m ()
connectSendClientKeyXchg handle = do connectSendClientKeyXchg ctx = do
prerand <- ClientKeyData <$> getStateRNG 46 prerand <- getStateRNG ctx 46 >>= return . ClientKeyData
ver <- pConnectVersion . scParams <$> get let ver = pConnectVersion $ getParams ctx
sendPacket handle $ Handshake (ClientKeyXchg ver prerand) sendPacket ctx $ Handshake (ClientKeyXchg ver prerand)
connectSendFinish :: Handle -> TLSClient IO () connectSendFinish :: MonadIO m => TLSCtx -> m ()
connectSendFinish handle = do connectSendFinish ctx = do
cf <- usingState_ $ getHandshakeDigest True cf <- usingState_ ctx $ getHandshakeDigest True
sendPacket handle (Handshake $ Finished $ B.unpack cf) sendPacket ctx (Handshake $ Finished $ B.unpack cf)
{- | initiate a new TLS connection through a handshake on a handle. -} {- | initiate a new TLS connection through a handshake on a handle. -}
initiate :: Handle -> TLSClient IO () initiate :: MonadIO m => TLSCtx -> m ()
initiate handle = do initiate handle = do
connectSendClientHello handle connectSendClientHello handle
recvServerInfo handle recvServerInfo handle
@ -171,7 +109,7 @@ initiate handle = do
{- FIXME not implemented yet -} {- FIXME not implemented yet -}
sendPacket handle (ChangeCipherSpec) sendPacket handle (ChangeCipherSpec)
lift $ hFlush handle liftIO $ hFlush $ getHandle handle
{- send Finished -} {- send Finished -}
connectSendFinish handle connectSendFinish handle
@ -184,38 +122,32 @@ initiate handle = do
return () return ()
{-# DEPRECATED connect "use initiate" #-} {- | sendData sends a bunch of data -}
connect :: Handle -> TLSClient IO () sendData :: MonadIO m => TLSCtx -> L.ByteString -> m ()
connect = initiate sendData ctx dataToSend = mapM_ sendDataChunk (L.toChunks dataToSend)
where sendDataChunk d =
sendDataChunk :: Handle -> B.ByteString -> TLSClient IO ()
sendDataChunk handle d =
if B.length d > 16384 if B.length d > 16384
then do then do
let (sending, remain) = B.splitAt 16384 d let (sending, remain) = B.splitAt 16384 d
sendPacket handle $ AppData sending sendPacket ctx $ AppData sending
sendDataChunk handle remain sendDataChunk remain
else else
sendPacket handle $ AppData d sendPacket ctx $ AppData d
{- | sendData sends a bunch of data -}
sendData :: Handle -> L.ByteString -> TLSClient IO ()
sendData handle d = mapM_ (sendDataChunk handle) (L.toChunks d)
{- | recvData get data out of Data packet, and automatically try to renegociate if {- | recvData get data out of Data packet, and automatically renegociate if
- a Handshake HelloRequest is received -} - a Handshake ClientHello is received -}
recvData :: Handle -> TLSClient IO L.ByteString recvData :: MonadIO m => TLSCtx -> m L.ByteString
recvData handle = do recvData handle = do
pkt <- recvPacket handle pkt <- recvPacket handle
case pkt of case pkt of
Right [AppData x] -> return $ L.fromChunks [x] Right [AppData x] -> return $ L.fromChunks [x]
Right [Handshake HelloRequest] -> connect handle >> recvData handle Right [Handshake HelloRequest] -> initiate handle >> recvData handle
Left err -> error ("error received: " ++ show err) Left err -> error ("error received: " ++ show err)
_ -> error "unexpected item" _ -> error "unexpected item"
{- | close a TLS connection. {- | close a TLS connection.
- note that it doesn't close the handle, but just signal we're going to close - note that it doesn't close the handle, but just signal we're going to close
- the connection to the other side -} - the connection to the other side -}
close :: Handle -> TLSClient IO () close :: MonadIO m => TLSCtx -> m ()
close handle = do close ctx = sendPacket ctx $ Alert (AlertLevel_Warning, CloseNotify)
sendPacket handle $ Alert (AlertLevel_Warning, CloseNotify)

View file

@ -32,13 +32,11 @@ import Network.TLS.SRandom
import Data.Certificate.X509 import Data.Certificate.X509
import Data.List (intercalate) import Data.List (intercalate)
import qualified Data.ByteString as B import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as L
import Control.Applicative ((<$>)) import Control.Applicative ((<$>))
import Control.Concurrent.MVar import Control.Concurrent.MVar
import Control.Monad (when, unless) --import Control.Monad (when, unless)
import Control.Monad.State import Control.Monad.State
import Control.Monad.Trans (MonadIO, liftIO)
import System.IO (Handle, hSetBuffering, BufferMode(..)) import System.IO (Handle, hSetBuffering, BufferMode(..))
data TLSParams = TLSParams data TLSParams = TLSParams

View file

@ -47,16 +47,15 @@ readOne h = do
Right True -> B.hGetNonBlocking h 4096 Right True -> B.hGetNonBlocking h 4096
Right False -> return B.empty Right False -> return B.empty
tlsclient :: Handle -> Handle -> C.TLSClient IO () tlsclient :: Handle -> TLSCtx -> IO ()
tlsclient srchandle dsthandle = do tlsclient srchandle dsthandle = do
lift $ hSetBuffering dsthandle NoBuffering hSetBuffering srchandle NoBuffering
lift $ hSetBuffering srchandle NoBuffering
C.initiate dsthandle C.initiate dsthandle
loopUntil $ do loopUntil $ do
b <- lift $ readOne srchandle b <- readOne srchandle
lift $ putStrLn ("sending " ++ show b) putStrLn ("sending " ++ show b)
if B.null b if B.null b
then do then do
C.close dsthandle C.close dsthandle
@ -231,8 +230,9 @@ doClient pargs = do
(StunnelSocket dst) <- connectAddressDescription dstaddr (StunnelSocket dst) <- connectAddressDescription dstaddr
dsth <- socketToHandle dst ReadWriteMode dsth <- socketToHandle dst ReadWriteMode
dstctx <- C.client clientstate rng dsth
_ <- forkIO $ finally _ <- forkIO $ finally
(C.runTLSClient (tlsclient srch dsth) clientstate rng >> return ()) (tlsclient srch dstctx)
(hClose srch >> hClose dsth) (hClose srch >> hClose dsth)
return () return ()
AddrFD _ _ -> error "bad error fd. not implemented" AddrFD _ _ -> error "bad error fd. not implemented"

View file

@ -121,8 +121,8 @@ makeValidParams serverCerts = do
{- | setup create all necessary connection point to create a data "pipe" {- | setup create all necessary connection point to create a data "pipe"
- ---(startQueue)---> tlsClient ---(socketPair)---> tlsServer ---(resultQueue)---> - ---(startQueue)---> tlsClient ---(socketPair)---> tlsServer ---(resultQueue)--->
-} -}
setup :: TLSParams -> IO (Handle, TLSCtx, SRandomGen, Chan a, Chan a) setup :: (TLSParams, TLSParams) -> IO (TLSCtx, TLSCtx, Chan a, Chan a)
setup serverState = do setup (clientState, serverState) = do
(cSocket, sSocket) <- socketPair AF_UNIX Stream defaultProtocol (cSocket, sSocket) <- socketPair AF_UNIX Stream defaultProtocol
cHandle <- socketToHandle cSocket ReadWriteMode cHandle <- socketToHandle cSocket ReadWriteMode
sHandle <- socketToHandle sSocket ReadWriteMode sHandle <- socketToHandle sSocket ReadWriteMode
@ -135,20 +135,21 @@ setup serverState = do
startQueue <- newChan startQueue <- newChan
resultQueue <- newChan resultQueue <- newChan
cCtx <- C.client clientState clientRNG cHandle
sCtx <- S.server serverState serverRNG sHandle sCtx <- S.server serverState serverRNG sHandle
return (cHandle, sCtx, clientRNG, startQueue, resultQueue) return (cCtx, sCtx, startQueue, resultQueue)
testInitiate spCert = do testInitiate spCert = do
(clientstate, serverstate) <- pick (makeValidParams spCert) states <- pick (makeValidParams spCert)
(cHandle, sCtx, clientRNG, startQueue, resultQueue) <- run (setup serverstate) (cCtx, sCtx, startQueue, resultQueue) <- run (setup states)
run $ forkIO $ do run $ forkIO $ do
catch (tlsServer sCtx resultQueue) catch (tlsServer sCtx resultQueue)
(\e -> putStrLn ("server exception: " ++ show e) >> throw (e :: SomeException)) (\e -> putStrLn ("server exception: " ++ show e) >> throw (e :: SomeException))
return () return ()
run $ forkIO $ do run $ forkIO $ do
catch (C.runTLSClient (tlsClient startQueue cHandle) clientstate clientRNG) catch (tlsClient startQueue cCtx)
(\e -> putStrLn ("client exception: " ++ show e) >> throw (e :: SomeException)) (\e -> putStrLn ("client exception: " ++ show e) >> throw (e :: SomeException))
return () return ()
@ -161,7 +162,7 @@ testInitiate spCert = do
assert $ d == dres assert $ d == dres
-- cleanup -- cleanup
run $ (hClose cHandle >> hClose (getHandle sCtx)) run $ (hClose (getHandle cCtx) >> hClose (getHandle sCtx))
where where
tlsServer handle queue = do tlsServer handle queue = do
@ -171,7 +172,7 @@ testInitiate spCert = do
return () return ()
tlsClient queue handle = do tlsClient queue handle = do
C.initiate handle C.initiate handle
d <- lift $ readChan queue d <- readChan queue
C.sendData handle d C.sendData handle d
return () return ()