bump 0.1.2.5
[htalkat.git] / RelayStream.hs
blob48261dc9df929006c5a74c7073138bc260531a1d
1 -- This file is part of htalkat
2 -- Copyright (C) 2021 Martin Bays <mbays@sdf.org>
3 --
4 -- This program is free software: you can redistribute it and/or modify
5 -- it under the terms of version 3 of the GNU General Public License as
6 -- published by the Free Software Foundation, or any later version.
7 --
8 -- You should have received a copy of the GNU General Public License
9 -- along with this program. If not, see http://www.gnu.org/licenses/.
11 {-# LANGUAGE LambdaCase #-}
13 module RelayStream where
15 import Control.Concurrent
16 import Control.Exception (SomeException, handle)
17 import Control.Monad (foldM_, forever, unless, void,
18 when)
19 import System.Timeout (timeout)
21 import qualified Data.ByteString as BS
22 import qualified Data.ByteString.Lazy as BL
23 import qualified Data.Text.Encoding.Error as T
24 import qualified Data.Text.Lazy as T
25 import qualified Data.Text.Lazy.Encoding as T
26 import qualified Network.Socket as S
27 import qualified Network.Socket.ByteString.Lazy as SL
28 import qualified Network.TLS as TLS
29 import qualified Time.System as TM
30 import qualified Time.Types as TM
32 import Mundanities
33 import TimedText
-- | Which side of a relay sends its handshake byte first.
data WriteOrder
    = WriteFirst  -- ^ Send our handshake byte immediately.
    | WriteSecond -- ^ Send our handshake byte only after the peer's first
                  --   byte has been received.
    deriving (Eq, Ord, Show)
-- | Relay a talk session between a TLS peer and a local connection.
--
-- Arguments:
--
-- * @ctxt@: TLS context for the remote peer.
-- * @ord@: with 'WriteSecond', our handshake byte is sent only after the
--   peer's first byte has arrived; with 'WriteFirst' it is sent at once.
-- * @dSock@: listening socket; one connection is accepted from it and used
--   as the local end of the relay.
--
-- Outgoing (local -> TLS): once the local connection is accepted and our
-- handshake byte sent, bytes read from it are decoded leniently as UTF-8,
-- interleaved with the pauses (in ms) between characters, and sent as
-- timed text in batches gathered over 300ms.
--
-- Incoming (TLS -> local): the first byte received must be the handshake
-- byte, otherwise the relay aborts. Received data is decoded as timed text
-- and replayed to the local connection with its recorded pauses.
--
-- Returns once any worker signals completion (end of either stream, or an
-- exception caught in a worker). It then stops waiting for a local
-- connection, attempts a TLS close (IO errors ignored via 'ignoreIOErr'),
-- and gracefully closes the local connection if one was accepted. The
-- main thread does not wait for or kill the relay worker threads.
relayStream :: TLS.Context -> WriteOrder -> S.Socket -> IO ()
relayStream ctxt ord dSock = do
    -- Filled by 'recvAll' with whether the peer's first byte was 'introByte'.
    receivedHandshake <- newEmptyMVar
    -- Filled by 'abort' to tell this (main) thread to shut down.
    finished <- newEmptyMVar
    -- Raw chunks received over TLS, consumed by the timed-text decoder.
    rawInChan <- newChan
    let -- Request shutdown. Several workers may call this, some more than
        -- once, while the main thread takes 'finished' only once; so this
        -- must never block (a plain putMVar would hang every caller after
        -- the first one that finds 'finished' already full).
        abort = void $ tryPutMVar finished ()
        -- Run an action, turning any exception into a shutdown request.
        abortOnErr = handle abortHandler
          where
            abortHandler :: Monoid a => SomeException -> IO a
            abortHandler _ = abort >> pure mempty
        -- Receive TLS data until EOF, checking the handshake on the first
        -- chunk and forwarding accepted chunks to 'rawInChan'.
        recvAll = do
            b <- TLS.recvData ctxt
            case BS.uncons b of
                -- An empty read means the peer closed the connection.
                Nothing -> abort
                Just (h, _) -> do
                    ok <- tryReadMVar receivedHandshake >>= \case
                        Just known -> pure known
                        Nothing -> do
                            let isHandshakeByte = h == introByte
                            putMVar receivedHandshake isHandshakeByte
                            if isHandshakeByte
                                then pure True
                                else abort >> pure False
                    if ok
                        then writeChan rawInChan b >> recvAll
                        else writeChan rawInChan BS.empty
        -- Send our handshake byte, first waiting for the peer's first byte
        -- when we are to write second.
        sendHandshake = do
            when (ord == WriteSecond) . void $ readMVar receivedHandshake
            TLS.sendData ctxt $ BL.singleton introByte

    sockMV <- newEmptyMVar
    sockThread <- forkIO $ putMVar sockMV . fst =<< S.accept dSock

    -- Outgoing: local connection -> TLS.
    _ <- forkIO $ do
        sock <- readMVar sockMV
        abortOnErr sendHandshake
        tsOutChan <- newChan
        rawOutChan <- newChan
        _ <- forkIO . abortOnErr $ do
            writeList2Chan rawOutChan . T.unpack . T.decodeUtf8With T.lenientDecode =<<
                SL.getContents sock
            -- The local side closed its connection.
            abort
        pausesThread <- forkIO $ insertPauses rawOutChan tsOutChan
        abortOnErr $ sendAll tsOutChan
        killThread pausesThread

    -- Incoming: TLS -> local connection.
    _ <- forkIO $ do
        tsInChan <- newChan
        decodeTTThread <- forkIO $
            writeList2Chan tsInChan . decodeTimedText . BL.fromChunks =<< getChanContents rawInChan
        _ <- forkIO . abortOnErr $ relayTimed tsInChan =<< readMVar sockMV
        abortOnErr recvAll
        killThread decodeTTThread

    _ <- takeMVar finished
    -- Stop the accept thread unconditionally. Previously this ran after
    -- TLS.bye inside the same ignoreIOErr, so an IO error from bye (likely
    -- when the peer is already gone) skipped it and left the thread blocked
    -- in accept.
    killThread sockThread
    ignoreIOErr $ TLS.bye ctxt
    tryTakeMVar sockMV >>= \case
        Nothing -> pure ()
        Just sock -> S.gracefulClose sock 1000
  where
    -- Handshake byte each side sends first; incoming data must start with it.
    introByte = fromIntegral $ fromEnum 'T'

    -- Copy characters from rawChan to ttChan, preceding each with the
    -- elapsed time in ms (when non-zero) since the previous one was read.
    insertPauses rawChan ttChan = TM.timeCurrentP >>= insertPauses'
      where
        insertPauses' e = do
            c <- readChan rawChan
            e' <- TM.timeCurrentP
            let ms = elapsedPToMS $ e' - e
            when (ms > 0) . writeChan ttChan . Left $ fromIntegral ms
            writeChan ttChan $ Right c
            insertPauses' e'

    -- Forever: gather whatever timed text arrives on ttChan within
    -- sendTimeout microseconds and, if anything arrived, send it as timed
    -- text over TLS.
    sendAll ttChan = forever $ do
        readBufMV <- newMVar []
        _ <- timeout sendTimeout . forever $
            -- Dequeue and buffer as one masked step. While masked, the
            -- timeout can only interrupt the blocking wait inside readChan,
            -- which leaves the chan untouched. It can no longer strike
            -- between dequeuing an item and recording it, which would
            -- silently drop that item.
            modifyMVarMasked_ readBufMV $ \buf -> (: buf) <$> readChan ttChan
        rtt <- readMVar readBufMV
        unless (null rtt) . TLS.sendData ctxt . rechunk . encodeTimedText $ reverse rtt
      where
        rechunk =
            -- TLS.sendData sends one packet per chunk, while encodeTimedText
            -- returns a chunk per char, so it's important to rechunk.
            BL.fromStrict . BL.toStrict
        sendTimeout = 1000 * 300

    -- Replay a timed-text stream to sock, reproducing recorded pauses.
    -- State: Nothing while unsynchronised (pauses are dropped; the next
    -- character is sent after a bufferTime delay), otherwise Just the time
    -- at which the last replayed pause ended. Synchronisation is dropped
    -- when a pause of exactly pauseMax arrives and playback is already
    -- later than its end.
    relayTimed chan sock = foldM_ sendTimed' Nothing =<< getChanContents chan
      where
        sendTimed' :: Maybe TM.ElapsedP -> Either Int Char -> IO (Maybe TM.ElapsedP)
        sendTimed' Nothing (Right c) = do
            threadDelay bufferTime
            e <- TM.timeCurrentP
            sendTimed' (Just e) (Right c)
        sendTimed' Nothing _ = pure Nothing
        sendTimed' (Just e) (Right c) = do
            SL.sendAll sock . T.encodeUtf8 $ T.singleton c
            pure $ Just e
        sendTimed' (Just e) (Left n) = do
            delayed <- elapsedPToMS . flip (-) e <$> TM.timeCurrentP
            when (n > delayed) . threadDelay . (1000 *) $ n - delayed
            pure $ if n == pauseMax && n < delayed
                then Nothing
                else Just $ e + msToElapsedP n
        bufferTime = 1000 * 300
-- | Convert a millisecond count to an 'TM.ElapsedP'.
--
-- The split uses 'divMod', so the seconds part is rounded towards negative
-- infinity and the nanosecond part is a whole number of milliseconds in
-- the range [0, 999000000].
msToElapsedP :: Int -> TM.ElapsedP
msToElapsedP ms =
    let (secs, msRem) = fromIntegral ms `divMod` 1000
    in TM.ElapsedP (TM.Elapsed (TM.Seconds secs)) (TM.NanoSeconds (msRem * 1000000))
-- | Convert an 'TM.ElapsedP' to whole milliseconds: seconds times 1000 plus
-- the nanosecond part floor-divided by 10^6 (sub-millisecond remainder
-- discarded).
elapsedPToMS :: TM.ElapsedP -> Int
elapsedPToMS (TM.ElapsedP (TM.Elapsed (TM.Seconds secs)) (TM.NanoSeconds nanos)) =
    fromIntegral (msFromSecs + msFromNanos)
  where
    msFromSecs = secs * 1000
    msFromNanos = nanos `div` 1000000