1 -- This file is part of htalkat
2 -- Copyright (C) 2021 Martin Bays <mbays@sdf.org>
4 -- This program is free software: you can redistribute it and/or modify
5 -- it under the terms of version 3 of the GNU General Public License as
6 -- published by the Free Software Foundation, or any later version.
8 -- You should have received a copy of the GNU General Public License
9 -- along with this program. If not, see http://www.gnu.org/licenses/.
11 {-# LANGUAGE LambdaCase #-}
13 module RelayStream
where
15 import Control
.Concurrent
16 import Control
.Exception
(SomeException
, handle
)
17 import Control
.Monad
(foldM_
, forever
, unless, void
,
19 import System
.Timeout
(timeout
)
21 import qualified Data
.ByteString
as BS
22 import qualified Data
.ByteString
.Lazy
as BL
23 import qualified Data
.Text
.Encoding
.Error
as T
24 import qualified Data
.Text
.Lazy
as T
25 import qualified Data
.Text
.Lazy
.Encoding
as T
26 import qualified Network
.Socket
as S
27 import qualified Network
.Socket
.ByteString
.Lazy
as SL
28 import qualified Network
.TLS
as TLS
29 import qualified Time
.System
as TM
30 import qualified Time
.Types
as TM
-- | Selects which side of the connection sends its introductory byte first.
-- With 'WriteSecond', 'relayStream' waits for the peer's handshake result
-- before sending its own byte; with 'WriteFirst' it sends without waiting.
data WriteOrder
    = WriteFirst
    | WriteSecond
    deriving (Eq, Ord, Show)
-- Relay a stream between the TLS connection `ctxt` and one local connection
-- accepted on the listening socket `dSock`. `ord` decides whether we wait for
-- the peer's handshake result before sending our own introductory byte.
-- NOTE(review): the embedded line numbering has gaps, i.e. some lines of this
-- definition are not visible here; the comments below describe only what the
-- visible lines establish.
38 relayStream
:: TLS
.Context
-> WriteOrder
-> S
.Socket
-> IO ()
39 relayStream ctxt
ord dSock
= do
40 receivedHandshake
<- newEmptyMVar
41 finished
<- newEmptyMVar
-- `abort` signals shutdown by filling `finished`. `abortOnErr` wraps an action
-- so that any exception triggers `abort` and the action yields `mempty`.
43 let abort
= putMVar finished
()
44 abortOnErr
= handle abortHandler
where
45 abortHandler
:: Monoid a
=> SomeException
-> IO a
46 abortHandler _
= abort
>> pure mempty
-- Receive side (its binding header is not visible here; the recursive call
-- below names it `recvAll`): read a chunk from the TLS peer and consult
-- `receivedHandshake`. The visible branch compares a byte `h` with `introByte`
-- and stores the result; how `h` is obtained is on a line not shown.
48 b
<- TLS
.recvData ctxt
52 ok
<- tryReadMVar receivedHandshake
>>= \case
55 let isHandshakeByte
= h
== introByte
56 putMVar receivedHandshake isHandshakeByte
59 else abort
>> pure
False
-- On success forward the chunk to rawInChan and keep receiving; otherwise
-- write an empty chunk to rawInChan and stop recursing.
60 if ok
then writeChan rawInChan b
>> recvAll
61 else writeChan rawInChan BS
.empty
-- Handshake send (presumably the body of `sendHandshake`, whose header is not
-- visible): with WriteSecond, first block until the peer's handshake result is
-- available, then send the single introductory byte.
63 when (ord == WriteSecond
) . void
$ readMVar receivedHandshake
64 TLS
.sendData ctxt
$ BL
.singleton introByte
-- Accept one connection on dSock in a helper thread and publish the accepted
-- socket through sockMV.
66 sockMV
<- newEmptyMVar
67 sockThread
<- forkIO
$ putMVar sockMV
. fst =<< S
.accept dSock
70 sock
<- readMVar sockMV
71 abortOnErr sendHandshake
-- Outgoing direction: bytes (their source is on a line not visible here) are
-- decoded as UTF-8 with lenient error handling and written char by char to
-- rawOutChan; insertPauses turns that into a timed stream that sendAll sends.
74 _
<- forkIO
. abortOnErr
$ do
75 writeList2Chan rawOutChan
. T
.unpack
. T
.decodeUtf8With T
.lenientDecode
=<<
78 pausesThread
<- forkIO
$ insertPauses rawOutChan tsOutChan
79 abortOnErr
$ sendAll tsOutChan
80 killThread pausesThread
-- Incoming direction: the lazily concatenated chunks from rawInChan are decoded
-- as timed text into tsInChan, which relayTimed replays to the accepted socket.
84 decodeTTThread
<- forkIO
$
85 writeList2Chan tsInChan
. decodeTimedText
. BL
.fromChunks
=<< getChanContents rawInChan
86 _
<- forkIO
. abortOnErr
$ relayTimed tsInChan
=<< readMVar sockMV
88 killThread decodeTTThread
-- Block until something signals `finished`, then (under ignoreIOErr) say bye on
-- the TLS context and kill the accept thread. If a socket was accepted, close
-- it with S.gracefulClose; 1000 is the timeout argument passed to it.
90 _
<- takeMVar finished
91 ignoreIOErr
$ TLS
.bye ctxt
>> killThread sockThread
92 tryTakeMVar sockMV
>>= \case
94 Just sock
-> S
.gracefulClose sock
1000
-- The introductory (handshake) byte: the character code of 'T'.
-- Deliberately left without a type signature so its numeric type is still
-- inferred from the use sites, as in the original.
introByte = fromIntegral (fromEnum 'T')
-- Turn the raw character stream on rawChan into a timed stream on ttChan.
-- Starts from the current time and hands it to the helper insertPauses'.
-- NOTE(review): the helper's header and its recursive call are on lines not
-- visible here; `e` is presumably the timestamp passed to the helper.
98 insertPauses rawChan ttChan
= TM
.timeCurrentP
>>= insertPauses
'
-- Wait for the next character, then take a fresh timestamp e'.
101 c
<- readChan rawChan
102 e
' <- TM
.timeCurrentP
-- Milliseconds between e and e'; emitted as `Left ms` only when positive.
103 let ms
= elapsedPToMS
$ e
' - e
104 when (ms
> 0) . writeChan ttChan
. Left
$ fromIntegral ms
-- The character itself always follows as `Right c`.
105 writeChan ttChan
$ Right c
-- Send the timed stream from ttChan over the TLS context in batches, forever.
-- Each round collects items for at most sendTimeout (the argument to
-- System.Timeout.timeout, which is in microseconds) and sends them in one
-- TLS.sendData call if any arrived.
108 sendAll ttChan
= forever
$ do
109 readBufMV
<- newMVar
[]
-- Keep prepending incoming items to the buffer until the timeout cuts the
-- inner loop off; the buffer lives in an MVar so it survives the interruption.
110 _
<- timeout sendTimeout
. forever
$
111 modifyMVar_ readBufMV
. (pure
.) . (:) =<< readChan ttChan
112 rtt
<- readMVar readBufMV
-- Items were consed on, so reverse restores arrival order before encoding.
113 unless (null rtt
) . TLS
.sendData ctxt
. rechunk
. encodeTimedText
$ reverse rtt
-- NOTE(review): the header of the `rechunk` binding is on a line not visible
-- here; the composition below (toStrict then fromStrict) is presumably its body.
116 -- TLS.sendData sends one packet per chunk, while encodeTimedText
117 -- returns a chunk per char, so it's important to rechunk.
118 BL
.fromStrict
. BL
.toStrict
-- 1000 * 300 microseconds, i.e. 300 ms per batching window.
119 sendTimeout
= 1000 * 300
-- Replay the timed stream from `chan` to `sock`. Folds sendTimed' over the
-- lazily read channel contents; the fold state is an optional reference
-- timestamp, initially Nothing.
-- NOTE(review): some lines of this definition are not visible here (gaps in
-- the embedded numbering); comments cover only the visible lines.
121 relayTimed chan sock
= foldM_ sendTimed
' Nothing
=<< getChanContents chan
where
-- State is the reference timestamp; items are `Left n` (a pause, compared
-- against milliseconds below) or `Right c` (a character).
122 sendTimed
' :: Maybe TM
.ElapsedP
-> Either Int Char -> IO (Maybe TM
.ElapsedP
)
-- First character with no reference time yet: sleep for bufferTime, then
-- handle the character as in the (Just e) case. The line binding `e` is not
-- visible here.
123 sendTimed
' Nothing
(Right c
) = do
124 threadDelay bufferTime
126 sendTimed
' (Just e
) (Right c
)
-- Anything else without a reference time is ignored.
127 sendTimed
' Nothing _
= pure Nothing
-- Character with a reference time: send its UTF-8 encoding to the socket.
-- The returned state is on a line not visible here.
128 sendTimed
' (Just e
) (Right c
) = do
129 SL
.sendAll sock
. T
.encodeUtf8
$ T
.singleton c
-- Pause of n: `delayed` is the milliseconds elapsed since e (now - e). Sleep
-- only for the part of n not already elapsed (threadDelay takes microseconds,
-- hence the factor 1000).
131 sendTimed
' (Just e
) (Left n
) = do
132 delayed
<- elapsedPToMS
. flip (-) e
<$> TM
.timeCurrentP
133 when (n
> delayed
) . threadDelay
. (1000 *) $ n
- delayed
-- If the pause equals pauseMax and has already been exceeded, the `then`
-- branch applies (not visible here); otherwise advance the reference time by n.
134 pure
$ if n
== pauseMax
&& n
< delayed
136 else Just
$ e
+ msToElapsedP n
-- 1000 * 300 microseconds (threadDelay's unit), i.e. 300 ms of initial delay.
137 bufferTime
= 1000 * 300
-- | Convert a duration in milliseconds into a 'TM.ElapsedP': whole seconds
-- go into the 'TM.Seconds' part, the remaining milliseconds are scaled to
-- nanoseconds. 'divMod' is used, so the millisecond remainder is never
-- negative.
msToElapsedP :: Int -> TM.ElapsedP
msToElapsedP ms =
    TM.ElapsedP (TM.Elapsed (TM.Seconds wholeSecs)) (TM.NanoSeconds nanos)
  where
    (wholeSecs, restMillis) = fromIntegral ms `divMod` 1000
    nanos = 1000000 * restMillis
-- | Convert a 'TM.ElapsedP' into whole milliseconds: seconds times 1000 plus
-- the nanosecond part integer-divided by one million.
elapsedPToMS :: TM.ElapsedP -> Int
elapsedPToMS (TM.ElapsedP (TM.Elapsed (TM.Seconds secs)) (TM.NanoSeconds nanos)) =
    fromIntegral (millisFromSecs + millisFromNanos)
  where
    millisFromSecs = secs * 1000
    millisFromNanos = nanos `div` 1000000