r7750: handle STATUS_MORE_ENTRIES on send in tls
[Samba/ekacnet.git] / source4 / lib / tls / tls.c
blob559a54a2f0e2115fc25c8c778f2d0a92c5cf1bb7
1 /*
2 Unix SMB/CIFS implementation.
4 transport layer security handling code
6 Copyright (C) Andrew Tridgell 2005
8 This program is free software; you can redistribute it and/or modify
9 it under the terms of the GNU General Public License as published by
10 the Free Software Foundation; either version 2 of the License, or
11 (at your option) any later version.
13 This program is distributed in the hope that it will be useful,
14 but WITHOUT ANY WARRANTY; without even the implied warranty of
15 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 GNU General Public License for more details.
18 You should have received a copy of the GNU General Public License
19 along with this program; if not, write to the Free Software
20 Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
23 #include "includes.h"
24 #include "lib/events/events.h"
25 #include "lib/socket/socket.h"
26 #include "lib/tls/tls.h"
28 #if HAVE_LIBGNUTLS
29 #include "gnutls/gnutls.h"
/* size in bits of the Diffie-Hellman prime generated in tls_initialise()
   and requested for every session in tls_init_server() */
enum { DH_BITS = 1024 };
33 /* hold persistent tls data */
34 struct tls_params {
35 gnutls_certificate_credentials x509_cred;
36 gnutls_dh_params dh_params;
37 BOOL tls_enabled;
40 /* hold per connection tls data */
41 struct tls_context {
42 struct tls_params *params;
43 struct socket_context *socket;
44 struct fd_event *fde;
45 gnutls_session session;
46 BOOL done_handshake;
47 BOOL have_first_byte;
48 uint8_t first_byte;
49 BOOL tls_enabled;
50 BOOL tls_detect;
51 const char *plain_chars;
52 BOOL output_pending;
57 callback for reading from a socket
59 static ssize_t tls_pull(gnutls_transport_ptr ptr, void *buf, size_t size)
61 struct tls_context *tls = talloc_get_type(ptr, struct tls_context);
62 NTSTATUS status;
63 size_t nread;
65 if (tls->have_first_byte) {
66 *(uint8_t *)buf = tls->first_byte;
67 tls->have_first_byte = False;
68 return 1;
71 status = socket_recv(tls->socket, buf, size, &nread, 0);
72 if (NT_STATUS_EQUAL(status, NT_STATUS_END_OF_FILE)) {
73 return 0;
75 if (NT_STATUS_IS_ERR(status)) {
76 EVENT_FD_NOT_READABLE(tls->fde);
77 EVENT_FD_NOT_WRITEABLE(tls->fde);
78 errno = EBADF;
79 return -1;
81 if (!NT_STATUS_IS_OK(status)) {
82 EVENT_FD_READABLE(tls->fde);
83 EVENT_FD_NOT_WRITEABLE(tls->fde);
84 errno = EAGAIN;
85 return -1;
87 if (tls->output_pending) {
88 EVENT_FD_WRITEABLE(tls->fde);
90 if (size != nread) {
91 EVENT_FD_READABLE(tls->fde);
93 return nread;
97 callback for writing to a socket
99 static ssize_t tls_push(gnutls_transport_ptr ptr, const void *buf, size_t size)
101 struct tls_context *tls = talloc_get_type(ptr, struct tls_context);
102 NTSTATUS status;
103 size_t nwritten;
104 DATA_BLOB b;
106 if (!tls->tls_enabled) {
107 return size;
110 b.data = discard_const(buf);
111 b.length = size;
113 status = socket_send(tls->socket, &b, &nwritten, 0);
114 if (NT_STATUS_EQUAL(status, STATUS_MORE_ENTRIES)) {
115 errno = EAGAIN;
116 return -1;
118 if (!NT_STATUS_IS_OK(status)) {
119 EVENT_FD_WRITEABLE(tls->fde);
120 return -1;
122 if (size != nwritten) {
123 EVENT_FD_WRITEABLE(tls->fde);
125 return nwritten;
129 destroy a tls session
131 static int tls_destructor(void *ptr)
133 struct tls_context *tls = talloc_get_type(ptr, struct tls_context);
134 int ret;
135 ret = gnutls_bye(tls->session, GNUTLS_SHUT_WR);
136 if (ret < 0) {
137 DEBUG(0,("TLS gnutls_bye failed - %s\n", gnutls_strerror(ret)));
139 return 0;
144 possibly continue the handshake process
146 static NTSTATUS tls_handshake(struct tls_context *tls)
148 int ret;
150 if (tls->done_handshake) {
151 return NT_STATUS_OK;
154 ret = gnutls_handshake(tls->session);
155 if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
156 return STATUS_MORE_ENTRIES;
158 if (ret < 0) {
159 DEBUG(0,("TLS gnutls_handshake failed - %s\n", gnutls_strerror(ret)));
160 return NT_STATUS_UNEXPECTED_NETWORK_ERROR;
162 tls->done_handshake = True;
163 return NT_STATUS_OK;
167 see how many bytes are pending on the connection
169 NTSTATUS tls_socket_pending(struct tls_context *tls, size_t *npending)
171 if (!tls->tls_enabled || tls->tls_detect) {
172 return socket_pending(tls->socket, npending);
174 *npending = gnutls_record_check_pending(tls->session);
175 if (*npending == 0) {
176 return socket_pending(tls->socket, npending);
178 return NT_STATUS_OK;
182 receive data either by tls or normal socket_recv
184 NTSTATUS tls_socket_recv(struct tls_context *tls, void *buf, size_t wantlen,
185 size_t *nread)
187 int ret;
188 NTSTATUS status;
189 if (tls->tls_enabled && tls->tls_detect) {
190 status = socket_recv(tls->socket, &tls->first_byte, 1, nread, 0);
191 NT_STATUS_NOT_OK_RETURN(status);
192 if (*nread == 0) return NT_STATUS_OK;
193 tls->tls_detect = False;
194 /* look for the first byte of a valid HTTP operation */
195 if (strchr(tls->plain_chars, tls->first_byte)) {
196 /* not a tls link */
197 tls->tls_enabled = False;
198 *(uint8_t *)buf = tls->first_byte;
199 return NT_STATUS_OK;
201 tls->have_first_byte = True;
204 if (!tls->tls_enabled) {
205 return socket_recv(tls->socket, buf, wantlen, nread, 0);
208 status = tls_handshake(tls);
209 NT_STATUS_NOT_OK_RETURN(status);
211 ret = gnutls_record_recv(tls->session, buf, wantlen);
212 if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
213 return STATUS_MORE_ENTRIES;
215 if (ret < 0) {
216 return NT_STATUS_UNEXPECTED_NETWORK_ERROR;
218 *nread = ret;
219 return NT_STATUS_OK;
224 send data either by tls or normal socket_recv
226 NTSTATUS tls_socket_send(struct tls_context *tls, const DATA_BLOB *blob, size_t *sendlen)
228 NTSTATUS status;
229 int ret;
231 if (!tls->tls_enabled) {
232 return socket_send(tls->socket, blob, sendlen, 0);
235 status = tls_handshake(tls);
236 NT_STATUS_NOT_OK_RETURN(status);
238 ret = gnutls_record_send(tls->session, blob->data, blob->length);
239 if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
240 return STATUS_MORE_ENTRIES;
242 if (ret < 0) {
243 DEBUG(0,("gnutls_record_send of %d failed - %s\n", blob->length, gnutls_strerror(ret)));
244 return NT_STATUS_UNEXPECTED_NETWORK_ERROR;
246 *sendlen = ret;
247 tls->output_pending = (ret < blob->length);
248 return NT_STATUS_OK;
253 initialise global tls state
255 struct tls_params *tls_initialise(TALLOC_CTX *mem_ctx)
257 struct tls_params *params;
258 int ret;
259 const char *keyfile = lp_tls_keyfile();
260 const char *certfile = lp_tls_certfile();
261 const char *cafile = lp_tls_cafile();
262 const char *crlfile = lp_tls_crlfile();
263 void tls_cert_generate(TALLOC_CTX *, const char *, const char *, const char *);
265 params = talloc(mem_ctx, struct tls_params);
266 if (params == NULL) return NULL;
268 if (!lp_tls_enabled() || keyfile == NULL || *keyfile == 0) {
269 params->tls_enabled = False;
270 return params;
273 if (!file_exist(cafile)) {
274 tls_cert_generate(params, keyfile, certfile, cafile);
277 ret = gnutls_global_init();
278 if (ret < 0) goto init_failed;
280 gnutls_certificate_allocate_credentials(&params->x509_cred);
281 if (ret < 0) goto init_failed;
283 if (cafile && *cafile) {
284 ret = gnutls_certificate_set_x509_trust_file(params->x509_cred, cafile,
285 GNUTLS_X509_FMT_PEM);
286 if (ret < 0) {
287 DEBUG(0,("TLS failed to initialise cafile %s\n", cafile));
288 goto init_failed;
292 if (crlfile && *crlfile) {
293 ret = gnutls_certificate_set_x509_crl_file(params->x509_cred,
294 crlfile,
295 GNUTLS_X509_FMT_PEM);
296 if (ret < 0) {
297 DEBUG(0,("TLS failed to initialise crlfile %s\n", crlfile));
298 goto init_failed;
302 ret = gnutls_certificate_set_x509_key_file(params->x509_cred,
303 certfile, keyfile,
304 GNUTLS_X509_FMT_PEM);
305 if (ret < 0) {
306 DEBUG(0,("TLS failed to initialise certfile %s and keyfile %s\n",
307 certfile, keyfile));
308 goto init_failed;
311 ret = gnutls_dh_params_init(&params->dh_params);
312 if (ret < 0) goto init_failed;
314 ret = gnutls_dh_params_generate2(params->dh_params, DH_BITS);
315 if (ret < 0) goto init_failed;
317 gnutls_certificate_set_dh_params(params->x509_cred, params->dh_params);
319 params->tls_enabled = True;
320 return params;
322 init_failed:
323 DEBUG(0,("GNUTLS failed to initialise - %s\n", gnutls_strerror(ret)));
324 params->tls_enabled = False;
325 return params;
330 setup for a new connection
332 struct tls_context *tls_init_server(struct tls_params *params,
333 struct socket_context *socket,
334 struct fd_event *fde,
335 const char *plain_chars)
337 struct tls_context *tls;
338 int ret;
340 tls = talloc(socket, struct tls_context);
341 if (tls == NULL) return NULL;
343 tls->socket = socket;
344 tls->fde = fde;
346 if (!params->tls_enabled) {
347 tls->tls_enabled = False;
348 return tls;
351 #define TLSCHECK(call) do { \
352 ret = call; \
353 if (ret < 0) { \
354 DEBUG(0,("TLS %s - %s\n", #call, gnutls_strerror(ret))); \
355 goto failed; \
357 } while (0)
359 TLSCHECK(gnutls_init(&tls->session, GNUTLS_SERVER));
361 talloc_set_destructor(tls, tls_destructor);
363 TLSCHECK(gnutls_set_default_priority(tls->session));
364 TLSCHECK(gnutls_credentials_set(tls->session, GNUTLS_CRD_CERTIFICATE,
365 params->x509_cred));
366 gnutls_certificate_server_set_request(tls->session, GNUTLS_CERT_REQUEST);
367 gnutls_dh_set_prime_bits(tls->session, DH_BITS);
368 gnutls_transport_set_ptr(tls->session, (gnutls_transport_ptr)tls);
369 gnutls_transport_set_pull_function(tls->session, (gnutls_pull_func)tls_pull);
370 gnutls_transport_set_push_function(tls->session, (gnutls_push_func)tls_push);
371 gnutls_transport_set_lowat(tls->session, 0);
373 tls->plain_chars = plain_chars;
374 if (plain_chars) {
375 tls->tls_detect = True;
376 } else {
377 tls->tls_detect = False;
380 tls->output_pending = False;
381 tls->params = params;
382 tls->done_handshake = False;
383 tls->have_first_byte = False;
384 tls->tls_enabled = True;
386 return tls;
388 failed:
389 DEBUG(0,("TLS init connection failed - %s\n", gnutls_strerror(ret)));
390 tls->tls_enabled = False;
391 params->tls_enabled = False;
392 return tls;
395 BOOL tls_enabled(struct tls_context *tls)
397 return tls->tls_enabled;
400 BOOL tls_support(struct tls_params *params)
402 return params->tls_enabled;
406 #else
408 /* for systems without tls we just map the tls socket calls to the
409 normal socket calls */
411 struct tls_params *tls_initialise(TALLOC_CTX *mem_ctx)
413 return talloc_new(mem_ctx);
416 struct tls_context *tls_init_server(struct tls_params *params,
417 struct socket_context *sock,
418 struct fd_event *fde,
419 const char *plain_chars)
421 if (plain_chars == NULL) return NULL;
422 return (struct tls_context *)sock;
426 NTSTATUS tls_socket_recv(struct tls_context *tls, void *buf, size_t wantlen,
427 size_t *nread)
429 return socket_recv((struct socket_context *)tls, buf, wantlen, nread, 0);
432 NTSTATUS tls_socket_send(struct tls_context *tls, const DATA_BLOB *blob, size_t *sendlen)
434 return socket_send((struct socket_context *)tls, blob, sendlen, 0);
437 BOOL tls_enabled(struct tls_context *tls)
439 return False;
442 BOOL tls_support(struct tls_params *params)
444 return False;
447 NTSTATUS tls_socket_pending(struct tls_context *tls, size_t *npending)
449 return socket_pending((struct socket_context *)tls, npending);
452 #endif