hlint
[htalkat.git] / RelayStream.hs
blobf8c24375038b5faf8c92012dea11421c4b1e28da
1 -- This file is part of htalkat
2 -- Copyright (C) 2021 Martin Bays <mbays@sdf.org>
3 --
4 -- This program is free software: you can redistribute it and/or modify
5 -- it under the terms of version 3 of the GNU General Public License as
6 -- published by the Free Software Foundation, or any later version.
7 --
8 -- You should have received a copy of the GNU General Public License
9 -- along with this program. If not, see http://www.gnu.org/licenses/.
11 {-# LANGUAGE LambdaCase #-}
12 {-# LANGUAGE OverloadedStrings #-}
14 module RelayStream where
import Control.Concurrent
import Control.Exception (SomeException, handle, mask_)
import Control.Monad (foldM_, forever, unless, void,
                      when)
import System.Timeout (timeout)

import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BL
import qualified Data.Text.Encoding.Error as T
import qualified Data.Text.Lazy as T
import qualified Data.Text.Lazy.Encoding as T
import qualified Network.Socket as S
import qualified Network.Socket.ByteString.Lazy as SL
import qualified Network.TLS as TLS
import qualified Time.System as TM
import qualified Time.Types as TM

import Mundanities
import TimedText
-- | Ordering of the intro-byte exchange performed by 'relayStream'.
--
-- With 'WriteFirst' our intro byte is sent straight away; with
-- 'WriteSecond' it is sent only once the first byte from the peer has
-- been received.
data WriteOrder
  = WriteFirst
  | WriteSecond
  deriving (Eq, Ord, Show)
-- | Relay a talk session between an established TLS connection to a peer
-- and a single local client accepted on the listening socket @dSock@.
--
-- * Peer to client: chunks received with 'TLS.recvData' are decoded with
--   'decodeTimedText' and written to the local client as UTF-8, honouring
--   the pauses they carry (see @relayTimed@).
--
-- * Client to peer: bytes read from the local client are decoded leniently
--   as UTF-8, interleaved with the pauses observed between characters
--   (see @insertPauses@), and sent to the peer in batches encoded with
--   'encodeTimedText' (see @sendAll@).
--
-- Each side starts by sending one intro byte (the character @T@). With
-- 'WriteSecond' ours is sent only after the peer's first byte has arrived.
-- If the first byte received from the peer is not the intro byte, the relay
-- is aborted.
--
-- Returns as soon as any part of the relay finishes or fails. It then stops
-- waiting for a local client, calls 'TLS.bye' (wrapped in 'ignoreIOErr'),
-- and gracefully closes the local socket if one was accepted.
relayStream :: TLS.Context -> WriteOrder -> S.Socket -> IO ()
relayStream ctxt ord dSock = do
  -- Filled with whether the peer's first byte was the intro byte.
  receivedHandshake <- newEmptyMVar
  -- Filled (at most once) to tell the main thread to shut down.
  finished <- newEmptyMVar
  -- Raw chunks received from the peer.
  rawInChan <- newChan
  -- 'abort' is called from several threads (and again by the TLS sender
  -- once 'TLS.bye' has closed the context), so it must never block: only
  -- the first call has any effect. A blocking 'putMVar' here would leave
  -- every later caller stuck forever.
  let abort = void $ tryPutMVar finished ()
      -- Run an action, turning any exception into an 'abort'.
      abortOnErr = handle abortHandler
        where
          abortHandler :: Monoid a => SomeException -> IO a
          abortHandler _ = abort >> pure mempty
      -- Receive from the peer until the connection ends, checking the intro
      -- byte on the first chunk. Every accepted chunk (including the first)
      -- is forwarded to 'rawInChan'; an empty chunk is pushed if the intro
      -- byte was wrong.
      recvAll = do
        b <- TLS.recvData ctxt
        case BS.uncons b of
          Nothing -> abort
          Just (h, _) -> do
            ok <- tryReadMVar receivedHandshake >>= \case
              Just ok -> pure ok
              Nothing -> do
                let isHandshakeByte = h == introByte
                putMVar receivedHandshake isHandshakeByte
                if isHandshakeByte
                  then pure True
                  else abort >> pure False
            if ok
              then writeChan rawInChan b >> recvAll
              else writeChan rawInChan BS.empty
      -- Send our intro byte, first waiting for the peer's if we go second.
      sendHandshake = do
        when (ord == WriteSecond) . void $ readMVar receivedHandshake
        TLS.sendData ctxt $ BL.singleton introByte

  sockMV <- newEmptyMVar
  sockThread <- forkIO $ putMVar sockMV . fst =<< S.accept dSock

  -- Local client -> peer.
  _ <- forkIO $ do
    sock <- readMVar sockMV
    abortOnErr sendHandshake
    tsOutChan <- newChan
    rawOutChan <- newChan
    _ <- forkIO . abortOnErr $ do
      writeList2Chan rawOutChan . T.unpack . T.decodeUtf8With T.lenientDecode
        =<< SL.getContents sock
      abort
    pausesThread <- forkIO $ insertPauses rawOutChan tsOutChan
    abortOnErr $ sendAll tsOutChan
    killThread pausesThread

  -- Peer -> local client.
  _ <- forkIO $ do
    tsInChan <- newChan
    decodeTTThread <- forkIO $
      writeList2Chan tsInChan . decodeTimedText . BL.fromChunks
        =<< getChanContents rawInChan
    _ <- forkIO . abortOnErr $ relayTimed tsInChan =<< readMVar sockMV
    abortOnErr recvAll
    killThread decodeTTThread

  _ <- takeMVar finished
  -- Stop the accept thread unconditionally and before 'TLS.bye'. The
  -- original code chained the kill after 'bye' inside 'ignoreIOErr', so a
  -- failing 'bye' (e.g. once the peer has gone) skipped it.
  killThread sockThread
  ignoreIOErr $ TLS.bye ctxt
  tryTakeMVar sockMV >>= \case
    Nothing -> pure ()
    Just sock -> S.gracefulClose sock 1000
  where
    introByte = fromIntegral $ fromEnum 'T'

    -- Copy characters from @rawChan@ to @ttChan@ as @Right c@. Each one is
    -- preceded by @Left ms@ when a positive number of milliseconds
    -- (measured with 'TM.timeCurrentP') passed since the previous character.
    insertPauses rawChan ttChan = TM.timeCurrentP >>= insertPauses'
      where
        insertPauses' e = do
          c <- readChan rawChan
          e' <- TM.timeCurrentP
          let ms = elapsedPToMS $ e' - e
          when (ms > 0) . writeChan ttChan . Left $ fromIntegral ms
          writeChan ttChan $ Right c
          insertPauses' e'

    -- Forever: collect items from @ttChan@ for 'sendTimeout' microseconds,
    -- then, if any arrived, send them to the peer in order as one
    -- 'encodeTimedText' payload.
    sendAll ttChan = forever $ do
      readBufMV <- newMVar []
      -- 'mask_' makes "read an item, then store it" atomic with respect to
      -- the timeout's async exception. Without it, the timeout could land
      -- after 'readChan' returned but before the item was buffered, and the
      -- item would be lost. 'readChan' stays interruptible while it is
      -- blocked, so the timeout still fires when no input arrives.
      _ <- timeout sendTimeout . forever . mask_ $
        readChan ttChan >>= \item -> modifyMVar_ readBufMV (pure . (item :))
      rtt <- readMVar readBufMV
      unless (null rtt) . TLS.sendData ctxt . rechunk . encodeTimedText $
        reverse rtt
      where
        rechunk =
          -- TLS.sendData sends one packet per chunk, while encodeTimedText
          -- returns a chunk per char, so it's important to rechunk.
          BL.fromStrict . BL.toStrict
        -- 300 ms, in the microseconds expected by 'timeout'.
        sendTimeout = 1000 * 300

    -- Replay timed text from @chan@ to @sock@. The fold state is a reference
    -- time (or 'Nothing'):
    --
    -- * a character arriving with no reference time is delayed by
    --   'bufferTime' first, and the current time becomes the reference;
    -- * a pause with no reference time is dropped;
    -- * @Left n@ sleeps until @n@ ms after the reference time and advances
    --   the reference by @n@ ms, except that a 'pauseMax' pause that was
    --   already overdue resets the state to 'Nothing'.
    relayTimed chan sock = foldM_ sendTimed' Nothing =<< getChanContents chan
      where
        sendTimed' :: Maybe TM.ElapsedP -> Either Int Char -> IO (Maybe TM.ElapsedP)
        sendTimed' Nothing (Right c) = do
          threadDelay bufferTime
          e <- TM.timeCurrentP
          sendTimed' (Just e) (Right c)
        sendTimed' Nothing _ = pure Nothing
        sendTimed' (Just e) (Right c) = do
          SL.sendAll sock . T.encodeUtf8 $ T.singleton c
          pure $ Just e
        sendTimed' (Just e) (Left n) = do
          delayed <- elapsedPToMS . flip (-) e <$> TM.timeCurrentP
          when (n > delayed) . threadDelay . (1000 *) $ n - delayed
          pure $ if n == pauseMax && n < delayed
            then Nothing
            else Just $ e + msToElapsedP n
        -- 300 ms, in the microseconds expected by 'threadDelay'.
        bufferTime = 1000 * 300
-- | Convert a number of milliseconds into an 'TM.ElapsedP' duration.
--
-- The count is split with 'divMod', so the nanoseconds field always holds
-- a non-negative whole number of milliseconds below one second.
msToElapsedP :: Int -> TM.ElapsedP
msToElapsedP ms =
  TM.ElapsedP (TM.Elapsed (TM.Seconds wholeSecs)) (TM.NanoSeconds subSecNanos)
  where
    (wholeSecs, leftoverMs) = fromIntegral ms `divMod` 1000
    subSecNanos = leftoverMs * 1000000
-- | Convert an 'TM.ElapsedP' duration into whole milliseconds.
--
-- Any sub-millisecond remainder of the nanoseconds field is discarded
-- using 'div'.
elapsedPToMS :: TM.ElapsedP -> Int
elapsedPToMS (TM.ElapsedP (TM.Elapsed (TM.Seconds secs)) (TM.NanoSeconds nanos)) =
  fromIntegral (msFromSecs + msFromNanos)
  where
    msFromSecs = secs * 1000
    msFromNanos = nanos `div` 1000000