hs-tls/Stunnel.hs

262 lines
8.3 KiB
Haskell
Raw Normal View History

{-# LANGUAGE DeriveDataTypeable #-}
import Network.BSD
import Network.Socket
2010-09-09 21:47:19 +00:00
import System.IO
import System.IO.Error hiding (try)
import System.Console.CmdArgs
2010-09-19 09:50:37 +00:00
2010-09-09 21:47:19 +00:00
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as L
2010-09-19 09:50:37 +00:00
import Control.Concurrent (forkIO)
import Control.Exception (finally, try, throw)
import Control.Monad (when, forever)
2010-09-19 09:50:37 +00:00
import Control.Monad.Trans (lift)
import Data.Char (isDigit)
2010-09-19 09:50:37 +00:00
2010-09-09 21:47:19 +00:00
import Data.Certificate.PEM
import Data.Certificate.X509
2011-01-02 09:49:21 +00:00
import qualified Data.Certificate.KeyRSA as KeyRSA
2010-09-09 21:47:19 +00:00
2010-09-19 09:50:37 +00:00
import Network.TLS.Cipher
import Network.TLS.SRandom
import Network.TLS.Struct
import qualified Network.TLS.Client as C
import qualified Network.TLS.Server as S
2010-09-09 21:47:19 +00:00
-- | Cipher suites used both by the client parameters ('C.cpCiphers') and by
-- the server parameters ('S.spCiphers').
--
-- The RC4-based suites previously listed here were removed: RFC 7465
-- prohibits RC4 cipher suites in TLS, so only the AES suites remain.
ciphers :: [Cipher]
ciphers =
    [ cipher_AES128_SHA1
    , cipher_AES256_SHA1
    ]
-- | Run a monadic action over and over until it yields 'True'.
--
-- The action always runs at least once; its result is checked after each
-- run, and the loop returns @()@ as soon as it sees 'True'.
loopUntil :: Monad m => m Bool -> m ()
loopUntil step = go
  where
    go = step >>= finish
    finish True  = return ()
    finish False = go
-- | Block until the handle has input, then read up to 4096 bytes of it.
--
-- Returns an empty 'B.ByteString' when end of file is reached; any other
-- I/O error raised while waiting is re-raised to the caller.
readOne :: Handle -> IO B.ByteString
readOne h = do
    r <- try $ hWaitForInput h (-1)
    case r of
        Left err
            | isEOFError err -> return B.empty
            -- ioError raises in IO order, unlike the pure 'throw'
            | otherwise      -> ioError err
        Right True  -> B.hGetNonBlocking h 4096
        -- a negative timeout waits indefinitely (System.IO docs), so this
        -- branch is not expected; it is treated like end of input
        Right False -> return B.empty
-- | Client side of one tunnelled connection.
--
-- Starts a TLS session on @dsthandle@, then repeatedly reads chunks from
-- @srchandle@ and sends each one through the TLS session. Once the source
-- yields an empty chunk, the TLS session is closed and the loop stops.
-- Every chunk is printed to stdout for debugging.
--
-- NOTE(review): nothing is ever read back from the destination, so the
-- tunnel only carries data in the source -> destination direction.
tlsclient :: Handle -> Handle -> C.TLSClient IO ()
tlsclient srchandle dsthandle = do
    mapM_ (\h -> lift $ hSetBuffering h NoBuffering) [dsthandle, srchandle]
    C.initiate dsthandle
    loopUntil pump
  where
    -- one iteration: True means "source exhausted, stop looping"
    pump = do
        chunk <- lift $ readOne srchandle
        lift $ putStrLn ("sending " ++ show chunk)
        if B.null chunk
            then C.close dsthandle >> return True
            else C.sendData dsthandle (L.fromChunks [chunk]) >> return False
-- | Create a fresh 'SRandomGen'.
--
-- If 'makeSRandomGen' reports an error, it is shown and raised with 'fail'
-- in 'IO'.
getRandomGen :: IO SRandomGen
getRandomGen = makeSRandomGen >>= either (fail . show) return
-- | Server side of one accepted connection.
--
-- Switches both handles to unbuffered mode and calls 'S.listen' on
-- @srchandle@. It then loops: print each chunk of data received over TLS and
-- answer it with a fixed test payload.
--
-- NOTE(review): nothing is relayed to @dsthandle@; it is only made
-- unbuffered. The loop body always returns 'False', so the loop never ends
-- normally and the final "end" message is unreachable.
tlsserver srchandle dsthandle = do
    lift $ do
        hSetBuffering dsthandle NoBuffering
        hSetBuffering srchandle NoBuffering
    S.listen srchandle
    loopUntil $ do
        received <- S.recvData srchandle
        lift $ putStrLn ("received: " ++ show received)
        S.sendData srchandle reply
        lift $ hFlush srchandle
        return False
    lift $ putStrLn "end"
  where
    -- fixed answer sent for every received chunk
    reply = L.pack (map (toEnum . fromEnum) "this is some data")
-- | Serve one connection accepted in server mode.
--
-- Builds the TLS server parameters from the loaded certificate and key, then
-- runs 'tlsserver' on the connection with a fresh random generator.
--
-- Arguments, in order:
-- * ((certificate bytes, decoded certificate), private key)
-- * the accepted connection's handle
-- * the destination handle
-- * the peer address returned by 'accept' (unused)
clientProcess ((certdata, cert), pk) handle dsthandle _ = do
    rng <- getRandomGen
    let serverstate = S.TLSServerParams
            -- SSL3 is no longer accepted: it is deprecated by RFC 7568, and
            -- the client side of this program already only allows TLS.
            { S.spAllowedVersions = [TLS10, TLS11]
            , S.spSessions        = []
            , S.spCiphers         = ciphers
            , S.spCertificate     = Just (certdata, cert, pk)
            , S.spWantClientCert  = False
            , S.spCallbacks       = S.TLSServerCallbacks
                { S.cbCertificates = Nothing }
            }
    S.runTLSServer (tlsserver handle dsthandle) serverstate rng
-- | Read a PEM file and extract its certificate.
--
-- Returns the bytes extracted by 'parsePEMCert' together with the decoded
-- 'Certificate'. Parse and decode failures are raised here, in 'IO' via
-- 'fail'. A bad file is therefore reported when it is loaded, not later
-- (possibly inside a per-connection thread) when a lazily built result is
-- first forced.
readCertificate :: FilePath -> IO (B.ByteString, Certificate)
readCertificate filepath = do
    content <- B.readFile filepath
    certdata <- maybe (fail "no valid certificate section") return
                      (parsePEMCert content)
    cert <- either (fail . ("cannot decode certificate: " ++)) return
                   (decodeCertificate $ L.fromChunks [certdata])
    return (certdata, cert)
-- | Read a PEM file and extract its RSA private key.
--
-- Returns the key section extracted by 'parsePEMKeyRSA' (as a lazy
-- ByteString) together with the decoded private key. Parse and decode
-- failures are raised here, in 'IO' via 'fail', so a bad key file is
-- reported at load time instead of when the result is first forced.
readPrivateKey :: FilePath -> IO (L.ByteString, KeyRSA.Private)
readPrivateKey filepath = do
    content <- B.readFile filepath
    pkdata <- maybe (fail "no valid RSA key section")
                    (\x -> return (L.fromChunks [x]))
                    (parsePEMKeyRSA content)
    pk <- either (fail . ("cannot decode key: " ++)) return
                 (KeyRSA.decodePrivate pkdata)
    return (pkdata, pk)
-- | Command-line configuration, with one constructor per sub-command.
--
-- Each address field is interpreted according to its matching @*Type@ field
-- (see 'getAddressDescription').
data Stunnel =
      -- | Client mode: listen on the source, and for each accepted connection
      -- connect to the destination and speak TLS to it.
      Client
        { destinationType :: String    -- ^ kind of destination address
        , destination     :: String    -- ^ destination address
        , sourceType      :: String    -- ^ kind of source address
        , source          :: String    -- ^ source address
        }
      -- | Server mode: listen on the source and serve TLS to each accepted
      -- connection using the given certificate and key.
    | Server
        { destinationType :: String    -- ^ kind of destination address
        , destination     :: String    -- ^ destination address
        , sourceType      :: String    -- ^ kind of source address
        , source          :: String    -- ^ source address
        , certificate     :: FilePath  -- ^ PEM file holding the certificate
        , key             :: FilePath  -- ^ PEM file holding the RSA private key
        }
    deriving (Show, Data, Typeable)
-- | Defaults and help text for the @client@ sub-command.
--
-- The help text for @destinationType@ previously said "type of source"; it
-- now correctly describes the destination.
clientOpts :: Stunnel
clientOpts = Client
    { destinationType = "tcp" &= help "type of destination (tcp, unix, fd)" &= typ "DESTTYPE"
    , destination = "localhost:6061" &= help "destination address influenced by destination type" &= typ "ADDRESS"
    , sourceType = "tcp" &= help "type of source (tcp, unix, fd)" &= typ "SOURCETYPE"
    , source = "localhost:6060" &= help "source address influenced by source type" &= typ "ADDRESS"
    }
    &= help "connect to a remote destination that use SSL/TLS"
-- | Defaults and help text for the @server@ sub-command.
--
-- The help text for @destinationType@ previously said "type of source"; it
-- now correctly describes the destination.
serverOpts :: Stunnel
serverOpts = Server
    { destinationType = "tcp" &= help "type of destination (tcp, unix, fd)" &= typ "DESTTYPE"
    , destination = "localhost:6060" &= help "destination address influenced by destination type" &= typ "ADDRESS"
    , sourceType = "tcp" &= help "type of source (tcp, unix, fd)" &= typ "SOURCETYPE"
    , source = "localhost:6061" &= help "source address influenced by source type" &= typ "ADDRESS"
    , certificate = "certificate.pem" &= help "X509 public certificate to use" &= typ "FILE"
    , key = "certificate.key" &= help "private key linked to the certificate" &= typ "FILE"
    }
    &= help "listen for connection that use SSL/TLS and relay it to a different connection"
-- | The complete command-line interface: the @client@ and @server@
-- sub-commands plus program-level help, name and summary.
mode = cmdArgsMode stunnelModes
  where
    stunnelModes =
        modes [clientOpts, serverOpts]
            &= help "create SSL/TLS tunnel in client or server mode"
            &= program "stunnel"
            &= summary "Stunnel v0.1 (Haskell TLS)"
-- | An endpoint description produced by 'getAddressDescription', before
-- anything is opened.
data StunnelAddr
    = AddrSocket Family SockAddr  -- ^ socket family and address
    | AddrFD Handle Handle        -- ^ already-open input and output handles

-- | An opened endpoint, produced by connecting or listening on a
-- 'StunnelAddr'.
data StunnelHandle
    = StunnelSocket Socket        -- ^ a connected or listening socket
    | StunnelFd Handle Handle     -- ^ already-open input and output handles
-- | Turn an address type and an address string into a 'StunnelAddr'.
--
-- * @"tcp"@: @host:port@. The port is either a decimal number no larger
--   than 65535, or a service name looked up with 'getServiceByName'. The
--   host is resolved with 'getHostByName' and its first address is used as
--   an IPv4 ('AF_INET') address.
-- * @"unix"@: the string is used as a Unix-domain socket path.
-- * @"fd"@: the string is ignored; stdin and stdout are used.
--
-- Malformed input (missing or empty port, out-of-range port, host with no
-- address, unknown type) raises an error.
getAddressDescription :: String -> String -> IO StunnelAddr
getAddressDescription "tcp" desc = do
    let (s, p) = break ((==) ':') desc
        portStr = drop 1 p
    when (p == "") (error "missing port: expecting [source]:port")
    -- "host:" used to reach `read ""` and crash with "no parse"
    when (portStr == "") (error "empty port: expecting [source]:port")
    pn <- if all isDigit portStr
        then do
            -- read as Integer so an oversized number is rejected instead of
            -- silently wrapping around when converted to a port number
            let n = read portStr :: Integer
            when (n > 65535) (error ("port out of range: " ++ portStr))
            return (fromIntegral n)
        else do
            service <- getServiceByName portStr "tcp"
            return $ servicePort service
    he <- getHostByName s
    case hostAddresses he of
        (addr:_) -> return $ AddrSocket AF_INET (SockAddrInet pn addr)
        []       -> error ("no IPv4 address found for host " ++ s)
getAddressDescription "unix" desc =
    return $ AddrSocket AF_UNIX (SockAddrUnix desc)
getAddressDescription "fd" _ =
    return $ AddrFD stdin stdout
getAddressDescription _ _ = error "unrecognized source type (expecting tcp/unix/fd)"
-- | Open a connection to the given endpoint.
--
-- For a socket address, a new stream socket is created and connected. If
-- connecting fails, the socket is closed and an error naming both the
-- address and the underlying I/O failure is raised. (Previously the
-- underlying failure was discarded.) A handle pair is returned as-is.
connectAddressDescription :: StunnelAddr -> IO StunnelHandle
connectAddressDescription (AddrSocket family sockaddr) = do
    sock <- socket family Stream defaultProtocol
    catch (connect sock sockaddr)
          (\e -> sClose sock >> error ("cannot open socket " ++ show sockaddr ++ ": " ++ show (e :: IOError)))
    return $ StunnelSocket sock
connectAddressDescription (AddrFD h1 h2) =
    return $ StunnelFd h1 h2
-- | Create a stream socket bound to the given address and listening on it,
-- with a backlog of 10.
--
-- SO_REUSEADDR is now enabled *before* binding. Previously it was set after
-- bind/listen, where it cannot influence the bind. If any step fails, the
-- socket is closed and an error naming the address and the underlying I/O
-- failure is raised. Listening on a handle pair is not supported.
listenAddressDescription :: StunnelAddr -> IO StunnelHandle
listenAddressDescription (AddrSocket family sockaddr) = do
    sock <- socket family Stream defaultProtocol
    catch (setSocketOption sock ReuseAddr 1 >> bindSocket sock sockaddr >> listen sock 10)
          (\e -> sClose sock >> error ("cannot open socket " ++ show sockaddr ++ ": " ++ show (e :: IOError)))
    return $ StunnelSocket sock
listenAddressDescription (AddrFD _ _) =
    error "cannot listen on fd"
-- | Run in client mode.
--
-- Listens on the source address. For every accepted connection, it
-- connects to the destination and runs 'tlsclient' between the two.
--
-- Connecting to the destination now happens inside the per-connection
-- thread. An unreachable destination therefore only fails that connection
-- instead of escaping 'forever' and stopping the accept loop. Every handle
-- that was opened is closed when the thread ends, including the accepted
-- one when the destination connect fails.
doClient :: Stunnel -> IO ()
doClient pargs = do
    srcaddr <- getAddressDescription (sourceType pargs) (source pargs)
    dstaddr <- getAddressDescription (destinationType pargs) (destination pargs)
    let clientstate = C.TLSClientParams
            { C.cpConnectVersion  = TLS10
            , C.cpAllowedVersions = [ TLS10, TLS11 ]
            , C.cpSession         = Nothing
            , C.cpCiphers         = ciphers
            , C.cpCertificate     = Nothing
            , C.cpCallbacks       = C.TLSClientCallbacks
                { C.cbCertificates = Nothing
                }
            }
    case srcaddr of
        AddrSocket _ _ -> do
            (StunnelSocket srcsocket) <- listenAddressDescription srcaddr
            forever $ do
                (s, _) <- accept srcsocket
                srch <- socketToHandle s ReadWriteMode
                _ <- forkIO $ flip finally (hClose srch) $ do
                    rng <- getRandomGen
                    (StunnelSocket dst) <- connectAddressDescription dstaddr
                    dsth <- socketToHandle dst ReadWriteMode
                    finally
                        (C.runTLSClient (tlsclient srch dsth) clientstate rng >> return ())
                        (hClose dsth)
                return ()
        AddrFD _ _ -> error "bad error fd. not implemented"
-- | Run in server mode.
--
-- Loads the certificate and private key, then listens on the source
-- address. For every accepted connection, it connects to the destination
-- and runs 'clientProcess' on the pair.
--
-- Connecting to the destination now happens inside the per-connection
-- thread. An unreachable destination therefore only fails that connection
-- instead of stopping the accept loop. Every handle that was opened is
-- closed when the thread ends.
doServer :: Stunnel -> IO ()
doServer pargs = do
    cert <- readCertificate $ certificate pargs
    pk <- readPrivateKey $ key pargs
    srcaddr <- getAddressDescription (sourceType pargs) (source pargs)
    dstaddr <- getAddressDescription (destinationType pargs) (destination pargs)
    case srcaddr of
        AddrSocket _ _ -> do
            (StunnelSocket srcsocket) <- listenAddressDescription srcaddr
            forever $ do
                (s, addr) <- accept srcsocket
                srch <- socketToHandle s ReadWriteMode
                _ <- forkIO $ flip finally (hClose srch) $ do
                    (StunnelSocket dst) <- connectAddressDescription dstaddr
                    dsth <- socketToHandle dst ReadWriteMode
                    finally
                        (clientProcess (cert, snd pk) srch dsth addr >> return ())
                        (hClose dsth)
                return ()
        AddrFD _ _ -> error "bad error fd. not implemented"
-- | Entry point: parse the command line and run the selected mode.
main :: IO ()
main = cmdArgsRun mode >>= dispatch
  where
    -- record-brace patterns match each constructor regardless of field count
    dispatch opts@Client{} = doClient opts
    dispatch opts@Server{} = doServer opts