net/core/sock_map.c


DEFINITIONS

This source file includes the following definitions:
  1. sock_map_alloc
  2. sock_map_get_from_fd
  3. sock_map_sk_acquire
  4. sock_map_sk_release
  5. sock_map_add_link
  6. sock_map_del_link
  7. sock_map_unref
  8. sock_map_link
  9. sock_map_free
  10. sock_map_release_progs
  11. __sock_map_lookup_elem
  12. sock_map_lookup
  13. __sock_map_delete
  14. sock_map_delete_from_link
  15. sock_map_delete_elem
  16. sock_map_get_next_key
  17. sock_map_update_common
  18. sock_map_op_okay
  19. sock_map_sk_is_suitable
  20. sock_map_update_elem
  21. BPF_CALL_4
  22. BPF_CALL_4
  23. BPF_CALL_4
  24. sock_hash_bucket_hash
  25. sock_hash_select_bucket
  26. sock_hash_lookup_elem_raw
  27. __sock_hash_lookup_elem
  28. sock_hash_free_elem
  29. sock_hash_delete_from_link
  30. sock_hash_delete_elem
  31. sock_hash_alloc_elem
  32. sock_hash_update_common
  33. sock_hash_update_elem
  34. sock_hash_get_next_key
  35. sock_hash_alloc
  36. sock_hash_free
  37. sock_hash_release_progs
  38. BPF_CALL_4
  39. BPF_CALL_4
  40. BPF_CALL_4
  41. sock_map_progs
  42. sock_map_prog_update
  43. sk_psock_unlink

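OVERVIEW

This file implements the BPF_MAP_TYPE_SOCKMAP and BPF_MAP_TYPE_SOCKHASH map
types: maps whose values are established TCP sockets, which sk_skb and sk_msg
BPF programs use to redirect traffic from one socket to another. As a rough
illustration of how the pieces below fit together (this sketch is not part of
the file; the map and program names are made up and libbpf-style BTF map
definitions are assumed), a BPF program pairs such a map with the
bpf_sk_redirect_map() helper implemented further down, while user space
inserts established TCP sockets by file descriptor with bpf_map_update_elem():

    /* BPF program side -- illustrative sketch only */
    #include <linux/bpf.h>
    #include <bpf/bpf_helpers.h>

    struct {
            __uint(type, BPF_MAP_TYPE_SOCKMAP);
            __uint(max_entries, 2);
            __type(key, __u32);     /* key_size == 4, enforced by sock_map_alloc() */
            __type(value, __u32);   /* value_size == 4, enforced by sock_map_alloc() */
    } sock_map SEC(".maps");

    SEC("sk_skb/stream_verdict")
    int prog_stream_verdict(struct __sk_buff *skb)
    {
            __u32 idx = 0;

            /* 0 or BPF_F_INGRESS are the only flags the helper accepts */
            return bpf_sk_redirect_map(skb, &sock_map, idx, 0);
    }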
   1 // SPDX-License-Identifier: GPL-2.0
   2 /* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */
   3 
   4 #include <linux/bpf.h>
   5 #include <linux/filter.h>
   6 #include <linux/errno.h>
   7 #include <linux/file.h>
   8 #include <linux/net.h>
   9 #include <linux/workqueue.h>
  10 #include <linux/skmsg.h>
  11 #include <linux/list.h>
  12 #include <linux/jhash.h>
  13 
  14 struct bpf_stab {
  15         struct bpf_map map;
  16         struct sock **sks;
  17         struct sk_psock_progs progs;
  18         raw_spinlock_t lock;
  19 };
  20 
  21 #define SOCK_CREATE_FLAG_MASK                           \
  22         (BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)
  23 
  24 static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
  25 {
  26         struct bpf_stab *stab;
  27         u64 cost;
  28         int err;
  29 
  30         if (!capable(CAP_NET_ADMIN))
  31                 return ERR_PTR(-EPERM);
  32         if (attr->max_entries == 0 ||
  33             attr->key_size    != 4 ||
  34             attr->value_size  != 4 ||
  35             attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
  36                 return ERR_PTR(-EINVAL);
  37 
  38         stab = kzalloc(sizeof(*stab), GFP_USER);
  39         if (!stab)
  40                 return ERR_PTR(-ENOMEM);
  41 
  42         bpf_map_init_from_attr(&stab->map, attr);
  43         raw_spin_lock_init(&stab->lock);
  44 
  45         /* Make sure page count doesn't overflow. */
  46         cost = (u64) stab->map.max_entries * sizeof(struct sock *);
  47         err = bpf_map_charge_init(&stab->map.memory, cost);
  48         if (err)
  49                 goto free_stab;
  50 
  51         stab->sks = bpf_map_area_alloc(stab->map.max_entries *
  52                                        sizeof(struct sock *),
  53                                        stab->map.numa_node);
  54         if (stab->sks)
  55                 return &stab->map;
  56         err = -ENOMEM;
  57         bpf_map_charge_finish(&stab->map.memory);
  58 free_stab:
  59         kfree(stab);
  60         return ERR_PTR(err);
  61 }
  62 
  63 int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog)
  64 {
  65         u32 ufd = attr->target_fd;
  66         struct bpf_map *map;
  67         struct fd f;
  68         int ret;
  69 
  70         f = fdget(ufd);
  71         map = __bpf_map_get(f);
  72         if (IS_ERR(map))
  73                 return PTR_ERR(map);
  74         ret = sock_map_prog_update(map, prog, attr->attach_type);
  75         fdput(f);
  76         return ret;
  77 }
  78 
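     /* Map updates coming from the bpf() syscall take the socket lock and
      * then enter the same context a BPF program update runs in (RCU read
      * side with preemption disabled), so that sock_map_update_common() and
      * sock_hash_update_common() can be shared by both entry points.
      */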
  79 static void sock_map_sk_acquire(struct sock *sk)
  80         __acquires(&sk->sk_lock.slock)
  81 {
  82         lock_sock(sk);
  83         preempt_disable();
  84         rcu_read_lock();
  85 }
  86 
  87 static void sock_map_sk_release(struct sock *sk)
  88         __releases(&sk->sk_lock.slock)
  89 {
  90         rcu_read_unlock();
  91         preempt_enable();
  92         release_sock(sk);
  93 }
  94 
  95 static void sock_map_add_link(struct sk_psock *psock,
  96                               struct sk_psock_link *link,
  97                               struct bpf_map *map, void *link_raw)
  98 {
  99         link->link_raw = link_raw;
 100         link->map = map;
 101         spin_lock_bh(&psock->link_lock);
 102         list_add_tail(&link->list, &psock->link);
 103         spin_unlock_bh(&psock->link_lock);
 104 }
 105 
 106 static void sock_map_del_link(struct sock *sk,
 107                               struct sk_psock *psock, void *link_raw)
 108 {
 109         struct sk_psock_link *link, *tmp;
 110         bool strp_stop = false;
 111 
 112         spin_lock_bh(&psock->link_lock);
 113         list_for_each_entry_safe(link, tmp, &psock->link, list) {
 114                 if (link->link_raw == link_raw) {
 115                         struct bpf_map *map = link->map;
 116                         struct bpf_stab *stab = container_of(map, struct bpf_stab,
 117                                                              map);
 118                         if (psock->parser.enabled && stab->progs.skb_parser)
 119                                 strp_stop = true;
 120                         list_del(&link->list);
 121                         sk_psock_free_link(link);
 122                 }
 123         }
 124         spin_unlock_bh(&psock->link_lock);
 125         if (strp_stop) {
 126                 write_lock_bh(&sk->sk_callback_lock);
 127                 sk_psock_stop_strp(sk, psock);
 128                 write_unlock_bh(&sk->sk_callback_lock);
 129         }
 130 }
 131 
 132 static void sock_map_unref(struct sock *sk, void *link_raw)
 133 {
 134         struct sk_psock *psock = sk_psock(sk);
 135 
 136         if (likely(psock)) {
 137                 sock_map_del_link(sk, psock, link_raw);
 138                 sk_psock_put(sk, psock);
 139         }
 140 }
 141 
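     /* sock_map_link() wires a socket into the map's programs: it takes
      * references on the configured msg/skb programs, finds or allocates the
      * socket's psock, installs the programs on it, switches the socket over
      * to the sockmap TCP callbacks (tcp_bpf_init()/tcp_bpf_reinit()) and,
      * when both skb programs are present, starts the strparser. Errors
      * unwind these steps in reverse order.
      */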
 142 static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
 143                          struct sock *sk)
 144 {
 145         struct bpf_prog *msg_parser, *skb_parser, *skb_verdict;
 146         bool skb_progs, sk_psock_is_new = false;
 147         struct sk_psock *psock;
 148         int ret;
 149 
 150         skb_verdict = READ_ONCE(progs->skb_verdict);
 151         skb_parser = READ_ONCE(progs->skb_parser);
 152         skb_progs = skb_parser && skb_verdict;
 153         if (skb_progs) {
 154                 skb_verdict = bpf_prog_inc_not_zero(skb_verdict);
 155                 if (IS_ERR(skb_verdict))
 156                         return PTR_ERR(skb_verdict);
 157                 skb_parser = bpf_prog_inc_not_zero(skb_parser);
 158                 if (IS_ERR(skb_parser)) {
 159                         bpf_prog_put(skb_verdict);
 160                         return PTR_ERR(skb_parser);
 161                 }
 162         }
 163 
 164         msg_parser = READ_ONCE(progs->msg_parser);
 165         if (msg_parser) {
 166                 msg_parser = bpf_prog_inc_not_zero(msg_parser);
 167                 if (IS_ERR(msg_parser)) {
 168                         ret = PTR_ERR(msg_parser);
 169                         goto out;
 170                 }
 171         }
 172 
 173         psock = sk_psock_get_checked(sk);
 174         if (IS_ERR(psock)) {
 175                 ret = PTR_ERR(psock);
 176                 goto out_progs;
 177         }
 178 
 179         if (psock) {
 180                 if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) ||
 181                     (skb_progs  && READ_ONCE(psock->progs.skb_parser))) {
 182                         sk_psock_put(sk, psock);
 183                         ret = -EBUSY;
 184                         goto out_progs;
 185                 }
 186         } else {
 187                 psock = sk_psock_init(sk, map->numa_node);
 188                 if (!psock) {
 189                         ret = -ENOMEM;
 190                         goto out_progs;
 191                 }
 192                 sk_psock_is_new = true;
 193         }
 194 
 195         if (msg_parser)
 196                 psock_set_prog(&psock->progs.msg_parser, msg_parser);
 197         if (sk_psock_is_new) {
 198                 ret = tcp_bpf_init(sk);
 199                 if (ret < 0)
 200                         goto out_drop;
 201         } else {
 202                 tcp_bpf_reinit(sk);
 203         }
 204 
 205         write_lock_bh(&sk->sk_callback_lock);
 206         if (skb_progs && !psock->parser.enabled) {
 207                 ret = sk_psock_init_strp(sk, psock);
 208                 if (ret) {
 209                         write_unlock_bh(&sk->sk_callback_lock);
 210                         goto out_drop;
 211                 }
 212                 psock_set_prog(&psock->progs.skb_verdict, skb_verdict);
 213                 psock_set_prog(&psock->progs.skb_parser, skb_parser);
 214                 sk_psock_start_strp(sk, psock);
 215         }
 216         write_unlock_bh(&sk->sk_callback_lock);
 217         return 0;
 218 out_drop:
 219         sk_psock_put(sk, psock);
 220 out_progs:
 221         if (msg_parser)
 222                 bpf_prog_put(msg_parser);
 223 out:
 224         if (skb_progs) {
 225                 bpf_prog_put(skb_verdict);
 226                 bpf_prog_put(skb_parser);
 227         }
 228         return ret;
 229 }
 230 
 231 static void sock_map_free(struct bpf_map *map)
 232 {
 233         struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
 234         int i;
 235 
 236         /* After the sync no updates or deletes will be in-flight, so it
 237          * is safe to walk the map and remove entries without risking a
 238          * race in the EEXIST update case.
 239          */
 240         synchronize_rcu();
 241         for (i = 0; i < stab->map.max_entries; i++) {
 242                 struct sock **psk = &stab->sks[i];
 243                 struct sock *sk;
 244 
 245                 sk = xchg(psk, NULL);
 246                 if (sk) {
 247                         lock_sock(sk);
 248                         rcu_read_lock();
 249                         sock_map_unref(sk, psk);
 250                         rcu_read_unlock();
 251                         release_sock(sk);
 252                 }
 253         }
 254 
 255         /* wait for psock readers accessing its map link */
 256         synchronize_rcu();
 257 
 258         bpf_map_area_free(stab->sks);
 259         kfree(stab);
 260 }
 261 
 262 static void sock_map_release_progs(struct bpf_map *map)
 263 {
 264         psock_progs_drop(&container_of(map, struct bpf_stab, map)->progs);
 265 }
 266 
 267 static struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
 268 {
 269         struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
 270 
 271         WARN_ON_ONCE(!rcu_read_lock_held());
 272 
 273         if (unlikely(key >= map->max_entries))
 274                 return NULL;
 275         return READ_ONCE(stab->sks[key]);
 276 }
 277 
 278 static void *sock_map_lookup(struct bpf_map *map, void *key)
 279 {
 280         return ERR_PTR(-EOPNOTSUPP);
 281 }
 282 
 283 static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test,
 284                              struct sock **psk)
 285 {
 286         struct sock *sk;
 287         int err = 0;
 288 
 289         raw_spin_lock_bh(&stab->lock);
 290         sk = *psk;
 291         if (!sk_test || sk_test == sk)
 292                 sk = xchg(psk, NULL);
 293 
 294         if (likely(sk))
 295                 sock_map_unref(sk, psk);
 296         else
 297                 err = -EINVAL;
 298 
 299         raw_spin_unlock_bh(&stab->lock);
 300         return err;
 301 }
 302 
 303 static void sock_map_delete_from_link(struct bpf_map *map, struct sock *sk,
 304                                       void *link_raw)
 305 {
 306         struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
 307 
 308         __sock_map_delete(stab, sk, link_raw);
 309 }
 310 
 311 static int sock_map_delete_elem(struct bpf_map *map, void *key)
 312 {
 313         struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
 314         u32 i = *(u32 *)key;
 315         struct sock **psk;
 316 
 317         if (unlikely(i >= map->max_entries))
 318                 return -EINVAL;
 319 
 320         psk = &stab->sks[i];
 321         return __sock_map_delete(stab, NULL, psk);
 322 }
 323 
 324 static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next)
 325 {
 326         struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
 327         u32 i = key ? *(u32 *)key : U32_MAX;
 328         u32 *key_next = next;
 329 
 330         if (i == stab->map.max_entries - 1)
 331                 return -ENOENT;
 332         if (i >= stab->map.max_entries)
 333                 *key_next = 0;
 334         else
 335                 *key_next = i + 1;
 336         return 0;
 337 }
 338 
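     /* Common update path for both the syscall and the BPF program entry
      * points: allocate a psock link up front, attach the socket to the map's
      * programs via sock_map_link(), then swap the slot under stab->lock
      * while honouring the BPF_NOEXIST/BPF_EXIST flags against the old entry.
      */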
 339 static int sock_map_update_common(struct bpf_map *map, u32 idx,
 340                                   struct sock *sk, u64 flags)
 341 {
 342         struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
 343         struct inet_connection_sock *icsk = inet_csk(sk);
 344         struct sk_psock_link *link;
 345         struct sk_psock *psock;
 346         struct sock *osk;
 347         int ret;
 348 
 349         WARN_ON_ONCE(!rcu_read_lock_held());
 350         if (unlikely(flags > BPF_EXIST))
 351                 return -EINVAL;
 352         if (unlikely(idx >= map->max_entries))
 353                 return -E2BIG;
 354         if (unlikely(rcu_access_pointer(icsk->icsk_ulp_data)))
 355                 return -EINVAL;
 356 
 357         link = sk_psock_init_link();
 358         if (!link)
 359                 return -ENOMEM;
 360 
 361         ret = sock_map_link(map, &stab->progs, sk);
 362         if (ret < 0)
 363                 goto out_free;
 364 
 365         psock = sk_psock(sk);
 366         WARN_ON_ONCE(!psock);
 367 
 368         raw_spin_lock_bh(&stab->lock);
 369         osk = stab->sks[idx];
 370         if (osk && flags == BPF_NOEXIST) {
 371                 ret = -EEXIST;
 372                 goto out_unlock;
 373         } else if (!osk && flags == BPF_EXIST) {
 374                 ret = -ENOENT;
 375                 goto out_unlock;
 376         }
 377 
 378         sock_map_add_link(psock, link, map, &stab->sks[idx]);
 379         stab->sks[idx] = sk;
 380         if (osk)
 381                 sock_map_unref(osk, &stab->sks[idx]);
 382         raw_spin_unlock_bh(&stab->lock);
 383         return 0;
 384 out_unlock:
 385         raw_spin_unlock_bh(&stab->lock);
 386         if (psock)
 387                 sk_psock_put(sk, psock);
 388 out_free:
 389         sk_psock_free_link(link);
 390         return ret;
 391 }
 392 
 393 static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops)
 394 {
 395         return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
 396                ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB;
 397 }
 398 
 399 static bool sock_map_sk_is_suitable(const struct sock *sk)
 400 {
 401         return sk->sk_type == SOCK_STREAM &&
 402                sk->sk_protocol == IPPROTO_TCP;
 403 }
 404 
 405 static int sock_map_update_elem(struct bpf_map *map, void *key,
 406                                 void *value, u64 flags)
 407 {
 408         u32 ufd = *(u32 *)value;
 409         u32 idx = *(u32 *)key;
 410         struct socket *sock;
 411         struct sock *sk;
 412         int ret;
 413 
 414         sock = sockfd_lookup(ufd, &ret);
 415         if (!sock)
 416                 return ret;
 417         sk = sock->sk;
 418         if (!sk) {
 419                 ret = -EINVAL;
 420                 goto out;
 421         }
 422         if (!sock_map_sk_is_suitable(sk)) {
 423                 ret = -EOPNOTSUPP;
 424                 goto out;
 425         }
 426 
 427         sock_map_sk_acquire(sk);
 428         if (sk->sk_state != TCP_ESTABLISHED)
 429                 ret = -EOPNOTSUPP;
 430         else
 431                 ret = sock_map_update_common(map, idx, sk, flags);
 432         sock_map_sk_release(sk);
 433 out:
 434         fput(sock->file);
 435         return ret;
 436 }
 437 
 438 BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, sops,
 439            struct bpf_map *, map, void *, key, u64, flags)
 440 {
 441         WARN_ON_ONCE(!rcu_read_lock_held());
 442 
 443         if (likely(sock_map_sk_is_suitable(sops->sk) &&
 444                    sock_map_op_okay(sops)))
 445                 return sock_map_update_common(map, *(u32 *)key, sops->sk,
 446                                               flags);
 447         return -EOPNOTSUPP;
 448 }
 449 
 450 const struct bpf_func_proto bpf_sock_map_update_proto = {
 451         .func           = bpf_sock_map_update,
 452         .gpl_only       = false,
 453         .pkt_access     = true,
 454         .ret_type       = RET_INTEGER,
 455         .arg1_type      = ARG_PTR_TO_CTX,
 456         .arg2_type      = ARG_CONST_MAP_PTR,
 457         .arg3_type      = ARG_PTR_TO_MAP_KEY,
 458         .arg4_type      = ARG_ANYTHING,
 459 };
 460 
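     /* The redirect helpers below do not move any data themselves; they only
      * stash the looked-up destination socket (and the BPF_F_INGRESS flag) in
      * the skb control block or in the sk_msg, and the skmsg infrastructure
      * consumes sk_redir afterwards to perform the actual redirect.
      */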
 461 BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
 462            struct bpf_map *, map, u32, key, u64, flags)
 463 {
 464         struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
 465 
 466         if (unlikely(flags & ~(BPF_F_INGRESS)))
 467                 return SK_DROP;
 468         tcb->bpf.flags = flags;
 469         tcb->bpf.sk_redir = __sock_map_lookup_elem(map, key);
 470         if (!tcb->bpf.sk_redir)
 471                 return SK_DROP;
 472         return SK_PASS;
 473 }
 474 
 475 const struct bpf_func_proto bpf_sk_redirect_map_proto = {
 476         .func           = bpf_sk_redirect_map,
 477         .gpl_only       = false,
 478         .ret_type       = RET_INTEGER,
 479         .arg1_type      = ARG_PTR_TO_CTX,
 480         .arg2_type      = ARG_CONST_MAP_PTR,
 481         .arg3_type      = ARG_ANYTHING,
 482         .arg4_type      = ARG_ANYTHING,
 483 };
 484 
 485 BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg *, msg,
 486            struct bpf_map *, map, u32, key, u64, flags)
 487 {
 488         if (unlikely(flags & ~(BPF_F_INGRESS)))
 489                 return SK_DROP;
 490         msg->flags = flags;
 491         msg->sk_redir = __sock_map_lookup_elem(map, key);
 492         if (!msg->sk_redir)
 493                 return SK_DROP;
 494         return SK_PASS;
 495 }
 496 
 497 const struct bpf_func_proto bpf_msg_redirect_map_proto = {
 498         .func           = bpf_msg_redirect_map,
 499         .gpl_only       = false,
 500         .ret_type       = RET_INTEGER,
 501         .arg1_type      = ARG_PTR_TO_CTX,
 502         .arg2_type      = ARG_CONST_MAP_PTR,
 503         .arg3_type      = ARG_ANYTHING,
 504         .arg4_type      = ARG_ANYTHING,
 505 };
 506 
 507 const struct bpf_map_ops sock_map_ops = {
 508         .map_alloc              = sock_map_alloc,
 509         .map_free               = sock_map_free,
 510         .map_get_next_key       = sock_map_get_next_key,
 511         .map_update_elem        = sock_map_update_elem,
 512         .map_delete_elem        = sock_map_delete_elem,
 513         .map_lookup_elem        = sock_map_lookup,
 514         .map_release_uref       = sock_map_release_progs,
 515         .map_check_btf          = map_check_no_btf,
 516 };
 517 
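     /* BPF_MAP_TYPE_SOCKHASH: same socket semantics as the sockmap above, but
      * keyed by an arbitrary key of up to MAX_BPF_STACK bytes. Keys are
      * jhash'ed into a power-of-two number of buckets, each bucket an
      * RCU-protected hlist guarded by a raw spinlock, and each element embeds
      * a copy of its key for lookups.
      */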
 518 struct bpf_htab_elem {
 519         struct rcu_head rcu;
 520         u32 hash;
 521         struct sock *sk;
 522         struct hlist_node node;
 523         u8 key[0];
 524 };
 525 
 526 struct bpf_htab_bucket {
 527         struct hlist_head head;
 528         raw_spinlock_t lock;
 529 };
 530 
 531 struct bpf_htab {
 532         struct bpf_map map;
 533         struct bpf_htab_bucket *buckets;
 534         u32 buckets_num;
 535         u32 elem_size;
 536         struct sk_psock_progs progs;
 537         atomic_t count;
 538 };
 539 
 540 static inline u32 sock_hash_bucket_hash(const void *key, u32 len)
 541 {
 542         return jhash(key, len, 0);
 543 }
 544 
 545 static struct bpf_htab_bucket *sock_hash_select_bucket(struct bpf_htab *htab,
 546                                                        u32 hash)
 547 {
 548         return &htab->buckets[hash & (htab->buckets_num - 1)];
 549 }
 550 
 551 static struct bpf_htab_elem *
 552 sock_hash_lookup_elem_raw(struct hlist_head *head, u32 hash, void *key,
 553                           u32 key_size)
 554 {
 555         struct bpf_htab_elem *elem;
 556 
 557         hlist_for_each_entry_rcu(elem, head, node) {
 558                 if (elem->hash == hash &&
 559                     !memcmp(&elem->key, key, key_size))
 560                         return elem;
 561         }
 562 
 563         return NULL;
 564 }
 565 
 566 static struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
 567 {
 568         struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
 569         u32 key_size = map->key_size, hash;
 570         struct bpf_htab_bucket *bucket;
 571         struct bpf_htab_elem *elem;
 572 
 573         WARN_ON_ONCE(!rcu_read_lock_held());
 574 
 575         hash = sock_hash_bucket_hash(key, key_size);
 576         bucket = sock_hash_select_bucket(htab, hash);
 577         elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
 578 
 579         return elem ? elem->sk : NULL;
 580 }
 581 
 582 static void sock_hash_free_elem(struct bpf_htab *htab,
 583                                 struct bpf_htab_elem *elem)
 584 {
 585         atomic_dec(&htab->count);
 586         kfree_rcu(elem, rcu);
 587 }
 588 
 589 static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk,
 590                                        void *link_raw)
 591 {
 592         struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
 593         struct bpf_htab_elem *elem_probe, *elem = link_raw;
 594         struct bpf_htab_bucket *bucket;
 595 
 596         WARN_ON_ONCE(!rcu_read_lock_held());
 597         bucket = sock_hash_select_bucket(htab, elem->hash);
 598 
 599         /* elem may be deleted in parallel from the map, but access here
 600          * is okay since it's going away only after an RCU grace period.
 601          * However, we need to check whether it's still present.
 602          */
 603         raw_spin_lock_bh(&bucket->lock);
 604         elem_probe = sock_hash_lookup_elem_raw(&bucket->head, elem->hash,
 605                                                elem->key, map->key_size);
 606         if (elem_probe && elem_probe == elem) {
 607                 hlist_del_rcu(&elem->node);
 608                 sock_map_unref(elem->sk, elem);
 609                 sock_hash_free_elem(htab, elem);
 610         }
 611         raw_spin_unlock_bh(&bucket->lock);
 612 }
 613 
 614 static int sock_hash_delete_elem(struct bpf_map *map, void *key)
 615 {
 616         struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
 617         u32 hash, key_size = map->key_size;
 618         struct bpf_htab_bucket *bucket;
 619         struct bpf_htab_elem *elem;
 620         int ret = -ENOENT;
 621 
 622         hash = sock_hash_bucket_hash(key, key_size);
 623         bucket = sock_hash_select_bucket(htab, hash);
 624 
 625         raw_spin_lock_bh(&bucket->lock);
 626         elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
 627         if (elem) {
 628                 hlist_del_rcu(&elem->node);
 629                 sock_map_unref(elem->sk, elem);
 630                 sock_hash_free_elem(htab, elem);
 631                 ret = 0;
 632         }
 633         raw_spin_unlock_bh(&bucket->lock);
 634         return ret;
 635 }
 636 
 637 static struct bpf_htab_elem *sock_hash_alloc_elem(struct bpf_htab *htab,
 638                                                   void *key, u32 key_size,
 639                                                   u32 hash, struct sock *sk,
 640                                                   struct bpf_htab_elem *old)
 641 {
 642         struct bpf_htab_elem *new;
 643 
 644         if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
 645                 if (!old) {
 646                         atomic_dec(&htab->count);
 647                         return ERR_PTR(-E2BIG);
 648                 }
 649         }
 650 
 651         new = kmalloc_node(htab->elem_size, GFP_ATOMIC | __GFP_NOWARN,
 652                            htab->map.numa_node);
 653         if (!new) {
 654                 atomic_dec(&htab->count);
 655                 return ERR_PTR(-ENOMEM);
 656         }
 657         memcpy(new->key, key, key_size);
 658         new->sk = sk;
 659         new->hash = hash;
 660         return new;
 661 }
 662 
 663 static int sock_hash_update_common(struct bpf_map *map, void *key,
 664                                    struct sock *sk, u64 flags)
 665 {
 666         struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
 667         struct inet_connection_sock *icsk = inet_csk(sk);
 668         u32 key_size = map->key_size, hash;
 669         struct bpf_htab_elem *elem, *elem_new;
 670         struct bpf_htab_bucket *bucket;
 671         struct sk_psock_link *link;
 672         struct sk_psock *psock;
 673         int ret;
 674 
 675         WARN_ON_ONCE(!rcu_read_lock_held());
 676         if (unlikely(flags > BPF_EXIST))
 677                 return -EINVAL;
 678         if (unlikely(rcu_access_pointer(icsk->icsk_ulp_data)))
 679                 return -EINVAL;
 680 
 681         link = sk_psock_init_link();
 682         if (!link)
 683                 return -ENOMEM;
 684 
 685         ret = sock_map_link(map, &htab->progs, sk);
 686         if (ret < 0)
 687                 goto out_free;
 688 
 689         psock = sk_psock(sk);
 690         WARN_ON_ONCE(!psock);
 691 
 692         hash = sock_hash_bucket_hash(key, key_size);
 693         bucket = sock_hash_select_bucket(htab, hash);
 694 
 695         raw_spin_lock_bh(&bucket->lock);
 696         elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
 697         if (elem && flags == BPF_NOEXIST) {
 698                 ret = -EEXIST;
 699                 goto out_unlock;
 700         } else if (!elem && flags == BPF_EXIST) {
 701                 ret = -ENOENT;
 702                 goto out_unlock;
 703         }
 704 
 705         elem_new = sock_hash_alloc_elem(htab, key, key_size, hash, sk, elem);
 706         if (IS_ERR(elem_new)) {
 707                 ret = PTR_ERR(elem_new);
 708                 goto out_unlock;
 709         }
 710 
 711         sock_map_add_link(psock, link, map, elem_new);
 712         /* Add the new element to the head of the list, so that a
 713          * concurrent search will find it before the old element.
 714          */
 715         hlist_add_head_rcu(&elem_new->node, &bucket->head);
 716         if (elem) {
 717                 hlist_del_rcu(&elem->node);
 718                 sock_map_unref(elem->sk, elem);
 719                 sock_hash_free_elem(htab, elem);
 720         }
 721         raw_spin_unlock_bh(&bucket->lock);
 722         return 0;
 723 out_unlock:
 724         raw_spin_unlock_bh(&bucket->lock);
 725         sk_psock_put(sk, psock);
 726 out_free:
 727         sk_psock_free_link(link);
 728         return ret;
 729 }
 730 
 731 static int sock_hash_update_elem(struct bpf_map *map, void *key,
 732                                  void *value, u64 flags)
 733 {
 734         u32 ufd = *(u32 *)value;
 735         struct socket *sock;
 736         struct sock *sk;
 737         int ret;
 738 
 739         sock = sockfd_lookup(ufd, &ret);
 740         if (!sock)
 741                 return ret;
 742         sk = sock->sk;
 743         if (!sk) {
 744                 ret = -EINVAL;
 745                 goto out;
 746         }
 747         if (!sock_map_sk_is_suitable(sk)) {
 748                 ret = -EOPNOTSUPP;
 749                 goto out;
 750         }
 751 
 752         sock_map_sk_acquire(sk);
 753         if (sk->sk_state != TCP_ESTABLISHED)
 754                 ret = -EOPNOTSUPP;
 755         else
 756                 ret = sock_hash_update_common(map, key, sk, flags);
 757         sock_map_sk_release(sk);
 758 out:
 759         fput(sock->file);
 760         return ret;
 761 }
 762 
 763 static int sock_hash_get_next_key(struct bpf_map *map, void *key,
 764                                   void *key_next)
 765 {
 766         struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
 767         struct bpf_htab_elem *elem, *elem_next;
 768         u32 hash, key_size = map->key_size;
 769         struct hlist_head *head;
 770         int i = 0;
 771 
 772         if (!key)
 773                 goto find_first_elem;
 774         hash = sock_hash_bucket_hash(key, key_size);
 775         head = &sock_hash_select_bucket(htab, hash)->head;
 776         elem = sock_hash_lookup_elem_raw(head, hash, key, key_size);
 777         if (!elem)
 778                 goto find_first_elem;
 779 
 780         elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_next_rcu(&elem->node)),
 781                                      struct bpf_htab_elem, node);
 782         if (elem_next) {
 783                 memcpy(key_next, elem_next->key, key_size);
 784                 return 0;
 785         }
 786 
 787         i = hash & (htab->buckets_num - 1);
 788         i++;
 789 find_first_elem:
 790         for (; i < htab->buckets_num; i++) {
 791                 head = &sock_hash_select_bucket(htab, i)->head;
 792                 elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_first_rcu(head)),
 793                                              struct bpf_htab_elem, node);
 794                 if (elem_next) {
 795                         memcpy(key_next, elem_next->key, key_size);
 796                         return 0;
 797                 }
 798         }
 799 
 800         return -ENOENT;
 801 }
 802 
 803 static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
 804 {
 805         struct bpf_htab *htab;
 806         int i, err;
 807         u64 cost;
 808 
 809         if (!capable(CAP_NET_ADMIN))
 810                 return ERR_PTR(-EPERM);
 811         if (attr->max_entries == 0 ||
 812             attr->key_size    == 0 ||
 813             attr->value_size  != 4 ||
 814             attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
 815                 return ERR_PTR(-EINVAL);
 816         if (attr->key_size > MAX_BPF_STACK)
 817                 return ERR_PTR(-E2BIG);
 818 
 819         htab = kzalloc(sizeof(*htab), GFP_USER);
 820         if (!htab)
 821                 return ERR_PTR(-ENOMEM);
 822 
 823         bpf_map_init_from_attr(&htab->map, attr);
 824 
 825         htab->buckets_num = roundup_pow_of_two(htab->map.max_entries);
 826         htab->elem_size = sizeof(struct bpf_htab_elem) +
 827                           round_up(htab->map.key_size, 8);
 828         if (htab->buckets_num == 0 ||
 829             htab->buckets_num > U32_MAX / sizeof(struct bpf_htab_bucket)) {
 830                 err = -EINVAL;
 831                 goto free_htab;
 832         }
 833 
 834         cost = (u64) htab->buckets_num * sizeof(struct bpf_htab_bucket) +
 835                (u64) htab->elem_size * htab->map.max_entries;
 836         if (cost >= U32_MAX - PAGE_SIZE) {
 837                 err = -EINVAL;
 838                 goto free_htab;
 839         }
 840 
 841         htab->buckets = bpf_map_area_alloc(htab->buckets_num *
 842                                            sizeof(struct bpf_htab_bucket),
 843                                            htab->map.numa_node);
 844         if (!htab->buckets) {
 845                 err = -ENOMEM;
 846                 goto free_htab;
 847         }
 848 
 849         for (i = 0; i < htab->buckets_num; i++) {
 850                 INIT_HLIST_HEAD(&htab->buckets[i].head);
 851                 raw_spin_lock_init(&htab->buckets[i].lock);
 852         }
 853 
 854         return &htab->map;
 855 free_htab:
 856         kfree(htab);
 857         return ERR_PTR(err);
 858 }
 859 
 860 static void sock_hash_free(struct bpf_map *map)
 861 {
 862         struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
 863         struct bpf_htab_bucket *bucket;
 864         struct bpf_htab_elem *elem;
 865         struct hlist_node *node;
 866         int i;
 867 
 868         /* After the sync no updates or deletes will be in-flight, so it
 869          * is safe to walk the map and remove entries without risking a
 870          * race in the EEXIST update case.
 871          */
 872         synchronize_rcu();
 873         for (i = 0; i < htab->buckets_num; i++) {
 874                 bucket = sock_hash_select_bucket(htab, i);
 875                 hlist_for_each_entry_safe(elem, node, &bucket->head, node) {
 876                         hlist_del_rcu(&elem->node);
 877                         lock_sock(elem->sk);
 878                         rcu_read_lock();
 879                         sock_map_unref(elem->sk, elem);
 880                         rcu_read_unlock();
 881                         release_sock(elem->sk);
 882                 }
 883         }
 884 
 885         /* wait for psock readers accessing its map link */
 886         synchronize_rcu();
 887 
 891         bpf_map_area_free(htab->buckets);
 892         kfree(htab);
 893 }
 894 
 895 static void sock_hash_release_progs(struct bpf_map *map)
 896 {
 897         psock_progs_drop(&container_of(map, struct bpf_htab, map)->progs);
 898 }
 899 
 900 BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, sops,
 901            struct bpf_map *, map, void *, key, u64, flags)
 902 {
 903         WARN_ON_ONCE(!rcu_read_lock_held());
 904 
 905         if (likely(sock_map_sk_is_suitable(sops->sk) &&
 906                    sock_map_op_okay(sops)))
 907                 return sock_hash_update_common(map, key, sops->sk, flags);
 908         return -EOPNOTSUPP;
 909 }
 910 
 911 const struct bpf_func_proto bpf_sock_hash_update_proto = {
 912         .func           = bpf_sock_hash_update,
 913         .gpl_only       = false,
 914         .pkt_access     = true,
 915         .ret_type       = RET_INTEGER,
 916         .arg1_type      = ARG_PTR_TO_CTX,
 917         .arg2_type      = ARG_CONST_MAP_PTR,
 918         .arg3_type      = ARG_PTR_TO_MAP_KEY,
 919         .arg4_type      = ARG_ANYTHING,
 920 };
 921 
 922 BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
 923            struct bpf_map *, map, void *, key, u64, flags)
 924 {
 925         struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
 926 
 927         if (unlikely(flags & ~(BPF_F_INGRESS)))
 928                 return SK_DROP;
 929         tcb->bpf.flags = flags;
 930         tcb->bpf.sk_redir = __sock_hash_lookup_elem(map, key);
 931         if (!tcb->bpf.sk_redir)
 932                 return SK_DROP;
 933         return SK_PASS;
 934 }
 935 
 936 const struct bpf_func_proto bpf_sk_redirect_hash_proto = {
 937         .func           = bpf_sk_redirect_hash,
 938         .gpl_only       = false,
 939         .ret_type       = RET_INTEGER,
 940         .arg1_type      = ARG_PTR_TO_CTX,
 941         .arg2_type      = ARG_CONST_MAP_PTR,
 942         .arg3_type      = ARG_PTR_TO_MAP_KEY,
 943         .arg4_type      = ARG_ANYTHING,
 944 };
 945 
 946 BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg *, msg,
 947            struct bpf_map *, map, void *, key, u64, flags)
 948 {
 949         if (unlikely(flags & ~(BPF_F_INGRESS)))
 950                 return SK_DROP;
 951         msg->flags = flags;
 952         msg->sk_redir = __sock_hash_lookup_elem(map, key);
 953         if (!msg->sk_redir)
 954                 return SK_DROP;
 955         return SK_PASS;
 956 }
 957 
 958 const struct bpf_func_proto bpf_msg_redirect_hash_proto = {
 959         .func           = bpf_msg_redirect_hash,
 960         .gpl_only       = false,
 961         .ret_type       = RET_INTEGER,
 962         .arg1_type      = ARG_PTR_TO_CTX,
 963         .arg2_type      = ARG_CONST_MAP_PTR,
 964         .arg3_type      = ARG_PTR_TO_MAP_KEY,
 965         .arg4_type      = ARG_ANYTHING,
 966 };
 967 
 968 const struct bpf_map_ops sock_hash_ops = {
 969         .map_alloc              = sock_hash_alloc,
 970         .map_free               = sock_hash_free,
 971         .map_get_next_key       = sock_hash_get_next_key,
 972         .map_update_elem        = sock_hash_update_elem,
 973         .map_delete_elem        = sock_hash_delete_elem,
 974         .map_lookup_elem        = sock_map_lookup,
 975         .map_release_uref       = sock_hash_release_progs,
 976         .map_check_btf          = map_check_no_btf,
 977 };
 978 
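     /* sock_map_progs() resolves the program set embedded in either map type;
      * sock_map_prog_update() then installs an attached program into it:
      * BPF_SK_MSG_VERDICT -> msg_parser, BPF_SK_SKB_STREAM_PARSER ->
      * skb_parser, BPF_SK_SKB_STREAM_VERDICT -> skb_verdict. Other attach
      * types return -EOPNOTSUPP.
      */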
 979 static struct sk_psock_progs *sock_map_progs(struct bpf_map *map)
 980 {
 981         switch (map->map_type) {
 982         case BPF_MAP_TYPE_SOCKMAP:
 983                 return &container_of(map, struct bpf_stab, map)->progs;
 984         case BPF_MAP_TYPE_SOCKHASH:
 985                 return &container_of(map, struct bpf_htab, map)->progs;
 986         default:
 987                 break;
 988         }
 989 
 990         return NULL;
 991 }
 992 
 993 int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
 994                          u32 which)
 995 {
 996         struct sk_psock_progs *progs = sock_map_progs(map);
 997 
 998         if (!progs)
 999                 return -EOPNOTSUPP;
1000 
1001         switch (which) {
1002         case BPF_SK_MSG_VERDICT:
1003                 psock_set_prog(&progs->msg_parser, prog);
1004                 break;
1005         case BPF_SK_SKB_STREAM_PARSER:
1006                 psock_set_prog(&progs->skb_parser, prog);
1007                 break;
1008         case BPF_SK_SKB_STREAM_VERDICT:
1009                 psock_set_prog(&progs->skb_verdict, prog);
1010                 break;
1011         default:
1012                 return -EOPNOTSUPP;
1013         }
1014 
1015         return 0;
1016 }
1017 
1018 void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link)
1019 {
1020         switch (link->map->map_type) {
1021         case BPF_MAP_TYPE_SOCKMAP:
1022                 return sock_map_delete_from_link(link->map, sk,
1023                                                  link->link_raw);
1024         case BPF_MAP_TYPE_SOCKHASH:
1025                 return sock_hash_delete_from_link(link->map, sk,
1026                                                   link->link_raw);
1027         default:
1028                 break;
1029         }
1030 }
