modify client API to be like the server API.

This commit is contained in:
Vincent Hanquez 2011-03-01 20:01:40 +00:00
parent 5481816e0e
commit f260c5b9cf
4 changed files with 81 additions and 150 deletions

View file

@ -11,155 +11,93 @@
-- aka. a client socket.
--
module Network.TLS.Client
( TLSParams(..)
, TLSStateClient
, TLSClient (..)
, runTLSClient
-- * low level packet sending receiving.
, recvPacket
, sendPacket
( client
-- * API, warning probably subject to change
, initiate
, connect
, sendData
, recvData
, close
) where
import Data.Maybe
import Control.Applicative ((<$>))
import Control.Monad.Trans
import Control.Monad.State
import Network.TLS.Cipher
import Network.TLS.Compression
import Network.TLS.Struct
import Network.TLS.Packet
import Network.TLS.State
import Network.TLS.Sending
import Network.TLS.Receiving
import Network.TLS.SRandom
import Network.TLS.Core (TLSParams(..))
import Network.TLS.Core
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as L
import System.IO (Handle, hFlush)
import Data.List (find)
data TLSStateClient = TLSStateClient
{ scParams :: TLSParams -- ^ client params and config for this connection
, scTLSState :: TLSState -- ^ client TLS State for this connection
, scCertRequested :: Bool -- ^ mark that the server requested a certificate
}
client :: MonadIO m => TLSParams -> SRandomGen -> Handle -> m TLSCtx
client params rng handle = liftIO $ newCtx handle params state
where state = (newTLSState rng) { stClientContext = True }
newtype TLSClient m a = TLSClient { runTLSC :: StateT TLSStateClient m a }
deriving (Monad, MonadState TLSStateClient)
processServerInfo :: MonadIO m => TLSCtx -> Packet -> m ()
processServerInfo ctx (Handshake (ServerHello ver _ _ cipher _ _)) = do
let ciphers = pCiphers $ getParams ctx
let allowedvers = pAllowedVersions $ getParams ctx
instance MonadTrans TLSClient where
lift = TLSClient . lift
instance (Functor m, Monad m) => Functor (TLSClient m) where
fmap f = TLSClient . fmap f . runTLSC
runTLSClientST :: TLSClient m a -> TLSStateClient -> m (a, TLSStateClient)
runTLSClientST f s = runStateT (runTLSC f) s
runTLSClient :: TLSClient m a -> TLSParams -> SRandomGen -> m (a, TLSStateClient)
runTLSClient f params rng = runTLSClientST f (TLSStateClient { scParams = params, scTLSState = state, scCertRequested = False })
where state = (newTLSState rng) { stVersion = pConnectVersion params, stClientContext = True }
usingState :: Monad m => TLSSt a -> TLSClient m (Either TLSError a)
usingState f =
get >>= return . scTLSState >>= execAndStore
where
execAndStore st = do
let (a, newst) = runTLSState f st
modify (\stateclient -> stateclient { scTLSState = newst })
return a
usingState_ f = do
ret <- usingState f
case ret of
Left err -> error "assertion failed, error in path without an error"
Right r -> return r
getStateRNG n = usingState_ (withTLSRNG (\rng -> getRandomBytes rng n))
{- | receive a single TLS packet or on error a TLSError -}
recvPacket :: Handle -> TLSClient IO (Either TLSError [Packet])
recvPacket handle = do
hdr <- lift $ B.hGet handle 5 >>= return . decodeHeader
case hdr of
Left err -> return $ Left err
Right header@(Header _ _ readlen) -> do
content <- lift $ B.hGet handle (fromIntegral readlen)
usingState $ readPacket header (EncryptedData content)
{- | send a single TLS packet -}
sendPacket :: Handle -> Packet -> TLSClient IO ()
sendPacket handle pkt = do
dataToSend <- usingState_ $ writePacket pkt
lift $ B.hPut handle dataToSend
processServerInfo :: Packet -> TLSClient IO ()
processServerInfo (Handshake (ServerHello ver _ _ cipher _ _)) = do
ciphers <- pCiphers . scParams <$> get
allowedvers <- pAllowedVersions . scParams <$> get
case find ((==) ver) allowedvers of
Nothing -> error ("received version which is not allowed: " ++ show ver)
Just _ -> usingState_ $ setVersion ver
Just _ -> usingState_ ctx $ setVersion ver
case find ((==) cipher . cipherID) ciphers of
Nothing -> error "no cipher in common with the server"
Just c -> usingState_ $ setCipher c
Just c -> usingState_ ctx $ setCipher c
processServerInfo (Handshake (CertRequest _ _ _)) = do
modify (\sc -> sc { scCertRequested = True })
processServerInfo _ (Handshake (CertRequest _ _ _)) = do
return ()
--modify (\sc -> sc { scCertRequested = True })
processServerInfo (Handshake (Certificates certs)) = do
cb <- onCertificatesRecv . scParams <$> get
valid <- lift $ cb certs
processServerInfo ctx (Handshake (Certificates certs)) = do
let cb = onCertificatesRecv $ getParams ctx
valid <- liftIO $ cb certs
unless valid $ error "certificates received deemed invalid by user"
processServerInfo _ = return ()
processServerInfo _ _ = return ()
recvServerInfo :: Handle -> TLSClient IO ()
recvServerInfo handle = do
whileStatus (/= (StatusHandshake HsStatusServerHelloDone)) $ do
pkts <- recvPacket handle
recvServerInfo :: MonadIO m => TLSCtx -> m ()
recvServerInfo ctx = do
whileStatus ctx (/= (StatusHandshake HsStatusServerHelloDone)) $ do
pkts <- recvPacket ctx
case pkts of
Left err -> error ("error received: " ++ show err)
Right l -> forM_ l processServerInfo
Right l -> mapM_ (processServerInfo ctx) l
connectSendClientHello :: MonadIO m => TLSCtx -> m ()
connectSendClientHello ctx = do
crand <- getStateRNG ctx 32 >>= return . fromJust . clientRandom
sendPacket ctx $ Handshake (ClientHello ver crand (Session Nothing) (map cipherID ciphers) (map compressionID compressions) Nothing)
where
whileStatus p a = do
b <- usingState_ (p . stStatus <$> get)
when b (a >> whileStatus p a)
params = getParams ctx
ver = pConnectVersion params
ciphers = pCiphers params
compressions = pCompressions params
connectSendClientHello :: Handle -> TLSClient IO ()
connectSendClientHello handle = do
crand <- fromJust . clientRandom <$> getStateRNG 32
ver <- pConnectVersion . scParams <$> get
ciphers <- pCiphers . scParams <$> get
compressions <- pCompressions . scParams <$> get
sendPacket handle $ Handshake (ClientHello ver crand (Session Nothing) (map cipherID ciphers) (map compressionID compressions) Nothing)
connectSendClientCertificate :: Handle -> TLSClient IO ()
connectSendClientCertificate handle = do
certRequested <- scCertRequested <$> get
connectSendClientCertificate :: MonadIO m => TLSCtx -> m ()
connectSendClientCertificate ctx = do
certRequested <- return False -- scCertRequested <$> get
when certRequested $ do
clientCerts <- map fst . pCertificates . scParams <$> get
sendPacket handle $ Handshake (Certificates clientCerts)
let clientCerts = map fst $ pCertificates $ getParams ctx
sendPacket ctx $ Handshake (Certificates clientCerts)
connectSendClientKeyXchg :: Handle -> TLSClient IO ()
connectSendClientKeyXchg handle = do
prerand <- ClientKeyData <$> getStateRNG 46
ver <- pConnectVersion . scParams <$> get
sendPacket handle $ Handshake (ClientKeyXchg ver prerand)
connectSendClientKeyXchg :: MonadIO m => TLSCtx -> m ()
connectSendClientKeyXchg ctx = do
prerand <- getStateRNG ctx 46 >>= return . ClientKeyData
let ver = pConnectVersion $ getParams ctx
sendPacket ctx $ Handshake (ClientKeyXchg ver prerand)
connectSendFinish :: Handle -> TLSClient IO ()
connectSendFinish handle = do
cf <- usingState_ $ getHandshakeDigest True
sendPacket handle (Handshake $ Finished $ B.unpack cf)
connectSendFinish :: MonadIO m => TLSCtx -> m ()
connectSendFinish ctx = do
cf <- usingState_ ctx $ getHandshakeDigest True
sendPacket ctx (Handshake $ Finished $ B.unpack cf)
{- | initiate a new TLS connection through a handshake on a handle. -}
initiate :: Handle -> TLSClient IO ()
initiate :: MonadIO m => TLSCtx -> m ()
initiate handle = do
connectSendClientHello handle
recvServerInfo handle
@ -171,7 +109,7 @@ initiate handle = do
{- FIXME not implemented yet -}
sendPacket handle (ChangeCipherSpec)
lift $ hFlush handle
liftIO $ hFlush $ getHandle handle
{- send Finished -}
connectSendFinish handle
@ -184,38 +122,32 @@ initiate handle = do
return ()
{-# DEPRECATED connect "use initiate" #-}
connect :: Handle -> TLSClient IO ()
connect = initiate
sendDataChunk :: Handle -> B.ByteString -> TLSClient IO ()
sendDataChunk handle d =
if B.length d > 16384
then do
let (sending, remain) = B.splitAt 16384 d
sendPacket handle $ AppData sending
sendDataChunk handle remain
else
sendPacket handle $ AppData d
{- | sendData sends a bunch of data -}
sendData :: Handle -> L.ByteString -> TLSClient IO ()
sendData handle d = mapM_ (sendDataChunk handle) (L.toChunks d)
sendData :: MonadIO m => TLSCtx -> L.ByteString -> m ()
sendData ctx dataToSend = mapM_ sendDataChunk (L.toChunks dataToSend)
where sendDataChunk d =
if B.length d > 16384
then do
let (sending, remain) = B.splitAt 16384 d
sendPacket ctx $ AppData sending
sendDataChunk remain
else
sendPacket ctx $ AppData d
{- | recvData get data out of Data packet, and automatically try to renegociate if
- a Handshake HelloRequest is received -}
recvData :: Handle -> TLSClient IO L.ByteString
{- | recvData get data out of Data packet, and automatically renegociate if
- a Handshake ClientHello is received -}
recvData :: MonadIO m => TLSCtx -> m L.ByteString
recvData handle = do
pkt <- recvPacket handle
case pkt of
Right [AppData x] -> return $ L.fromChunks [x]
Right [Handshake HelloRequest] -> connect handle >> recvData handle
Right [Handshake HelloRequest] -> initiate handle >> recvData handle
Left err -> error ("error received: " ++ show err)
_ -> error "unexpected item"
{- | close a TLS connection.
- note that it doesn't close the handle, but just signal we're going to close
- the connection to the other side -}
close :: Handle -> TLSClient IO ()
close handle = do
sendPacket handle $ Alert (AlertLevel_Warning, CloseNotify)
close :: MonadIO m => TLSCtx -> m ()
close ctx = sendPacket ctx $ Alert (AlertLevel_Warning, CloseNotify)

View file

@ -32,13 +32,11 @@ import Network.TLS.SRandom
import Data.Certificate.X509
import Data.List (intercalate)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as L
import Control.Applicative ((<$>))
import Control.Concurrent.MVar
import Control.Monad (when, unless)
--import Control.Monad (when, unless)
import Control.Monad.State
import Control.Monad.Trans (MonadIO, liftIO)
import System.IO (Handle, hSetBuffering, BufferMode(..))
data TLSParams = TLSParams

View file

@ -47,16 +47,15 @@ readOne h = do
Right True -> B.hGetNonBlocking h 4096
Right False -> return B.empty
tlsclient :: Handle -> Handle -> C.TLSClient IO ()
tlsclient :: Handle -> TLSCtx -> IO ()
tlsclient srchandle dsthandle = do
lift $ hSetBuffering dsthandle NoBuffering
lift $ hSetBuffering srchandle NoBuffering
hSetBuffering srchandle NoBuffering
C.initiate dsthandle
loopUntil $ do
b <- lift $ readOne srchandle
lift $ putStrLn ("sending " ++ show b)
b <- readOne srchandle
putStrLn ("sending " ++ show b)
if B.null b
then do
C.close dsthandle
@ -231,8 +230,9 @@ doClient pargs = do
(StunnelSocket dst) <- connectAddressDescription dstaddr
dsth <- socketToHandle dst ReadWriteMode
dstctx <- C.client clientstate rng dsth
_ <- forkIO $ finally
(C.runTLSClient (tlsclient srch dsth) clientstate rng >> return ())
(tlsclient srch dstctx)
(hClose srch >> hClose dsth)
return ()
AddrFD _ _ -> error "bad error fd. not implemented"

View file

@ -121,8 +121,8 @@ makeValidParams serverCerts = do
{- | setup create all necessary connection point to create a data "pipe"
- ---(startQueue)---> tlsClient ---(socketPair)---> tlsServer ---(resultQueue)--->
-}
setup :: TLSParams -> IO (Handle, TLSCtx, SRandomGen, Chan a, Chan a)
setup serverState = do
setup :: (TLSParams, TLSParams) -> IO (TLSCtx, TLSCtx, Chan a, Chan a)
setup (clientState, serverState) = do
(cSocket, sSocket) <- socketPair AF_UNIX Stream defaultProtocol
cHandle <- socketToHandle cSocket ReadWriteMode
sHandle <- socketToHandle sSocket ReadWriteMode
@ -135,20 +135,21 @@ setup serverState = do
startQueue <- newChan
resultQueue <- newChan
cCtx <- C.client clientState clientRNG cHandle
sCtx <- S.server serverState serverRNG sHandle
return (cHandle, sCtx, clientRNG, startQueue, resultQueue)
return (cCtx, sCtx, startQueue, resultQueue)
testInitiate spCert = do
(clientstate, serverstate) <- pick (makeValidParams spCert)
(cHandle, sCtx, clientRNG, startQueue, resultQueue) <- run (setup serverstate)
states <- pick (makeValidParams spCert)
(cCtx, sCtx, startQueue, resultQueue) <- run (setup states)
run $ forkIO $ do
catch (tlsServer sCtx resultQueue)
(\e -> putStrLn ("server exception: " ++ show e) >> throw (e :: SomeException))
return ()
run $ forkIO $ do
catch (C.runTLSClient (tlsClient startQueue cHandle) clientstate clientRNG)
catch (tlsClient startQueue cCtx)
(\e -> putStrLn ("client exception: " ++ show e) >> throw (e :: SomeException))
return ()
@ -161,7 +162,7 @@ testInitiate spCert = do
assert $ d == dres
-- cleanup
run $ (hClose cHandle >> hClose (getHandle sCtx))
run $ (hClose (getHandle cCtx) >> hClose (getHandle sCtx))
where
tlsServer handle queue = do
@ -171,7 +172,7 @@ testInitiate spCert = do
return ()
tlsClient queue handle = do
C.initiate handle
d <- lift $ readChan queue
d <- readChan queue
C.sendData handle d
return ()