Linux 2.2.0
[davej-history.git] / fs / smbfs / sock.c
blob5d8ddc96243912969e16044adf78f626dee0f16e
1 /*
2 * sock.c
4 * Copyright (C) 1995, 1996 by Paal-Kr. Engstad and Volker Lendecke
5 * Copyright (C) 1997 by Volker Lendecke
7 */
9 #include <linux/sched.h>
10 #include <linux/errno.h>
11 #include <linux/socket.h>
12 #include <linux/fcntl.h>
13 #include <linux/file.h>
14 #include <linux/in.h>
15 #include <linux/net.h>
16 #include <linux/mm.h>
17 #include <linux/netdevice.h>
18 #include <net/scm.h>
19 #include <net/ip.h>
21 #include <linux/smb_fs.h>
22 #include <linux/smb.h>
23 #include <linux/smbno.h>
25 #include <asm/uaccess.h>
27 #define SMBFS_PARANOIA 1
28 /* #define SMBFS_DEBUG_VERBOSE 1 */
30 static int
31 _recvfrom(struct socket *socket, unsigned char *ubuf, int size,
32 unsigned flags)
34 struct iovec iov;
35 struct msghdr msg;
36 struct scm_cookie scm;
38 msg.msg_name = NULL;
39 msg.msg_namelen = 0;
40 msg.msg_iov = &iov;
41 msg.msg_iovlen = 1;
42 msg.msg_control = NULL;
43 iov.iov_base = ubuf;
44 iov.iov_len = size;
46 memset(&scm, 0,sizeof(scm));
47 size=socket->ops->recvmsg(socket, &msg, size, flags, &scm);
48 if(size>=0)
49 scm_recv(socket,&msg,&scm,flags);
50 return size;
53 static int
54 _send(struct socket *socket, const void *buff, int len)
56 struct iovec iov;
57 struct msghdr msg;
58 struct scm_cookie scm;
59 int err;
61 msg.msg_name = NULL;
62 msg.msg_namelen = 0;
63 msg.msg_iov = &iov;
64 msg.msg_iovlen = 1;
65 msg.msg_control = NULL;
66 msg.msg_controllen = 0;
68 iov.iov_base = (void *)buff;
69 iov.iov_len = len;
71 msg.msg_flags = 0;
73 err = scm_send(socket, &msg, &scm);
74 if (err >= 0)
76 err = socket->ops->sendmsg(socket, &msg, len, &scm);
77 scm_destroy(&scm);
79 return err;
83 * N.B. What happens if we're in here when the socket closes??
85 static void
86 smb_data_callback(struct sock *sk, int len)
88 struct socket *socket = sk->socket;
89 unsigned char peek_buf[4];
90 int result;
91 mm_segment_t fs;
93 fs = get_fs();
94 set_fs(get_ds());
96 while (1)
98 result = -EIO;
99 if (sk->dead)
101 #ifdef SMBFS_PARANOIA
102 printk("smb_data_callback: sock dead!\n");
103 #endif
104 break;
107 result = _recvfrom(socket, (void *) peek_buf, 1,
108 MSG_PEEK | MSG_DONTWAIT);
109 if (result == -EAGAIN)
110 break;
111 if (peek_buf[0] != 0x85)
112 break;
114 /* got SESSION KEEP ALIVE */
115 result = _recvfrom(socket, (void *) peek_buf, 4,
116 MSG_DONTWAIT);
118 pr_debug("smb_data_callback: got SESSION KEEPALIVE\n");
120 if (result == -EAGAIN)
121 break;
123 set_fs(fs);
125 if (result != -EAGAIN)
127 wake_up_interruptible(sk->sleep);
132 smb_valid_socket(struct inode * inode)
134 return (inode && S_ISSOCK(inode->i_mode) &&
135 inode->u.socket_i.type == SOCK_STREAM);
138 static struct socket *
139 server_sock(struct smb_sb_info *server)
141 struct file *file;
143 if (server && (file = server->sock_file))
145 #ifdef SMBFS_PARANOIA
146 if (!smb_valid_socket(file->f_dentry->d_inode))
147 printk("smb_server_sock: bad socket!\n");
148 #endif
149 return &file->f_dentry->d_inode->u.socket_i;
151 return NULL;
155 smb_catch_keepalive(struct smb_sb_info *server)
157 struct socket *socket;
158 struct sock *sk;
159 void *data_ready;
160 int error;
162 error = -EINVAL;
163 socket = server_sock(server);
164 if (!socket)
166 printk("smb_catch_keepalive: did not get valid server!\n");
167 server->data_ready = NULL;
168 goto out;
171 sk = socket->sk;
172 if (sk == NULL)
174 pr_debug("smb_catch_keepalive: sk == NULL");
175 server->data_ready = NULL;
176 goto out;
178 pr_debug("smb_catch_keepalive.: sk->d_r = %x, server->d_r = %x\n",
179 (unsigned int) (sk->data_ready),
180 (unsigned int) (server->data_ready));
183 * Install the callback atomically to avoid races ...
185 data_ready = xchg(&sk->data_ready, smb_data_callback);
186 if (data_ready != smb_data_callback)
188 server->data_ready = data_ready;
189 error = 0;
190 } else
191 printk(KERN_ERR "smb_catch_keepalive: already done\n");
192 out:
193 return error;
197 smb_dont_catch_keepalive(struct smb_sb_info *server)
199 struct socket *socket;
200 struct sock *sk;
201 void * data_ready;
202 int error;
204 error = -EINVAL;
205 socket = server_sock(server);
206 if (!socket)
208 printk("smb_dont_catch_keepalive: did not get valid server!\n");
209 goto out;
212 sk = socket->sk;
213 if (sk == NULL)
215 printk("smb_dont_catch_keepalive: sk == NULL");
216 goto out;
219 /* Is this really an error?? */
220 if (server->data_ready == NULL)
222 printk("smb_dont_catch_keepalive: "
223 "server->data_ready == NULL\n");
224 goto out;
226 pr_debug("smb_dont_catch_keepalive: sk->d_r = %x, server->d_r = %x\n",
227 (unsigned int) (sk->data_ready),
228 (unsigned int) (server->data_ready));
231 * Restore the original callback atomically to avoid races ...
233 data_ready = xchg(&sk->data_ready, server->data_ready);
234 server->data_ready = NULL;
235 if (data_ready != smb_data_callback)
237 printk("smb_dont_catch_keepalive: "
238 "sk->data_callback != smb_data_callback\n");
240 error = 0;
241 out:
242 return error;
246 * Called with the server locked.
248 void
249 smb_close_socket(struct smb_sb_info *server)
251 struct file * file = server->sock_file;
253 if (file)
255 #ifdef SMBFS_DEBUG_VERBOSE
256 printk("smb_close_socket: closing socket %p\n", server_sock(server));
257 #endif
258 #ifdef SMBFS_PARANOIA
259 if (server_sock(server)->sk->data_ready == smb_data_callback)
260 printk("smb_close_socket: still catching keepalives!\n");
261 #endif
262 server->sock_file = NULL;
263 fput(file);
267 static int
268 smb_send_raw(struct socket *socket, unsigned char *source, int length)
270 int result;
271 int already_sent = 0;
273 while (already_sent < length)
275 result = _send(socket,
276 (void *) (source + already_sent),
277 length - already_sent);
279 if (result == 0)
281 return -EIO;
283 if (result < 0)
285 pr_debug("smb_send_raw: sendto error = %d\n",
286 -result);
287 return result;
289 already_sent += result;
291 return already_sent;
294 static int
295 smb_receive_raw(struct socket *socket, unsigned char *target, int length)
297 int result;
298 int already_read = 0;
300 while (already_read < length)
302 result = _recvfrom(socket,
303 (void *) (target + already_read),
304 length - already_read, 0);
306 if (result == 0)
308 return -EIO;
310 if (result < 0)
312 pr_debug("smb_receive_raw: recvfrom error = %d\n",
313 -result);
314 return result;
316 already_read += result;
318 return already_read;
321 static int
322 smb_get_length(struct socket *socket, unsigned char *header)
324 int result;
325 unsigned char peek_buf[4];
326 mm_segment_t fs;
328 re_recv:
329 fs = get_fs();
330 set_fs(get_ds());
331 result = smb_receive_raw(socket, peek_buf, 4);
332 set_fs(fs);
334 if (result < 0)
336 #ifdef SMBFS_PARANOIA
337 printk("smb_get_length: recv error = %d\n", -result);
338 #endif
339 return result;
341 switch (peek_buf[0])
343 case 0x00:
344 case 0x82:
345 break;
347 case 0x85:
348 pr_debug("smb_get_length: Got SESSION KEEP ALIVE\n");
349 goto re_recv;
351 default:
352 #ifdef SMBFS_PARANOIA
353 printk("smb_get_length: Invalid NBT packet, code=%x\n", peek_buf[0]);
354 #endif
355 return -EIO;
358 if (header != NULL)
360 memcpy(header, peek_buf, 4);
362 /* The length in the RFC NB header is the raw data length */
363 return smb_len(peek_buf);
/*
 * smb_round_length - round `len' up to the next multiple of PAGE_SIZE,
 * since packet memory is allocated in whole pages.
 */
int
smb_round_length(int len)
{
	unsigned long mask = PAGE_SIZE - 1;

	return (len + mask) & ~mask;
}
377 * smb_receive
378 * fs points to the correct segment
380 static int
381 smb_receive(struct smb_sb_info *server)
383 struct socket *socket = server_sock(server);
384 unsigned char * packet = server->packet;
385 int len, result;
386 unsigned char peek_buf[4];
388 result = smb_get_length(socket, peek_buf);
389 if (result < 0)
390 goto out;
391 len = result;
393 * Some servers do not respect our max_xmit and send
394 * larger packets. Try to allocate a new packet,
395 * but don't free the old one unless we succeed.
397 if (len + 4 > server->packet_size)
399 int new_len = smb_round_length(len + 4);
401 result = -ENOMEM;
402 packet = smb_vmalloc(new_len);
403 if (packet == NULL)
404 goto out;
405 smb_vfree(server->packet);
406 server->packet = packet;
407 server->packet_size = new_len;
409 memcpy(packet, peek_buf, 4);
410 result = smb_receive_raw(socket, packet + 4, len);
411 if (result < 0)
413 #ifdef SMBFS_DEBUG_VERBOSE
414 printk("smb_receive: receive error: %d\n", result);
415 #endif
416 goto out;
418 server->rcls = *(packet + smb_rcls);
419 server->err = WVAL(packet, smb_err);
421 #ifdef SMBFS_DEBUG_VERBOSE
422 if (server->rcls != 0)
423 printk("smb_receive: rcls=%d, err=%d\n", server->rcls, server->err);
424 #endif
425 out:
426 return result;
430 * This routine checks first for "fast track" processing, as most
431 * packets won't need to be copied. Otherwise, it allocates a new
432 * packet to hold the incoming data.
434 * Note that the final server packet must be the larger of the two;
435 * server packets aren't allowed to shrink.
437 static int
438 smb_receive_trans2(struct smb_sb_info *server,
439 int *ldata, unsigned char **data,
440 int *lparm, unsigned char **parm)
442 unsigned char *inbuf, *base, *rcv_buf = NULL;
443 unsigned int parm_disp, parm_offset, parm_count, parm_tot, parm_len = 0;
444 unsigned int data_disp, data_offset, data_count, data_tot, data_len = 0;
445 unsigned int total_p = 0, total_d = 0, buf_len = 0;
446 int result;
448 while (1)
450 result = smb_receive(server);
451 if (result < 0)
452 goto out;
453 inbuf = server->packet;
454 if (server->rcls != 0)
456 *parm = *data = inbuf;
457 *ldata = *lparm = 0;
458 goto out;
461 * Extract the control data from the packet.
463 data_tot = WVAL(inbuf, smb_tdrcnt);
464 parm_tot = WVAL(inbuf, smb_tprcnt);
465 parm_disp = WVAL(inbuf, smb_prdisp);
466 parm_offset = WVAL(inbuf, smb_proff);
467 parm_count = WVAL(inbuf, smb_prcnt);
468 data_disp = WVAL(inbuf, smb_drdisp);
469 data_offset = WVAL(inbuf, smb_droff);
470 data_count = WVAL(inbuf, smb_drcnt);
471 base = smb_base(inbuf);
474 * Assume success and increment lengths.
476 parm_len += parm_count;
477 data_len += data_count;
479 if (!rcv_buf)
482 * Check for fast track processing ... just this packet.
484 if (parm_count == parm_tot && data_count == data_tot)
486 #ifdef SMBFS_DEBUG_VERBOSE
487 printk("smb_receive_trans2: fast track, parm=%u %u %u, data=%u %u %u\n",
488 parm_disp, parm_offset, parm_count, data_disp, data_offset, data_count);
489 #endif
490 *parm = base + parm_offset;
491 *data = base + data_offset;
492 goto success;
495 if (parm_tot > TRANS2_MAX_TRANSFER ||
496 data_tot > TRANS2_MAX_TRANSFER)
497 goto out_too_long;
500 * Save the total parameter and data length.
502 total_d = data_tot;
503 total_p = parm_tot;
505 buf_len = total_d + total_p;
506 if (server->packet_size > buf_len)
507 buf_len = server->packet_size;
508 buf_len = smb_round_length(buf_len);
510 rcv_buf = smb_vmalloc(buf_len);
511 if (!rcv_buf)
512 goto out_no_mem;
513 *parm = rcv_buf;
514 *data = rcv_buf + total_p;
516 else if (data_tot > total_d || parm_tot > total_p)
517 goto out_data_grew;
519 if (parm_disp + parm_count > total_p)
520 goto out_bad_parm;
521 if (data_disp + data_count > total_d)
522 goto out_bad_data;
523 memcpy(*parm + parm_disp, base + parm_offset, parm_count);
524 memcpy(*data + data_disp, base + data_offset, data_count);
526 #ifdef SMBFS_PARANOIA
527 printk("smb_receive_trans2: copied, parm=%u of %u, data=%u of %u\n",
528 parm_len, parm_tot, data_len, data_tot);
529 #endif
531 * Check whether we've received all of the data. Note that
532 * we use the packet totals -- total lengths might shrink!
534 if (data_len >= data_tot && parm_len >= parm_tot)
535 break;
539 * Install the new packet. Note that it's possible, though
540 * unlikely, that the new packet could be smaller than the
541 * old one, in which case we just copy the data.
543 inbuf = server->packet;
544 if (buf_len >= server->packet_size)
546 server->packet_size = buf_len;
547 server->packet = rcv_buf;
548 rcv_buf = inbuf;
549 } else
551 #ifdef SMBFS_PARANOIA
552 printk("smb_receive_trans2: copying data, old size=%d, new size=%u\n",
553 server->packet_size, buf_len);
554 #endif
555 memcpy(inbuf, rcv_buf, parm_len + data_len);
558 success:
559 *ldata = data_len;
560 *lparm = parm_len;
561 out:
562 if (rcv_buf)
563 smb_vfree(rcv_buf);
564 return result;
566 out_no_mem:
567 #ifdef SMBFS_PARANOIA
568 printk("smb_receive_trans2: couldn't allocate data area\n");
569 #endif
570 result = -ENOMEM;
571 goto out;
572 out_too_long:
573 printk("smb_receive_trans2: data/param too long, data=%d, parm=%d\n",
574 data_tot, parm_tot);
575 goto out_error;
576 out_data_grew:
577 printk("smb_receive_trans2: data/params grew!\n");
578 goto out_error;
579 out_bad_parm:
580 printk("smb_receive_trans2: invalid parms, disp=%d, cnt=%d, tot=%d\n",
581 parm_disp, parm_count, parm_tot);
582 goto out_error;
583 out_bad_data:
584 printk("smb_receive_trans2: invalid data, disp=%d, cnt=%d, tot=%d\n",
585 data_disp, data_count, data_tot);
586 out_error:
587 result = -EIO;
588 goto out;
592 * Called with the server locked
595 smb_request(struct smb_sb_info *server)
597 unsigned long flags, sigpipe;
598 mm_segment_t fs;
599 sigset_t old_set;
600 int len, result;
601 unsigned char *buffer;
603 result = -EBADF;
604 buffer = server->packet;
605 if (!buffer)
606 goto bad_no_packet;
608 result = -EIO;
609 if (server->state != CONN_VALID)
610 goto bad_no_conn;
612 if ((result = smb_dont_catch_keepalive(server)) != 0)
613 goto bad_conn;
615 len = smb_len(buffer) + 4;
616 pr_debug("smb_request: len = %d cmd = 0x%X\n", len, buffer[8]);
618 spin_lock_irqsave(&current->sigmask_lock, flags);
619 sigpipe = sigismember(&current->signal, SIGPIPE);
620 old_set = current->blocked;
621 siginitsetinv(&current->blocked, sigmask(SIGKILL)|sigmask(SIGSTOP));
622 recalc_sigpending(current);
623 spin_unlock_irqrestore(&current->sigmask_lock, flags);
625 fs = get_fs();
626 set_fs(get_ds());
628 result = smb_send_raw(server_sock(server), (void *) buffer, len);
629 if (result > 0)
631 result = smb_receive(server);
634 /* read/write errors are handled by errno */
635 spin_lock_irqsave(&current->sigmask_lock, flags);
636 if (result == -EPIPE && !sigpipe)
637 sigdelset(&current->signal, SIGPIPE);
638 current->blocked = old_set;
639 recalc_sigpending(current);
640 spin_unlock_irqrestore(&current->sigmask_lock, flags);
642 set_fs(fs);
644 if (result >= 0)
646 int result2 = smb_catch_keepalive(server);
647 if (result2 < 0)
649 printk("smb_request: catch keepalive failed\n");
650 result = result2;
653 if (result < 0)
654 goto bad_conn;
656 * Check for fatal server errors ...
658 if (server->rcls) {
659 int error = smb_errno(server);
660 if (error == EBADSLT) {
661 printk("smb_request: tree ID invalid\n");
662 result = error;
663 goto bad_conn;
667 out:
668 pr_debug("smb_request: result = %d\n", result);
669 return result;
671 bad_conn:
672 #ifdef SMBFS_PARANOIA
673 printk("smb_request: result %d, setting invalid\n", result);
674 #endif
675 server->state = CONN_INVALID;
676 smb_invalidate_inodes(server);
677 goto out;
678 bad_no_packet:
679 printk("smb_request: no packet!\n");
680 goto out;
681 bad_no_conn:
682 printk("smb_request: connection %d not valid!\n", server->state);
683 goto out;
686 #define ROUND_UP(x) (((x)+3) & ~3)
687 static int
688 smb_send_trans2(struct smb_sb_info *server, __u16 trans2_command,
689 int ldata, unsigned char *data,
690 int lparam, unsigned char *param)
692 struct socket *sock = server_sock(server);
693 struct scm_cookie scm;
694 int err;
696 /* I know the following is very ugly, but I want to build the
697 smb packet as efficiently as possible. */
699 const int smb_parameters = 15;
700 const int oparam =
701 ROUND_UP(SMB_HEADER_LEN + 2 * smb_parameters + 2 + 3);
702 const int odata =
703 ROUND_UP(oparam + lparam);
704 const int bcc =
705 odata + ldata - (SMB_HEADER_LEN + 2 * smb_parameters + 2);
706 const int packet_length =
707 SMB_HEADER_LEN + 2 * smb_parameters + bcc + 2;
709 unsigned char padding[4] =
710 {0,};
711 char *p;
713 struct iovec iov[4];
714 struct msghdr msg;
716 /* N.B. This test isn't valid! packet_size may be < max_xmit */
717 if ((bcc + oparam) > server->opt.max_xmit)
719 return -ENOMEM;
721 p = smb_setup_header(server, SMBtrans2, smb_parameters, bcc);
723 WSET(server->packet, smb_tpscnt, lparam);
724 WSET(server->packet, smb_tdscnt, ldata);
725 /* N.B. these values should reflect out current packet size */
726 WSET(server->packet, smb_mprcnt, TRANS2_MAX_TRANSFER);
727 WSET(server->packet, smb_mdrcnt, TRANS2_MAX_TRANSFER);
728 WSET(server->packet, smb_msrcnt, 0);
729 WSET(server->packet, smb_flags, 0);
730 DSET(server->packet, smb_timeout, 0);
731 WSET(server->packet, smb_pscnt, lparam);
732 WSET(server->packet, smb_psoff, oparam - 4);
733 WSET(server->packet, smb_dscnt, ldata);
734 WSET(server->packet, smb_dsoff, odata - 4);
735 WSET(server->packet, smb_suwcnt, 1);
736 WSET(server->packet, smb_setup0, trans2_command);
737 *p++ = 0; /* null smb_name for trans2 */
738 *p++ = 'D'; /* this was added because OS/2 does it */
739 *p++ = ' ';
742 msg.msg_name = NULL;
743 msg.msg_namelen = 0;
744 msg.msg_control = NULL;
745 msg.msg_controllen = 0;
746 msg.msg_iov = iov;
747 msg.msg_iovlen = 4;
748 msg.msg_flags = 0;
750 iov[0].iov_base = (void *) server->packet;
751 iov[0].iov_len = oparam;
752 iov[1].iov_base = (param == NULL) ? padding : param;
753 iov[1].iov_len = lparam;
754 iov[2].iov_base = padding;
755 iov[2].iov_len = odata - oparam - lparam;
756 iov[3].iov_base = (data == NULL) ? padding : data;
757 iov[3].iov_len = ldata;
759 err = scm_send(sock, &msg, &scm);
760 if (err >= 0)
762 err = sock->ops->sendmsg(sock, &msg, packet_length, &scm);
763 scm_destroy(&scm);
765 return err;
769 * This is not really a trans2 request, we assume that you only have
770 * one packet to send.
773 smb_trans2_request(struct smb_sb_info *server, __u16 trans2_command,
774 int ldata, unsigned char *data,
775 int lparam, unsigned char *param,
776 int *lrdata, unsigned char **rdata,
777 int *lrparam, unsigned char **rparam)
779 sigset_t old_set;
780 unsigned long flags, sigpipe;
781 mm_segment_t fs;
782 int result;
784 pr_debug("smb_trans2_request: com=%d, ld=%d, lp=%d\n",
785 trans2_command, ldata, lparam);
788 * These are initialized in smb_request_ok, but not here??
790 server->rcls = 0;
791 server->err = 0;
793 result = -EIO;
794 if (server->state != CONN_VALID)
795 goto out;
797 if ((result = smb_dont_catch_keepalive(server)) != 0)
798 goto bad_conn;
800 spin_lock_irqsave(&current->sigmask_lock, flags);
801 sigpipe = sigismember(&current->signal, SIGPIPE);
802 old_set = current->blocked;
803 siginitsetinv(&current->blocked, sigmask(SIGKILL)|sigmask(SIGSTOP));
804 recalc_sigpending(current);
805 spin_unlock_irqrestore(&current->sigmask_lock, flags);
807 fs = get_fs();
808 set_fs(get_ds());
810 result = smb_send_trans2(server, trans2_command,
811 ldata, data, lparam, param);
812 if (result >= 0)
814 result = smb_receive_trans2(server,
815 lrdata, rdata, lrparam, rparam);
818 /* read/write errors are handled by errno */
819 spin_lock_irqsave(&current->sigmask_lock, flags);
820 if (result == -EPIPE && !sigpipe)
821 sigdelset(&current->signal, SIGPIPE);
822 current->blocked = old_set;
823 recalc_sigpending(current);
824 spin_unlock_irqrestore(&current->sigmask_lock, flags);
826 set_fs(fs);
828 if (result >= 0)
830 int result2 = smb_catch_keepalive(server);
831 if (result2 < 0)
833 result = result2;
836 if (result < 0)
837 goto bad_conn;
839 * Check for fatal server errors ...
841 if (server->rcls) {
842 int error = smb_errno(server);
843 if (error == EBADSLT) {
844 printk("smb_request: tree ID invalid\n");
845 result = error;
846 goto bad_conn;
850 out:
851 return result;
853 bad_conn:
854 #ifdef SMBFS_PARANOIA
855 printk("smb_trans2_request: result=%d, setting invalid\n", result);
856 #endif
857 server->state = CONN_INVALID;
858 smb_invalidate_inodes(server);
859 goto out;