root/tools/testing/selftests/net/tcp_fastopen_backup_key.c

/* [<][>][^][v][top][bottom][index][help] */

DEFINITIONS

This source file includes the following definitions.
  1. get_keys
  2. set_keys
  3. build_rcv_fd
  4. connect_and_send
  5. is_listen_fd
  6. rotate_key
  7. run_one_test
  8. parse_opts
  9. main

   1 // SPDX-License-Identifier: GPL-2.0
   2 
   3 /*
   4  * Test key rotation for TFO.
   5  * New keys are 'rotated' in two steps:
   6  * 1) Add new key as the 'backup' key 'behind' the primary key
   7  * 2) Make new key the primary by swapping the backup and primary keys
   8  *
   9  * The rotation is done in stages using multiple sockets bound
  10  * to the same port via SO_REUSEPORT. This simulates key rotation
  11  * behind say a load balancer. We verify that across the rotation
  12  * there are no cases in which a cookie is not accepted by verifying
  13  * that TcpExtTCPFastOpenPassiveFail remains 0.
  14  */
  15 #define _GNU_SOURCE
  16 #include <arpa/inet.h>
  17 #include <errno.h>
  18 #include <error.h>
  19 #include <stdbool.h>
  20 #include <stdio.h>
  21 #include <stdlib.h>
  22 #include <string.h>
  23 #include <sys/epoll.h>
  24 #include <unistd.h>
  25 #include <netinet/tcp.h>
  26 #include <fcntl.h>
  27 #include <time.h>
  28 
/* Fallback for headers that do not yet define the TFO key socket option. */
#ifndef TCP_FASTOPEN_KEY
#define TCP_FASTOPEN_KEY 33
#endif

/* Number of listening sockets sharing PORT through SO_REUSEPORT. */
#define N_LISTEN 10
/* Sysctl file used to read and write the keys when -s is not given. */
#define PROC_FASTOPEN_KEY "/proc/sys/net/ipv4/tcp_fastopen_key"
/* Size in bytes of a single key (four 32-bit words). */
#define KEY_LENGTH 16

#ifndef ARRAY_SIZE
#define ARRAY_SIZE(arr) (sizeof(arr) / sizeof((arr)[0]))
#endif

/* Command line switches, filled in by parse_opts(). */
static bool do_ipv6;
static bool do_sockopt;
static bool do_rotate;
/* Length passed to setsockopt(); doubled by -r so the backup key is included. */
static int key_len = KEY_LENGTH;
/* Listening sockets created by build_rcv_fd(). */
static int rcv_fds[N_LISTEN];
/* Descriptor for PROC_FASTOPEN_KEY, opened in main(). */
static int proc_fd;
/* Destination addresses used by connect_and_send(). */
static const char *IP4_ADDR = "127.0.0.1";
static const char *IP6_ADDR = "::1";
/* TCP port the listeners bind to and the client connects to. */
static const int PORT = 8891;
  50 
  51 static void get_keys(int fd, uint32_t *keys)
  52 {
  53         char buf[128];
  54         socklen_t len = KEY_LENGTH * 2;
  55 
  56         if (do_sockopt) {
  57                 if (getsockopt(fd, SOL_TCP, TCP_FASTOPEN_KEY, keys, &len))
  58                         error(1, errno, "Unable to get key");
  59                 return;
  60         }
  61         lseek(proc_fd, 0, SEEK_SET);
  62         if (read(proc_fd, buf, sizeof(buf)) <= 0)
  63                 error(1, errno, "Unable to read %s", PROC_FASTOPEN_KEY);
  64         if (sscanf(buf, "%x-%x-%x-%x,%x-%x-%x-%x", keys, keys + 1, keys + 2,
  65             keys + 3, keys + 4, keys + 5, keys + 6, keys + 7) != 8)
  66                 error(1, 0, "Unable to parse %s", PROC_FASTOPEN_KEY);
  67 }
  68 
  69 static void set_keys(int fd, uint32_t *keys)
  70 {
  71         char buf[128];
  72 
  73         if (do_sockopt) {
  74                 if (setsockopt(fd, SOL_TCP, TCP_FASTOPEN_KEY, keys,
  75                     key_len))
  76                         error(1, errno, "Unable to set key");
  77                 return;
  78         }
  79         if (do_rotate)
  80                 snprintf(buf, 128, "%08x-%08x-%08x-%08x,%08x-%08x-%08x-%08x",
  81                          keys[0], keys[1], keys[2], keys[3], keys[4], keys[5],
  82                          keys[6], keys[7]);
  83         else
  84                 snprintf(buf, 128, "%08x-%08x-%08x-%08x",
  85                          keys[0], keys[1], keys[2], keys[3]);
  86         lseek(proc_fd, 0, SEEK_SET);
  87         if (write(proc_fd, buf, sizeof(buf)) <= 0)
  88                 error(1, errno, "Unable to write %s", PROC_FASTOPEN_KEY);
  89 }
  90 
  91 static void build_rcv_fd(int family, int proto, int *rcv_fds)
  92 {
  93         struct sockaddr_in  addr4 = {0};
  94         struct sockaddr_in6 addr6 = {0};
  95         struct sockaddr *addr;
  96         int opt = 1, i, sz;
  97         int qlen = 100;
  98         uint32_t keys[8];
  99 
 100         switch (family) {
 101         case AF_INET:
 102                 addr4.sin_family = family;
 103                 addr4.sin_addr.s_addr = htonl(INADDR_ANY);
 104                 addr4.sin_port = htons(PORT);
 105                 sz = sizeof(addr4);
 106                 addr = (struct sockaddr *)&addr4;
 107                 break;
 108         case AF_INET6:
 109                 addr6.sin6_family = AF_INET6;
 110                 addr6.sin6_addr = in6addr_any;
 111                 addr6.sin6_port = htons(PORT);
 112                 sz = sizeof(addr6);
 113                 addr = (struct sockaddr *)&addr6;
 114                 break;
 115         default:
 116                 error(1, 0, "Unsupported family %d", family);
 117                 /* clang does not recognize error() above as terminating
 118                  * the program, so it complains that saddr, sz are
 119                  * not initialized when this code path is taken. Silence it.
 120                  */
 121                 return;
 122         }
 123         for (i = 0; i < ARRAY_SIZE(keys); i++)
 124                 keys[i] = rand();
 125         for (i = 0; i < N_LISTEN; i++) {
 126                 rcv_fds[i] = socket(family, proto, 0);
 127                 if (rcv_fds[i] < 0)
 128                         error(1, errno, "failed to create receive socket");
 129                 if (setsockopt(rcv_fds[i], SOL_SOCKET, SO_REUSEPORT, &opt,
 130                                sizeof(opt)))
 131                         error(1, errno, "failed to set SO_REUSEPORT");
 132                 if (bind(rcv_fds[i], addr, sz))
 133                         error(1, errno, "failed to bind receive socket");
 134                 if (setsockopt(rcv_fds[i], SOL_TCP, TCP_FASTOPEN, &qlen,
 135                                sizeof(qlen)))
 136                         error(1, errno, "failed to set TCP_FASTOPEN");
 137                 set_keys(rcv_fds[i], keys);
 138                 if (proto == SOCK_STREAM && listen(rcv_fds[i], 10))
 139                         error(1, errno, "failed to listen on receive port");
 140         }
 141 }
 142 
 143 static int connect_and_send(int family, int proto)
 144 {
 145         struct sockaddr_in  saddr4 = {0};
 146         struct sockaddr_in  daddr4 = {0};
 147         struct sockaddr_in6 saddr6 = {0};
 148         struct sockaddr_in6 daddr6 = {0};
 149         struct sockaddr *saddr, *daddr;
 150         int fd, sz, ret;
 151         char data[1];
 152 
 153         switch (family) {
 154         case AF_INET:
 155                 saddr4.sin_family = AF_INET;
 156                 saddr4.sin_addr.s_addr = htonl(INADDR_ANY);
 157                 saddr4.sin_port = 0;
 158 
 159                 daddr4.sin_family = AF_INET;
 160                 if (!inet_pton(family, IP4_ADDR, &daddr4.sin_addr.s_addr))
 161                         error(1, errno, "inet_pton failed: %s", IP4_ADDR);
 162                 daddr4.sin_port = htons(PORT);
 163 
 164                 sz = sizeof(saddr4);
 165                 saddr = (struct sockaddr *)&saddr4;
 166                 daddr = (struct sockaddr *)&daddr4;
 167                 break;
 168         case AF_INET6:
 169                 saddr6.sin6_family = AF_INET6;
 170                 saddr6.sin6_addr = in6addr_any;
 171 
 172                 daddr6.sin6_family = AF_INET6;
 173                 if (!inet_pton(family, IP6_ADDR, &daddr6.sin6_addr))
 174                         error(1, errno, "inet_pton failed: %s", IP6_ADDR);
 175                 daddr6.sin6_port = htons(PORT);
 176 
 177                 sz = sizeof(saddr6);
 178                 saddr = (struct sockaddr *)&saddr6;
 179                 daddr = (struct sockaddr *)&daddr6;
 180                 break;
 181         default:
 182                 error(1, 0, "Unsupported family %d", family);
 183                 /* clang does not recognize error() above as terminating
 184                  * the program, so it complains that saddr, daddr, sz are
 185                  * not initialized when this code path is taken. Silence it.
 186                  */
 187                 return -1;
 188         }
 189         fd = socket(family, proto, 0);
 190         if (fd < 0)
 191                 error(1, errno, "failed to create send socket");
 192         if (bind(fd, saddr, sz))
 193                 error(1, errno, "failed to bind send socket");
 194         data[0] = 'a';
 195         ret = sendto(fd, data, 1, MSG_FASTOPEN, daddr, sz);
 196         if (ret != 1)
 197                 error(1, errno, "failed to sendto");
 198 
 199         return fd;
 200 }
 201 
 202 static bool is_listen_fd(int fd)
 203 {
 204         int i;
 205 
 206         for (i = 0; i < N_LISTEN; i++) {
 207                 if (rcv_fds[i] == fd)
 208                         return true;
 209         }
 210         return false;
 211 }
 212 
 213 static void rotate_key(int fd)
 214 {
 215         static int iter;
 216         static uint32_t new_key[4];
 217         uint32_t keys[8];
 218         uint32_t tmp_key[4];
 219         int i;
 220 
 221         if (iter < N_LISTEN) {
 222                 /* first set new key as backups */
 223                 if (iter == 0) {
 224                         for (i = 0; i < ARRAY_SIZE(new_key); i++)
 225                                 new_key[i] = rand();
 226                 }
 227                 get_keys(fd, keys);
 228                 memcpy(keys + 4, new_key, KEY_LENGTH);
 229                 set_keys(fd, keys);
 230         } else {
 231                 /* swap the keys */
 232                 get_keys(fd, keys);
 233                 memcpy(tmp_key, keys + 4, KEY_LENGTH);
 234                 memcpy(keys + 4, keys, KEY_LENGTH);
 235                 memcpy(keys, tmp_key, KEY_LENGTH);
 236                 set_keys(fd, keys);
 237         }
 238         if (++iter >= (N_LISTEN * 2))
 239                 iter = 0;
 240 }
 241 
 242 static void run_one_test(int family)
 243 {
 244         struct epoll_event ev;
 245         int i, send_fd;
 246         int n_loops = 10000;
 247         int rotate_key_fd = 0;
 248         int key_rotate_interval = 50;
 249         int fd, epfd;
 250         char buf[1];
 251 
 252         build_rcv_fd(family, SOCK_STREAM, rcv_fds);
 253         epfd = epoll_create(1);
 254         if (epfd < 0)
 255                 error(1, errno, "failed to create epoll");
 256         ev.events = EPOLLIN;
 257         for (i = 0; i < N_LISTEN; i++) {
 258                 ev.data.fd = rcv_fds[i];
 259                 if (epoll_ctl(epfd, EPOLL_CTL_ADD, rcv_fds[i], &ev))
 260                         error(1, errno, "failed to register sock epoll");
 261         }
 262         while (n_loops--) {
 263                 send_fd = connect_and_send(family, SOCK_STREAM);
 264                 if (do_rotate && ((n_loops % key_rotate_interval) == 0)) {
 265                         rotate_key(rcv_fds[rotate_key_fd]);
 266                         if (++rotate_key_fd >= N_LISTEN)
 267                                 rotate_key_fd = 0;
 268                 }
 269                 while (1) {
 270                         i = epoll_wait(epfd, &ev, 1, -1);
 271                         if (i < 0)
 272                                 error(1, errno, "epoll_wait failed");
 273                         if (is_listen_fd(ev.data.fd)) {
 274                                 fd = accept(ev.data.fd, NULL, NULL);
 275                                 if (fd < 0)
 276                                         error(1, errno, "failed to accept");
 277                                 ev.data.fd = fd;
 278                                 if (epoll_ctl(epfd, EPOLL_CTL_ADD, fd, &ev))
 279                                         error(1, errno, "failed epoll add");
 280                                 continue;
 281                         }
 282                         i = recv(ev.data.fd, buf, sizeof(buf), 0);
 283                         if (i != 1)
 284                                 error(1, errno, "failed recv data");
 285                         if (epoll_ctl(epfd, EPOLL_CTL_DEL, ev.data.fd, NULL))
 286                                 error(1, errno, "failed epoll del");
 287                         close(ev.data.fd);
 288                         break;
 289                 }
 290                 close(send_fd);
 291         }
 292         for (i = 0; i < N_LISTEN; i++)
 293                 close(rcv_fds[i]);
 294 }
 295 
 296 static void parse_opts(int argc, char **argv)
 297 {
 298         int c;
 299 
 300         while ((c = getopt(argc, argv, "46sr")) != -1) {
 301                 switch (c) {
 302                 case '4':
 303                         do_ipv6 = false;
 304                         break;
 305                 case '6':
 306                         do_ipv6 = true;
 307                         break;
 308                 case 's':
 309                         do_sockopt = true;
 310                         break;
 311                 case 'r':
 312                         do_rotate = true;
 313                         key_len = KEY_LENGTH * 2;
 314                         break;
 315                 default:
 316                         error(1, 0, "%s: parse error", argv[0]);
 317                 }
 318         }
 319 }
 320 
 321 int main(int argc, char **argv)
 322 {
 323         parse_opts(argc, argv);
 324         proc_fd = open(PROC_FASTOPEN_KEY, O_RDWR);
 325         if (proc_fd < 0)
 326                 error(1, errno, "Unable to open %s", PROC_FASTOPEN_KEY);
 327         srand(time(NULL));
 328         if (do_ipv6)
 329                 run_one_test(AF_INET6);
 330         else
 331                 run_one_test(AF_INET);
 332         close(proc_fd);
 333         fprintf(stderr, "PASS\n");
 334         return 0;
 335 }

/* [<][>][^][v][top][bottom][index][help] */