sq3: do not explicitly init/shutdown sqlite library; added some missing constants
[iv.d.git] / sslsocket.d
blob27f677b4c942e407d98a6c0b678390e69acc9610
1 /* Invisible Vector Library
2 * coded by Ketmar // Invisible Vector <ketmar@ketmar.no-ip.org>
3 * Understanding is not required. Only obedience.
5 * This program is free software: you can redistribute it and/or modify
6 * it under the terms of the GNU General Public License as published by
7 * the Free Software Foundation, version 3 of the License ONLY.
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <http://www.gnu.org/licenses/>.
17 // loosely based on opticron and Adam D. Ruppe work
18 module iv.sslsocket /*is aliced*/;
20 import iv.alice;
21 public import std.socket;
22 import iv.gnutls;
25 // ///////////////////////////////////////////////////////////////////////// //
// Process-wide GnuTLS setup: runs once, before any socket from this module is created.
shared static this () {
  gnutls_global_init();
}

// Matching process-wide GnuTLS teardown, run once at program shutdown.
shared static ~this () {
  gnutls_global_deinit();
}
30 // ///////////////////////////////////////////////////////////////////////// //
/// deprecated!
///
/// Client-side TLS socket: a `Socket` whose `send()`/`receive()` go through a GnuTLS session.
/// The session and the credentials are created by the constructors and released by `close()`.
class SSLClientSocket : Socket {
  gnutls_certificate_credentials_t xcred;
  gnutls_session_t session;
  private bool sslInitialized; // `true` while `session` and `xcred` are allocated
  bool manualHandshake = false; // for non-blocking sockets this should be `true`
  bool isblocking = false; // last value passed to the `blocking` setter
  string certBaseName; // client certificate base name (".cer" and ".key" are appended); empty means "no client certificate"

  // take care of pre-connection TLS stuff
  // every error path below frees what was allocated before throwing
  // Params:
  //   acertbasename = base name of the client certificate/key pair, or empty string
  // Throws: Exception if the certificate cannot be loaded or priorities cannot be set
  private void sslInit (string acertbasename) {
    if (sslInitialized) return;

    // x509 stuff
    gnutls_certificate_allocate_credentials(&xcred);

    // sets the trusted certificate authority file (no need for us, as we aren't checking any certificate)
    //gnutls_certificate_set_x509_trust_file(xcred, CAFILE, GNUTLS_X509_FMT_PEM);
    if (acertbasename.length) {
      certBaseName = acertbasename;
      import std.internal.cstring : tempCString;
      string cfname = certBaseName~".cer";
      string kfname = certBaseName~".key";
      int ret = gnutls_certificate_set_x509_key_file(xcred, cfname.tempCString, kfname.tempCString, GNUTLS_X509_FMT_PEM);
      if (ret < 0) {
        import std.string : fromStringz;
        import std.conv : to;
        string errstr = gnutls_strerror(ret).fromStringz.idup;
        gnutls_certificate_free_credentials(xcred);
        throw new Exception("TLS Error ("~errstr~"): cannot load certificate, err="~ret.to!string);
      }
    }

    // initialize TLS session
    gnutls_init(&session, GNUTLS_CLIENT|(certBaseName.length ? GNUTLS_FORCE_CLIENT_CERT : 0));

    // use default priorities
    /+ disabled alternative: explicit priority string
    const(char)* err;
    auto ret = gnutls_priority_set_direct(session, "PERFORMANCE", &err);
    if (ret < 0) {
      import std.string : fromStringz;
      import std.conv : to;
      if (ret == GNUTLS_E_INVALID_REQUEST) throw new Exception("Syntax error at: "~err.fromStringz.idup);
      string errstr = gnutls_strerror(ret).fromStringz.idup;
      gnutls_deinit(session);
      gnutls_certificate_free_credentials(xcred);
      throw new Exception("TLS Error ("~errstr~"): returned with "~ret.to!string);
    }
    +/
    auto ret = gnutls_set_default_priority(session);
    if (ret < 0) {
      import std.string : fromStringz;
      import std.conv : to;
      //if (ret == GNUTLS_E_INVALID_REQUEST) throw new Exception("Syntax error at: "~err.fromStringz.idup);
      string errstr = gnutls_strerror(ret).fromStringz.idup;
      gnutls_deinit(session);
      gnutls_certificate_free_credentials(xcred);
      throw new Exception("TLS Error ("~errstr~"): returned with "~ret.to!string);
    }
    gnutls_session_enable_compatibility_mode(session);

    // put the x509 credentials to the current session
    gnutls_credentials_set(session, GNUTLS_CRD_CERTIFICATE, xcred);
    sslInitialized = true;
  }

  // this is required for new TLS
  // call this before connecting
  // Params:
  //   hname = DNS host name to put into the server name indication
  // Throws: Exception if GnuTLS rejects the name
  public void sslhostname (const(char)[] hname) @trusted {
    import std.internal.cstring : tempCString;
    int res = gnutls_server_name_set(session, GNUTLS_NAME_DNS, hname.tempCString, hname.length);
    if (res < 0) {
      import std.string : fromStringz;
      import std.conv : to;
      string errstr = gnutls_strerror(res).fromStringz.idup;
      //gnutls_deinit(session);
      //gnutls_certificate_free_credentials(xcred);
      throw new Exception("TLS Error ("~errstr~"): returned with "~res.to!string);
    }
  }

  // perform the TLS handshake on the already connected socket
  // non-fatal GnuTLS errors are retried immediately, in a loop
  // Throws: Exception on fatal handshake error
  public void sslHandshake () {
    // lob the socket handle off to gnutls
    gnutls_transport_set_ptr(session, cast(gnutls_transport_ptr_t)handle);
    // perform the TLS handshake
    for (;;) {
      auto ret = gnutls_handshake(session);
      if (ret < 0 && !gnutls_error_is_fatal(ret)) continue;
      if (ret < 0) {
        import std.string : fromStringz;
        throw new Exception("Handshake failed: "~gnutls_strerror(ret).fromStringz.idup);
      }
      break;
    }
  }

  // remember the blocking mode: `send()`/`receive()` retry interrupted calls only in blocking mode
  override @property void blocking (bool byes) @trusted {
    super.blocking(byes);
    isblocking = byes;
  }

  // connect, then do the TLS handshake (unless `manualHandshake` is set)
  override void connect (Address to) @trusted {
    super.connect(to);
    if (!manualHandshake) sslHandshake();
  }

  // close the encrypted connection
  // releases the TLS session and the credentials (once), then closes the socket
  override void close () @trusted {
    scope(exit) sslInitialized = false;
    if (sslInitialized) {
      //{ import core.stdc.stdio : printf; printf("deiniting\n"); }
      gnutls_bye(session, GNUTLS_SHUT_RDWR);
      gnutls_deinit(session);
      gnutls_certificate_free_credentials(xcred);
    }
    super.close();
  }

  // send data through the TLS session; `flags` are ignored
  // Returns: number of bytes sent, or negative GnuTLS error code
  // in blocking mode `GNUTLS_E_INTERRUPTED` and `GNUTLS_E_AGAIN` are retried
  // Throws: SocketException if the TLS state was already released by `close()`
  override ptrdiff_t send (const(void)[] buf, SocketFlags flags) @trusted {
    // the session is freed by `close()`; don't touch it after that
    if (!sslInitialized) throw new SocketException("not initialized");
    if (buf.length == 0) return 0;
    for (;;) {
      auto res = gnutls_record_send(session, buf.ptr, buf.length);
      if (res >= 0 || !isblocking) return res;
      if (res == GNUTLS_E_INTERRUPTED || res == GNUTLS_E_AGAIN) continue;
      return res;
    }
  }

  override ptrdiff_t send (const(void)[] buf) {
    import core.sys.posix.sys.socket;
    static if (is(typeof(MSG_NOSIGNAL))) {
      return send(buf, cast(SocketFlags)MSG_NOSIGNAL);
    } else {
      return send(buf, SocketFlags.NOSIGNAL);
    }
  }

  // receive data through the TLS session; `flags` are ignored
  // Returns: number of bytes received, or negative GnuTLS error code
  // in blocking mode `GNUTLS_E_INTERRUPTED` and `GNUTLS_E_AGAIN` are retried
  // Throws: SocketException if the TLS state was already released by `close()`
  override ptrdiff_t receive (void[] buf, SocketFlags flags) @trusted {
    // the session is freed by `close()`; don't touch it after that
    if (!sslInitialized) throw new SocketException("not initialized");
    if (buf.length == 0) return 0;
    for (;;) {
      auto res = gnutls_record_recv(session, buf.ptr, buf.length);
      // error codes are returned as is (negative), exactly like `send()` does:
      // converting them to positive errno values makes them look like a byte count
      if (res >= 0 || !isblocking) return res;
      if (res == GNUTLS_E_INTERRUPTED || res == GNUTLS_E_AGAIN) continue;
      return res;
    }
  }

  override ptrdiff_t receive (void[] buf) { return receive(buf, SocketFlags.NONE); }

  // create the TLS state first, then the socket itself
  this (AddressFamily af, SocketType type=SocketType.STREAM, string certbasename=null) {
    sslInit(certbasename);
    super(af, type);
  }

  // wrap an already existing socket handle
  this (socket_t sock, AddressFamily af, string certbasename=null) {
    sslInit(certbasename);
    super(sock, af);
  }
}
207 // ///////////////////////////////////////////////////////////////////////// //
// this can be used as both client and server socket
// don't forget to set certificate file (and key file, if you have both) for server!
// `connect()` will do client mode, `accept()` will do server mode (and will return `SSLSocket` instance)
class SSLSocket : Socket {
  gnutls_certificate_credentials_t xcred;
  gnutls_session_t session;
  private bool sslInitialized = false; // `true` while `session` and `xcred` are allocated
  bool manualHandshake = false; // for non-blocking sockets this should be `true`
  bool isblocking = false; // last value passed to the `blocking` setter
  private bool thisIsServer = false; // set by `accept()` on the sockets it returns
  // server
  private string certfilez; // "cert.pem"; zero-terminated copy
  private string keyfilez; // "key.pem"; zero-terminated copy

  // both key and cert can be in one file
  // if only one name is given, it is used for both the certificate and the key
  // passing two empty names resets both
  void setKeyCertFile (const(char)[] certname, const(char)[] keyname=null) {
    if (certname.length == 0) { certname = keyname; keyname = null; }
    if (certname.length == 0) {
      certfilez = keyfilez = "";
    } else {
      // the names are passed to C code via `.ptr`, so keep a trailing zero byte
      auto buf = new char[](certname.length+1);
      buf[] = 0;
      buf[0..certname.length] = certname;
      certfilez = cast(string)buf;
      if (keyname.length != 0) {
        buf = new char[](keyname.length+1);
        buf[] = 0;
        buf[0..keyname.length] = keyname;
      }
      // without separate key name this is the same buffer as `certfilez`
      keyfilez = cast(string)buf;
    }
  }

  // take care of pre-connection TLS stuff
  // does nothing if the TLS state is already created
  // on failure everything allocated here is released, and the socket stays "not initialized"
  // Throws: SocketException or Exception on TLS setup error
  private void sslInit () {
    if (sslInitialized) return;

    // x509 stuff
    gnutls_certificate_allocate_credentials(&xcred);

    // roll back on any error below, so `close()` will never see half-built (or already freed) state
    scope(failure) {
      if (session !is null) { gnutls_deinit(session); session = null; }
      gnutls_certificate_free_credentials(xcred);
      xcred = typeof(xcred).init;
    }

    // sets the trusted certificate authority file (no need for us, as we aren't checking any certificate)
    //gnutls_certificate_set_x509_trust_file(xcred, CAFILE, GNUTLS_X509_FMT_PEM);

    if (thisIsServer) {
      // server
      if (certfilez.length < 1) throw new SocketException("TLS Error: certificate file not set");
      if (keyfilez.length < 1) throw new SocketException("TLS Error: key file not set");
      auto res = gnutls_certificate_set_x509_key_file(xcred, certfilez.ptr, keyfilez.ptr, GNUTLS_X509_FMT_PEM);
      if (res < 0) {
        import std.conv : to;
        throw new SocketException("TLS Error: returned with "~res.to!string);
      }
      gnutls_init(&session, GNUTLS_SERVER);
      gnutls_certificate_server_set_request(session, GNUTLS_CERT_IGNORE);
      gnutls_handshake_set_timeout(session, /*GNUTLS_DEFAULT_HANDSHAKE_TIMEOUT*/2300);
    } else {
      // client
      // initialize TLS session
      gnutls_init(&session, GNUTLS_CLIENT);
    }

    // use default priorities
    /+ disabled alternative: explicit priority string
    const(char)* err;
    auto ret = gnutls_priority_set_direct(session, "PERFORMANCE", &err);
    if (ret < 0) {
      import std.string : fromStringz;
      import std.conv : to;
      if (ret == GNUTLS_E_INVALID_REQUEST) throw new SocketException("Syntax error at: "~err.fromStringz.idup);
      throw new SocketException("TLS Error: returned with "~ret.to!string);
    }
    +/
    auto ret = gnutls_set_default_priority(session);
    if (ret < 0) {
      import std.string : fromStringz;
      import std.conv : to;
      //if (ret == GNUTLS_E_INVALID_REQUEST) throw new Exception("Syntax error at: "~err.fromStringz.idup);
      string errstr = gnutls_strerror(ret).fromStringz.idup;
      // session and credentials are released by the `scope(failure)` above
      throw new Exception("TLS Error ("~errstr~"): returned with "~ret.to!string);
    }
    gnutls_session_enable_compatibility_mode(session);

    // put the x509 credentials to the current session
    gnutls_credentials_set(session, GNUTLS_CRD_CERTIFICATE, xcred);

    // only now the state is complete, and `close()` is allowed to tear it down
    sslInitialized = true;
  }

  // this is required for new TLS
  // call this before connecting
  // Params:
  //   hname = DNS host name to put into the server name indication
  // Throws: Exception if GnuTLS rejects the name (or if TLS setup failed)
  public void sslhostname (const(char)[] hname) @trusted {
    import std.internal.cstring : tempCString;
    // the session is normally created by `connect()`, i.e. *after* this call;
    // create it right here, so we won't pass null session to GnuTLS
    sslInit();
    int res = gnutls_server_name_set(session, GNUTLS_NAME_DNS, hname.tempCString, hname.length);
    if (res < 0) {
      import std.string : fromStringz;
      import std.conv : to;
      string errstr = gnutls_strerror(res).fromStringz.idup;
      //gnutls_deinit(session);
      //gnutls_certificate_free_credentials(xcred);
      throw new Exception("TLS Error ("~errstr~"): returned with "~res.to!string);
    }
  }

  // perform the TLS handshake on the already connected (or accepted) socket
  // non-fatal GnuTLS errors are retried immediately, in a loop
  // Throws: Exception on fatal handshake error
  public void sslHandshake () {
    sslInit();
    // lob the socket handle off to gnutls
    gnutls_transport_set_ptr(session, cast(gnutls_transport_ptr_t)handle);
    // perform the TLS handshake
    for (;;) {
      auto ret = gnutls_handshake(session);
      if (ret < 0 && !gnutls_error_is_fatal(ret)) continue;
      if (ret < 0) {
        import std.string : fromStringz;
        throw new Exception("Handshake failed: "~gnutls_strerror(ret).fromStringz.idup);
      }
      break;
    }
  }

  // switching to non-blocking mode also turns on manual handshaking (and vice versa)
  override @property void blocking (bool byes) @trusted {
    super.blocking(byes);
    manualHandshake = !byes;
    isblocking = byes;
  }

  // client mode: create TLS state, connect, then do the handshake (unless `manualHandshake` is set)
  // Throws: SocketException if this socket was already set up as a server one
  override void connect (Address to) @trusted {
    if (sslInitialized && thisIsServer) throw new SocketException("wtf?!");
    thisIsServer = false;
    sslInit();
    super.connect(to);
    if (!manualHandshake) sslHandshake();
  }

  // factory for `accept()`: accepted connections are `SSLSocket`s too
  protected override Socket accepting () pure nothrow {
    return new SSLSocket();
  }

  // server mode: accept a connection, and set up TLS on it using our certificate and key
  // Throws: SocketAcceptException if accepted socket is not `SSLSocket`;
  //         SocketException or Exception on TLS setup/handshake error
  override Socket accept () @trusted {
    auto sk = super.accept();
    // don't leak accepted connection if anything below fails
    scope(failure) sk.close();
    if (auto ssk = cast(SSLSocket)sk) {
      ssk.keyfilez = keyfilez;
      ssk.certfilez = certfilez;
      ssk.manualHandshake = manualHandshake;
      ssk.thisIsServer = true;
      ssk.sslInit();
      if (!ssk.manualHandshake) ssk.sslHandshake();
    } else {
      throw new SocketAcceptException("failed to create ssl socket");
    }
    return sk;
  }

  // close the encrypted connection
  // releases the TLS session and the credentials (once), then closes the socket
  override void close () @trusted {
    scope(exit) sslInitialized = false;
    if (sslInitialized) {
      //{ import core.stdc.stdio : printf; printf("deiniting\n"); }
      gnutls_bye(session, GNUTLS_SHUT_RDWR);
      gnutls_deinit(session);
      gnutls_certificate_free_credentials(xcred);
      // forget freed handles, so "is null" checks will work
      session = null;
      xcred = typeof(xcred).init;
    }
    super.close();
  }

  // send data through the TLS session; `flags` are ignored
  // Returns: number of bytes sent, or negative GnuTLS error code
  // in blocking mode `GNUTLS_E_INTERRUPTED` and `GNUTLS_E_AGAIN` are retried
  // Throws: SocketException if there is no TLS session
  override ptrdiff_t send (const(void)[] buf, SocketFlags flags) @trusted {
    if (session is null || !sslInitialized) throw new SocketException("not initialized");
    //return gnutls_record_send(session, buf.ptr, buf.length);
    if (buf.length == 0) return 0;
    for (;;) {
      auto res = gnutls_record_send(session, buf.ptr, buf.length);
      if (res >= 0 || !isblocking) return res;
      if (res == GNUTLS_E_INTERRUPTED || res == GNUTLS_E_AGAIN) continue;
      return res;
    }
  }

  override ptrdiff_t send (const(void)[] buf) {
    import core.sys.posix.sys.socket;
    static if (is(typeof(MSG_NOSIGNAL))) {
      return send(buf, cast(SocketFlags)MSG_NOSIGNAL);
    } else {
      return send(buf, SocketFlags.NOSIGNAL);
    }
  }

  // receive data through the TLS session; `flags` are ignored
  // Returns: number of bytes received, or negative GnuTLS error code
  // in blocking mode `GNUTLS_E_INTERRUPTED` and `GNUTLS_E_AGAIN` are retried
  // Throws: SocketException if there is no TLS session
  override ptrdiff_t receive (void[] buf, SocketFlags flags) @trusted {
    if (session is null || !sslInitialized) throw new SocketException("not initialized");
    //return gnutls_record_recv(session, buf.ptr, buf.length);
    if (buf.length == 0) return 0;
    for (;;) {
      auto res = gnutls_record_recv(session, buf.ptr, buf.length);
      // error codes are returned as is (negative), exactly like `send()` does:
      // converting them to positive errno values makes them look like a byte count
      if (res >= 0 || !isblocking) return res;
      if (res == GNUTLS_E_INTERRUPTED || res == GNUTLS_E_AGAIN) continue;
      return res;
    }
  }

  override ptrdiff_t receive (void[] buf) { return receive(buf, SocketFlags.NONE); }

  // used by `accepting()` only
  private this () pure nothrow @safe {}

  // TLS state is created later, by `connect()`, `accept()`, `sslhostname()` or `sslHandshake()`
  this (AddressFamily af, SocketType type=SocketType.STREAM) {
    super(af, type);
  }

  // wrap an already existing socket handle
  this (socket_t sock, AddressFamily af) {
    super(sock, af);
  }
}