root/tools/testing/selftests/bpf/prog_tests/sockopt_inherit.c

/* [<][>][^][v][top][bottom][index][help] */

DEFINITIONS

This source file includes the following definitions.
  1. connect_to_server
  2. verify_sockopt
  3. server_thread
  4. start_server
  5. prog_attach
  6. run_test
  7. test_sockopt_inherit

   1 // SPDX-License-Identifier: GPL-2.0
   2 #include <test_progs.h>
   3 #include "cgroup_helpers.h"
   4 
   5 #define SOL_CUSTOM                      0xdeadbeef
   6 #define CUSTOM_INHERIT1                 0
   7 #define CUSTOM_INHERIT2                 1
   8 #define CUSTOM_LISTENER                 2
   9 
  10 static int connect_to_server(int server_fd)
  11 {
  12         struct sockaddr_storage addr;
  13         socklen_t len = sizeof(addr);
  14         int fd;
  15 
  16         fd = socket(AF_INET, SOCK_STREAM, 0);
  17         if (fd < 0) {
  18                 log_err("Failed to create client socket");
  19                 return -1;
  20         }
  21 
  22         if (getsockname(server_fd, (struct sockaddr *)&addr, &len)) {
  23                 log_err("Failed to get server addr");
  24                 goto out;
  25         }
  26 
  27         if (connect(fd, (const struct sockaddr *)&addr, len) < 0) {
  28                 log_err("Fail to connect to server");
  29                 goto out;
  30         }
  31 
  32         return fd;
  33 
  34 out:
  35         close(fd);
  36         return -1;
  37 }
  38 
  39 static int verify_sockopt(int fd, int optname, const char *msg, char expected)
  40 {
  41         socklen_t optlen = 1;
  42         char buf = 0;
  43         int err;
  44 
  45         err = getsockopt(fd, SOL_CUSTOM, optname, &buf, &optlen);
  46         if (err) {
  47                 log_err("%s: failed to call getsockopt", msg);
  48                 return 1;
  49         }
  50 
  51         printf("%s %d: got=0x%x ? expected=0x%x\n", msg, optname, buf, expected);
  52 
  53         if (buf != expected) {
  54                 log_err("%s: unexpected getsockopt value %d != %d", msg,
  55                         buf, expected);
  56                 return 1;
  57         }
  58 
  59         return 0;
  60 }
  61 
  62 static pthread_mutex_t server_started_mtx = PTHREAD_MUTEX_INITIALIZER;
  63 static pthread_cond_t server_started = PTHREAD_COND_INITIALIZER;
  64 
  65 static void *server_thread(void *arg)
  66 {
  67         struct sockaddr_storage addr;
  68         socklen_t len = sizeof(addr);
  69         int fd = *(int *)arg;
  70         int client_fd;
  71         int err = 0;
  72 
  73         err = listen(fd, 1);
  74 
  75         pthread_mutex_lock(&server_started_mtx);
  76         pthread_cond_signal(&server_started);
  77         pthread_mutex_unlock(&server_started_mtx);
  78 
  79         if (CHECK_FAIL(err < 0)) {
  80                 perror("Failed to listed on socket");
  81                 return NULL;
  82         }
  83 
  84         err += verify_sockopt(fd, CUSTOM_INHERIT1, "listen", 1);
  85         err += verify_sockopt(fd, CUSTOM_INHERIT2, "listen", 1);
  86         err += verify_sockopt(fd, CUSTOM_LISTENER, "listen", 1);
  87 
  88         client_fd = accept(fd, (struct sockaddr *)&addr, &len);
  89         if (CHECK_FAIL(client_fd < 0)) {
  90                 perror("Failed to accept client");
  91                 return NULL;
  92         }
  93 
  94         err += verify_sockopt(client_fd, CUSTOM_INHERIT1, "accept", 1);
  95         err += verify_sockopt(client_fd, CUSTOM_INHERIT2, "accept", 1);
  96         err += verify_sockopt(client_fd, CUSTOM_LISTENER, "accept", 0);
  97 
  98         close(client_fd);
  99 
 100         return (void *)(long)err;
 101 }
 102 
 103 static int start_server(void)
 104 {
 105         struct sockaddr_in addr = {
 106                 .sin_family = AF_INET,
 107                 .sin_addr.s_addr = htonl(INADDR_LOOPBACK),
 108         };
 109         char buf;
 110         int err;
 111         int fd;
 112         int i;
 113 
 114         fd = socket(AF_INET, SOCK_STREAM, 0);
 115         if (fd < 0) {
 116                 log_err("Failed to create server socket");
 117                 return -1;
 118         }
 119 
 120         for (i = CUSTOM_INHERIT1; i <= CUSTOM_LISTENER; i++) {
 121                 buf = 0x01;
 122                 err = setsockopt(fd, SOL_CUSTOM, i, &buf, 1);
 123                 if (err) {
 124                         log_err("Failed to call setsockopt(%d)", i);
 125                         close(fd);
 126                         return -1;
 127                 }
 128         }
 129 
 130         if (bind(fd, (const struct sockaddr *)&addr, sizeof(addr)) < 0) {
 131                 log_err("Failed to bind socket");
 132                 close(fd);
 133                 return -1;
 134         }
 135 
 136         return fd;
 137 }
 138 
 139 static int prog_attach(struct bpf_object *obj, int cgroup_fd, const char *title)
 140 {
 141         enum bpf_attach_type attach_type;
 142         enum bpf_prog_type prog_type;
 143         struct bpf_program *prog;
 144         int err;
 145 
 146         err = libbpf_prog_type_by_name(title, &prog_type, &attach_type);
 147         if (err) {
 148                 log_err("Failed to deduct types for %s BPF program", title);
 149                 return -1;
 150         }
 151 
 152         prog = bpf_object__find_program_by_title(obj, title);
 153         if (!prog) {
 154                 log_err("Failed to find %s BPF program", title);
 155                 return -1;
 156         }
 157 
 158         err = bpf_prog_attach(bpf_program__fd(prog), cgroup_fd,
 159                               attach_type, 0);
 160         if (err) {
 161                 log_err("Failed to attach %s BPF program", title);
 162                 return -1;
 163         }
 164 
 165         return 0;
 166 }
 167 
 168 static void run_test(int cgroup_fd)
 169 {
 170         struct bpf_prog_load_attr attr = {
 171                 .file = "./sockopt_inherit.o",
 172         };
 173         int server_fd = -1, client_fd;
 174         struct bpf_object *obj;
 175         void *server_err;
 176         pthread_t tid;
 177         int ignored;
 178         int err;
 179 
 180         err = bpf_prog_load_xattr(&attr, &obj, &ignored);
 181         if (CHECK_FAIL(err))
 182                 return;
 183 
 184         err = prog_attach(obj, cgroup_fd, "cgroup/getsockopt");
 185         if (CHECK_FAIL(err))
 186                 goto close_bpf_object;
 187 
 188         err = prog_attach(obj, cgroup_fd, "cgroup/setsockopt");
 189         if (CHECK_FAIL(err))
 190                 goto close_bpf_object;
 191 
 192         server_fd = start_server();
 193         if (CHECK_FAIL(server_fd < 0))
 194                 goto close_bpf_object;
 195 
 196         if (CHECK_FAIL(pthread_create(&tid, NULL, server_thread,
 197                                       (void *)&server_fd)))
 198                 goto close_server_fd;
 199 
 200         pthread_mutex_lock(&server_started_mtx);
 201         pthread_cond_wait(&server_started, &server_started_mtx);
 202         pthread_mutex_unlock(&server_started_mtx);
 203 
 204         client_fd = connect_to_server(server_fd);
 205         if (CHECK_FAIL(client_fd < 0))
 206                 goto close_server_fd;
 207 
 208         CHECK_FAIL(verify_sockopt(client_fd, CUSTOM_INHERIT1, "connect", 0));
 209         CHECK_FAIL(verify_sockopt(client_fd, CUSTOM_INHERIT2, "connect", 0));
 210         CHECK_FAIL(verify_sockopt(client_fd, CUSTOM_LISTENER, "connect", 0));
 211 
 212         pthread_join(tid, &server_err);
 213 
 214         err = (int)(long)server_err;
 215         CHECK_FAIL(err);
 216 
 217         close(client_fd);
 218 
 219 close_server_fd:
 220         close(server_fd);
 221 close_bpf_object:
 222         bpf_object__close(obj);
 223 }
 224 
 225 void test_sockopt_inherit(void)
 226 {
 227         int cgroup_fd;
 228 
 229         cgroup_fd = test__join_cgroup("/sockopt_inherit");
 230         if (CHECK_FAIL(cgroup_fd < 0))
 231                 return;
 232 
 233         run_test(cgroup_fd);
 234         close(cgroup_fd);
 235 }

/* [<][>][^][v][top][bottom][index][help] */