/*
 * Pluggable TCP congestion control support and NewReno
 * congestion control.
 * Based on ideas from I/O scheduler support and Web100.
 *
 * Copyright (C) 2005 Stephen Hemminger <shemminger@osdl.org>
 */

#define pr_fmt(fmt) "TCP: " fmt

#include <linux/module.h>
#include <linux/mm.h>
#include <linux/types.h>
#include <linux/list.h>
#include <linux/gfp.h>
#include <linux/jhash.h>
#include <net/tcp.h>

static DEFINE_SPINLOCK(tcp_cong_list_lock);
static LIST_HEAD(tcp_cong_list);

/* Simple linear search, don't expect many entries! */
static struct tcp_congestion_ops *tcp_ca_find(const char *name)
{
	struct tcp_congestion_ops *e;

	list_for_each_entry_rcu(e, &tcp_cong_list, list) {
		if (strcmp(e->name, name) == 0)
			return e;
	}

	return NULL;
}

/* Must be called with rcu lock held */
static const struct tcp_congestion_ops *__tcp_ca_find_autoload(const char *name)
{
	const struct tcp_congestion_ops *ca = tcp_ca_find(name);
#ifdef CONFIG_MODULES
	if (!ca && capable(CAP_NET_ADMIN)) {
		rcu_read_unlock();
		request_module("tcp_%s", name);
		rcu_read_lock();
		ca = tcp_ca_find(name);
	}
#endif
	return ca;
}

/* Simple linear search, not much in here. */
struct tcp_congestion_ops *tcp_ca_find_key(u32 key)
{
	struct tcp_congestion_ops *e;

	list_for_each_entry_rcu(e, &tcp_cong_list, list) {
		if (e->key == key)
			return e;
	}

	return NULL;
}

/*
 * Attach new congestion control algorithm to the list
 * of available options.
 */
int tcp_register_congestion_control(struct tcp_congestion_ops *ca)
{
	int ret = 0;

	/* all algorithms must implement ssthresh and cong_avoid ops */
	if (!ca->ssthresh || !ca->cong_avoid) {
		pr_err("%s does not implement required ops\n", ca->name);
		return -EINVAL;
	}

	ca->key = jhash(ca->name, sizeof(ca->name), strlen(ca->name));

	spin_lock(&tcp_cong_list_lock);
	if (ca->key == TCP_CA_UNSPEC || tcp_ca_find_key(ca->key)) {
		pr_notice("%s already registered or non-unique key\n",
			  ca->name);
		ret = -EEXIST;
	} else {
		list_add_tail_rcu(&ca->list, &tcp_cong_list);
		pr_debug("%s registered\n", ca->name);
	}
	spin_unlock(&tcp_cong_list_lock);

	return ret;
}
EXPORT_SYMBOL_GPL(tcp_register_congestion_control);

/*
 * Remove a congestion control algorithm, called from
 * the module's remove function.  Module ref counts are used
 * to ensure that this can't be done until all sockets using
 * that method are closed.
 */
void tcp_unregister_congestion_control(struct tcp_congestion_ops *ca)
{
	spin_lock(&tcp_cong_list_lock);
	list_del_rcu(&ca->list);
	spin_unlock(&tcp_cong_list_lock);

	/* Wait for outstanding readers to complete before the
	 * module gets removed entirely.
	 *
	 * A try_module_get() should fail by now, as the module is
	 * in "going" state: no refs are held anymore and the
	 * module_exit() handler is being called.
	 */
	synchronize_rcu();
}
EXPORT_SYMBOL_GPL(tcp_unregister_congestion_control);
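
/* Example (illustrative sketch, not part of this file): an out-of-tree
 * module would pair the two calls above in its init/exit handlers.  The
 * "mycc" name and the reuse of the Reno handlers below are hypothetical;
 * a real algorithm supplies its own ssthresh/cong_avoid implementations.
 *
 *	static struct tcp_congestion_ops mycc __read_mostly = {
 *		.name		= "mycc",
 *		.owner		= THIS_MODULE,
 *		.ssthresh	= tcp_reno_ssthresh,
 *		.cong_avoid	= tcp_reno_cong_avoid,
 *	};
 *
 *	static int __init mycc_init(void)
 *	{
 *		return tcp_register_congestion_control(&mycc);
 *	}
 *
 *	static void __exit mycc_exit(void)
 *	{
 *		tcp_unregister_congestion_control(&mycc);
 *	}
 *
 *	module_init(mycc_init);
 *	module_exit(mycc_exit);
 */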

u32 tcp_ca_get_key_by_name(const char *name)
{
	const struct tcp_congestion_ops *ca;
	u32 key;

	might_sleep();

	rcu_read_lock();
	ca = __tcp_ca_find_autoload(name);
	key = ca ? ca->key : TCP_CA_UNSPEC;
	rcu_read_unlock();

	return key;
}
EXPORT_SYMBOL_GPL(tcp_ca_get_key_by_name);
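
/* Example (illustrative): a caller that needs to pin a congestion
 * control choice by key, e.g. when one is configured on a route, can
 * resolve the name up front.  The error handling here is a sketch:
 *
 *	u32 key = tcp_ca_get_key_by_name("reno");
 *
 *	if (key == TCP_CA_UNSPEC)
 *		return -ENOENT;
 */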

char *tcp_ca_get_name_by_key(u32 key, char *buffer)
{
	const struct tcp_congestion_ops *ca;
	char *ret = NULL;

	rcu_read_lock();
	ca = tcp_ca_find_key(key);
	if (ca)
		ret = strncpy(buffer, ca->name,
			      TCP_CA_NAME_MAX);
	rcu_read_unlock();

	return ret;
}
EXPORT_SYMBOL_GPL(tcp_ca_get_name_by_key);

/* Assign choice of congestion control. */
void tcp_assign_congestion_control(struct sock *sk)
{
	struct inet_connection_sock *icsk = inet_csk(sk);
	struct tcp_congestion_ops *ca;

	rcu_read_lock();
	list_for_each_entry_rcu(ca, &tcp_cong_list, list) {
		if (likely(try_module_get(ca->owner))) {
			icsk->icsk_ca_ops = ca;
			goto out;
		}
		/* Fall back to the next available algorithm.  Reno is
		 * built in, so it is the guaranteed last resort on this
		 * list.
		 */
	}
out:
	rcu_read_unlock();

	/* Clear out private data before diag gets it, since the ca
	 * has not been initialized at this point.
	 */
	if (ca->get_info)
		memset(icsk->icsk_ca_priv, 0, sizeof(icsk->icsk_ca_priv));
}

void tcp_init_congestion_control(struct sock *sk)
{
	const struct inet_connection_sock *icsk = inet_csk(sk);

	if (icsk->icsk_ca_ops->init)
		icsk->icsk_ca_ops->init(sk);
}

static void tcp_reinit_congestion_control(struct sock *sk,
					  const struct tcp_congestion_ops *ca)
{
	struct inet_connection_sock *icsk = inet_csk(sk);

	tcp_cleanup_congestion_control(sk);
	icsk->icsk_ca_ops = ca;
	icsk->icsk_ca_setsockopt = 1;

	if (sk->sk_state != TCP_CLOSE && icsk->icsk_ca_ops->init)
		icsk->icsk_ca_ops->init(sk);
}

/* Manage refcounts on socket close. */
void tcp_cleanup_congestion_control(struct sock *sk)
{
	struct inet_connection_sock *icsk = inet_csk(sk);

	if (icsk->icsk_ca_ops->release)
		icsk->icsk_ca_ops->release(sk);
	module_put(icsk->icsk_ca_ops->owner);
}

/* Used by sysctl to change default congestion control */
int tcp_set_default_congestion_control(const char *name)
{
	struct tcp_congestion_ops *ca;
	int ret = -ENOENT;

	spin_lock(&tcp_cong_list_lock);
	ca = tcp_ca_find(name);
#ifdef CONFIG_MODULES
	if (!ca && capable(CAP_NET_ADMIN)) {
		spin_unlock(&tcp_cong_list_lock);

		request_module("tcp_%s", name);
		spin_lock(&tcp_cong_list_lock);
		ca = tcp_ca_find(name);
	}
#endif

	if (ca) {
		ca->flags |= TCP_CONG_NON_RESTRICTED;	/* default is always allowed */
		list_move(&ca->list, &tcp_cong_list);
		ret = 0;
	}
	spin_unlock(&tcp_cong_list_lock);

	return ret;
}

/* Set default value from kernel configuration at bootup */
static int __init tcp_congestion_default(void)
{
	return tcp_set_default_congestion_control(CONFIG_DEFAULT_TCP_CONG);
}
late_initcall(tcp_congestion_default);

/* Build string with list of available congestion control values */
void tcp_get_available_congestion_control(char *buf, size_t maxlen)
{
	struct tcp_congestion_ops *ca;
	size_t offs = 0;

	rcu_read_lock();
	list_for_each_entry_rcu(ca, &tcp_cong_list, list) {
		offs += snprintf(buf + offs, maxlen - offs,
				 "%s%s",
				 offs == 0 ? "" : " ", ca->name);
	}
	rcu_read_unlock();
}

/* Get current default congestion control */
void tcp_get_default_congestion_control(char *name)
{
	struct tcp_congestion_ops *ca;
	/* We will always have reno... */
	BUG_ON(list_empty(&tcp_cong_list));

	rcu_read_lock();
	ca = list_entry(tcp_cong_list.next, struct tcp_congestion_ops, list);
	strncpy(name, ca->name, TCP_CA_NAME_MAX);
	rcu_read_unlock();
}

/* Build list of non-restricted congestion control values */
void tcp_get_allowed_congestion_control(char *buf, size_t maxlen)
{
	struct tcp_congestion_ops *ca;
	size_t offs = 0;

	*buf = '\0';
	rcu_read_lock();
	list_for_each_entry_rcu(ca, &tcp_cong_list, list) {
		if (!(ca->flags & TCP_CONG_NON_RESTRICTED))
			continue;
		offs += snprintf(buf + offs, maxlen - offs,
				 "%s%s",
				 offs == 0 ? "" : " ", ca->name);
	}
	rcu_read_unlock();
}

/* Change list of non-restricted congestion control */
int tcp_set_allowed_congestion_control(char *val)
{
	struct tcp_congestion_ops *ca;
	char *saved_clone, *clone, *name;
	int ret = 0;

	saved_clone = clone = kstrdup(val, GFP_USER);
	if (!clone)
		return -ENOMEM;

	spin_lock(&tcp_cong_list_lock);
	/* pass 1 check for bad entries */
	while ((name = strsep(&clone, " ")) && *name) {
		ca = tcp_ca_find(name);
		if (!ca) {
			ret = -ENOENT;
			goto out;
		}
	}

	/* pass 2 clear old values */
	list_for_each_entry_rcu(ca, &tcp_cong_list, list)
		ca->flags &= ~TCP_CONG_NON_RESTRICTED;

	/* pass 3 mark as allowed */
	while ((name = strsep(&val, " ")) && *name) {
		ca = tcp_ca_find(name);
		WARN_ON(!ca);
		if (ca)
			ca->flags |= TCP_CONG_NON_RESTRICTED;
	}
out:
	spin_unlock(&tcp_cong_list_lock);
	kfree(saved_clone);

	return ret;
}
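
/* Example (illustrative): the sysctl handler hands this function a
 * space-separated list.  Note the buffer must be writable, because
 * strsep() modifies it in place:
 *
 *	char val[] = "reno cubic";
 *
 *	if (tcp_set_allowed_congestion_control(val) == 0)
 *		pr_debug("only reno and cubic are now non-restricted\n");
 */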

/* Change congestion control for socket */
int tcp_set_congestion_control(struct sock *sk, const char *name)
{
	struct inet_connection_sock *icsk = inet_csk(sk);
	const struct tcp_congestion_ops *ca;
	int err = 0;

	if (icsk->icsk_ca_dst_locked)
		return -EPERM;

	rcu_read_lock();
	ca = __tcp_ca_find_autoload(name);
	/* No change if asking for the value already in use */
	if (ca == icsk->icsk_ca_ops) {
		icsk->icsk_ca_setsockopt = 1;
		goto out;
	}
	if (!ca)
		err = -ENOENT;
	else if (!((ca->flags & TCP_CONG_NON_RESTRICTED) ||
		   ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)))
		err = -EPERM;
	else if (!try_module_get(ca->owner))
		err = -EBUSY;
	else
		tcp_reinit_congestion_control(sk, ca);
 out:
	rcu_read_unlock();
	return err;
}
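
/* Example (illustrative, userspace): applications reach the function
 * above through the TCP_CONGESTION socket option.  The helper below is
 * a hypothetical sketch:
 *
 *	#include <string.h>
 *	#include <sys/socket.h>
 *	#include <netinet/in.h>
 *	#include <netinet/tcp.h>
 *
 *	static int set_cc(int fd, const char *name)
 *	{
 *		return setsockopt(fd, IPPROTO_TCP, TCP_CONGESTION,
 *				  name, strlen(name));
 *	}
 */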

/* Slow start is used while the congestion window is no greater than the
 * slow start threshold.  It is based on RFC 2581 and also handles stretch
 * ACKs properly.  We do not implement RFC 3465 Appropriate Byte Counting
 * (ABC) per se, but something better: a packet is only considered (s)acked
 * in its entirety, to defend against the ACK-splitting attacks described
 * in the RFC.  Slow start processes a stretch ACK of degree N as if N ACKs
 * of degree 1 are received back to back, except that ABC caps N to 2.
 * Slow start exits when cwnd grows over ssthresh and returns the leftover
 * ACKs so that cwnd can be adjusted in congestion avoidance mode.
 */
u32 tcp_slow_start(struct tcp_sock *tp, u32 acked)
{
	u32 cwnd = tp->snd_cwnd + acked;

	if (cwnd > tp->snd_ssthresh)
		cwnd = tp->snd_ssthresh + 1;
	acked -= cwnd - tp->snd_cwnd;
	tp->snd_cwnd = min(cwnd, tp->snd_cwnd_clamp);

	return acked;
}
EXPORT_SYMBOL_GPL(tcp_slow_start);
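
/* Worked example (illustrative): with snd_cwnd = 10, snd_ssthresh = 12
 * and a stretch ACK covering 5 packets, cwnd first becomes 15, which is
 * capped at ssthresh + 1 = 13.  The window grew by 3, so acked becomes
 * 5 - 3 = 2, snd_cwnd is set to 13, and the 2 leftover ACKs are
 * returned for congestion avoidance to consume.
 */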

/* In theory this increments tp->snd_cwnd by 1 / tp->snd_cwnd (or by
 * 1 / w for an alternative window w) for every packet that was ACKed.
 */
void tcp_cong_avoid_ai(struct tcp_sock *tp, u32 w, u32 acked)
{
	/* If credits accumulated at a higher w, apply them gently now. */
	if (tp->snd_cwnd_cnt >= w) {
		tp->snd_cwnd_cnt = 0;
		tp->snd_cwnd++;
	}

	tp->snd_cwnd_cnt += acked;
	if (tp->snd_cwnd_cnt >= w) {
		u32 delta = tp->snd_cwnd_cnt / w;

		tp->snd_cwnd_cnt -= delta * w;
		tp->snd_cwnd += delta;
	}
	tp->snd_cwnd = min(tp->snd_cwnd, tp->snd_cwnd_clamp);
}
EXPORT_SYMBOL_GPL(tcp_cong_avoid_ai);
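
/* Worked example (illustrative): with w = 10, snd_cwnd_cnt = 4 and
 * acked = 9, the counter reaches 13, so delta = 13 / 10 = 1: snd_cwnd
 * grows by one segment and snd_cwnd_cnt keeps the remainder, 3.  Over
 * time this approximates one cwnd increment per w ACKed packets.
 */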

/*
 * TCP Reno congestion control
 * This is a special case, also used as the fallback.
 */
/* This is Jacobson's slow start and congestion avoidance.
 * SIGCOMM '88, p. 328.
 */
void tcp_reno_cong_avoid(struct sock *sk, u32 ack, u32 acked)
{
	struct tcp_sock *tp = tcp_sk(sk);

	if (!tcp_is_cwnd_limited(sk))
		return;

	/* In "safe" area, increase. */
	if (tp->snd_cwnd <= tp->snd_ssthresh) {
		acked = tcp_slow_start(tp, acked);
		if (!acked)
			return;
	}
	/* In dangerous area, increase slowly. */
	tcp_cong_avoid_ai(tp, tp->snd_cwnd, acked);
}
EXPORT_SYMBOL_GPL(tcp_reno_cong_avoid);

/* Slow start threshold is half the congestion window (min 2) */
u32 tcp_reno_ssthresh(struct sock *sk)
{
	const struct tcp_sock *tp = tcp_sk(sk);

	return max(tp->snd_cwnd >> 1U, 2U);
}
EXPORT_SYMBOL_GPL(tcp_reno_ssthresh);

struct tcp_congestion_ops tcp_reno = {
	.flags		= TCP_CONG_NON_RESTRICTED,
	.name		= "reno",
	.owner		= THIS_MODULE,
	.ssthresh	= tcp_reno_ssthresh,
	.cong_avoid	= tcp_reno_cong_avoid,
};