/* Copyright 2018-2021, The Tor Project Inc. */
/* See LICENSE for licensing information */

/**
 * \file nss_countbytes.c
 * \brief A PRFileDesc layer to let us count the number of bytes
 *        actually read and written on a PRFileDesc.
 **/
#include "lib/log/util_bug.h"
#include "lib/malloc/malloc.h"
#include "lib/tls/nss_countbytes.h"

#include <stdbool.h>
#include <stdint.h>
#include <string.h>
21 /** Boolean: have we initialized this module */
22 static bool countbytes_initialized
= false;
24 /** Integer to identity this layer. */
25 static PRDescIdentity countbytes_layer_id
= PR_INVALID_IO_LAYER
;
27 /** Table of methods for this layer.*/
28 static PRIOMethods countbytes_methods
;
30 /** Default close function provided by NSPR. We use this to help
31 * implement our own close function.*/
32 static PRStatus(*default_close_fn
)(PRFileDesc
*fd
);
34 static PRStatus
countbytes_close_fn(PRFileDesc
*fd
);
35 static PRInt32
countbytes_read_fn(PRFileDesc
*fd
, void *buf
, PRInt32 amount
);
36 static PRInt32
countbytes_write_fn(PRFileDesc
*fd
, const void *buf
,
38 static PRInt32
countbytes_writev_fn(PRFileDesc
*fd
, const PRIOVec
*iov
,
39 PRInt32 size
, PRIntervalTime timeout
);
40 static PRInt32
countbytes_send_fn(PRFileDesc
*fd
, const void *buf
,
41 PRInt32 amount
, PRIntn flags
,
42 PRIntervalTime timeout
);
43 static PRInt32
countbytes_recv_fn(PRFileDesc
*fd
, void *buf
, PRInt32 amount
,
44 PRIntn flags
, PRIntervalTime timeout
);
/** Private fields for the byte-counter layer.  We cast this to and from
 * PRFilePrivate*, which is supposed to be allowed. */
typedef struct tor_nss_bytecounts_t {
  /** Total bytes reported as read by the lower layer (read/recv). */
  uint64_t n_read;
  /** Total bytes reported as written by the lower layer (write/writev/send). */
  uint64_t n_written;
} tor_nss_bytecounts_t;
54 * Initialize this module, if it is not already initialized.
57 tor_nss_countbytes_init(void)
59 if (countbytes_initialized
)
62 countbytes_layer_id
= PR_GetUniqueIdentity("Tor byte-counting layer");
63 tor_assert(countbytes_layer_id
!= PR_INVALID_IO_LAYER
);
65 memcpy(&countbytes_methods
, PR_GetDefaultIOMethods(), sizeof(PRIOMethods
));
67 default_close_fn
= countbytes_methods
.close
;
68 countbytes_methods
.close
= countbytes_close_fn
;
69 countbytes_methods
.read
= countbytes_read_fn
;
70 countbytes_methods
.write
= countbytes_write_fn
;
71 countbytes_methods
.writev
= countbytes_writev_fn
;
72 countbytes_methods
.send
= countbytes_send_fn
;
73 countbytes_methods
.recv
= countbytes_recv_fn
;
74 /* NOTE: We aren't wrapping recvfrom, sendto, or sendfile, since I think
75 * NSS won't be using them for TLS connections. */
77 countbytes_initialized
= true;
81 * Return the tor_nss_bytecounts_t object for a given IO layer. Asserts that
82 * the IO layer is in fact a layer created by this module.
84 static tor_nss_bytecounts_t
*
85 get_counts(PRFileDesc
*fd
)
87 tor_assert(fd
->identity
== countbytes_layer_id
);
88 return (tor_nss_bytecounts_t
*) fd
->secret
;
/** Helper: increment the read-count of an fd by n. */
#define INC_READ(fd, n) STMT_BEGIN                      \
    get_counts(fd)->n_read += (n);                      \
  STMT_END
/** Helper: increment the write-count of an fd by n. */
#define INC_WRITTEN(fd, n) STMT_BEGIN                   \
    get_counts(fd)->n_written += (n);                   \
  STMT_END
101 /** Implementation for PR_Close: frees the 'secret' field, then passes control
102 * to the default close function */
104 countbytes_close_fn(PRFileDesc
*fd
)
108 tor_nss_bytecounts_t
*counts
= (tor_nss_bytecounts_t
*)fd
->secret
;
112 return default_close_fn(fd
);
115 /** Implementation for PR_Read: Calls the lower-level read function,
116 * and records what it said. */
118 countbytes_read_fn(PRFileDesc
*fd
, void *buf
, PRInt32 amount
)
121 tor_assert(fd
->lower
);
123 PRInt32 result
= (fd
->lower
->methods
->read
)(fd
->lower
, buf
, amount
);
125 INC_READ(fd
, result
);
128 /** Implementation for PR_Write: Calls the lower-level write function,
129 * and records what it said. */
131 countbytes_write_fn(PRFileDesc
*fd
, const void *buf
, PRInt32 amount
)
134 tor_assert(fd
->lower
);
136 PRInt32 result
= (fd
->lower
->methods
->write
)(fd
->lower
, buf
, amount
);
138 INC_WRITTEN(fd
, result
);
141 /** Implementation for PR_Writev: Calls the lower-level writev function,
142 * and records what it said. */
144 countbytes_writev_fn(PRFileDesc
*fd
, const PRIOVec
*iov
,
145 PRInt32 size
, PRIntervalTime timeout
)
148 tor_assert(fd
->lower
);
150 PRInt32 result
= (fd
->lower
->methods
->writev
)(fd
->lower
, iov
, size
, timeout
);
152 INC_WRITTEN(fd
, result
);
155 /** Implementation for PR_Send: Calls the lower-level send function,
156 * and records what it said. */
158 countbytes_send_fn(PRFileDesc
*fd
, const void *buf
,
159 PRInt32 amount
, PRIntn flags
, PRIntervalTime timeout
)
162 tor_assert(fd
->lower
);
164 PRInt32 result
= (fd
->lower
->methods
->send
)(fd
->lower
, buf
, amount
, flags
,
167 INC_WRITTEN(fd
, result
);
170 /** Implementation for PR_Recv: Calls the lower-level recv function,
171 * and records what it said. */
173 countbytes_recv_fn(PRFileDesc
*fd
, void *buf
, PRInt32 amount
,
174 PRIntn flags
, PRIntervalTime timeout
)
177 tor_assert(fd
->lower
);
179 PRInt32 result
= (fd
->lower
->methods
->recv
)(fd
->lower
, buf
, amount
, flags
,
182 INC_READ(fd
, result
);
187 * Wrap a PRFileDesc from NSPR with a new PRFileDesc that will count the
188 * total number of bytes read and written. Return the new PRFileDesc.
190 * This function takes ownership of its input.
193 tor_wrap_prfiledesc_with_byte_counter(PRFileDesc
*stack
)
195 if (BUG(! countbytes_initialized
)) {
196 tor_nss_countbytes_init();
199 tor_nss_bytecounts_t
*bytecounts
= tor_malloc_zero(sizeof(*bytecounts
));
201 PRFileDesc
*newfd
= PR_CreateIOLayerStub(countbytes_layer_id
,
202 &countbytes_methods
);
204 newfd
->secret
= (PRFilePrivate
*)bytecounts
;
206 /* This does some complicated messing around with the headers of these
207 objects; see the NSPR documentation for more. The upshot is that
208 after PushIOLayer, "stack" will be the head of the stack.
210 PRStatus status
= PR_PushIOLayer(stack
, PR_TOP_IO_LAYER
, newfd
);
211 tor_assert(status
== PR_SUCCESS
);
217 * Given a PRFileDesc returned by tor_wrap_prfiledesc_with_byte_counter(),
218 * or another PRFileDesc wrapping that PRFileDesc, set the provided
219 * pointers to the number of bytes read and written on the descriptor since
222 * Return 0 on success, -1 on failure.
225 tor_get_prfiledesc_byte_counts(PRFileDesc
*fd
,
226 uint64_t *n_read_out
,
227 uint64_t *n_written_out
)
229 if (BUG(! countbytes_initialized
)) {
230 tor_nss_countbytes_init();
234 PRFileDesc
*bclayer
= PR_GetIdentitiesLayer(fd
, countbytes_layer_id
);
235 if (BUG(bclayer
== NULL
))
238 tor_nss_bytecounts_t
*counts
= get_counts(bclayer
);
240 *n_read_out
= counts
->n_read
;
241 *n_written_out
= counts
->n_written
;