2 * vhost transport for vsock
4 * Copyright (C) 2013-2015 Red Hat, Inc.
5 * Author: Asias He <asias@redhat.com>
6 * Stefan Hajnoczi <stefanha@redhat.com>
8 * This work is licensed under the terms of the GNU GPL, version 2.
10 #include <linux/miscdevice.h>
11 #include <linux/module.h>
12 #include <linux/mutex.h>
14 #include <linux/virtio_vsock.h>
15 #include <linux/vhost.h>
17 #include <net/af_vsock.h>
/* Well-known CID used for the host side of every connection. */
21 #define VHOST_VSOCK_DEFAULT_HOST_CID 2
/* Forward declaration: the definition below needs vhost_ops. */
23 static int vhost_transport_socket_init(struct vsock_sock
*vsk
,
24 struct vsock_sock
*psk
);
/* Feature bits offered to the guest; currently just the core vhost set.
 * NOTE(review): the enclosing "enum { ... };" lines are not visible in
 * this chunk -- confirm against the full file.
 */
27 VHOST_VSOCK_FEATURES
= VHOST_FEATURES
,
30 /* Used to track all the vhost_vsock instances on the system. */
31 static LIST_HEAD(vhost_vsock_list
);
/* Serializes access to vhost_vsock_list and to guest_cid updates
 * (see vhost_vsock_set_cid()).
 */
32 static DEFINE_MUTEX(vhost_vsock_mutex
);
/* Per-queue wrapper so each virtqueue can later carry queue state. */
34 struct vhost_vsock_virtqueue
{
35 struct vhost_virtqueue vq
;
/* NOTE(review): the closing brace of vhost_vsock_virtqueue and the
 * opening of the owning device struct (presumably "struct vhost_vsock {"
 * with a struct vhost_dev member) are not visible in this chunk.
 */
41 /* Vhost vsock virtqueues (ctrl/tx/rx) */
42 struct vhost_vsock_virtqueue vqs
[VSOCK_VQ_MAX
];
43 /* Link to global vhost_vsock_list */
44 struct list_head list
;
45 /* Head for pkt from host to guest */
46 struct list_head send_pkt_list
;
47 /* Work item to send pkt */
48 struct vhost_work send_pkt_work
;
49 /* Wait queue for send pkt (senders block here on the tx-buf limit) */
50 wait_queue_head_t queue_wait
;
51 /* Used for global tx buf limitation */
53 /* Guest context id this vhost_vsock instance handles */
57 static u32
vhost_transport_get_local_cid(void)
59 return VHOST_VSOCK_DEFAULT_HOST_CID
;
62 static struct vhost_vsock
*vhost_vsock_get(u32 guest_cid
)
64 struct vhost_vsock
*vsock
;
66 mutex_lock(&vhost_vsock_mutex
);
67 list_for_each_entry(vsock
, &vhost_vsock_list
, list
) {
68 if (vsock
->guest_cid
== guest_cid
) {
69 mutex_unlock(&vhost_vsock_mutex
);
73 mutex_unlock(&vhost_vsock_mutex
);
/*
 * Drain vsock->send_pkt_list into the guest's RX virtqueue @vq
 * (host-to-guest direction).  Runs with vq->mutex held for the
 * duration; guest notifications are disabled while we actively
 * process the list.
 * NOTE(review): several lines of this function (the loop construct,
 * error-path jumps, local declarations, return type) are not visible
 * in this chunk; comments below describe only the visible statements.
 */
79 vhost_transport_do_send_pkt(struct vhost_vsock
*vsock
,
80 struct vhost_virtqueue
*vq
)
84 mutex_lock(&vq
->mutex
);
/* Suppress further guest kicks while we work through the list. */
85 vhost_disable_notify(&vsock
->dev
, vq
);
87 struct virtio_vsock_pkt
*pkt
;
88 struct iov_iter iov_iter
;
/* Nothing left to send: re-arm guest notification before leaving. */
95 if (list_empty(&vsock
->send_pkt_list
)) {
96 vhost_enable_notify(&vsock
->dev
, vq
);
/* Fetch a guest-posted RX buffer to copy the next packet into. */
100 head
= vhost_get_vq_desc(vq
, vq
->iov
, ARRAY_SIZE(vq
->iov
),
101 &out
, &in
, NULL
, NULL
);
102 pr_debug("%s: head = %d\n", __func__
, head
);
/* head == vq->num means no buffer is available: re-enable notify and
 * re-check (a kick may have raced in); on a race, disable and retry.
 */
106 if (head
== vq
->num
) {
107 if (unlikely(vhost_enable_notify(&vsock
->dev
, vq
))) {
108 vhost_disable_notify(&vsock
->dev
, vq
);
114 pkt
= list_first_entry(&vsock
->send_pkt_list
,
115 struct virtio_vsock_pkt
, list
);
116 list_del_init(&pkt
->list
);
/* RX buffers must be write-only from the host side (out == 0). */
119 virtio_transport_free_pkt(pkt
);
120 vq_err(vq
, "Expected 0 output buffers, got %u\n", out
);
/* Copy header then payload into the guest buffer; a short copy means
 * the guest supplied an undersized buffer -- drop the packet.
 */
124 len
= iov_length(&vq
->iov
[out
], in
);
125 iov_iter_init(&iov_iter
, READ
, &vq
->iov
[out
], in
, len
);
127 nbytes
= copy_to_iter(&pkt
->hdr
, sizeof(pkt
->hdr
), &iov_iter
);
128 if (nbytes
!= sizeof(pkt
->hdr
)) {
129 virtio_transport_free_pkt(pkt
);
130 vq_err(vq
, "Faulted on copying pkt hdr\n");
134 nbytes
= copy_to_iter(pkt
->buf
, pkt
->len
, &iov_iter
);
135 if (nbytes
!= pkt
->len
) {
136 virtio_transport_free_pkt(pkt
);
137 vq_err(vq
, "Faulted on copying pkt buf\n");
/* NOTE(review): sizeof(pkt->hdr) + pkt->len bytes were written into
 * the buffer, so reporting only pkt->len as used looks like a bug --
 * the used length should cover the header too.  Confirm against the
 * virtio-vsock device spec before changing.
 */
141 vhost_add_used(vq
, head
, pkt
->len
); /* TODO should this be sizeof(pkt->hdr) + pkt->len? */
144 virtio_transport_dec_tx_pkt(pkt
);
/* Return this packet's bytes to the global tx-buffer budget. */
145 vsock
->total_tx_buf
-= pkt
->len
;
147 sk
= sk_vsock(pkt
->trans
->vsk
);
148 /* Release refcnt taken in vhost_transport_send_pkt */
151 virtio_transport_free_pkt(pkt
);
/* Tell the guest about the consumed buffers, drop the vq lock, and
 * wake any sender blocked on the tx-buffer limit in
 * vhost_transport_send_pkt().
 */
154 vhost_signal(&vsock
->dev
, vq
);
155 mutex_unlock(&vq
->mutex
);
158 wake_up(&vsock
->queue_wait
);
161 static void vhost_transport_send_pkt_work(struct vhost_work
*work
)
163 struct vhost_virtqueue
*vq
;
164 struct vhost_vsock
*vsock
;
166 vsock
= container_of(work
, struct vhost_vsock
, send_pkt_work
);
167 vq
= &vsock
->vqs
[VSOCK_VQ_RX
].vq
;
169 vhost_transport_do_send_pkt(vsock
, vq
);
/*
 * Queue a packet for transmission to the guest (host->guest).
 * Resolves addressing, applies credit and the global tx-buffer limit
 * (blocking uninterruptibly until space frees up), then allocates the
 * packet and hands it to the send_pkt_work item.
 * NOTE(review): this chunk is missing the function's return type and
 * several lines (error returns, the schedule() presumably inside the
 * wait loop, the assignment of "trans", alloc-failure handling) --
 * comments below describe only the visible statements.
 */
173 vhost_transport_send_pkt(struct vsock_sock
*vsk
,
174 struct virtio_vsock_pkt_info
*info
)
176 u32 src_cid
, src_port
, dst_cid
, dst_port
;
177 struct virtio_transport
*trans
;
178 struct virtio_vsock_pkt
*pkt
;
179 struct vhost_virtqueue
*vq
;
180 struct vhost_vsock
*vsock
;
181 u32 pkt_len
= info
->pkt_len
;
/* Source is always the host; destination comes from the socket's
 * remote address unless the caller overrides it via info.
 */
184 src_cid
= vhost_transport_get_local_cid();
185 src_port
= vsk
->local_addr
.svm_port
;
186 if (!info
->remote_cid
) {
187 dst_cid
= vsk
->remote_addr
.svm_cid
;
188 dst_port
= vsk
->remote_addr
.svm_port
;
190 dst_cid
= info
->remote_cid
;
191 dst_port
= info
->remote_port
;
194 /* Find the vhost_vsock according to guest context id */
195 vsock
= vhost_vsock_get(dst_cid
);
200 vq
= &vsock
->vqs
[VSOCK_VQ_RX
].vq
;
202 /* We can send less than pkt_len bytes: clamp to the RX buffer size. */
203 if (pkt_len
> VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE
)
204 pkt_len
= VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE
;
206 /* virtio_transport_get_credit might return less than pkt_len credit */
207 pkt_len
= virtio_transport_get_credit(trans
, pkt_len
);
209 /* Do not send zero length OP_RW pkt */
210 if (pkt_len
== 0 && info
->op
== VIRTIO_VSOCK_OP_RW
)
/* Respect global tx buf limitation: block (uninterruptibly) until
 * vhost_transport_do_send_pkt() frees enough budget and wakes us.
 * The vq mutex is dropped across the wait to let the drain run.
 */
213 /* Respect global tx buf limitation */
214 mutex_lock(&vq
->mutex
);
215 while (pkt_len
+ vsock
->total_tx_buf
> VIRTIO_VSOCK_MAX_TX_BUF_SIZE
) {
216 prepare_to_wait_exclusive(&vsock
->queue_wait
, &wait
,
217 TASK_UNINTERRUPTIBLE
);
218 mutex_unlock(&vq
->mutex
);
220 mutex_lock(&vq
->mutex
);
221 finish_wait(&vsock
->queue_wait
, &wait
);
/* Reserve our bytes from the budget before allocating. */
223 vsock
->total_tx_buf
+= pkt_len
;
224 mutex_unlock(&vq
->mutex
);
226 pkt
= virtio_transport_alloc_pkt(vsk
, info
, pkt_len
,
/* Allocation-failure path: give the reserved budget and credit back. */
230 mutex_lock(&vq
->mutex
);
231 vsock
->total_tx_buf
-= pkt_len
;
232 mutex_unlock(&vq
->mutex
);
233 virtio_transport_put_credit(trans
, pkt_len
);
237 pr_debug("%s:info->pkt_len= %d\n", __func__
, pkt_len
);
238 /* Released in vhost_transport_do_send_pkt */
239 sock_hold(&trans
->vsk
->sk
);
240 virtio_transport_inc_tx_pkt(pkt
);
242 /* Queue it up in vhost work */
243 mutex_lock(&vq
->mutex
);
244 list_add_tail(&pkt
->list
, &vsock
->send_pkt_list
);
245 vhost_work_queue(&vsock
->dev
, &vsock
->send_pkt_work
);
246 mutex_unlock(&vq
->mutex
);
/* Packet ops handed to the core virtio transport; the only hook is the
 * host->guest send path.
 * NOTE(review): the closing "};" is not visible in this chunk.
 */
251 static struct virtio_transport_pkt_ops vhost_ops
= {
252 .send_pkt
= vhost_transport_send_pkt
,
/*
 * Build a virtio_vsock_pkt from a guest TX descriptor chain: validate
 * the buffer layout, copy in the header, then the payload.
 * Returns the packet, or (per the visible error paths) fails on bad
 * layout, short copies, or oversized packets.
 * NOTE(review): the actual "return NULL" / "return pkt" lines, the
 * kzalloc/kmalloc NULL checks, and the local declarations of len/nbytes
 * are not visible in this chunk.
 */
255 static struct virtio_vsock_pkt
*
256 vhost_vsock_alloc_pkt(struct vhost_virtqueue
*vq
,
257 unsigned int out
, unsigned int in
)
259 struct virtio_vsock_pkt
*pkt
;
260 struct iov_iter iov_iter
;
/* TX buffers must be read-only from the host side (in == 0). */
265 vq_err(vq
, "Expected 0 input buffers, got %u\n", in
);
269 pkt
= kzalloc(sizeof(*pkt
), GFP_KERNEL
);
/* Copy the packet header out of the guest's descriptor chain. */
273 len
= iov_length(vq
->iov
, out
);
274 iov_iter_init(&iov_iter
, WRITE
, vq
->iov
, out
, len
);
276 nbytes
= copy_from_iter(&pkt
->hdr
, sizeof(pkt
->hdr
), &iov_iter
);
277 if (nbytes
!= sizeof(pkt
->hdr
)) {
278 vq_err(vq
, "Expected %zu bytes for pkt->hdr, got %zu bytes\n",
279 sizeof(pkt
->hdr
), nbytes
);
/* NOTE(review): for DGRAM only the low 16 bits of hdr.len are taken as
 * the payload length -- presumably the high bits carry other dgram
 * state; confirm against the dgram header layout.  "0XFFFF" also uses
 * a nonstandard uppercase X.
 */
284 if (le16_to_cpu(pkt
->hdr
.type
) == VIRTIO_VSOCK_TYPE_DGRAM
)
285 pkt
->len
= le32_to_cpu(pkt
->hdr
.len
) & 0XFFFF;
286 else if (le16_to_cpu(pkt
->hdr
.type
) == VIRTIO_VSOCK_TYPE_STREAM
)
287 pkt
->len
= le32_to_cpu(pkt
->hdr
.len
);
293 /* The pkt is too big */
294 if (pkt
->len
> VIRTIO_VSOCK_MAX_PKT_BUF_SIZE
) {
/* pkt->len is guest-controlled but bounded by the check above. */
299 pkt
->buf
= kmalloc(pkt
->len
, GFP_KERNEL
);
305 nbytes
= copy_from_iter(pkt
->buf
, pkt
->len
, &iov_iter
);
306 if (nbytes
!= pkt
->len
) {
307 vq_err(vq
, "Expected %u byte payload, got %zu bytes\n",
309 virtio_transport_free_pkt(pkt
);
316 static void vhost_vsock_handle_ctl_kick(struct vhost_work
*work
)
318 struct vhost_virtqueue
*vq
= container_of(work
, struct vhost_virtqueue
,
320 struct vhost_vsock
*vsock
= container_of(vq
->dev
, struct vhost_vsock
,
323 pr_debug("%s vq=%p, vsock=%p\n", __func__
, vq
, vsock
);
/*
 * TX-queue kick handler: pull guest->host packets off the TX virtqueue,
 * validate their addressing, and feed them to the core receive path.
 * NOTE(review): the loop construct, break/continue statements, the
 * container_of member arguments (presumably poll.work and dev), and
 * the declarations of head/len are not visible in this chunk.
 */
326 static void vhost_vsock_handle_tx_kick(struct vhost_work
*work
)
328 struct vhost_virtqueue
*vq
= container_of(work
, struct vhost_virtqueue
,
330 struct vhost_vsock
*vsock
= container_of(vq
->dev
, struct vhost_vsock
,
332 struct virtio_vsock_pkt
*pkt
;
334 unsigned int out
, in
;
338 mutex_lock(&vq
->mutex
);
/* Suppress further kicks while we drain the queue. */
339 vhost_disable_notify(&vsock
->dev
, vq
);
341 head
= vhost_get_vq_desc(vq
, vq
->iov
, ARRAY_SIZE(vq
->iov
),
342 &out
, &in
, NULL
, NULL
);
/* Queue empty: re-enable notification, re-checking for a racing kick. */
346 if (head
== vq
->num
) {
347 if (unlikely(vhost_enable_notify(&vsock
->dev
, vq
))) {
348 vhost_disable_notify(&vsock
->dev
, vq
);
354 pkt
= vhost_vsock_alloc_pkt(vq
, out
, in
);
356 vq_err(vq
, "Faulted on pkt\n");
/* Only accept correctly addressed packets: the source must be the
 * guest this instance serves and the destination must be the host.
 * Anything else is dropped (freed) below.
 */
362 /* Only accept correctly addressed packets */
363 if (le32_to_cpu(pkt
->hdr
.src_cid
) == vsock
->guest_cid
&&
364 le32_to_cpu(pkt
->hdr
.dst_cid
) == vhost_transport_get_local_cid())
365 virtio_transport_recv_pkt(pkt
);
367 virtio_transport_free_pkt(pkt
);
369 vhost_add_used(vq
, head
, len
);
/* Signal consumed buffers back to the guest and release the queue. */
373 vhost_signal(&vsock
->dev
, vq
);
374 mutex_unlock(&vq
->mutex
);
377 static void vhost_vsock_handle_rx_kick(struct vhost_work
*work
)
379 struct vhost_virtqueue
*vq
= container_of(work
, struct vhost_virtqueue
,
381 struct vhost_vsock
*vsock
= container_of(vq
->dev
, struct vhost_vsock
,
384 vhost_transport_do_send_pkt(vsock
, vq
);
/*
 * open() on /dev/vhost-vsock: allocate a vhost_vsock instance, wire up
 * the three virtqueues and their kick handlers, initialize the
 * host->guest send machinery, and publish the instance on the global
 * list.
 * NOTE(review): the allocation-failure branches, error labels, and the
 * final "return 0" are not visible in this chunk.
 */
387 static int vhost_vsock_dev_open(struct inode
*inode
, struct file
*file
)
389 struct vhost_virtqueue
**vqs
;
390 struct vhost_vsock
*vsock
;
393 vsock
= kzalloc(sizeof(*vsock
), GFP_KERNEL
);
397 pr_debug("%s:vsock=%p\n", __func__
, vsock
);
/* NOTE(review): kmalloc_array() would be the overflow-safe idiom here,
 * though VSOCK_VQ_MAX is a small constant.
 */
399 vqs
= kmalloc(VSOCK_VQ_MAX
* sizeof(*vqs
), GFP_KERNEL
);
/* Point the vhost core's vq array at our embedded virtqueues and
 * install the per-queue kick handlers.
 */
405 vqs
[VSOCK_VQ_CTRL
] = &vsock
->vqs
[VSOCK_VQ_CTRL
].vq
;
406 vqs
[VSOCK_VQ_TX
] = &vsock
->vqs
[VSOCK_VQ_TX
].vq
;
407 vqs
[VSOCK_VQ_RX
] = &vsock
->vqs
[VSOCK_VQ_RX
].vq
;
408 vsock
->vqs
[VSOCK_VQ_CTRL
].vq
.handle_kick
= vhost_vsock_handle_ctl_kick
;
409 vsock
->vqs
[VSOCK_VQ_TX
].vq
.handle_kick
= vhost_vsock_handle_tx_kick
;
410 vsock
->vqs
[VSOCK_VQ_RX
].vq
.handle_kick
= vhost_vsock_handle_rx_kick
;
412 vhost_dev_init(&vsock
->dev
, vqs
, VSOCK_VQ_MAX
);
/* Per-fd state plus the host->guest send machinery. */
414 file
->private_data
= vsock
;
415 init_waitqueue_head(&vsock
->queue_wait
);
416 INIT_LIST_HEAD(&vsock
->send_pkt_list
);
417 vhost_work_init(&vsock
->send_pkt_work
, vhost_transport_send_pkt_work
);
/* Make the instance discoverable by vhost_vsock_get(). */
419 mutex_lock(&vhost_vsock_mutex
);
420 list_add_tail(&vsock
->list
, &vhost_vsock_list
);
421 mutex_unlock(&vhost_vsock_mutex
);
429 static void vhost_vsock_flush(struct vhost_vsock
*vsock
)
433 for (i
= 0; i
< VSOCK_VQ_MAX
; i
++)
434 vhost_poll_flush(&vsock
->vqs
[i
].vq
.poll
);
435 vhost_work_flush(&vsock
->dev
, &vsock
->send_pkt_work
);
/*
 * release() on /dev/vhost-vsock: unpublish the instance, stop the
 * vhost device, drain outstanding work, and tear everything down.
 * NOTE(review): the final kfree(vsock) and "return 0" are not visible
 * in this chunk.
 */
438 static int vhost_vsock_dev_release(struct inode
*inode
, struct file
*file
)
440 struct vhost_vsock
*vsock
= file
->private_data
;
/* Remove from the global list first so vhost_vsock_get() can no
 * longer hand out this instance.
 */
442 mutex_lock(&vhost_vsock_mutex
);
443 list_del(&vsock
->list
);
444 mutex_unlock(&vhost_vsock_mutex
);
/* Stop, flush pending work, then free vhost resources; the vqs array
 * was allocated in vhost_vsock_dev_open().
 */
446 vhost_dev_stop(&vsock
->dev
);
447 vhost_vsock_flush(vsock
);
448 vhost_dev_cleanup(&vsock
->dev
, false);
449 kfree(vsock
->dev
.vqs
);
/*
 * Assign the guest CID for this instance (VHOST_VSOCK_SET_GUEST_CID).
 * Rejects reserved CIDs and CIDs already claimed by another instance.
 * NOTE(review): the error-return statements and the final "return 0"
 * are not visible in this chunk.
 */
454 static int vhost_vsock_set_cid(struct vhost_vsock
*vsock
, u32 guest_cid
)
456 struct vhost_vsock
*other
;
458 /* Refuse reserved CIDs */
459 if (guest_cid
<= VMADDR_CID_HOST
) {
463 /* Refuse if CID is already in use */
464 other
= vhost_vsock_get(guest_cid
)
;
465 if (other
&& other
!= vsock
) {
/* Publish the new CID under the same mutex vhost_vsock_get() uses. */
469 mutex_lock(&vhost_vsock_mutex
);
470 vsock
->guest_cid
= guest_cid
;
/* NOTE(review): %d for a u32 -- %u would be the matching specifier. */
471 pr_debug("%s:guest_cid=%d\n", __func__
, guest_cid
);
472 mutex_unlock(&vhost_vsock_mutex
);
/*
 * VHOST_SET_FEATURES: validate the requested feature bits and apply
 * them to every virtqueue under the appropriate locks.
 * NOTE(review): the error-return statements, the declaration of "i",
 * and the final "return 0" are not visible in this chunk.
 */
477 static int vhost_vsock_set_features(struct vhost_vsock
*vsock
, u64 features
)
479 struct vhost_virtqueue
*vq
;
/* Reject any feature bit we never offered. */
482 if (features
& ~VHOST_VSOCK_FEATURES
)
485 mutex_lock(&vsock
->dev
.mutex
);
/* Enabling dirty logging requires the log region to be accessible. */
486 if ((features
& (1 << VHOST_F_LOG_ALL
)) &&
487 !vhost_log_access_ok(&vsock
->dev
)) {
488 mutex_unlock(&vsock
->dev
.mutex
);
/* Propagate the acked features to each vq under its own mutex. */
492 for (i
= 0; i
< VSOCK_VQ_MAX
; i
++) {
493 vq
= &vsock
->vqs
[i
].vq
;
494 mutex_lock(&vq
->mutex
);
495 vq
->acked_features
= features
;
496 mutex_unlock(&vq
->mutex
);
498 mutex_unlock(&vsock
->dev
.mutex
);
/*
 * ioctl() dispatch for /dev/vhost-vsock: handles the vsock-specific
 * commands here and falls through to the generic vhost dev/vring
 * ioctls for everything else.
 * NOTE(review): the switch statement skeleton, the "arg" parameter
 * declaration, local declarations (guest_cid, features, r), the
 * -EFAULT returns, and the final "return r" are not visible in this
 * chunk.
 */
502 static long vhost_vsock_dev_ioctl(struct file
*f
, unsigned int ioctl
,
505 struct vhost_vsock
*vsock
= f
->private_data
;
506 void __user
*argp
= (void __user
*)arg
;
507 u64 __user
*featurep
= argp
;
508 u32 __user
*cidp
= argp
;
/* Vsock-specific: set the CID of the attached guest. */
514 case VHOST_VSOCK_SET_GUEST_CID
:
515 if (get_user(guest_cid
, cidp
))
517 return vhost_vsock_set_cid(vsock
, guest_cid
);
/* Report / accept the negotiated feature bits. */
518 case VHOST_GET_FEATURES
:
519 features
= VHOST_VSOCK_FEATURES
;
520 if (copy_to_user(featurep
, &features
, sizeof(features
)))
523 case VHOST_SET_FEATURES
:
524 if (copy_from_user(&features
, featurep
, sizeof(features
)))
526 return vhost_vsock_set_features(vsock
, features
);
/* Everything else: generic vhost device ioctls, then vring ioctls,
 * flushing pending work afterwards.
 */
528 mutex_lock(&vsock
->dev
.mutex
);
529 r
= vhost_dev_ioctl(&vsock
->dev
, ioctl
, argp
);
530 if (r
== -ENOIOCTLCMD
)
531 r
= vhost_vring_ioctl(&vsock
->dev
, ioctl
, argp
);
533 vhost_vsock_flush(vsock
);
534 mutex_unlock(&vsock
->dev
.mutex
);
/* Char-device entry points for /dev/vhost-vsock.
 * NOTE(review): the closing "};" is not visible in this chunk.
 */
539 static const struct file_operations vhost_vsock_fops
= {
540 .owner
= THIS_MODULE
,
541 .open
= vhost_vsock_dev_open
,
542 .release
= vhost_vsock_dev_release
,
543 .llseek
= noop_llseek
,
544 .unlocked_ioctl
= vhost_vsock_dev_ioctl
,
/* Misc device registration: exposes /dev/vhost-vsock with a
 * dynamically assigned minor.
 * NOTE(review): the closing "};" is not visible in this chunk.
 */
547 static struct miscdevice vhost_vsock_misc
= {
548 .minor
= MISC_DYNAMIC_MINOR
,
549 .name
= "vhost-vsock",
550 .fops
= &vhost_vsock_fops
,
/*
 * Per-socket init hook: run the generic virtio transport socket init,
 * then point the socket's transport at our vhost packet ops.
 * NOTE(review): the "static int" line of the signature (declared
 * earlier in the file), the declaration of "ret", its error check, the
 * assignment of "trans" (presumably from vsk->trans), and the final
 * "return ret" are not visible in this chunk.
 */
554 vhost_transport_socket_init(struct vsock_sock
*vsk
, struct vsock_sock
*psk
)
556 struct virtio_transport
*trans
;
559 ret
= virtio_transport_do_socket_init(vsk
, psk
);
/* Route this socket's sends through vhost_transport_send_pkt(). */
564 trans
->ops
= &vhost_ops
;
/* vsock_transport registered with the AF_VSOCK core.  Only
 * get_local_cid and init are vhost-specific; every other operation is
 * delegated straight to the shared virtio transport helpers.
 * NOTE(review): the closing "};" is not visible in this chunk.
 */
569 static struct vsock_transport vhost_transport
= {
570 .get_local_cid
= vhost_transport_get_local_cid
,
572 .init
= vhost_transport_socket_init
,
573 .destruct
= virtio_transport_destruct
,
574 .release
= virtio_transport_release
,
575 .connect
= virtio_transport_connect
,
576 .shutdown
= virtio_transport_shutdown
,
578 .dgram_enqueue
= virtio_transport_dgram_enqueue
,
579 .dgram_dequeue
= virtio_transport_dgram_dequeue
,
580 .dgram_bind
= virtio_transport_dgram_bind
,
581 .dgram_allow
= virtio_transport_dgram_allow
,
583 .stream_enqueue
= virtio_transport_stream_enqueue
,
584 .stream_dequeue
= virtio_transport_stream_dequeue
,
585 .stream_has_data
= virtio_transport_stream_has_data
,
586 .stream_has_space
= virtio_transport_stream_has_space
,
587 .stream_rcvhiwat
= virtio_transport_stream_rcvhiwat
,
588 .stream_is_active
= virtio_transport_stream_is_active
,
589 .stream_allow
= virtio_transport_stream_allow
,
591 .notify_poll_in
= virtio_transport_notify_poll_in
,
592 .notify_poll_out
= virtio_transport_notify_poll_out
,
593 .notify_recv_init
= virtio_transport_notify_recv_init
,
594 .notify_recv_pre_block
= virtio_transport_notify_recv_pre_block
,
595 .notify_recv_pre_dequeue
= virtio_transport_notify_recv_pre_dequeue
,
596 .notify_recv_post_dequeue
= virtio_transport_notify_recv_post_dequeue
,
597 .notify_send_init
= virtio_transport_notify_send_init
,
598 .notify_send_pre_block
= virtio_transport_notify_send_pre_block
,
599 .notify_send_pre_enqueue
= virtio_transport_notify_send_pre_enqueue
,
600 .notify_send_post_enqueue
= virtio_transport_notify_send_post_enqueue
,
602 .set_buffer_size
= virtio_transport_set_buffer_size
,
603 .set_min_buffer_size
= virtio_transport_set_min_buffer_size
,
604 .set_max_buffer_size
= virtio_transport_set_max_buffer_size
,
605 .get_buffer_size
= virtio_transport_get_buffer_size
,
606 .get_min_buffer_size
= virtio_transport_get_min_buffer_size
,
607 .get_max_buffer_size
= virtio_transport_get_max_buffer_size
,
/*
 * Module init: register the transport with the AF_VSOCK core, then
 * expose /dev/vhost-vsock.
 * NOTE(review): the declaration of "ret" and its error check are not
 * visible in this chunk; also note that if misc_register() fails there
 * is no visible vsock_core_exit() rollback -- confirm in the full file.
 */
610 static int __init
vhost_vsock_init(void)
614 ret
= vsock_core_init(&vhost_transport
);
617 return misc_register(&vhost_vsock_misc
);
/*
 * Module exit: remove /dev/vhost-vsock.
 * NOTE(review): the matching vsock_core_exit() call is not visible in
 * this chunk -- confirm it exists in the full file.
 */
620 static void __exit
vhost_vsock_exit(void)
622 misc_deregister(&vhost_vsock_misc
);
/* Standard module boilerplate: entry/exit hooks and metadata. */
626 module_init(vhost_vsock_init
);
627 module_exit(vhost_vsock_exit
);
628 MODULE_LICENSE("GPL v2");
629 MODULE_AUTHOR("Asias He");
630 MODULE_DESCRIPTION("vhost transport for vsock ");