#include <linux/module.h>
#include <linux/errno.h>
#include <linux/socket.h>
#include <linux/skbuff.h>
#include <linux/ip.h>
#include <linux/ipv6.h>
#include <linux/udp.h>
#include <linux/types.h>
#include <linux/kernel.h>
#include <net/genetlink.h>
#include <net/gue.h>
#include <net/ip.h>
#include <net/protocol.h>
#include <net/udp.h>
#include <net/udp_tunnel.h>
#include <net/xfrm.h>
#include <uapi/linux/fou.h>
#include <uapi/linux/genetlink.h>
18 
/*
 * Receive-side state for one configured FOU/GUE UDP port. Created in
 * fou_create(), kept on the per-namespace port list, and freed via
 * kfree_rcu() in fou_release().
 */
struct fou {
	struct socket *sock;		/* UDP socket from udp_sock_create() */
	u8 protocol;			/* inner IP protocol for direct FOU */
	u8 flags;			/* FOU_F_* flags from the netlink config */
	__be16 port;			/* local UDP port (network order); list key */
	u16 type;			/* FOU_ENCAP_DIRECT or FOU_ENCAP_GUE */
	struct udp_offload udp_offloads; /* GRO hooks; registered for AF_INET only */
	struct list_head list;		/* entry in fou_net.fou_list */
	struct rcu_head rcu;		/* used by kfree_rcu() in fou_release() */
};
29 
/*
 * Set when FOU_ATTR_REMCSUM_NOPARTIAL is present in the netlink request;
 * passed as the 'nopartial' argument to remote checksum processing.
 */
#define FOU_F_REMCSUM_NOPARTIAL BIT(0)

/* Port configuration as parsed from a netlink request by parse_nl_config(). */
struct fou_cfg {
	u16 type;			/* FOU_ENCAP_* encapsulation type */
	u8 protocol;			/* inner IP protocol (direct FOU) */
	u8 flags;			/* FOU_F_* flags */
	struct udp_port_cfg udp_config;	/* socket family and local UDP port */
};

/* Per-network-namespace storage id, filled in through fou_net_ops.id. */
static unsigned int fou_net_id;

/* Per-namespace state: the active port list and the mutex that guards it. */
struct fou_net {
	struct list_head fou_list;
	struct mutex fou_lock;
};
45 
fou_from_sock(struct sock * sk)46 static inline struct fou *fou_from_sock(struct sock *sk)
47 {
48 	return sk->sk_user_data;
49 }
50 
/*
 * Strip @len bytes of encapsulation (the UDP header, plus the FOU/GUE
 * header when present) from the front of @skb.
 *
 * The outer IP header's length field is reduced by the same amount so the
 * packet stays self-consistent when resubmitted to the inner protocol.
 * skb_postpull_rcsum() accounts for the removed bytes in the receive
 * checksum, and the transport header is reset to the new skb->data.
 */
static void fou_recv_pull(struct sk_buff *skb, size_t len)
{
	/*
	 * FOU sockets may be AF_INET6 (parse_nl_config() accepts it), so the
	 * outer header is not necessarily IPv4. IPv4 and IPv6 both carry the
	 * version number in the top nibble of the first header byte.
	 */
	if (ip_hdr(skb)->version == 4) {
		struct iphdr *iph = ip_hdr(skb);

		iph->tot_len = htons(ntohs(iph->tot_len) - len);
	} else {
		struct ipv6hdr *ip6h = ipv6_hdr(skb);

		ip6h->payload_len = htons(ntohs(ip6h->payload_len) - len);
	}

	__skb_pull(skb, len);
	skb_postpull_rcsum(skb, udp_hdr(skb), len);
	skb_reset_transport_header(skb);
}
63 
fou_udp_recv(struct sock * sk,struct sk_buff * skb)64 static int fou_udp_recv(struct sock *sk, struct sk_buff *skb)
65 {
66 	struct fou *fou = fou_from_sock(sk);
67 
68 	if (!fou)
69 		return 1;
70 
71 	fou_recv_pull(skb, sizeof(struct udphdr));
72 
73 	return -fou->protocol;
74 }
75 
/*
 * Apply GUE remote checksum offload on the regular (non-GRO) receive path.
 *
 * @data points at the two remcsum option fields: checksum start and
 * checksum field offset. Both are relative to the end of the GUE header,
 * i.e. @hdrlen bytes past @guehdr. The skb is first pulled far enough to
 * make the checksum field linear, then skb_remcsum_process() is applied.
 *
 * Returns the (possibly relocated) GUE header, or NULL if the pull failed.
 */
static struct guehdr *gue_remcsum(struct sk_buff *skb, struct guehdr *guehdr,
				  void *data, size_t hdrlen, u8 ipproto,
				  bool nopartial)
{
	__be16 *pd = data;
	size_t start = ntohs(pd[0]);
	size_t offset = ntohs(pd[1]);
	/*
	 * skb->data still points at the UDP header here (the GUE header is
	 * found at &udp_hdr(skb)[1]), so the UDP header has to be counted in
	 * the number of bytes that must be linear. Without it the pull is 8
	 * bytes short of the checksum field that skb_remcsum_process() may
	 * write.
	 */
	size_t plen = sizeof(struct udphdr) + hdrlen +
		      max_t(size_t, offset + sizeof(u16), start);

	if (!pskb_may_pull(skb, plen))
		return NULL;
	guehdr = (struct guehdr *)&udp_hdr(skb)[1];

	skb_remcsum_process(skb, (void *)guehdr + hdrlen,
			    start, offset, nopartial);

	return guehdr;
}
94 
/*
 * Handle a GUE control message. Control messages are not supported yet,
 * so the packet is freed and 0 is returned to report it as consumed.
 */
static int gue_control_message(struct sk_buff *skb, struct guehdr *guehdr)
{
	int consumed = 0;

	kfree_skb(skb);

	return consumed;
}
101 
gue_udp_recv(struct sock * sk,struct sk_buff * skb)102 static int gue_udp_recv(struct sock *sk, struct sk_buff *skb)
103 {
104 	struct fou *fou = fou_from_sock(sk);
105 	size_t len, optlen, hdrlen;
106 	struct guehdr *guehdr;
107 	void *data;
108 	u16 doffset = 0;
109 
110 	if (!fou)
111 		return 1;
112 
113 	len = sizeof(struct udphdr) + sizeof(struct guehdr);
114 	if (!pskb_may_pull(skb, len))
115 		goto drop;
116 
117 	guehdr = (struct guehdr *)&udp_hdr(skb)[1];
118 
119 	optlen = guehdr->hlen << 2;
120 	len += optlen;
121 
122 	if (!pskb_may_pull(skb, len))
123 		goto drop;
124 
125 	/* guehdr may change after pull */
126 	guehdr = (struct guehdr *)&udp_hdr(skb)[1];
127 
128 	hdrlen = sizeof(struct guehdr) + optlen;
129 
130 	if (guehdr->version != 0 || validate_gue_flags(guehdr, optlen))
131 		goto drop;
132 
133 	hdrlen = sizeof(struct guehdr) + optlen;
134 
135 	ip_hdr(skb)->tot_len = htons(ntohs(ip_hdr(skb)->tot_len) - len);
136 
137 	/* Pull csum through the guehdr now . This can be used if
138 	 * there is a remote checksum offload.
139 	 */
140 	skb_postpull_rcsum(skb, udp_hdr(skb), len);
141 
142 	data = &guehdr[1];
143 
144 	if (guehdr->flags & GUE_FLAG_PRIV) {
145 		__be32 flags = *(__be32 *)(data + doffset);
146 
147 		doffset += GUE_LEN_PRIV;
148 
149 		if (flags & GUE_PFLAG_REMCSUM) {
150 			guehdr = gue_remcsum(skb, guehdr, data + doffset,
151 					     hdrlen, guehdr->proto_ctype,
152 					     !!(fou->flags &
153 						FOU_F_REMCSUM_NOPARTIAL));
154 			if (!guehdr)
155 				goto drop;
156 
157 			data = &guehdr[1];
158 
159 			doffset += GUE_PLEN_REMCSUM;
160 		}
161 	}
162 
163 	if (unlikely(guehdr->control))
164 		return gue_control_message(skb, guehdr);
165 
166 	__skb_pull(skb, sizeof(struct udphdr) + hdrlen);
167 	skb_reset_transport_header(skb);
168 
169 	return -guehdr->proto_ctype;
170 
171 drop:
172 	kfree_skb(skb);
173 	return 0;
174 }
175 
fou_gro_receive(struct sk_buff ** head,struct sk_buff * skb,struct udp_offload * uoff)176 static struct sk_buff **fou_gro_receive(struct sk_buff **head,
177 					struct sk_buff *skb,
178 					struct udp_offload *uoff)
179 {
180 	const struct net_offload *ops;
181 	struct sk_buff **pp = NULL;
182 	u8 proto = NAPI_GRO_CB(skb)->proto;
183 	const struct net_offload **offloads;
184 
185 	rcu_read_lock();
186 	offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
187 	ops = rcu_dereference(offloads[proto]);
188 	if (!ops || !ops->callbacks.gro_receive)
189 		goto out_unlock;
190 
191 	pp = ops->callbacks.gro_receive(head, skb);
192 
193 out_unlock:
194 	rcu_read_unlock();
195 
196 	return pp;
197 }
198 
fou_gro_complete(struct sk_buff * skb,int nhoff,struct udp_offload * uoff)199 static int fou_gro_complete(struct sk_buff *skb, int nhoff,
200 			    struct udp_offload *uoff)
201 {
202 	const struct net_offload *ops;
203 	u8 proto = NAPI_GRO_CB(skb)->proto;
204 	int err = -ENOSYS;
205 	const struct net_offload **offloads;
206 
207 	udp_tunnel_gro_complete(skb, nhoff);
208 
209 	rcu_read_lock();
210 	offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
211 	ops = rcu_dereference(offloads[proto]);
212 	if (WARN_ON(!ops || !ops->callbacks.gro_complete))
213 		goto out_unlock;
214 
215 	err = ops->callbacks.gro_complete(skb, nhoff);
216 
217 out_unlock:
218 	rcu_read_unlock();
219 
220 	return err;
221 }
222 
/*
 * GRO-path counterpart of gue_remcsum().
 *
 * Refuses (returns NULL) when remote checksum offload was already applied
 * to this skb or the GRO checksum is not valid. Otherwise makes sure the
 * bytes up to the checksum field are available, applies
 * skb_gro_remcsum_process() with @grc, and marks skb->remcsum_offload.
 * Returns the (possibly relocated) GUE header, or NULL on failure.
 */
static struct guehdr *gue_gro_remcsum(struct sk_buff *skb, unsigned int off,
				      struct guehdr *guehdr, void *data,
				      size_t hdrlen, u8 ipproto,
				      struct gro_remcsum *grc, bool nopartial)
{
	__be16 *fields = data;
	size_t csum_start = ntohs(fields[0]);
	size_t csum_offset = ntohs(fields[1]);
	size_t need;

	if (skb->remcsum_offload || !NAPI_GRO_CB(skb)->csum_valid)
		return NULL;

	/* Pull checksum that will be written */
	need = off + hdrlen +
	       max_t(size_t, csum_offset + sizeof(u16), csum_start);
	if (skb_gro_header_hard(skb, need)) {
		guehdr = skb_gro_header_slow(skb, need, off);
		if (!guehdr)
			return NULL;
	}

	skb_gro_remcsum_process(skb, (void *)guehdr + hdrlen,
				csum_start, csum_offset, grc, nopartial);
	skb->remcsum_offload = 1;

	return guehdr;
}
253 
gue_gro_receive(struct sk_buff ** head,struct sk_buff * skb,struct udp_offload * uoff)254 static struct sk_buff **gue_gro_receive(struct sk_buff **head,
255 					struct sk_buff *skb,
256 					struct udp_offload *uoff)
257 {
258 	const struct net_offload **offloads;
259 	const struct net_offload *ops;
260 	struct sk_buff **pp = NULL;
261 	struct sk_buff *p;
262 	struct guehdr *guehdr;
263 	size_t len, optlen, hdrlen, off;
264 	void *data;
265 	u16 doffset = 0;
266 	int flush = 1;
267 	struct fou *fou = container_of(uoff, struct fou, udp_offloads);
268 	struct gro_remcsum grc;
269 
270 	skb_gro_remcsum_init(&grc);
271 
272 	off = skb_gro_offset(skb);
273 	len = off + sizeof(*guehdr);
274 
275 	guehdr = skb_gro_header_fast(skb, off);
276 	if (skb_gro_header_hard(skb, len)) {
277 		guehdr = skb_gro_header_slow(skb, len, off);
278 		if (unlikely(!guehdr))
279 			goto out;
280 	}
281 
282 	optlen = guehdr->hlen << 2;
283 	len += optlen;
284 
285 	if (skb_gro_header_hard(skb, len)) {
286 		guehdr = skb_gro_header_slow(skb, len, off);
287 		if (unlikely(!guehdr))
288 			goto out;
289 	}
290 
291 	if (unlikely(guehdr->control) || guehdr->version != 0 ||
292 	    validate_gue_flags(guehdr, optlen))
293 		goto out;
294 
295 	hdrlen = sizeof(*guehdr) + optlen;
296 
297 	/* Adjust NAPI_GRO_CB(skb)->csum to account for guehdr,
298 	 * this is needed if there is a remote checkcsum offload.
299 	 */
300 	skb_gro_postpull_rcsum(skb, guehdr, hdrlen);
301 
302 	data = &guehdr[1];
303 
304 	if (guehdr->flags & GUE_FLAG_PRIV) {
305 		__be32 flags = *(__be32 *)(data + doffset);
306 
307 		doffset += GUE_LEN_PRIV;
308 
309 		if (flags & GUE_PFLAG_REMCSUM) {
310 			guehdr = gue_gro_remcsum(skb, off, guehdr,
311 						 data + doffset, hdrlen,
312 						 guehdr->proto_ctype, &grc,
313 						 !!(fou->flags &
314 						    FOU_F_REMCSUM_NOPARTIAL));
315 			if (!guehdr)
316 				goto out;
317 
318 			data = &guehdr[1];
319 
320 			doffset += GUE_PLEN_REMCSUM;
321 		}
322 	}
323 
324 	skb_gro_pull(skb, hdrlen);
325 
326 	flush = 0;
327 
328 	for (p = *head; p; p = p->next) {
329 		const struct guehdr *guehdr2;
330 
331 		if (!NAPI_GRO_CB(p)->same_flow)
332 			continue;
333 
334 		guehdr2 = (struct guehdr *)(p->data + off);
335 
336 		/* Compare base GUE header to be equal (covers
337 		 * hlen, version, proto_ctype, and flags.
338 		 */
339 		if (guehdr->word != guehdr2->word) {
340 			NAPI_GRO_CB(p)->same_flow = 0;
341 			continue;
342 		}
343 
344 		/* Compare optional fields are the same. */
345 		if (guehdr->hlen && memcmp(&guehdr[1], &guehdr2[1],
346 					   guehdr->hlen << 2)) {
347 			NAPI_GRO_CB(p)->same_flow = 0;
348 			continue;
349 		}
350 	}
351 
352 	rcu_read_lock();
353 	offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
354 	ops = rcu_dereference(offloads[guehdr->proto_ctype]);
355 	if (WARN_ON(!ops || !ops->callbacks.gro_receive))
356 		goto out_unlock;
357 
358 	pp = ops->callbacks.gro_receive(head, skb);
359 
360 out_unlock:
361 	rcu_read_unlock();
362 out:
363 	NAPI_GRO_CB(skb)->flush |= flush;
364 	skb_gro_remcsum_cleanup(skb, &grc);
365 
366 	return pp;
367 }
368 
gue_gro_complete(struct sk_buff * skb,int nhoff,struct udp_offload * uoff)369 static int gue_gro_complete(struct sk_buff *skb, int nhoff,
370 			    struct udp_offload *uoff)
371 {
372 	const struct net_offload **offloads;
373 	struct guehdr *guehdr = (struct guehdr *)(skb->data + nhoff);
374 	const struct net_offload *ops;
375 	unsigned int guehlen;
376 	u8 proto;
377 	int err = -ENOENT;
378 
379 	proto = guehdr->proto_ctype;
380 
381 	guehlen = sizeof(*guehdr) + (guehdr->hlen << 2);
382 
383 	rcu_read_lock();
384 	offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
385 	ops = rcu_dereference(offloads[proto]);
386 	if (WARN_ON(!ops || !ops->callbacks.gro_complete))
387 		goto out_unlock;
388 
389 	err = ops->callbacks.gro_complete(skb, nhoff + guehlen);
390 
391 out_unlock:
392 	rcu_read_unlock();
393 	return err;
394 }
395 
fou_add_to_port_list(struct net * net,struct fou * fou)396 static int fou_add_to_port_list(struct net *net, struct fou *fou)
397 {
398 	struct fou_net *fn = net_generic(net, fou_net_id);
399 	struct fou *fout;
400 
401 	mutex_lock(&fn->fou_lock);
402 	list_for_each_entry(fout, &fn->fou_list, list) {
403 		if (fou->port == fout->port) {
404 			mutex_unlock(&fn->fou_lock);
405 			return -EALREADY;
406 		}
407 	}
408 
409 	list_add(&fou->list, &fn->fou_list);
410 	mutex_unlock(&fn->fou_lock);
411 
412 	return 0;
413 }
414 
fou_release(struct fou * fou)415 static void fou_release(struct fou *fou)
416 {
417 	struct socket *sock = fou->sock;
418 	struct sock *sk = sock->sk;
419 
420 	if (sk->sk_family == AF_INET)
421 		udp_del_offload(&fou->udp_offloads);
422 	list_del(&fou->list);
423 	udp_tunnel_sock_release(sock);
424 
425 	kfree_rcu(fou, rcu);
426 }
427 
fou_encap_init(struct sock * sk,struct fou * fou,struct fou_cfg * cfg)428 static int fou_encap_init(struct sock *sk, struct fou *fou, struct fou_cfg *cfg)
429 {
430 	udp_sk(sk)->encap_rcv = fou_udp_recv;
431 	fou->protocol = cfg->protocol;
432 	fou->udp_offloads.callbacks.gro_receive = fou_gro_receive;
433 	fou->udp_offloads.callbacks.gro_complete = fou_gro_complete;
434 	fou->udp_offloads.port = cfg->udp_config.local_udp_port;
435 	fou->udp_offloads.ipproto = cfg->protocol;
436 
437 	return 0;
438 }
439 
gue_encap_init(struct sock * sk,struct fou * fou,struct fou_cfg * cfg)440 static int gue_encap_init(struct sock *sk, struct fou *fou, struct fou_cfg *cfg)
441 {
442 	udp_sk(sk)->encap_rcv = gue_udp_recv;
443 	fou->udp_offloads.callbacks.gro_receive = gue_gro_receive;
444 	fou->udp_offloads.callbacks.gro_complete = gue_gro_complete;
445 	fou->udp_offloads.port = cfg->udp_config.local_udp_port;
446 
447 	return 0;
448 }
449 
fou_create(struct net * net,struct fou_cfg * cfg,struct socket ** sockp)450 static int fou_create(struct net *net, struct fou_cfg *cfg,
451 		      struct socket **sockp)
452 {
453 	struct socket *sock = NULL;
454 	struct fou *fou = NULL;
455 	struct sock *sk;
456 	int err;
457 
458 	/* Open UDP socket */
459 	err = udp_sock_create(net, &cfg->udp_config, &sock);
460 	if (err < 0)
461 		goto error;
462 
463 	/* Allocate FOU port structure */
464 	fou = kzalloc(sizeof(*fou), GFP_KERNEL);
465 	if (!fou) {
466 		err = -ENOMEM;
467 		goto error;
468 	}
469 
470 	sk = sock->sk;
471 
472 	fou->flags = cfg->flags;
473 	fou->port = cfg->udp_config.local_udp_port;
474 
475 	/* Initial for fou type */
476 	switch (cfg->type) {
477 	case FOU_ENCAP_DIRECT:
478 		err = fou_encap_init(sk, fou, cfg);
479 		if (err)
480 			goto error;
481 		break;
482 	case FOU_ENCAP_GUE:
483 		err = gue_encap_init(sk, fou, cfg);
484 		if (err)
485 			goto error;
486 		break;
487 	default:
488 		err = -EINVAL;
489 		goto error;
490 	}
491 
492 	fou->type = cfg->type;
493 
494 	udp_sk(sk)->encap_type = 1;
495 	udp_encap_enable();
496 
497 	sk->sk_user_data = fou;
498 	fou->sock = sock;
499 
500 	inet_inc_convert_csum(sk);
501 
502 	sk->sk_allocation = GFP_ATOMIC;
503 
504 	if (cfg->udp_config.family == AF_INET) {
505 		err = udp_add_offload(&fou->udp_offloads);
506 		if (err)
507 			goto error;
508 	}
509 
510 	err = fou_add_to_port_list(net, fou);
511 	if (err)
512 		goto error;
513 
514 	if (sockp)
515 		*sockp = sock;
516 
517 	return 0;
518 
519 error:
520 	kfree(fou);
521 	if (sock)
522 		udp_tunnel_sock_release(sock);
523 
524 	return err;
525 }
526 
fou_destroy(struct net * net,struct fou_cfg * cfg)527 static int fou_destroy(struct net *net, struct fou_cfg *cfg)
528 {
529 	struct fou_net *fn = net_generic(net, fou_net_id);
530 	__be16 port = cfg->udp_config.local_udp_port;
531 	int err = -EINVAL;
532 	struct fou *fou;
533 
534 	mutex_lock(&fn->fou_lock);
535 	list_for_each_entry(fou, &fn->fou_list, list) {
536 		if (fou->port == port) {
537 			fou_release(fou);
538 			err = 0;
539 			break;
540 		}
541 	}
542 	mutex_unlock(&fn->fou_lock);
543 
544 	return err;
545 }
546 
547 static struct genl_family fou_nl_family = {
548 	.id		= GENL_ID_GENERATE,
549 	.hdrsize	= 0,
550 	.name		= FOU_GENL_NAME,
551 	.version	= FOU_GENL_VERSION,
552 	.maxattr	= FOU_ATTR_MAX,
553 	.netnsok	= true,
554 };
555 
556 static struct nla_policy fou_nl_policy[FOU_ATTR_MAX + 1] = {
557 	[FOU_ATTR_PORT] = { .type = NLA_U16, },
558 	[FOU_ATTR_AF] = { .type = NLA_U8, },
559 	[FOU_ATTR_IPPROTO] = { .type = NLA_U8, },
560 	[FOU_ATTR_TYPE] = { .type = NLA_U8, },
561 	[FOU_ATTR_REMCSUM_NOPARTIAL] = { .type = NLA_FLAG, },
562 };
563 
parse_nl_config(struct genl_info * info,struct fou_cfg * cfg)564 static int parse_nl_config(struct genl_info *info,
565 			   struct fou_cfg *cfg)
566 {
567 	memset(cfg, 0, sizeof(*cfg));
568 
569 	cfg->udp_config.family = AF_INET;
570 
571 	if (info->attrs[FOU_ATTR_AF]) {
572 		u8 family = nla_get_u8(info->attrs[FOU_ATTR_AF]);
573 
574 		if (family != AF_INET && family != AF_INET6)
575 			return -EINVAL;
576 
577 		cfg->udp_config.family = family;
578 	}
579 
580 	if (info->attrs[FOU_ATTR_PORT]) {
581 		__be16 port = nla_get_be16(info->attrs[FOU_ATTR_PORT]);
582 
583 		cfg->udp_config.local_udp_port = port;
584 	}
585 
586 	if (info->attrs[FOU_ATTR_IPPROTO])
587 		cfg->protocol = nla_get_u8(info->attrs[FOU_ATTR_IPPROTO]);
588 
589 	if (info->attrs[FOU_ATTR_TYPE])
590 		cfg->type = nla_get_u8(info->attrs[FOU_ATTR_TYPE]);
591 
592 	if (info->attrs[FOU_ATTR_REMCSUM_NOPARTIAL])
593 		cfg->flags |= FOU_F_REMCSUM_NOPARTIAL;
594 
595 	return 0;
596 }
597 
fou_nl_cmd_add_port(struct sk_buff * skb,struct genl_info * info)598 static int fou_nl_cmd_add_port(struct sk_buff *skb, struct genl_info *info)
599 {
600 	struct net *net = genl_info_net(info);
601 	struct fou_cfg cfg;
602 	int err;
603 
604 	err = parse_nl_config(info, &cfg);
605 	if (err)
606 		return err;
607 
608 	return fou_create(net, &cfg, NULL);
609 }
610 
fou_nl_cmd_rm_port(struct sk_buff * skb,struct genl_info * info)611 static int fou_nl_cmd_rm_port(struct sk_buff *skb, struct genl_info *info)
612 {
613 	struct net *net = genl_info_net(info);
614 	struct fou_cfg cfg;
615 	int err;
616 
617 	err = parse_nl_config(info, &cfg);
618 	if (err)
619 		return err;
620 
621 	return fou_destroy(net, &cfg);
622 }
623 
fou_fill_info(struct fou * fou,struct sk_buff * msg)624 static int fou_fill_info(struct fou *fou, struct sk_buff *msg)
625 {
626 	if (nla_put_u8(msg, FOU_ATTR_AF, fou->sock->sk->sk_family) ||
627 	    nla_put_be16(msg, FOU_ATTR_PORT, fou->port) ||
628 	    nla_put_u8(msg, FOU_ATTR_IPPROTO, fou->protocol) ||
629 	    nla_put_u8(msg, FOU_ATTR_TYPE, fou->type))
630 		return -1;
631 
632 	if (fou->flags & FOU_F_REMCSUM_NOPARTIAL)
633 		if (nla_put_flag(msg, FOU_ATTR_REMCSUM_NOPARTIAL))
634 			return -1;
635 	return 0;
636 }
637 
/*
 * Emit one complete genetlink message describing @fou into @skb.
 * Returns 0; -ENOMEM if the message header cannot be added; -EMSGSIZE if
 * the attributes do not fit (the partial message is cancelled).
 */
static int fou_dump_info(struct fou *fou, u32 portid, u32 seq,
			 u32 flags, struct sk_buff *skb, u8 cmd)
{
	void *hdr = genlmsg_put(skb, portid, seq, &fou_nl_family, flags, cmd);

	if (!hdr)
		return -ENOMEM;

	if (fou_fill_info(fou, skb) >= 0) {
		genlmsg_end(skb, hdr);
		return 0;
	}

	genlmsg_cancel(skb, hdr);
	return -EMSGSIZE;
}
657 
fou_nl_cmd_get_port(struct sk_buff * skb,struct genl_info * info)658 static int fou_nl_cmd_get_port(struct sk_buff *skb, struct genl_info *info)
659 {
660 	struct net *net = genl_info_net(info);
661 	struct fou_net *fn = net_generic(net, fou_net_id);
662 	struct sk_buff *msg;
663 	struct fou_cfg cfg;
664 	struct fou *fout;
665 	__be16 port;
666 	int ret;
667 
668 	ret = parse_nl_config(info, &cfg);
669 	if (ret)
670 		return ret;
671 	port = cfg.udp_config.local_udp_port;
672 	if (port == 0)
673 		return -EINVAL;
674 
675 	msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
676 	if (!msg)
677 		return -ENOMEM;
678 
679 	ret = -ESRCH;
680 	mutex_lock(&fn->fou_lock);
681 	list_for_each_entry(fout, &fn->fou_list, list) {
682 		if (port == fout->port) {
683 			ret = fou_dump_info(fout, info->snd_portid,
684 					    info->snd_seq, 0, msg,
685 					    info->genlhdr->cmd);
686 			break;
687 		}
688 	}
689 	mutex_unlock(&fn->fou_lock);
690 	if (ret < 0)
691 		goto out_free;
692 
693 	return genlmsg_reply(msg, info);
694 
695 out_free:
696 	nlmsg_free(msg);
697 	return ret;
698 }
699 
fou_nl_dump(struct sk_buff * skb,struct netlink_callback * cb)700 static int fou_nl_dump(struct sk_buff *skb, struct netlink_callback *cb)
701 {
702 	struct net *net = sock_net(skb->sk);
703 	struct fou_net *fn = net_generic(net, fou_net_id);
704 	struct fou *fout;
705 	int idx = 0, ret;
706 
707 	mutex_lock(&fn->fou_lock);
708 	list_for_each_entry(fout, &fn->fou_list, list) {
709 		if (idx++ < cb->args[0])
710 			continue;
711 		ret = fou_dump_info(fout, NETLINK_CB(cb->skb).portid,
712 				    cb->nlh->nlmsg_seq, NLM_F_MULTI,
713 				    skb, FOU_CMD_GET);
714 		if (ret)
715 			break;
716 	}
717 	mutex_unlock(&fn->fou_lock);
718 
719 	cb->args[0] = idx;
720 	return skb->len;
721 }
722 
723 static const struct genl_ops fou_nl_ops[] = {
724 	{
725 		.cmd = FOU_CMD_ADD,
726 		.doit = fou_nl_cmd_add_port,
727 		.policy = fou_nl_policy,
728 		.flags = GENL_ADMIN_PERM,
729 	},
730 	{
731 		.cmd = FOU_CMD_DEL,
732 		.doit = fou_nl_cmd_rm_port,
733 		.policy = fou_nl_policy,
734 		.flags = GENL_ADMIN_PERM,
735 	},
736 	{
737 		.cmd = FOU_CMD_GET,
738 		.doit = fou_nl_cmd_get_port,
739 		.dumpit = fou_nl_dump,
740 		.policy = fou_nl_policy,
741 	},
742 };
743 
fou_encap_hlen(struct ip_tunnel_encap * e)744 size_t fou_encap_hlen(struct ip_tunnel_encap *e)
745 {
746 	return sizeof(struct udphdr);
747 }
748 EXPORT_SYMBOL(fou_encap_hlen);
749 
gue_encap_hlen(struct ip_tunnel_encap * e)750 size_t gue_encap_hlen(struct ip_tunnel_encap *e)
751 {
752 	size_t len;
753 	bool need_priv = false;
754 
755 	len = sizeof(struct udphdr) + sizeof(struct guehdr);
756 
757 	if (e->flags & TUNNEL_ENCAP_FLAG_REMCSUM) {
758 		len += GUE_PLEN_REMCSUM;
759 		need_priv = true;
760 	}
761 
762 	len += need_priv ? GUE_LEN_PRIV : 0;
763 
764 	return len;
765 }
766 EXPORT_SYMBOL(gue_encap_hlen);
767 
/*
 * Prepend a UDP header (source @sport, destination e->dport) covering the
 * whole current skb and set its checksum via udp_set_csum(). The checksum
 * is turned off unless TUNNEL_ENCAP_FLAG_CSUM is set. *@protocol becomes
 * IPPROTO_UDP.
 */
static void fou_build_udp(struct sk_buff *skb, struct ip_tunnel_encap *e,
			  struct flowi4 *fl4, u8 *protocol, __be16 sport)
{
	bool no_csum = !(e->flags & TUNNEL_ENCAP_FLAG_CSUM);
	struct udphdr *uh;

	skb_push(skb, sizeof(*uh));
	skb_reset_transport_header(skb);
	uh = udp_hdr(skb);

	uh->source = sport;
	uh->dest = e->dport;
	uh->len = htons(skb->len);
	uh->check = 0;

	udp_set_csum(no_csum, skb, fl4->saddr, fl4->daddr, skb->len);

	*protocol = IPPROTO_UDP;
}
787 
/*
 * Build the direct-FOU encapsulation (a UDP header) in front of @skb.
 *
 * The source port is e->sport if set, otherwise derived from the flow via
 * udp_flow_src_port(). Returns 0, or the error from
 * iptunnel_handle_offloads().
 */
int fou_build_header(struct sk_buff *skb, struct ip_tunnel_encap *e,
		     u8 *protocol, struct flowi4 *fl4)
{
	bool csum = !!(e->flags & TUNNEL_ENCAP_FLAG_CSUM);
	int gso_type = csum ? SKB_GSO_UDP_TUNNEL_CSUM : SKB_GSO_UDP_TUNNEL;
	__be16 sport;

	skb = iptunnel_handle_offloads(skb, csum, gso_type);
	if (IS_ERR(skb))
		return PTR_ERR(skb);

	sport = e->sport;
	if (!sport)
		sport = udp_flow_src_port(dev_net(skb->dev), skb,
					  0, 0, false);

	fou_build_udp(skb, e, fl4, protocol, sport);

	return 0;
}
EXPORT_SYMBOL(fou_build_header);
807 
/*
 * Build GUE encapsulation (GUE header, optional private-flags word and
 * remote checksum offload fields, then a UDP header) in front of @skb.
 *
 * Remote checksum offload is used only when TUNNEL_ENCAP_FLAG_REMCSUM is
 * set and the skb is CHECKSUM_PARTIAL. In that case the outer UDP checksum
 * is not requested (csum = false) and SKB_GSO_TUNNEL_REMCSUM is added to
 * the GSO type.
 *
 * Returns 0, the error from iptunnel_handle_offloads(), or -EINVAL when
 * the checksum start lies inside the headers just pushed.
 * NOTE(review): the -EINVAL path returns after skb_push() without undoing
 * it; presumably the caller frees the skb on error - confirm.
 */
int gue_build_header(struct sk_buff *skb, struct ip_tunnel_encap *e,
		     u8 *protocol, struct flowi4 *fl4)
{
	bool csum = !!(e->flags & TUNNEL_ENCAP_FLAG_CSUM);
	int type = csum ? SKB_GSO_UDP_TUNNEL_CSUM : SKB_GSO_UDP_TUNNEL;
	struct guehdr *guehdr;
	size_t hdrlen, optlen = 0;
	__be16 sport;
	void *data;
	bool need_priv = false;

	if ((e->flags & TUNNEL_ENCAP_FLAG_REMCSUM) &&
	    skb->ip_summed == CHECKSUM_PARTIAL) {
		csum = false;
		optlen += GUE_PLEN_REMCSUM;
		type |= SKB_GSO_TUNNEL_REMCSUM;
		need_priv = true;
	}

	/* Any private option needs the leading private-flags word. */
	optlen += need_priv ? GUE_LEN_PRIV : 0;

	skb = iptunnel_handle_offloads(skb, csum, type);

	if (IS_ERR(skb))
		return PTR_ERR(skb);

	/* Get source port (based on flow hash) before skb_push */
	sport = e->sport ? : udp_flow_src_port(dev_net(skb->dev),
					       skb, 0, 0, false);

	hdrlen = sizeof(struct guehdr) + optlen;

	skb_push(skb, hdrlen);

	guehdr = (struct guehdr *)skb->data;

	guehdr->control = 0;
	guehdr->version = 0;
	/* hlen is expressed in 4-byte words (the receive side uses hlen << 2) */
	guehdr->hlen = optlen >> 2;
	guehdr->flags = 0;
	guehdr->proto_ctype = *protocol;

	data = &guehdr[1];

	if (need_priv) {
		__be32 *flags = data;

		guehdr->flags |= GUE_FLAG_PRIV;
		*flags = 0;
		data += GUE_LEN_PRIV;

		if (type & SKB_GSO_TUNNEL_REMCSUM) {
			u16 csum_start = skb_checksum_start_offset(skb);
			__be16 *pd = data;

			/* The checksum must start after the GUE header. */
			if (csum_start < hdrlen)
				return -EINVAL;

			/* Both fields are encoded relative to the end of the
			 * GUE header: checksum start, then the location of
			 * the checksum field itself.
			 */
			csum_start -= hdrlen;
			pd[0] = htons(csum_start);
			pd[1] = htons(csum_start + skb->csum_offset);

			/* Non-GSO: the inner checksum is left to the remote
			 * end, so nothing is left to offload locally.
			 */
			if (!skb_is_gso(skb)) {
				skb->ip_summed = CHECKSUM_NONE;
				skb->encapsulation = 0;
			}

			*flags |= GUE_PFLAG_REMCSUM;
			data += GUE_PLEN_REMCSUM;
		}

	}

	fou_build_udp(skb, e, fl4, protocol, sport);

	return 0;
}
EXPORT_SYMBOL(gue_build_header);
886 
#ifdef CONFIG_NET_FOU_IP_TUNNELS

/* Encapsulation hooks exposed to IP tunnels for direct FOU. */
static const struct ip_tunnel_encap_ops fou_iptun_ops = {
	.build_header = fou_build_header,
	.encap_hlen = fou_encap_hlen,
};

/* Encapsulation hooks exposed to IP tunnels for GUE. */
static const struct ip_tunnel_encap_ops gue_iptun_ops = {
	.build_header = gue_build_header,
	.encap_hlen = gue_encap_hlen,
};

/*
 * Register the FOU and GUE tunnel encapsulation ops. If GUE registration
 * fails, the FOU registration is rolled back. Returns 0 or the error.
 */
static int ip_tunnel_encap_add_fou_ops(void)
{
	int err;

	err = ip_tunnel_encap_add_ops(&fou_iptun_ops, TUNNEL_ENCAP_FOU);
	if (err < 0) {
		pr_err("can't add fou ops\n");
		goto out;
	}

	err = ip_tunnel_encap_add_ops(&gue_iptun_ops, TUNNEL_ENCAP_GUE);
	if (err < 0) {
		pr_err("can't add gue ops\n");
		goto del_fou;
	}

	return 0;

del_fou:
	ip_tunnel_encap_del_ops(&fou_iptun_ops, TUNNEL_ENCAP_FOU);
out:
	return err;
}

/* Unregister both tunnel encapsulation ops. */
static void ip_tunnel_encap_del_fou_ops(void)
{
	ip_tunnel_encap_del_ops(&fou_iptun_ops, TUNNEL_ENCAP_FOU);
	ip_tunnel_encap_del_ops(&gue_iptun_ops, TUNNEL_ENCAP_GUE);
}

#else

/* IP tunnel encapsulation support not configured: nothing to register. */
static int ip_tunnel_encap_add_fou_ops(void)
{
	return 0;
}

static void ip_tunnel_encap_del_fou_ops(void)
{
}

#endif
937 
fou_init_net(struct net * net)938 static __net_init int fou_init_net(struct net *net)
939 {
940 	struct fou_net *fn = net_generic(net, fou_net_id);
941 
942 	INIT_LIST_HEAD(&fn->fou_list);
943 	mutex_init(&fn->fou_lock);
944 	return 0;
945 }
946 
fou_exit_net(struct net * net)947 static __net_exit void fou_exit_net(struct net *net)
948 {
949 	struct fou_net *fn = net_generic(net, fou_net_id);
950 	struct fou *fou, *next;
951 
952 	/* Close all the FOU sockets */
953 	mutex_lock(&fn->fou_lock);
954 	list_for_each_entry_safe(fou, next, &fn->fou_list, list)
955 		fou_release(fou);
956 	mutex_unlock(&fn->fou_lock);
957 }
958 
959 static struct pernet_operations fou_net_ops = {
960 	.init = fou_init_net,
961 	.exit = fou_exit_net,
962 	.id   = &fou_net_id,
963 	.size = sizeof(struct fou_net),
964 };
965 
fou_init(void)966 static int __init fou_init(void)
967 {
968 	int ret;
969 
970 	ret = register_pernet_device(&fou_net_ops);
971 	if (ret)
972 		goto exit;
973 
974 	ret = genl_register_family_with_ops(&fou_nl_family,
975 					    fou_nl_ops);
976 	if (ret < 0)
977 		goto unregister;
978 
979 	ret = ip_tunnel_encap_add_fou_ops();
980 	if (ret == 0)
981 		return 0;
982 
983 	genl_unregister_family(&fou_nl_family);
984 unregister:
985 	unregister_pernet_device(&fou_net_ops);
986 exit:
987 	return ret;
988 }
989 
fou_fini(void)990 static void __exit fou_fini(void)
991 {
992 	ip_tunnel_encap_del_fou_ops();
993 	genl_unregister_family(&fou_nl_family);
994 	unregister_pernet_device(&fou_net_ops);
995 }
996 
997 module_init(fou_init);
998 module_exit(fou_fini);
999 MODULE_AUTHOR("Tom Herbert <therbert@google.com>");
1000 MODULE_LICENSE("GPL");
1001