modify client API to be like the server API.
This commit is contained in:
parent
5481816e0e
commit
f260c5b9cf
4 changed files with 81 additions and 150 deletions
|
@ -11,155 +11,93 @@
|
||||||
-- aka. a client socket.
|
-- aka. a client socket.
|
||||||
--
|
--
|
||||||
module Network.TLS.Client
|
module Network.TLS.Client
|
||||||
( TLSParams(..)
|
( client
|
||||||
, TLSStateClient
|
|
||||||
, TLSClient (..)
|
|
||||||
, runTLSClient
|
|
||||||
-- * low level packet sending receiving.
|
|
||||||
, recvPacket
|
|
||||||
, sendPacket
|
|
||||||
-- * API, warning probably subject to change
|
-- * API, warning probably subject to change
|
||||||
, initiate
|
, initiate
|
||||||
, connect
|
|
||||||
, sendData
|
, sendData
|
||||||
, recvData
|
, recvData
|
||||||
, close
|
, close
|
||||||
) where
|
) where
|
||||||
|
|
||||||
import Data.Maybe
|
import Data.Maybe
|
||||||
import Control.Applicative ((<$>))
|
|
||||||
import Control.Monad.Trans
|
import Control.Monad.Trans
|
||||||
import Control.Monad.State
|
import Control.Monad.State
|
||||||
import Network.TLS.Cipher
|
import Network.TLS.Cipher
|
||||||
import Network.TLS.Compression
|
import Network.TLS.Compression
|
||||||
import Network.TLS.Struct
|
import Network.TLS.Struct
|
||||||
import Network.TLS.Packet
|
|
||||||
import Network.TLS.State
|
import Network.TLS.State
|
||||||
import Network.TLS.Sending
|
|
||||||
import Network.TLS.Receiving
|
|
||||||
import Network.TLS.SRandom
|
import Network.TLS.SRandom
|
||||||
import Network.TLS.Core (TLSParams(..))
|
import Network.TLS.Core
|
||||||
import qualified Data.ByteString as B
|
import qualified Data.ByteString as B
|
||||||
import qualified Data.ByteString.Lazy as L
|
import qualified Data.ByteString.Lazy as L
|
||||||
import System.IO (Handle, hFlush)
|
import System.IO (Handle, hFlush)
|
||||||
import Data.List (find)
|
import Data.List (find)
|
||||||
|
|
||||||
data TLSStateClient = TLSStateClient
|
client :: MonadIO m => TLSParams -> SRandomGen -> Handle -> m TLSCtx
|
||||||
{ scParams :: TLSParams -- ^ client params and config for this connection
|
client params rng handle = liftIO $ newCtx handle params state
|
||||||
, scTLSState :: TLSState -- ^ client TLS State for this connection
|
where state = (newTLSState rng) { stClientContext = True }
|
||||||
, scCertRequested :: Bool -- ^ mark that the server requested a certificate
|
|
||||||
}
|
|
||||||
|
|
||||||
newtype TLSClient m a = TLSClient { runTLSC :: StateT TLSStateClient m a }
|
processServerInfo :: MonadIO m => TLSCtx -> Packet -> m ()
|
||||||
deriving (Monad, MonadState TLSStateClient)
|
processServerInfo ctx (Handshake (ServerHello ver _ _ cipher _ _)) = do
|
||||||
|
let ciphers = pCiphers $ getParams ctx
|
||||||
|
let allowedvers = pAllowedVersions $ getParams ctx
|
||||||
|
|
||||||
instance MonadTrans TLSClient where
|
|
||||||
lift = TLSClient . lift
|
|
||||||
|
|
||||||
instance (Functor m, Monad m) => Functor (TLSClient m) where
|
|
||||||
fmap f = TLSClient . fmap f . runTLSC
|
|
||||||
|
|
||||||
runTLSClientST :: TLSClient m a -> TLSStateClient -> m (a, TLSStateClient)
|
|
||||||
runTLSClientST f s = runStateT (runTLSC f) s
|
|
||||||
|
|
||||||
runTLSClient :: TLSClient m a -> TLSParams -> SRandomGen -> m (a, TLSStateClient)
|
|
||||||
runTLSClient f params rng = runTLSClientST f (TLSStateClient { scParams = params, scTLSState = state, scCertRequested = False })
|
|
||||||
where state = (newTLSState rng) { stVersion = pConnectVersion params, stClientContext = True }
|
|
||||||
|
|
||||||
usingState :: Monad m => TLSSt a -> TLSClient m (Either TLSError a)
|
|
||||||
usingState f =
|
|
||||||
get >>= return . scTLSState >>= execAndStore
|
|
||||||
where
|
|
||||||
execAndStore st = do
|
|
||||||
let (a, newst) = runTLSState f st
|
|
||||||
modify (\stateclient -> stateclient { scTLSState = newst })
|
|
||||||
return a
|
|
||||||
|
|
||||||
usingState_ f = do
|
|
||||||
ret <- usingState f
|
|
||||||
case ret of
|
|
||||||
Left err -> error "assertion failed, error in path without an error"
|
|
||||||
Right r -> return r
|
|
||||||
|
|
||||||
getStateRNG n = usingState_ (withTLSRNG (\rng -> getRandomBytes rng n))
|
|
||||||
|
|
||||||
{- | receive a single TLS packet or on error a TLSError -}
|
|
||||||
recvPacket :: Handle -> TLSClient IO (Either TLSError [Packet])
|
|
||||||
recvPacket handle = do
|
|
||||||
hdr <- lift $ B.hGet handle 5 >>= return . decodeHeader
|
|
||||||
case hdr of
|
|
||||||
Left err -> return $ Left err
|
|
||||||
Right header@(Header _ _ readlen) -> do
|
|
||||||
content <- lift $ B.hGet handle (fromIntegral readlen)
|
|
||||||
usingState $ readPacket header (EncryptedData content)
|
|
||||||
|
|
||||||
{- | send a single TLS packet -}
|
|
||||||
sendPacket :: Handle -> Packet -> TLSClient IO ()
|
|
||||||
sendPacket handle pkt = do
|
|
||||||
dataToSend <- usingState_ $ writePacket pkt
|
|
||||||
lift $ B.hPut handle dataToSend
|
|
||||||
|
|
||||||
processServerInfo :: Packet -> TLSClient IO ()
|
|
||||||
processServerInfo (Handshake (ServerHello ver _ _ cipher _ _)) = do
|
|
||||||
ciphers <- pCiphers . scParams <$> get
|
|
||||||
allowedvers <- pAllowedVersions . scParams <$> get
|
|
||||||
case find ((==) ver) allowedvers of
|
case find ((==) ver) allowedvers of
|
||||||
Nothing -> error ("received version which is not allowed: " ++ show ver)
|
Nothing -> error ("received version which is not allowed: " ++ show ver)
|
||||||
Just _ -> usingState_ $ setVersion ver
|
Just _ -> usingState_ ctx $ setVersion ver
|
||||||
case find ((==) cipher . cipherID) ciphers of
|
case find ((==) cipher . cipherID) ciphers of
|
||||||
Nothing -> error "no cipher in common with the server"
|
Nothing -> error "no cipher in common with the server"
|
||||||
Just c -> usingState_ $ setCipher c
|
Just c -> usingState_ ctx $ setCipher c
|
||||||
|
|
||||||
processServerInfo (Handshake (CertRequest _ _ _)) = do
|
processServerInfo _ (Handshake (CertRequest _ _ _)) = do
|
||||||
modify (\sc -> sc { scCertRequested = True })
|
return ()
|
||||||
|
--modify (\sc -> sc { scCertRequested = True })
|
||||||
|
|
||||||
processServerInfo (Handshake (Certificates certs)) = do
|
processServerInfo ctx (Handshake (Certificates certs)) = do
|
||||||
cb <- onCertificatesRecv . scParams <$> get
|
let cb = onCertificatesRecv $ getParams ctx
|
||||||
valid <- lift $ cb certs
|
valid <- liftIO $ cb certs
|
||||||
unless valid $ error "certificates received deemed invalid by user"
|
unless valid $ error "certificates received deemed invalid by user"
|
||||||
|
|
||||||
processServerInfo _ = return ()
|
processServerInfo _ _ = return ()
|
||||||
|
|
||||||
recvServerInfo :: Handle -> TLSClient IO ()
|
recvServerInfo :: MonadIO m => TLSCtx -> m ()
|
||||||
recvServerInfo handle = do
|
recvServerInfo ctx = do
|
||||||
whileStatus (/= (StatusHandshake HsStatusServerHelloDone)) $ do
|
whileStatus ctx (/= (StatusHandshake HsStatusServerHelloDone)) $ do
|
||||||
pkts <- recvPacket handle
|
pkts <- recvPacket ctx
|
||||||
case pkts of
|
case pkts of
|
||||||
Left err -> error ("error received: " ++ show err)
|
Left err -> error ("error received: " ++ show err)
|
||||||
Right l -> forM_ l processServerInfo
|
Right l -> mapM_ (processServerInfo ctx) l
|
||||||
|
|
||||||
|
connectSendClientHello :: MonadIO m => TLSCtx -> m ()
|
||||||
|
connectSendClientHello ctx = do
|
||||||
|
crand <- getStateRNG ctx 32 >>= return . fromJust . clientRandom
|
||||||
|
sendPacket ctx $ Handshake (ClientHello ver crand (Session Nothing) (map cipherID ciphers) (map compressionID compressions) Nothing)
|
||||||
where
|
where
|
||||||
whileStatus p a = do
|
params = getParams ctx
|
||||||
b <- usingState_ (p . stStatus <$> get)
|
ver = pConnectVersion params
|
||||||
when b (a >> whileStatus p a)
|
ciphers = pCiphers params
|
||||||
|
compressions = pCompressions params
|
||||||
|
|
||||||
connectSendClientHello :: Handle -> TLSClient IO ()
|
connectSendClientCertificate :: MonadIO m => TLSCtx -> m ()
|
||||||
connectSendClientHello handle = do
|
connectSendClientCertificate ctx = do
|
||||||
crand <- fromJust . clientRandom <$> getStateRNG 32
|
certRequested <- return False -- scCertRequested <$> get
|
||||||
ver <- pConnectVersion . scParams <$> get
|
|
||||||
ciphers <- pCiphers . scParams <$> get
|
|
||||||
compressions <- pCompressions . scParams <$> get
|
|
||||||
sendPacket handle $ Handshake (ClientHello ver crand (Session Nothing) (map cipherID ciphers) (map compressionID compressions) Nothing)
|
|
||||||
|
|
||||||
connectSendClientCertificate :: Handle -> TLSClient IO ()
|
|
||||||
connectSendClientCertificate handle = do
|
|
||||||
certRequested <- scCertRequested <$> get
|
|
||||||
when certRequested $ do
|
when certRequested $ do
|
||||||
clientCerts <- map fst . pCertificates . scParams <$> get
|
let clientCerts = map fst $ pCertificates $ getParams ctx
|
||||||
sendPacket handle $ Handshake (Certificates clientCerts)
|
sendPacket ctx $ Handshake (Certificates clientCerts)
|
||||||
|
|
||||||
connectSendClientKeyXchg :: Handle -> TLSClient IO ()
|
connectSendClientKeyXchg :: MonadIO m => TLSCtx -> m ()
|
||||||
connectSendClientKeyXchg handle = do
|
connectSendClientKeyXchg ctx = do
|
||||||
prerand <- ClientKeyData <$> getStateRNG 46
|
prerand <- getStateRNG ctx 46 >>= return . ClientKeyData
|
||||||
ver <- pConnectVersion . scParams <$> get
|
let ver = pConnectVersion $ getParams ctx
|
||||||
sendPacket handle $ Handshake (ClientKeyXchg ver prerand)
|
sendPacket ctx $ Handshake (ClientKeyXchg ver prerand)
|
||||||
|
|
||||||
connectSendFinish :: Handle -> TLSClient IO ()
|
connectSendFinish :: MonadIO m => TLSCtx -> m ()
|
||||||
connectSendFinish handle = do
|
connectSendFinish ctx = do
|
||||||
cf <- usingState_ $ getHandshakeDigest True
|
cf <- usingState_ ctx $ getHandshakeDigest True
|
||||||
sendPacket handle (Handshake $ Finished $ B.unpack cf)
|
sendPacket ctx (Handshake $ Finished $ B.unpack cf)
|
||||||
|
|
||||||
{- | initiate a new TLS connection through a handshake on a handle. -}
|
{- | initiate a new TLS connection through a handshake on a handle. -}
|
||||||
initiate :: Handle -> TLSClient IO ()
|
initiate :: MonadIO m => TLSCtx -> m ()
|
||||||
initiate handle = do
|
initiate handle = do
|
||||||
connectSendClientHello handle
|
connectSendClientHello handle
|
||||||
recvServerInfo handle
|
recvServerInfo handle
|
||||||
|
@ -171,7 +109,7 @@ initiate handle = do
|
||||||
{- FIXME not implemented yet -}
|
{- FIXME not implemented yet -}
|
||||||
|
|
||||||
sendPacket handle (ChangeCipherSpec)
|
sendPacket handle (ChangeCipherSpec)
|
||||||
lift $ hFlush handle
|
liftIO $ hFlush $ getHandle handle
|
||||||
|
|
||||||
{- send Finished -}
|
{- send Finished -}
|
||||||
connectSendFinish handle
|
connectSendFinish handle
|
||||||
|
@ -184,38 +122,32 @@ initiate handle = do
|
||||||
|
|
||||||
return ()
|
return ()
|
||||||
|
|
||||||
{-# DEPRECATED connect "use initiate" #-}
|
{- | sendData sends a bunch of data -}
|
||||||
connect :: Handle -> TLSClient IO ()
|
sendData :: MonadIO m => TLSCtx -> L.ByteString -> m ()
|
||||||
connect = initiate
|
sendData ctx dataToSend = mapM_ sendDataChunk (L.toChunks dataToSend)
|
||||||
|
where sendDataChunk d =
|
||||||
sendDataChunk :: Handle -> B.ByteString -> TLSClient IO ()
|
|
||||||
sendDataChunk handle d =
|
|
||||||
if B.length d > 16384
|
if B.length d > 16384
|
||||||
then do
|
then do
|
||||||
let (sending, remain) = B.splitAt 16384 d
|
let (sending, remain) = B.splitAt 16384 d
|
||||||
sendPacket handle $ AppData sending
|
sendPacket ctx $ AppData sending
|
||||||
sendDataChunk handle remain
|
sendDataChunk remain
|
||||||
else
|
else
|
||||||
sendPacket handle $ AppData d
|
sendPacket ctx $ AppData d
|
||||||
|
|
||||||
{- | sendData sends a bunch of data -}
|
|
||||||
sendData :: Handle -> L.ByteString -> TLSClient IO ()
|
|
||||||
sendData handle d = mapM_ (sendDataChunk handle) (L.toChunks d)
|
|
||||||
|
|
||||||
{- | recvData get data out of Data packet, and automatically try to renegociate if
|
{- | recvData get data out of Data packet, and automatically renegociate if
|
||||||
- a Handshake HelloRequest is received -}
|
- a Handshake ClientHello is received -}
|
||||||
recvData :: Handle -> TLSClient IO L.ByteString
|
recvData :: MonadIO m => TLSCtx -> m L.ByteString
|
||||||
recvData handle = do
|
recvData handle = do
|
||||||
pkt <- recvPacket handle
|
pkt <- recvPacket handle
|
||||||
case pkt of
|
case pkt of
|
||||||
Right [AppData x] -> return $ L.fromChunks [x]
|
Right [AppData x] -> return $ L.fromChunks [x]
|
||||||
Right [Handshake HelloRequest] -> connect handle >> recvData handle
|
Right [Handshake HelloRequest] -> initiate handle >> recvData handle
|
||||||
Left err -> error ("error received: " ++ show err)
|
Left err -> error ("error received: " ++ show err)
|
||||||
_ -> error "unexpected item"
|
_ -> error "unexpected item"
|
||||||
|
|
||||||
{- | close a TLS connection.
|
{- | close a TLS connection.
|
||||||
- note that it doesn't close the handle, but just signal we're going to close
|
- note that it doesn't close the handle, but just signal we're going to close
|
||||||
- the connection to the other side -}
|
- the connection to the other side -}
|
||||||
close :: Handle -> TLSClient IO ()
|
close :: MonadIO m => TLSCtx -> m ()
|
||||||
close handle = do
|
close ctx = sendPacket ctx $ Alert (AlertLevel_Warning, CloseNotify)
|
||||||
sendPacket handle $ Alert (AlertLevel_Warning, CloseNotify)
|
|
||||||
|
|
|
@ -32,13 +32,11 @@ import Network.TLS.SRandom
|
||||||
import Data.Certificate.X509
|
import Data.Certificate.X509
|
||||||
import Data.List (intercalate)
|
import Data.List (intercalate)
|
||||||
import qualified Data.ByteString as B
|
import qualified Data.ByteString as B
|
||||||
import qualified Data.ByteString.Lazy as L
|
|
||||||
|
|
||||||
import Control.Applicative ((<$>))
|
import Control.Applicative ((<$>))
|
||||||
import Control.Concurrent.MVar
|
import Control.Concurrent.MVar
|
||||||
import Control.Monad (when, unless)
|
--import Control.Monad (when, unless)
|
||||||
import Control.Monad.State
|
import Control.Monad.State
|
||||||
import Control.Monad.Trans (MonadIO, liftIO)
|
|
||||||
import System.IO (Handle, hSetBuffering, BufferMode(..))
|
import System.IO (Handle, hSetBuffering, BufferMode(..))
|
||||||
|
|
||||||
data TLSParams = TLSParams
|
data TLSParams = TLSParams
|
||||||
|
|
12
Stunnel.hs
12
Stunnel.hs
|
@ -47,16 +47,15 @@ readOne h = do
|
||||||
Right True -> B.hGetNonBlocking h 4096
|
Right True -> B.hGetNonBlocking h 4096
|
||||||
Right False -> return B.empty
|
Right False -> return B.empty
|
||||||
|
|
||||||
tlsclient :: Handle -> Handle -> C.TLSClient IO ()
|
tlsclient :: Handle -> TLSCtx -> IO ()
|
||||||
tlsclient srchandle dsthandle = do
|
tlsclient srchandle dsthandle = do
|
||||||
lift $ hSetBuffering dsthandle NoBuffering
|
hSetBuffering srchandle NoBuffering
|
||||||
lift $ hSetBuffering srchandle NoBuffering
|
|
||||||
|
|
||||||
C.initiate dsthandle
|
C.initiate dsthandle
|
||||||
|
|
||||||
loopUntil $ do
|
loopUntil $ do
|
||||||
b <- lift $ readOne srchandle
|
b <- readOne srchandle
|
||||||
lift $ putStrLn ("sending " ++ show b)
|
putStrLn ("sending " ++ show b)
|
||||||
if B.null b
|
if B.null b
|
||||||
then do
|
then do
|
||||||
C.close dsthandle
|
C.close dsthandle
|
||||||
|
@ -231,8 +230,9 @@ doClient pargs = do
|
||||||
(StunnelSocket dst) <- connectAddressDescription dstaddr
|
(StunnelSocket dst) <- connectAddressDescription dstaddr
|
||||||
|
|
||||||
dsth <- socketToHandle dst ReadWriteMode
|
dsth <- socketToHandle dst ReadWriteMode
|
||||||
|
dstctx <- C.client clientstate rng dsth
|
||||||
_ <- forkIO $ finally
|
_ <- forkIO $ finally
|
||||||
(C.runTLSClient (tlsclient srch dsth) clientstate rng >> return ())
|
(tlsclient srch dstctx)
|
||||||
(hClose srch >> hClose dsth)
|
(hClose srch >> hClose dsth)
|
||||||
return ()
|
return ()
|
||||||
AddrFD _ _ -> error "bad error fd. not implemented"
|
AddrFD _ _ -> error "bad error fd. not implemented"
|
||||||
|
|
|
@ -121,8 +121,8 @@ makeValidParams serverCerts = do
|
||||||
{- | setup create all necessary connection point to create a data "pipe"
|
{- | setup create all necessary connection point to create a data "pipe"
|
||||||
- ---(startQueue)---> tlsClient ---(socketPair)---> tlsServer ---(resultQueue)--->
|
- ---(startQueue)---> tlsClient ---(socketPair)---> tlsServer ---(resultQueue)--->
|
||||||
-}
|
-}
|
||||||
setup :: TLSParams -> IO (Handle, TLSCtx, SRandomGen, Chan a, Chan a)
|
setup :: (TLSParams, TLSParams) -> IO (TLSCtx, TLSCtx, Chan a, Chan a)
|
||||||
setup serverState = do
|
setup (clientState, serverState) = do
|
||||||
(cSocket, sSocket) <- socketPair AF_UNIX Stream defaultProtocol
|
(cSocket, sSocket) <- socketPair AF_UNIX Stream defaultProtocol
|
||||||
cHandle <- socketToHandle cSocket ReadWriteMode
|
cHandle <- socketToHandle cSocket ReadWriteMode
|
||||||
sHandle <- socketToHandle sSocket ReadWriteMode
|
sHandle <- socketToHandle sSocket ReadWriteMode
|
||||||
|
@ -135,20 +135,21 @@ setup serverState = do
|
||||||
startQueue <- newChan
|
startQueue <- newChan
|
||||||
resultQueue <- newChan
|
resultQueue <- newChan
|
||||||
|
|
||||||
|
cCtx <- C.client clientState clientRNG cHandle
|
||||||
sCtx <- S.server serverState serverRNG sHandle
|
sCtx <- S.server serverState serverRNG sHandle
|
||||||
|
|
||||||
return (cHandle, sCtx, clientRNG, startQueue, resultQueue)
|
return (cCtx, sCtx, startQueue, resultQueue)
|
||||||
|
|
||||||
testInitiate spCert = do
|
testInitiate spCert = do
|
||||||
(clientstate, serverstate) <- pick (makeValidParams spCert)
|
states <- pick (makeValidParams spCert)
|
||||||
(cHandle, sCtx, clientRNG, startQueue, resultQueue) <- run (setup serverstate)
|
(cCtx, sCtx, startQueue, resultQueue) <- run (setup states)
|
||||||
|
|
||||||
run $ forkIO $ do
|
run $ forkIO $ do
|
||||||
catch (tlsServer sCtx resultQueue)
|
catch (tlsServer sCtx resultQueue)
|
||||||
(\e -> putStrLn ("server exception: " ++ show e) >> throw (e :: SomeException))
|
(\e -> putStrLn ("server exception: " ++ show e) >> throw (e :: SomeException))
|
||||||
return ()
|
return ()
|
||||||
run $ forkIO $ do
|
run $ forkIO $ do
|
||||||
catch (C.runTLSClient (tlsClient startQueue cHandle) clientstate clientRNG)
|
catch (tlsClient startQueue cCtx)
|
||||||
(\e -> putStrLn ("client exception: " ++ show e) >> throw (e :: SomeException))
|
(\e -> putStrLn ("client exception: " ++ show e) >> throw (e :: SomeException))
|
||||||
return ()
|
return ()
|
||||||
|
|
||||||
|
@ -161,7 +162,7 @@ testInitiate spCert = do
|
||||||
assert $ d == dres
|
assert $ d == dres
|
||||||
|
|
||||||
-- cleanup
|
-- cleanup
|
||||||
run $ (hClose cHandle >> hClose (getHandle sCtx))
|
run $ (hClose (getHandle cCtx) >> hClose (getHandle sCtx))
|
||||||
|
|
||||||
where
|
where
|
||||||
tlsServer handle queue = do
|
tlsServer handle queue = do
|
||||||
|
@ -171,7 +172,7 @@ testInitiate spCert = do
|
||||||
return ()
|
return ()
|
||||||
tlsClient queue handle = do
|
tlsClient queue handle = do
|
||||||
C.initiate handle
|
C.initiate handle
|
||||||
d <- lift $ readChan queue
|
d <- readChan queue
|
||||||
C.sendData handle d
|
C.sendData handle d
|
||||||
return ()
|
return ()
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue