// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/bpf.h>
#include <linux/filter.h>
#include <linux/errno.h>
#include <linux/file.h>
#include <linux/net.h>
#include <linux/workqueue.h>
#include <linux/skmsg.h>
#include <linux/list.h>
#include <linux/jhash.h>

struct bpf_stab {
	struct bpf_map map;
	struct sock **sks;
	struct sk_psock_progs progs;
	raw_spinlock_t lock;
};

#define SOCK_CREATE_FLAG_MASK				\
	(BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)

static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
{
	struct bpf_stab *stab;
	u64 cost;
	int err;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size    != 4 ||
	    attr->value_size  != 4 ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);

	stab = kzalloc(sizeof(*stab), GFP_USER);
	if (!stab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&stab->map, attr);
	raw_spin_lock_init(&stab->lock);

	/* Make sure page count doesn't overflow. */
	cost = (u64) stab->map.max_entries * sizeof(struct sock *);
	err = bpf_map_charge_init(&stab->map.memory, cost);
	if (err)
		goto free_stab;

	stab->sks = bpf_map_area_alloc(stab->map.max_entries *
				       sizeof(struct sock *),
				       stab->map.numa_node);
	if (stab->sks)
		return &stab->map;
	err = -ENOMEM;
	bpf_map_charge_finish(&stab->map.memory);
free_stab:
	kfree(stab);
	return ERR_PTR(err);
}

int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog)
{
	u32 ufd = attr->target_fd;
	struct bpf_map *map;
	struct fd f;
	int ret;

	f = fdget(ufd);
	map = __bpf_map_get(f);
	if (IS_ERR(map))
		return PTR_ERR(map);
	ret = sock_map_prog_update(map, prog, attr->attach_type);
	fdput(f);
	return ret;
}

static void sock_map_sk_acquire(struct sock *sk)
	__acquires(&sk->sk_lock.slock)
{
	lock_sock(sk);
	preempt_disable();
	rcu_read_lock();
}

static void sock_map_sk_release(struct sock *sk)
	__releases(&sk->sk_lock.slock)
{
	rcu_read_unlock();
	preempt_enable();
	release_sock(sk);
}

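/* Note on the acquire/release pair above: the socket lock is taken first,
 * then preemption is disabled and an RCU read-side section is entered.
 * This is what allows sock_map_update_common() and
 * sock_hash_update_common() further down to assert rcu_read_lock_held()
 * while running from the syscall update path.
 */
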
static void sock_map_add_link(struct sk_psock *psock,
			      struct sk_psock_link *link,
			      struct bpf_map *map, void *link_raw)
{
	link->link_raw = link_raw;
	link->map = map;
	spin_lock_bh(&psock->link_lock);
	list_add_tail(&link->list, &psock->link);
	spin_unlock_bh(&psock->link_lock);
}

static void sock_map_del_link(struct sock *sk,
			      struct sk_psock *psock, void *link_raw)
{
	struct sk_psock_link *link, *tmp;
	bool strp_stop = false;

	spin_lock_bh(&psock->link_lock);
	list_for_each_entry_safe(link, tmp, &psock->link, list) {
		if (link->link_raw == link_raw) {
			struct bpf_map *map = link->map;
			struct bpf_stab *stab = container_of(map, struct bpf_stab,
							     map);
			if (psock->parser.enabled && stab->progs.skb_parser)
				strp_stop = true;
			list_del(&link->list);
			sk_psock_free_link(link);
		}
	}
	spin_unlock_bh(&psock->link_lock);
	if (strp_stop) {
		write_lock_bh(&sk->sk_callback_lock);
		sk_psock_stop_strp(sk, psock);
		write_unlock_bh(&sk->sk_callback_lock);
	}
}

static void sock_map_unref(struct sock *sk, void *link_raw)
{
	struct sk_psock *psock = sk_psock(sk);

	if (likely(psock)) {
		sock_map_del_link(sk, psock, link_raw);
		sk_psock_put(sk, psock);
	}
}

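/* sock_map_link() below attaches the map's parser/verdict programs and a
 * psock to @sk before the socket is placed in a sockmap or sockhash
 * entry. It takes references on the BPF programs, reuses an already
 * attached psock (failing with -EBUSY when conflicting programs are
 * installed), and starts the strparser only when both an skb parser and
 * an skb verdict program are present.
 */
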
static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
			 struct sock *sk)
{
	struct bpf_prog *msg_parser, *skb_parser, *skb_verdict;
	bool skb_progs, sk_psock_is_new = false;
	struct sk_psock *psock;
	int ret;

	skb_verdict = READ_ONCE(progs->skb_verdict);
	skb_parser = READ_ONCE(progs->skb_parser);
	skb_progs = skb_parser && skb_verdict;
	if (skb_progs) {
		skb_verdict = bpf_prog_inc_not_zero(skb_verdict);
		if (IS_ERR(skb_verdict))
			return PTR_ERR(skb_verdict);
		skb_parser = bpf_prog_inc_not_zero(skb_parser);
		if (IS_ERR(skb_parser)) {
			bpf_prog_put(skb_verdict);
			return PTR_ERR(skb_parser);
		}
	}

	msg_parser = READ_ONCE(progs->msg_parser);
	if (msg_parser) {
		msg_parser = bpf_prog_inc_not_zero(msg_parser);
		if (IS_ERR(msg_parser)) {
			ret = PTR_ERR(msg_parser);
			goto out;
		}
	}

	psock = sk_psock_get_checked(sk);
	if (IS_ERR(psock)) {
		ret = PTR_ERR(psock);
		goto out_progs;
	}

	if (psock) {
		if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) ||
		    (skb_progs  && READ_ONCE(psock->progs.skb_parser))) {
			sk_psock_put(sk, psock);
			ret = -EBUSY;
			goto out_progs;
		}
	} else {
		psock = sk_psock_init(sk, map->numa_node);
		if (!psock) {
			ret = -ENOMEM;
			goto out_progs;
		}
		sk_psock_is_new = true;
	}

	if (msg_parser)
		psock_set_prog(&psock->progs.msg_parser, msg_parser);
	if (sk_psock_is_new) {
		ret = tcp_bpf_init(sk);
		if (ret < 0)
			goto out_drop;
	} else {
		tcp_bpf_reinit(sk);
	}

	write_lock_bh(&sk->sk_callback_lock);
	if (skb_progs && !psock->parser.enabled) {
		ret = sk_psock_init_strp(sk, psock);
		if (ret) {
			write_unlock_bh(&sk->sk_callback_lock);
			goto out_drop;
		}
		psock_set_prog(&psock->progs.skb_verdict, skb_verdict);
		psock_set_prog(&psock->progs.skb_parser, skb_parser);
		sk_psock_start_strp(sk, psock);
	}
	write_unlock_bh(&sk->sk_callback_lock);
	return 0;
out_drop:
	sk_psock_put(sk, psock);
out_progs:
	if (msg_parser)
		bpf_prog_put(msg_parser);
out:
	if (skb_progs) {
		bpf_prog_put(skb_verdict);
		bpf_prog_put(skb_parser);
	}
	return ret;
}

static void sock_map_free(struct bpf_map *map)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	int i;

	synchronize_rcu();
	rcu_read_lock();
	raw_spin_lock_bh(&stab->lock);
	for (i = 0; i < stab->map.max_entries; i++) {
		struct sock **psk = &stab->sks[i];
		struct sock *sk;

		sk = xchg(psk, NULL);
		if (sk)
			sock_map_unref(sk, psk);
	}
	raw_spin_unlock_bh(&stab->lock);
	rcu_read_unlock();

	synchronize_rcu();

	bpf_map_area_free(stab->sks);
	kfree(stab);
}

static void sock_map_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_stab, map)->progs);
}

static struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	WARN_ON_ONCE(!rcu_read_lock_held());

	if (unlikely(key >= map->max_entries))
		return NULL;
	return READ_ONCE(stab->sks[key]);
}

static void *sock_map_lookup(struct bpf_map *map, void *key)
{
	return ERR_PTR(-EOPNOTSUPP);
}

static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test,
			     struct sock **psk)
{
	struct sock *sk;
	int err = 0;

	raw_spin_lock_bh(&stab->lock);
	sk = *psk;
	if (!sk_test || sk_test == sk)
		sk = xchg(psk, NULL);

	if (likely(sk))
		sock_map_unref(sk, psk);
	else
		err = -EINVAL;

	raw_spin_unlock_bh(&stab->lock);
	return err;
}

static void sock_map_delete_from_link(struct bpf_map *map, struct sock *sk,
				      void *link_raw)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	__sock_map_delete(stab, sk, link_raw);
}

static int sock_map_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = *(u32 *)key;
	struct sock **psk;

	if (unlikely(i >= map->max_entries))
		return -EINVAL;

	psk = &stab->sks[i];
	return __sock_map_delete(stab, NULL, psk);
}

static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = key ? *(u32 *)key : U32_MAX;
	u32 *key_next = next;

	if (i == stab->map.max_entries - 1)
		return -ENOENT;
	if (i >= stab->map.max_entries)
		*key_next = 0;
	else
		*key_next = i + 1;
	return 0;
}

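/* Iteration contract for sock_map_get_next_key(): a NULL or out-of-range
 * key restarts the walk at index 0, the last valid index reports -ENOENT,
 * and any other key simply yields key + 1. Userspace iterates the map by
 * feeding each returned key back in.
 */
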
static int sock_map_update_common(struct bpf_map *map, u32 idx,
				  struct sock *sk, u64 flags)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	struct inet_connection_sock *icsk = inet_csk(sk);
	struct sk_psock_link *link;
	struct sk_psock *psock;
	struct sock *osk;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;
	if (unlikely(idx >= map->max_entries))
		return -E2BIG;
	if (unlikely(rcu_access_pointer(icsk->icsk_ulp_data)))
		return -EINVAL;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	ret = sock_map_link(map, &stab->progs, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	raw_spin_lock_bh(&stab->lock);
	osk = stab->sks[idx];
	if (osk && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!osk && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, &stab->sks[idx]);
	stab->sks[idx] = sk;
	if (osk)
		sock_map_unref(osk, &stab->sks[idx]);
	raw_spin_unlock_bh(&stab->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&stab->lock);
	if (psock)
		sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops)
{
	return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
	       ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB;
}

static bool sock_map_sk_is_suitable(const struct sock *sk)
{
	return sk->sk_type == SOCK_STREAM &&
	       sk->sk_protocol == IPPROTO_TCP;
}

static int sock_map_update_elem(struct bpf_map *map, void *key,
				void *value, u64 flags)
{
	u32 ufd = *(u32 *)value;
	u32 idx = *(u32 *)key;
	struct socket *sock;
	struct sock *sk;
	int ret;

	sock = sockfd_lookup(ufd, &ret);
	if (!sock)
		return ret;
	sk = sock->sk;
	if (!sk) {
		ret = -EINVAL;
		goto out;
	}
	if (!sock_map_sk_is_suitable(sk) ||
	    sk->sk_state != TCP_ESTABLISHED) {
		ret = -EOPNOTSUPP;
		goto out;
	}

	sock_map_sk_acquire(sk);
	ret = sock_map_update_common(map, idx, sk, flags);
	sock_map_sk_release(sk);
out:
	fput(sock->file);
	return ret;
}

BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_map_update_common(map, *(u32 *)key, sops->sk,
					      flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_map_update_proto = {
	.func		= bpf_sock_map_update,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	tcb->bpf.flags = flags;
	tcb->bpf.sk_redir = __sock_map_lookup_elem(map, key);
	if (!tcb->bpf.sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_map_proto = {
	.func		= bpf_sk_redirect_map,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_ANYTHING,
	.arg4_type	= ARG_ANYTHING,
};

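/* Illustrative use from the BPF side (not part of this file): an sk_skb
 * stream verdict program picks the destination socket with the
 * bpf_sk_redirect_map() helper. A minimal sketch, assuming libbpf macros
 * and a hypothetical "sock_map" map definition in the BPF object:
 *
 *	SEC("sk_skb/stream_verdict")
 *	int prog_stream_verdict(struct __sk_buff *skb)
 *	{
 *		__u32 idx = 0;	// hypothetical key choice
 *
 *		return bpf_sk_redirect_map(skb, &sock_map, idx, 0);
 *	}
 *
 * The helper returns SK_PASS when the lookup succeeds and SK_DROP
 * otherwise, mirroring bpf_sk_redirect_map() above.
 */
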
BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg *, msg,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	msg->flags = flags;
	msg->sk_redir = __sock_map_lookup_elem(map, key);
	if (!msg->sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_map_proto = {
	.func		= bpf_msg_redirect_map,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_ANYTHING,
	.arg4_type	= ARG_ANYTHING,
};

const struct bpf_map_ops sock_map_ops = {
	.map_alloc		= sock_map_alloc,
	.map_free		= sock_map_free,
	.map_get_next_key	= sock_map_get_next_key,
	.map_update_elem	= sock_map_update_elem,
	.map_delete_elem	= sock_map_delete_elem,
	.map_lookup_elem	= sock_map_lookup,
	.map_release_uref	= sock_map_release_progs,
	.map_check_btf		= map_check_no_btf,
};

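/* Illustrative userspace flow (not part of this file): a
 * BPF_MAP_TYPE_SOCKMAP is created like any other map and populated with
 * established TCP socket fds; sock_map_update_elem() above is what then
 * runs in the kernel. A minimal sketch, assuming a recent libbpf and a
 * hypothetical connected socket fd "sock_fd":
 *
 *	int map_fd = bpf_map_create(BPF_MAP_TYPE_SOCKMAP, "sock_map",
 *				    sizeof(__u32), sizeof(__u32), 2, NULL);
 *	__u32 key = 0, value = sock_fd;
 *
 *	bpf_map_update_elem(map_fd, &key, &value, BPF_ANY);
 */
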
struct bpf_htab_elem {
	struct rcu_head rcu;
	u32 hash;
	struct sock *sk;
	struct hlist_node node;
	u8 key[0];
};

struct bpf_htab_bucket {
	struct hlist_head head;
	raw_spinlock_t lock;
};

struct bpf_htab {
	struct bpf_map map;
	struct bpf_htab_bucket *buckets;
	u32 buckets_num;
	u32 elem_size;
	struct sk_psock_progs progs;
	atomic_t count;
};

static inline u32 sock_hash_bucket_hash(const void *key, u32 len)
{
	return jhash(key, len, 0);
}

static struct bpf_htab_bucket *sock_hash_select_bucket(struct bpf_htab *htab,
							u32 hash)
{
	return &htab->buckets[hash & (htab->buckets_num - 1)];
}

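/* The mask above is equivalent to "hash % buckets_num" only because
 * buckets_num is rounded up to a power of two in sock_hash_alloc();
 * with a power-of-two bucket count, (hash & (n - 1)) selects a bucket
 * without a division.
 */
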
static struct bpf_htab_elem *
sock_hash_lookup_elem_raw(struct hlist_head *head, u32 hash, void *key,
			  u32 key_size)
{
	struct bpf_htab_elem *elem;

	hlist_for_each_entry_rcu(elem, head, node) {
		if (elem->hash == hash &&
		    !memcmp(&elem->key, key, key_size))
			return elem;
	}

	return NULL;
}

static struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	u32 key_size = map->key_size, hash;
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;

	WARN_ON_ONCE(!rcu_read_lock_held());

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);

	return elem ? elem->sk : NULL;
}

static void sock_hash_free_elem(struct bpf_htab *htab,
				struct bpf_htab_elem *elem)
{
	atomic_dec(&htab->count);
	kfree_rcu(elem, rcu);
}

static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk,
				       void *link_raw)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_elem *elem_probe, *elem = link_raw;
	struct bpf_htab_bucket *bucket;

	WARN_ON_ONCE(!rcu_read_lock_held());
	bucket = sock_hash_select_bucket(htab, elem->hash);

	/* elem may be deleted in parallel from the map, but access here
	 * is okay since it's going away only after RCU grace period.
	 * However, we need to check whether it's still present.
	 */
	raw_spin_lock_bh(&bucket->lock);
	elem_probe = sock_hash_lookup_elem_raw(&bucket->head, elem->hash,
					       elem->key, map->key_size);
	if (elem_probe && elem_probe == elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
}

static int sock_hash_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	u32 hash, key_size = map->key_size;
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;
	int ret = -ENOENT;

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
		ret = 0;
	}
	raw_spin_unlock_bh(&bucket->lock);
	return ret;
}

static struct bpf_htab_elem *sock_hash_alloc_elem(struct bpf_htab *htab,
						  void *key, u32 key_size,
						  u32 hash, struct sock *sk,
						  struct bpf_htab_elem *old)
{
	struct bpf_htab_elem *new;

	if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
		if (!old) {
			atomic_dec(&htab->count);
			return ERR_PTR(-E2BIG);
		}
	}

	new = kmalloc_node(htab->elem_size, GFP_ATOMIC | __GFP_NOWARN,
			   htab->map.numa_node);
	if (!new) {
		atomic_dec(&htab->count);
		return ERR_PTR(-ENOMEM);
	}
	memcpy(new->key, key, key_size);
	new->sk = sk;
	new->hash = hash;
	return new;
}

static int sock_hash_update_common(struct bpf_map *map, void *key,
				   struct sock *sk, u64 flags)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct inet_connection_sock *icsk = inet_csk(sk);
	u32 key_size = map->key_size, hash;
	struct bpf_htab_elem *elem, *elem_new;
	struct bpf_htab_bucket *bucket;
	struct sk_psock_link *link;
	struct sk_psock *psock;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;
	if (unlikely(icsk->icsk_ulp_data))
		return -EINVAL;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	ret = sock_map_link(map, &htab->progs, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!elem && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	elem_new = sock_hash_alloc_elem(htab, key, key_size, hash, sk, elem);
	if (IS_ERR(elem_new)) {
		ret = PTR_ERR(elem_new);
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, elem_new);
	/* Add new element to the head of the list, so that
	 * concurrent search will find it before old elem.
	 */
	hlist_add_head_rcu(&elem_new->node, &bucket->head);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&bucket->lock);
	sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static int sock_hash_update_elem(struct bpf_map *map, void *key,
				 void *value, u64 flags)
{
	u32 ufd = *(u32 *)value;
	struct socket *sock;
	struct sock *sk;
	int ret;

	sock = sockfd_lookup(ufd, &ret);
	if (!sock)
		return ret;
	sk = sock->sk;
	if (!sk) {
		ret = -EINVAL;
		goto out;
	}
	if (!sock_map_sk_is_suitable(sk) ||
	    sk->sk_state != TCP_ESTABLISHED) {
		ret = -EOPNOTSUPP;
		goto out;
	}

	sock_map_sk_acquire(sk);
	ret = sock_hash_update_common(map, key, sk, flags);
	sock_map_sk_release(sk);
out:
	fput(sock->file);
	return ret;
}

static int sock_hash_get_next_key(struct bpf_map *map, void *key,
				  void *key_next)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_elem *elem, *elem_next;
	u32 hash, key_size = map->key_size;
	struct hlist_head *head;
	int i = 0;

	if (!key)
		goto find_first_elem;
	hash = sock_hash_bucket_hash(key, key_size);
	head = &sock_hash_select_bucket(htab, hash)->head;
	elem = sock_hash_lookup_elem_raw(head, hash, key, key_size);
	if (!elem)
		goto find_first_elem;

	elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_next_rcu(&elem->node)),
				     struct bpf_htab_elem, node);
	if (elem_next) {
		memcpy(key_next, elem_next->key, key_size);
		return 0;
	}

	i = hash & (htab->buckets_num - 1);
	i++;
find_first_elem:
	for (; i < htab->buckets_num; i++) {
		head = &sock_hash_select_bucket(htab, i)->head;
		elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_first_rcu(head)),
					     struct bpf_htab_elem, node);
		if (elem_next) {
			memcpy(key_next, elem_next->key, key_size);
			return 0;
		}
	}

	return -ENOENT;
}

static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
{
	struct bpf_htab *htab;
	int i, err;
	u64 cost;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size    == 0 ||
	    attr->value_size  != 4 ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);
	if (attr->key_size > MAX_BPF_STACK)
		return ERR_PTR(-E2BIG);

	htab = kzalloc(sizeof(*htab), GFP_USER);
	if (!htab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&htab->map, attr);

	htab->buckets_num = roundup_pow_of_two(htab->map.max_entries);
	htab->elem_size = sizeof(struct bpf_htab_elem) +
			  round_up(htab->map.key_size, 8);
	if (htab->buckets_num == 0 ||
	    htab->buckets_num > U32_MAX / sizeof(struct bpf_htab_bucket)) {
		err = -EINVAL;
		goto free_htab;
	}

	cost = (u64) htab->buckets_num * sizeof(struct bpf_htab_bucket) +
	       (u64) htab->elem_size * htab->map.max_entries;
	if (cost >= U32_MAX - PAGE_SIZE) {
		err = -EINVAL;
		goto free_htab;
	}

	htab->buckets = bpf_map_area_alloc(htab->buckets_num *
					   sizeof(struct bpf_htab_bucket),
					   htab->map.numa_node);
	if (!htab->buckets) {
		err = -ENOMEM;
		goto free_htab;
	}

	for (i = 0; i < htab->buckets_num; i++) {
		INIT_HLIST_HEAD(&htab->buckets[i].head);
		raw_spin_lock_init(&htab->buckets[i].lock);
	}

	return &htab->map;
free_htab:
	kfree(htab);
	return ERR_PTR(err);
}

static void sock_hash_free(struct bpf_map *map)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;
	struct hlist_node *node;
	int i;

	synchronize_rcu();
	rcu_read_lock();
	for (i = 0; i < htab->buckets_num; i++) {
		bucket = sock_hash_select_bucket(htab, i);
		raw_spin_lock_bh(&bucket->lock);
		hlist_for_each_entry_safe(elem, node, &bucket->head, node) {
			hlist_del_rcu(&elem->node);
			sock_map_unref(elem->sk, elem);
		}
		raw_spin_unlock_bh(&bucket->lock);
	}
	rcu_read_unlock();

	bpf_map_area_free(htab->buckets);
	kfree(htab);
}

static void sock_hash_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_htab, map)->progs);
}

BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_hash_update_common(map, key, sops->sk, flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_hash_update_proto = {
	.func		= bpf_sock_hash_update,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	tcb->bpf.flags = flags;
	tcb->bpf.sk_redir = __sock_hash_lookup_elem(map, key);
	if (!tcb->bpf.sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_hash_proto = {
	.func		= bpf_sk_redirect_hash,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg *, msg,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	msg->flags = flags;
	msg->sk_redir = __sock_hash_lookup_elem(map, key);
	if (!msg->sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_hash_proto = {
	.func		= bpf_msg_redirect_hash,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

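/* Illustrative use from the BPF side (not part of this file): an sk_msg
 * program attached as BPF_SK_MSG_VERDICT can steer sendmsg() data with
 * the bpf_msg_redirect_hash() helper. A minimal sketch, assuming libbpf
 * macros and a hypothetical "sock_hash" map keyed by a 4-byte value:
 *
 *	SEC("sk_msg")
 *	int prog_msg_verdict(struct sk_msg_md *msg)
 *	{
 *		__u32 key = msg->remote_port;	// hypothetical key choice
 *
 *		return bpf_msg_redirect_hash(msg, &sock_hash, &key, 0);
 *	}
 */
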
const struct bpf_map_ops sock_hash_ops = {
	.map_alloc		= sock_hash_alloc,
	.map_free		= sock_hash_free,
	.map_get_next_key	= sock_hash_get_next_key,
	.map_update_elem	= sock_hash_update_elem,
	.map_delete_elem	= sock_hash_delete_elem,
	.map_lookup_elem	= sock_map_lookup,
	.map_release_uref	= sock_hash_release_progs,
	.map_check_btf		= map_check_no_btf,
};

static struct sk_psock_progs *sock_map_progs(struct bpf_map *map)
{
	switch (map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return &container_of(map, struct bpf_stab, map)->progs;
	case BPF_MAP_TYPE_SOCKHASH:
		return &container_of(map, struct bpf_htab, map)->progs;
	default:
		break;
	}

	return NULL;
}

int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
			 u32 which)
{
	struct sk_psock_progs *progs = sock_map_progs(map);

	if (!progs)
		return -EOPNOTSUPP;

	switch (which) {
	case BPF_SK_MSG_VERDICT:
		psock_set_prog(&progs->msg_parser, prog);
		break;
	case BPF_SK_SKB_STREAM_PARSER:
		psock_set_prog(&progs->skb_parser, prog);
		break;
	case BPF_SK_SKB_STREAM_VERDICT:
		psock_set_prog(&progs->skb_verdict, prog);
		break;
	default:
		return -EOPNOTSUPP;
	}

	return 0;
}

void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link)
{
	switch (link->map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return sock_map_delete_from_link(link->map, sk,
						 link->link_raw);
	case BPF_MAP_TYPE_SOCKHASH:
		return sock_hash_delete_from_link(link->map, sk,
						  link->link_raw);
	default:
		break;
	}
}