modify client API to be like the server API.
This commit is contained in:
parent
5481816e0e
commit
f260c5b9cf
4 changed files with 81 additions and 150 deletions
|
@ -11,155 +11,93 @@
|
|||
-- aka. a client socket.
|
||||
--
|
||||
module Network.TLS.Client
|
||||
( TLSParams(..)
|
||||
, TLSStateClient
|
||||
, TLSClient (..)
|
||||
, runTLSClient
|
||||
-- * low level packet sending receiving.
|
||||
, recvPacket
|
||||
, sendPacket
|
||||
( client
|
||||
-- * API, warning probably subject to change
|
||||
, initiate
|
||||
, connect
|
||||
, sendData
|
||||
, recvData
|
||||
, close
|
||||
) where
|
||||
|
||||
import Data.Maybe
|
||||
import Control.Applicative ((<$>))
|
||||
import Control.Monad.Trans
|
||||
import Control.Monad.State
|
||||
import Network.TLS.Cipher
|
||||
import Network.TLS.Compression
|
||||
import Network.TLS.Struct
|
||||
import Network.TLS.Packet
|
||||
import Network.TLS.State
|
||||
import Network.TLS.Sending
|
||||
import Network.TLS.Receiving
|
||||
import Network.TLS.SRandom
|
||||
import Network.TLS.Core (TLSParams(..))
|
||||
import Network.TLS.Core
|
||||
import qualified Data.ByteString as B
|
||||
import qualified Data.ByteString.Lazy as L
|
||||
import System.IO (Handle, hFlush)
|
||||
import Data.List (find)
|
||||
|
||||
data TLSStateClient = TLSStateClient
|
||||
{ scParams :: TLSParams -- ^ client params and config for this connection
|
||||
, scTLSState :: TLSState -- ^ client TLS State for this connection
|
||||
, scCertRequested :: Bool -- ^ mark that the server requested a certificate
|
||||
}
|
||||
client :: MonadIO m => TLSParams -> SRandomGen -> Handle -> m TLSCtx
|
||||
client params rng handle = liftIO $ newCtx handle params state
|
||||
where state = (newTLSState rng) { stClientContext = True }
|
||||
|
||||
newtype TLSClient m a = TLSClient { runTLSC :: StateT TLSStateClient m a }
|
||||
deriving (Monad, MonadState TLSStateClient)
|
||||
processServerInfo :: MonadIO m => TLSCtx -> Packet -> m ()
|
||||
processServerInfo ctx (Handshake (ServerHello ver _ _ cipher _ _)) = do
|
||||
let ciphers = pCiphers $ getParams ctx
|
||||
let allowedvers = pAllowedVersions $ getParams ctx
|
||||
|
||||
instance MonadTrans TLSClient where
|
||||
lift = TLSClient . lift
|
||||
|
||||
instance (Functor m, Monad m) => Functor (TLSClient m) where
|
||||
fmap f = TLSClient . fmap f . runTLSC
|
||||
|
||||
runTLSClientST :: TLSClient m a -> TLSStateClient -> m (a, TLSStateClient)
|
||||
runTLSClientST f s = runStateT (runTLSC f) s
|
||||
|
||||
runTLSClient :: TLSClient m a -> TLSParams -> SRandomGen -> m (a, TLSStateClient)
|
||||
runTLSClient f params rng = runTLSClientST f (TLSStateClient { scParams = params, scTLSState = state, scCertRequested = False })
|
||||
where state = (newTLSState rng) { stVersion = pConnectVersion params, stClientContext = True }
|
||||
|
||||
usingState :: Monad m => TLSSt a -> TLSClient m (Either TLSError a)
|
||||
usingState f =
|
||||
get >>= return . scTLSState >>= execAndStore
|
||||
where
|
||||
execAndStore st = do
|
||||
let (a, newst) = runTLSState f st
|
||||
modify (\stateclient -> stateclient { scTLSState = newst })
|
||||
return a
|
||||
|
||||
usingState_ f = do
|
||||
ret <- usingState f
|
||||
case ret of
|
||||
Left err -> error "assertion failed, error in path without an error"
|
||||
Right r -> return r
|
||||
|
||||
getStateRNG n = usingState_ (withTLSRNG (\rng -> getRandomBytes rng n))
|
||||
|
||||
{- | receive a single TLS packet or on error a TLSError -}
|
||||
recvPacket :: Handle -> TLSClient IO (Either TLSError [Packet])
|
||||
recvPacket handle = do
|
||||
hdr <- lift $ B.hGet handle 5 >>= return . decodeHeader
|
||||
case hdr of
|
||||
Left err -> return $ Left err
|
||||
Right header@(Header _ _ readlen) -> do
|
||||
content <- lift $ B.hGet handle (fromIntegral readlen)
|
||||
usingState $ readPacket header (EncryptedData content)
|
||||
|
||||
{- | send a single TLS packet -}
|
||||
sendPacket :: Handle -> Packet -> TLSClient IO ()
|
||||
sendPacket handle pkt = do
|
||||
dataToSend <- usingState_ $ writePacket pkt
|
||||
lift $ B.hPut handle dataToSend
|
||||
|
||||
processServerInfo :: Packet -> TLSClient IO ()
|
||||
processServerInfo (Handshake (ServerHello ver _ _ cipher _ _)) = do
|
||||
ciphers <- pCiphers . scParams <$> get
|
||||
allowedvers <- pAllowedVersions . scParams <$> get
|
||||
case find ((==) ver) allowedvers of
|
||||
Nothing -> error ("received version which is not allowed: " ++ show ver)
|
||||
Just _ -> usingState_ $ setVersion ver
|
||||
Just _ -> usingState_ ctx $ setVersion ver
|
||||
case find ((==) cipher . cipherID) ciphers of
|
||||
Nothing -> error "no cipher in common with the server"
|
||||
Just c -> usingState_ $ setCipher c
|
||||
Just c -> usingState_ ctx $ setCipher c
|
||||
|
||||
processServerInfo (Handshake (CertRequest _ _ _)) = do
|
||||
modify (\sc -> sc { scCertRequested = True })
|
||||
processServerInfo _ (Handshake (CertRequest _ _ _)) = do
|
||||
return ()
|
||||
--modify (\sc -> sc { scCertRequested = True })
|
||||
|
||||
processServerInfo (Handshake (Certificates certs)) = do
|
||||
cb <- onCertificatesRecv . scParams <$> get
|
||||
valid <- lift $ cb certs
|
||||
processServerInfo ctx (Handshake (Certificates certs)) = do
|
||||
let cb = onCertificatesRecv $ getParams ctx
|
||||
valid <- liftIO $ cb certs
|
||||
unless valid $ error "certificates received deemed invalid by user"
|
||||
|
||||
processServerInfo _ = return ()
|
||||
processServerInfo _ _ = return ()
|
||||
|
||||
recvServerInfo :: Handle -> TLSClient IO ()
|
||||
recvServerInfo handle = do
|
||||
whileStatus (/= (StatusHandshake HsStatusServerHelloDone)) $ do
|
||||
pkts <- recvPacket handle
|
||||
recvServerInfo :: MonadIO m => TLSCtx -> m ()
|
||||
recvServerInfo ctx = do
|
||||
whileStatus ctx (/= (StatusHandshake HsStatusServerHelloDone)) $ do
|
||||
pkts <- recvPacket ctx
|
||||
case pkts of
|
||||
Left err -> error ("error received: " ++ show err)
|
||||
Right l -> forM_ l processServerInfo
|
||||
Right l -> mapM_ (processServerInfo ctx) l
|
||||
|
||||
connectSendClientHello :: MonadIO m => TLSCtx -> m ()
|
||||
connectSendClientHello ctx = do
|
||||
crand <- getStateRNG ctx 32 >>= return . fromJust . clientRandom
|
||||
sendPacket ctx $ Handshake (ClientHello ver crand (Session Nothing) (map cipherID ciphers) (map compressionID compressions) Nothing)
|
||||
where
|
||||
whileStatus p a = do
|
||||
b <- usingState_ (p . stStatus <$> get)
|
||||
when b (a >> whileStatus p a)
|
||||
params = getParams ctx
|
||||
ver = pConnectVersion params
|
||||
ciphers = pCiphers params
|
||||
compressions = pCompressions params
|
||||
|
||||
connectSendClientHello :: Handle -> TLSClient IO ()
|
||||
connectSendClientHello handle = do
|
||||
crand <- fromJust . clientRandom <$> getStateRNG 32
|
||||
ver <- pConnectVersion . scParams <$> get
|
||||
ciphers <- pCiphers . scParams <$> get
|
||||
compressions <- pCompressions . scParams <$> get
|
||||
sendPacket handle $ Handshake (ClientHello ver crand (Session Nothing) (map cipherID ciphers) (map compressionID compressions) Nothing)
|
||||
|
||||
connectSendClientCertificate :: Handle -> TLSClient IO ()
|
||||
connectSendClientCertificate handle = do
|
||||
certRequested <- scCertRequested <$> get
|
||||
connectSendClientCertificate :: MonadIO m => TLSCtx -> m ()
|
||||
connectSendClientCertificate ctx = do
|
||||
certRequested <- return False -- scCertRequested <$> get
|
||||
when certRequested $ do
|
||||
clientCerts <- map fst . pCertificates . scParams <$> get
|
||||
sendPacket handle $ Handshake (Certificates clientCerts)
|
||||
let clientCerts = map fst $ pCertificates $ getParams ctx
|
||||
sendPacket ctx $ Handshake (Certificates clientCerts)
|
||||
|
||||
connectSendClientKeyXchg :: Handle -> TLSClient IO ()
|
||||
connectSendClientKeyXchg handle = do
|
||||
prerand <- ClientKeyData <$> getStateRNG 46
|
||||
ver <- pConnectVersion . scParams <$> get
|
||||
sendPacket handle $ Handshake (ClientKeyXchg ver prerand)
|
||||
connectSendClientKeyXchg :: MonadIO m => TLSCtx -> m ()
|
||||
connectSendClientKeyXchg ctx = do
|
||||
prerand <- getStateRNG ctx 46 >>= return . ClientKeyData
|
||||
let ver = pConnectVersion $ getParams ctx
|
||||
sendPacket ctx $ Handshake (ClientKeyXchg ver prerand)
|
||||
|
||||
connectSendFinish :: Handle -> TLSClient IO ()
|
||||
connectSendFinish handle = do
|
||||
cf <- usingState_ $ getHandshakeDigest True
|
||||
sendPacket handle (Handshake $ Finished $ B.unpack cf)
|
||||
connectSendFinish :: MonadIO m => TLSCtx -> m ()
|
||||
connectSendFinish ctx = do
|
||||
cf <- usingState_ ctx $ getHandshakeDigest True
|
||||
sendPacket ctx (Handshake $ Finished $ B.unpack cf)
|
||||
|
||||
{- | initiate a new TLS connection through a handshake on a handle. -}
|
||||
initiate :: Handle -> TLSClient IO ()
|
||||
initiate :: MonadIO m => TLSCtx -> m ()
|
||||
initiate handle = do
|
||||
connectSendClientHello handle
|
||||
recvServerInfo handle
|
||||
|
@ -171,7 +109,7 @@ initiate handle = do
|
|||
{- FIXME not implemented yet -}
|
||||
|
||||
sendPacket handle (ChangeCipherSpec)
|
||||
lift $ hFlush handle
|
||||
liftIO $ hFlush $ getHandle handle
|
||||
|
||||
{- send Finished -}
|
||||
connectSendFinish handle
|
||||
|
@ -184,38 +122,32 @@ initiate handle = do
|
|||
|
||||
return ()
|
||||
|
||||
{-# DEPRECATED connect "use initiate" #-}
|
||||
connect :: Handle -> TLSClient IO ()
|
||||
connect = initiate
|
||||
|
||||
sendDataChunk :: Handle -> B.ByteString -> TLSClient IO ()
|
||||
sendDataChunk handle d =
|
||||
{- | sendData sends a bunch of data -}
|
||||
sendData :: MonadIO m => TLSCtx -> L.ByteString -> m ()
|
||||
sendData ctx dataToSend = mapM_ sendDataChunk (L.toChunks dataToSend)
|
||||
where sendDataChunk d =
|
||||
if B.length d > 16384
|
||||
then do
|
||||
let (sending, remain) = B.splitAt 16384 d
|
||||
sendPacket handle $ AppData sending
|
||||
sendDataChunk handle remain
|
||||
sendPacket ctx $ AppData sending
|
||||
sendDataChunk remain
|
||||
else
|
||||
sendPacket handle $ AppData d
|
||||
sendPacket ctx $ AppData d
|
||||
|
||||
{- | sendData sends a bunch of data -}
|
||||
sendData :: Handle -> L.ByteString -> TLSClient IO ()
|
||||
sendData handle d = mapM_ (sendDataChunk handle) (L.toChunks d)
|
||||
|
||||
{- | recvData get data out of Data packet, and automatically try to renegociate if
|
||||
- a Handshake HelloRequest is received -}
|
||||
recvData :: Handle -> TLSClient IO L.ByteString
|
||||
{- | recvData get data out of Data packet, and automatically renegociate if
|
||||
- a Handshake ClientHello is received -}
|
||||
recvData :: MonadIO m => TLSCtx -> m L.ByteString
|
||||
recvData handle = do
|
||||
pkt <- recvPacket handle
|
||||
case pkt of
|
||||
Right [AppData x] -> return $ L.fromChunks [x]
|
||||
Right [Handshake HelloRequest] -> connect handle >> recvData handle
|
||||
Right [Handshake HelloRequest] -> initiate handle >> recvData handle
|
||||
Left err -> error ("error received: " ++ show err)
|
||||
_ -> error "unexpected item"
|
||||
|
||||
{- | close a TLS connection.
|
||||
- note that it doesn't close the handle, but just signal we're going to close
|
||||
- the connection to the other side -}
|
||||
close :: Handle -> TLSClient IO ()
|
||||
close handle = do
|
||||
sendPacket handle $ Alert (AlertLevel_Warning, CloseNotify)
|
||||
close :: MonadIO m => TLSCtx -> m ()
|
||||
close ctx = sendPacket ctx $ Alert (AlertLevel_Warning, CloseNotify)
|
||||
|
|
|
@ -32,13 +32,11 @@ import Network.TLS.SRandom
|
|||
import Data.Certificate.X509
|
||||
import Data.List (intercalate)
|
||||
import qualified Data.ByteString as B
|
||||
import qualified Data.ByteString.Lazy as L
|
||||
|
||||
import Control.Applicative ((<$>))
|
||||
import Control.Concurrent.MVar
|
||||
import Control.Monad (when, unless)
|
||||
--import Control.Monad (when, unless)
|
||||
import Control.Monad.State
|
||||
import Control.Monad.Trans (MonadIO, liftIO)
|
||||
import System.IO (Handle, hSetBuffering, BufferMode(..))
|
||||
|
||||
data TLSParams = TLSParams
|
||||
|
|
12
Stunnel.hs
12
Stunnel.hs
|
@ -47,16 +47,15 @@ readOne h = do
|
|||
Right True -> B.hGetNonBlocking h 4096
|
||||
Right False -> return B.empty
|
||||
|
||||
tlsclient :: Handle -> Handle -> C.TLSClient IO ()
|
||||
tlsclient :: Handle -> TLSCtx -> IO ()
|
||||
tlsclient srchandle dsthandle = do
|
||||
lift $ hSetBuffering dsthandle NoBuffering
|
||||
lift $ hSetBuffering srchandle NoBuffering
|
||||
hSetBuffering srchandle NoBuffering
|
||||
|
||||
C.initiate dsthandle
|
||||
|
||||
loopUntil $ do
|
||||
b <- lift $ readOne srchandle
|
||||
lift $ putStrLn ("sending " ++ show b)
|
||||
b <- readOne srchandle
|
||||
putStrLn ("sending " ++ show b)
|
||||
if B.null b
|
||||
then do
|
||||
C.close dsthandle
|
||||
|
@ -231,8 +230,9 @@ doClient pargs = do
|
|||
(StunnelSocket dst) <- connectAddressDescription dstaddr
|
||||
|
||||
dsth <- socketToHandle dst ReadWriteMode
|
||||
dstctx <- C.client clientstate rng dsth
|
||||
_ <- forkIO $ finally
|
||||
(C.runTLSClient (tlsclient srch dsth) clientstate rng >> return ())
|
||||
(tlsclient srch dstctx)
|
||||
(hClose srch >> hClose dsth)
|
||||
return ()
|
||||
AddrFD _ _ -> error "bad error fd. not implemented"
|
||||
|
|
|
@ -121,8 +121,8 @@ makeValidParams serverCerts = do
|
|||
{- | setup create all necessary connection point to create a data "pipe"
|
||||
- ---(startQueue)---> tlsClient ---(socketPair)---> tlsServer ---(resultQueue)--->
|
||||
-}
|
||||
setup :: TLSParams -> IO (Handle, TLSCtx, SRandomGen, Chan a, Chan a)
|
||||
setup serverState = do
|
||||
setup :: (TLSParams, TLSParams) -> IO (TLSCtx, TLSCtx, Chan a, Chan a)
|
||||
setup (clientState, serverState) = do
|
||||
(cSocket, sSocket) <- socketPair AF_UNIX Stream defaultProtocol
|
||||
cHandle <- socketToHandle cSocket ReadWriteMode
|
||||
sHandle <- socketToHandle sSocket ReadWriteMode
|
||||
|
@ -135,20 +135,21 @@ setup serverState = do
|
|||
startQueue <- newChan
|
||||
resultQueue <- newChan
|
||||
|
||||
cCtx <- C.client clientState clientRNG cHandle
|
||||
sCtx <- S.server serverState serverRNG sHandle
|
||||
|
||||
return (cHandle, sCtx, clientRNG, startQueue, resultQueue)
|
||||
return (cCtx, sCtx, startQueue, resultQueue)
|
||||
|
||||
testInitiate spCert = do
|
||||
(clientstate, serverstate) <- pick (makeValidParams spCert)
|
||||
(cHandle, sCtx, clientRNG, startQueue, resultQueue) <- run (setup serverstate)
|
||||
states <- pick (makeValidParams spCert)
|
||||
(cCtx, sCtx, startQueue, resultQueue) <- run (setup states)
|
||||
|
||||
run $ forkIO $ do
|
||||
catch (tlsServer sCtx resultQueue)
|
||||
(\e -> putStrLn ("server exception: " ++ show e) >> throw (e :: SomeException))
|
||||
return ()
|
||||
run $ forkIO $ do
|
||||
catch (C.runTLSClient (tlsClient startQueue cHandle) clientstate clientRNG)
|
||||
catch (tlsClient startQueue cCtx)
|
||||
(\e -> putStrLn ("client exception: " ++ show e) >> throw (e :: SomeException))
|
||||
return ()
|
||||
|
||||
|
@ -161,7 +162,7 @@ testInitiate spCert = do
|
|||
assert $ d == dres
|
||||
|
||||
-- cleanup
|
||||
run $ (hClose cHandle >> hClose (getHandle sCtx))
|
||||
run $ (hClose (getHandle cCtx) >> hClose (getHandle sCtx))
|
||||
|
||||
where
|
||||
tlsServer handle queue = do
|
||||
|
@ -171,7 +172,7 @@ testInitiate spCert = do
|
|||
return ()
|
||||
tlsClient queue handle = do
|
||||
C.initiate handle
|
||||
d <- lift $ readChan queue
|
||||
d <- readChan queue
|
||||
C.sendData handle d
|
||||
return ()
|
||||
|
||||
|
|
Loading…
Reference in a new issue