-- This file is part of htalkat
-- Copyright (C) 2021 Martin Bays <mbays@sdf.org>
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of version 3 of the GNU General Public License as
-- published by the Free Software Foundation, or any later version.
--
-- You should have received a copy of the GNU General Public License
-- along with this program. If not, see http://www.gnu.org/licenses/.
11 {-# LANGUAGE LambdaCase #-}
12 {-# LANGUAGE OverloadedStrings #-}
14 module RelayStream
where
16 import Control
.Concurrent
17 import Control
.Exception
(SomeException
, handle
)
18 import Control
.Monad
(foldM_
, forever
, unless, void
,
20 import System
.Timeout
(timeout
)
22 import qualified Data
.ByteString
as BS
23 import qualified Data
.ByteString
.Lazy
as BL
24 import qualified Data
.Text
.Encoding
.Error
as T
25 import qualified Data
.Text
.Lazy
as T
26 import qualified Data
.Text
.Lazy
.Encoding
as T
27 import qualified Network
.Socket
as S
28 import qualified Network
.Socket
.ByteString
.Lazy
as SL
29 import qualified Network
.TLS
as TLS
30 import qualified Time
.System
as TM
31 import qualified Time
.Types
as TM
-- | When 'relayStream' sends its intro byte relative to the peer's.
--
-- With 'WriteFirst' the intro byte is sent without waiting. With
-- 'WriteSecond' it is sent only after the first chunk received from the
-- peer has been inspected.
data WriteOrder
    = WriteFirst   -- ^ send our intro byte straight away
    | WriteSecond  -- ^ wait for the peer's first chunk before sending
    deriving (Show, Eq, Ord)
-- | Relay a talk session between a TLS connection and one locally accepted
-- socket.
--
-- Parameters, as used below:
--
--   * @ctxt@: the TLS context. Data is read with 'TLS.recvData' and written
--     with 'TLS.sendData'. The session is closed with 'TLS.bye'.
--   * @ord@: whether our intro byte is sent immediately ('WriteFirst') or
--     only once @receivedHandshake@ has been filled ('WriteSecond').
--   * @dSock@: the socket on which a single connection is accepted with
--     'S.accept'.
--
-- The function returns after @finished@ has been filled by @abort@ and the
-- teardown at the end of the @do@ block has run.
--
-- NOTE(review): the bindings of rawInChan, rawOutChan, tsInChan, tsOutChan,
-- recvAll, sendHandshake and h are not visible in this block as written;
-- confirm them.
39 relayStream
:: TLS
.Context
-> WriteOrder
-> S
.Socket
-> IO ()
40 relayStream ctxt
ord dSock
= do
-- receivedHandshake: filled by the receive path with whether the byte it
-- inspected equals introByte.
-- finished: filled by abort. The main thread blocks on it before teardown.
41 receivedHandshake
<- newEmptyMVar
42 finished
<- newEmptyMVar
-- abort signals the main thread to tear down.
-- abortOnErr runs an action. If the action throws any exception, it calls
-- abort and returns mempty instead. The handler catches SomeException, so
-- asynchronous exceptions are included.
-- NOTE(review): abort uses putMVar. If several threads abort, the later
-- calls block on the full MVar. tryPutMVar would make abort idempotent;
-- confirm whether that matters here.
44 let abort
= putMVar finished
()
45 abortOnErr
= handle abortHandler
where
46 abortHandler
:: Monoid a
=> SomeException
-> IO a
47 abortHandler _
= abort
>> pure mempty
-- Receive path (presumably the body of recvAll, which it recurses into):
-- 1. Read a chunk over TLS.
-- 2. Decide via receivedHandshake whether to accept it. In the visible
--    branch, a byte h (presumably the first byte of the first chunk) is
--    compared with introByte and the result is stored. Otherwise abort is
--    called and the result is False.
-- 3. If accepted, forward the chunk to rawInChan and loop. If not, write an
--    empty chunk to rawInChan and stop looping.
49 b
<- TLS
.recvData ctxt
53 ok
<- tryReadMVar receivedHandshake
>>= \case
56 let isHandshakeByte
= h
== introByte
57 putMVar receivedHandshake isHandshakeByte
60 else abort
>> pure
False
61 if ok
then writeChan rawInChan b
>> recvAll
62 else writeChan rawInChan BS
.empty
-- Intro-byte send (presumably the body of sendHandshake, run below):
-- with WriteSecond, first wait until receivedHandshake has been filled.
-- Then send introByte over TLS as a single byte.
64 when (ord == WriteSecond
) . void
$ readMVar receivedHandshake
65 TLS
.sendData ctxt
$ BL
.singleton introByte
-- Accept one local connection in a separate thread and publish its socket
-- in sockMV. Teardown below attempts to kill this thread.
67 sockMV
<- newEmptyMVar
68 sockThread
<- forkIO
$ putMVar sockMV
. fst =<< S
.accept dSock
-- Wait for the accepted socket, then send the intro byte. If sending
-- fails, abort.
71 sock
<- readMVar sockMV
72 abortOnErr sendHandshake
-- Outgoing path: bytes (presumably read from the local socket) are decoded
-- leniently as UTF-8. Each resulting character is written to rawOutChan.
75 _
<- forkIO
. abortOnErr
$ do
76 writeList2Chan rawOutChan
. T
.unpack
. T
.decodeUtf8With T
.lenientDecode
=<<
-- insertPauses timestamps the characters onto tsOutChan, and sendAll
-- batches them over TLS. sendAll loops forever, so it returns only after an
-- exception (caught by abortOnErr). The pause thread is then killed.
79 pausesThread
<- forkIO
$ insertPauses rawOutChan tsOutChan
80 abortOnErr
$ sendAll tsOutChan
81 killThread pausesThread
-- Incoming path: chunks from rawInChan are joined into a lazy ByteString
-- and decoded with decodeTimedText onto tsInChan. relayTimed then replays
-- them to the local socket once it has been accepted.
85 decodeTTThread
<- forkIO
$
86 writeList2Chan tsInChan
. decodeTimedText
. BL
.fromChunks
=<< getChanContents rawInChan
87 _
<- forkIO
. abortOnErr
$ relayTimed tsInChan
=<< readMVar sockMV
-- Stop the timed-text decoder thread.
89 killThread decodeTTThread
-- Teardown:
-- 1. Wait for abort.
-- 2. End the TLS session and stop the accept thread.
-- 3. If a local socket was accepted, close it gracefully. gracefulClose
--    waits up to 1000 ms for the peer's FIN.
-- NOTE(review): `ignoreIOErr $ TLS.bye ctxt >> killThread sockThread` parses
-- as `ignoreIOErr (TLS.bye ctxt >> killThread sockThread)`. If bye throws,
-- killThread is skipped and the accept thread can stay blocked in accept.
-- Consider running killThread outside ignoreIOErr.
91 _
<- takeMVar finished
92 ignoreIOErr
$ TLS
.bye ctxt
>> killThread sockThread
93 tryTakeMVar sockMV
>>= \case
95 Just sock
-> S
.gracefulClose sock
1000
-- The intro byte sent at the start of a session: the character code of 'T'.
97 introByte
= fromIntegral $ fromEnum 'T
'
-- Read characters from rawChan and write each one to ttChan as `Right c`.
-- When ms > 0, it is preceded by `Left ms`. Here ms is the number of
-- milliseconds between the previous timestamp e and the current time e'.
-- Presumably the loop then continues with e'; those lines are not visible
-- here.
99 insertPauses rawChan ttChan
= TM
.timeCurrentP
>>= insertPauses
'
102 c
<- readChan rawChan
103 e
' <- TM
.timeCurrentP
104 let ms
= elapsedPToMS
$ e
' - e
105 when (ms
> 0) . writeChan ttChan
. Left
$ fromIntegral ms
106 writeChan ttChan
$ Right c
-- Loop forever:
-- 1. Collect items from ttChan for up to sendTimeout microseconds.
-- 2. If any arrived, send them in arrival order as one encodeTimedText
--    payload over TLS. The buffer is built by consing, hence the reverse.
-- NOTE(review): the timeout can fire after readChan has returned but before
-- modifyMVar_ stores the item, which would drop that item. Confirm whether
-- this window matters.
109 sendAll ttChan
= forever
$ do
110 readBufMV
<- newMVar
[]
111 _
<- timeout sendTimeout
. forever
$
112 modifyMVar_ readBufMV
. (pure
.) . (:) =<< readChan ttChan
113 rtt
<- readMVar readBufMV
114 unless (null rtt
) . TLS
.sendData ctxt
. rechunk
. encodeTimedText
$ reverse rtt
117 -- TLS.sendData sends one packet per chunk, while encodeTimedText
118 -- returns a chunk per char, so it's important to rechunk.
119 BL
.fromStrict
. BL
.toStrict
-- Batching window for sendAll, in microseconds as taken by 'timeout'
-- (300 ms).
120 sendTimeout
= 1000 * 300
-- Replay timed text from chan to sock. foldM_ threads an optional reference
-- timestamp through the items:
--   * Items arriving before the first character (state Nothing, not a
--     Right) are dropped.
--   * The first character is delayed by bufferTime and then handled as
--     below, presumably with the current time as the reference.
--   * A character is sent immediately, encoded as UTF-8.
--   * A pause `Left n` sleeps for whatever part of n ms has not already
--     elapsed since the reference, then advances the reference by n ms.
--     The pauseMax case is only partly visible here.
122 relayTimed chan sock
= foldM_ sendTimed
' Nothing
=<< getChanContents chan
where
123 sendTimed
' :: Maybe TM
.ElapsedP
-> Either Int Char -> IO (Maybe TM
.ElapsedP
)
124 sendTimed
' Nothing
(Right c
) = do
125 threadDelay bufferTime
127 sendTimed
' (Just e
) (Right c
)
128 sendTimed
' Nothing _
= pure Nothing
129 sendTimed
' (Just e
) (Right c
) = do
130 SL
.sendAll sock
. T
.encodeUtf8
$ T
.singleton c
132 sendTimed
' (Just e
) (Left n
) = do
133 delayed
<- elapsedPToMS
. flip (-) e
<$> TM
.timeCurrentP
134 when (n
> delayed
) . threadDelay
. (1000 *) $ n
- delayed
135 pure
$ if n
== pauseMax
&& n
< delayed
137 else Just
$ e
+ msToElapsedP n
-- Delay before the first replayed character, in microseconds as taken by
-- 'threadDelay' (300 ms).
138 bufferTime
= 1000 * 300
-- | Convert a duration in milliseconds to a 'TM.ElapsedP'.
--
-- The split into seconds and remainder uses 'divMod', so the sub-second
-- part is always between 0 and 999 ms (stored as nanoseconds). A negative
-- input therefore gives a negative seconds component with a non-negative
-- sub-second remainder.
msToElapsedP :: Int -> TM.ElapsedP
msToElapsedP ms = TM.ElapsedP elapsedSecs nanos
  where
    (wholeSecs, remMs) = fromIntegral ms `divMod` 1000
    elapsedSecs = TM.Elapsed (TM.Seconds wholeSecs)
    nanos = TM.NanoSeconds (remMs * 1000000)
-- | Convert a 'TM.ElapsedP' to whole milliseconds.
--
-- Sub-millisecond precision in the nanosecond component is discarded with
-- 'div'. The result is the seconds scaled to milliseconds plus that value.
elapsedPToMS :: TM.ElapsedP -> Int
elapsedPToMS (TM.ElapsedP (TM.Elapsed (TM.Seconds secs)) (TM.NanoSeconds nanos)) =
    fromIntegral totalMs
  where
    wholeMs = secs * 1000
    fracMs = nanos `div` 1000000
    totalMs = wholeMs + fracMs