1/*
2 * Copyright (c) 2014, Cisco Systems, Inc. All rights reserved.
3 *
4 * This program is free software; you may redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation; version 2 of the License.
7 *
8 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
9 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
10 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
11 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
12 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
13 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
14 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 * SOFTWARE.
16 *
17 */
18
19#include <linux/init.h>
20#include <linux/list.h>
21#include <linux/slab.h>
22#include <linux/list_sort.h>
23
24#include <linux/interval_tree_generic.h>
25#include "usnic_uiom_interval_tree.h"
26
/* Range accessors used by INTERVAL_TREE_DEFINE: closed interval [start, last]. */
#define START(node) ((node)->start)
#define LAST(node) ((node)->last)

/*
 * Allocate a node covering [start, end] into @node. On allocation failure,
 * set @err to -ENOMEM and jump to the caller-supplied label @err_out.
 */
#define MAKE_NODE(node, start, end, ref_cnt, flags, err, err_out)	\
	do {								\
		(node) = usnic_uiom_interval_node_alloc((start), (end),	\
						(ref_cnt), (flags));	\
		if (!(node)) {						\
			(err) = -ENOMEM;				\
			goto err_out;					\
		}							\
	} while (0)

/* Queue @node at the tail of @list through its 'link' member. */
#define MARK_FOR_ADD(node, list) (list_add_tail(&(node)->link, (list)))

/* MAKE_NODE followed by MARK_FOR_ADD; same error semantics as MAKE_NODE. */
#define MAKE_NODE_AND_APPEND(node, start, end, ref_cnt, flags, err,	\
				err_out, list)				\
	do {								\
		MAKE_NODE(node, start, end, ref_cnt, flags, err,	\
			  err_out);					\
		MARK_FOR_ADD(node, list);				\
	} while (0)

/* True when @flags1 and @flags2 agree on every bit set in @mask. */
#define FLAGS_EQUAL(flags1, flags2, mask)				\
	(((flags1) & (mask)) == ((flags2) & (mask)))
53
54static struct usnic_uiom_interval_node*
55usnic_uiom_interval_node_alloc(long int start, long int last, int ref_cnt,
56				int flags)
57{
58	struct usnic_uiom_interval_node *interval = kzalloc(sizeof(*interval),
59								GFP_ATOMIC);
60	if (!interval)
61		return NULL;
62
63	interval->start = start;
64	interval->last = last;
65	interval->flags = flags;
66	interval->ref_cnt = ref_cnt;
67
68	return interval;
69}
70
71static int interval_cmp(void *priv, struct list_head *a, struct list_head *b)
72{
73	struct usnic_uiom_interval_node *node_a, *node_b;
74
75	node_a = list_entry(a, struct usnic_uiom_interval_node, link);
76	node_b = list_entry(b, struct usnic_uiom_interval_node, link);
77
78	/* long to int */
79	if (node_a->start < node_b->start)
80		return -1;
81	else if (node_a->start > node_b->start)
82		return 1;
83
84	return 0;
85}
86
87static void
88find_intervals_intersection_sorted(struct rb_root *root, unsigned long start,
89					unsigned long last,
90					struct list_head *list)
91{
92	struct usnic_uiom_interval_node *node;
93
94	INIT_LIST_HEAD(list);
95
96	for (node = usnic_uiom_interval_tree_iter_first(root, start, last);
97		node;
98		node = usnic_uiom_interval_tree_iter_next(node, start, last))
99		list_add_tail(&node->link, list);
100
101	list_sort(NULL, list, interval_cmp);
102}
103
104int usnic_uiom_get_intervals_diff(unsigned long start, unsigned long last,
105					int flags, int flag_mask,
106					struct rb_root *root,
107					struct list_head *diff_set)
108{
109	struct usnic_uiom_interval_node *interval, *tmp;
110	int err = 0;
111	long int pivot = start;
112	LIST_HEAD(intersection_set);
113
114	INIT_LIST_HEAD(diff_set);
115
116	find_intervals_intersection_sorted(root, start, last,
117						&intersection_set);
118
119	list_for_each_entry(interval, &intersection_set, link) {
120		if (pivot < interval->start) {
121			MAKE_NODE_AND_APPEND(tmp, pivot, interval->start - 1,
122						1, flags, err, err_out,
123						diff_set);
124			pivot = interval->start;
125		}
126
127		/*
128		 * Invariant: Set [start, pivot] is either in diff_set or root,
129		 * but not in both.
130		 */
131
132		if (pivot > interval->last) {
133			continue;
134		} else if (pivot <= interval->last &&
135				FLAGS_EQUAL(interval->flags, flags,
136				flag_mask)) {
137			pivot = interval->last + 1;
138		}
139	}
140
141	if (pivot <= last)
142		MAKE_NODE_AND_APPEND(tmp, pivot, last, 1, flags, err, err_out,
143					diff_set);
144
145	return 0;
146
147err_out:
148	list_for_each_entry_safe(interval, tmp, diff_set, link) {
149		list_del(&interval->link);
150		kfree(interval);
151	}
152
153	return err;
154}
155
156void usnic_uiom_put_interval_set(struct list_head *intervals)
157{
158	struct usnic_uiom_interval_node *interval, *tmp;
159	list_for_each_entry_safe(interval, tmp, intervals, link)
160		kfree(interval);
161}
162
/*
 * usnic_uiom_insert_interval - add one reference to every point in
 * [start, last], splitting existing nodes at the range boundaries.
 * @root:  interval tree to update
 * @start: first point of the range (inclusive)
 * @last:  last point of the range (inclusive)
 * @flags: flags of the new reference. They are OR-ed into the flags of
 *         the overlapping part of any existing node; gaps get exactly @flags.
 *
 * All replacement nodes are built on a private list first. The tree is
 * changed only after every allocation has succeeded, so on failure the
 * tree structure is left untouched.
 *
 * Return: 0 on success, -ENOMEM if a node allocation fails.
 */
int usnic_uiom_insert_interval(struct rb_root *root, unsigned long start,
				unsigned long last, int flags)
{
	struct usnic_uiom_interval_node *interval, *tmp;
	unsigned long istart, ilast;
	int iref_cnt, iflags;
	unsigned long lpivot = start;
	int err = 0;
	LIST_HEAD(to_add);
	LIST_HEAD(intersection_set);

	/* Existing nodes overlapping [start, last], sorted by start. */
	find_intervals_intersection_sorted(root, start, last,
						&intersection_set);

	list_for_each_entry(interval, &intersection_set, link) {
		/*
		 * Invariant - lpivot is the left edge of next interval to be
		 * inserted
		 */
		istart = interval->start;
		ilast = interval->last;
		iref_cnt = interval->ref_cnt;
		iflags = interval->flags;

		if (istart < lpivot) {
			/*
			 * The node begins before lpivot: keep its
			 * [istart, lpivot - 1] prefix with the old refcount
			 * and flags.
			 */
			MAKE_NODE_AND_APPEND(tmp, istart, lpivot - 1, iref_cnt,
						iflags, err, err_out, &to_add);
		} else if (istart > lpivot) {
			/*
			 * Gap covered by no existing node: a new node with a
			 * single reference and only the caller's flags.
			 */
			MAKE_NODE_AND_APPEND(tmp, lpivot, istart - 1, 1, flags,
						err, err_out, &to_add);
			lpivot = istart;
		} else {
			lpivot = istart;
		}

		if (ilast > last) {
			/*
			 * The node extends past the range: the overlap gains a
			 * reference and the caller's flags, while the
			 * [last + 1, ilast] tail keeps the old values.
			 */
			MAKE_NODE_AND_APPEND(tmp, lpivot, last, iref_cnt + 1,
						iflags | flags, err, err_out,
						&to_add);
			MAKE_NODE_AND_APPEND(tmp, last + 1, ilast, iref_cnt,
						iflags, err, err_out, &to_add);
		} else {
			/* The node ends inside the range: the overlap gains a reference. */
			MAKE_NODE_AND_APPEND(tmp, lpivot, ilast, iref_cnt + 1,
						iflags | flags, err, err_out,
						&to_add);
		}

		lpivot = ilast + 1;
	}

	/* Trailing part of [start, last] past the last existing node. */
	if (lpivot <= last)
		MAKE_NODE_AND_APPEND(tmp, lpivot, last, 1, flags, err, err_out,
					&to_add);

	/* Every allocation succeeded: replace the old nodes with the new ones. */
	list_for_each_entry_safe(interval, tmp, &intersection_set, link) {
		usnic_uiom_interval_tree_remove(interval, root);
		kfree(interval);
	}

	list_for_each_entry(interval, &to_add, link)
		usnic_uiom_interval_tree_insert(interval, root);

	return 0;

err_out:
	/*
	 * Only the not-yet-inserted replacement nodes are freed; the tree
	 * structure was not changed before this point.
	 */
	list_for_each_entry_safe(interval, tmp, &to_add, link)
		kfree(interval);

	return err;
}
233
234void usnic_uiom_remove_interval(struct rb_root *root, unsigned long start,
235				unsigned long last, struct list_head *removed)
236{
237	struct usnic_uiom_interval_node *interval;
238
239	for (interval = usnic_uiom_interval_tree_iter_first(root, start, last);
240			interval;
241			interval = usnic_uiom_interval_tree_iter_next(interval,
242									start,
243									last)) {
244		if (--interval->ref_cnt == 0)
245			list_add_tail(&interval->link, removed);
246	}
247
248	list_for_each_entry(interval, removed, link)
249		usnic_uiom_interval_tree_remove(interval, root);
250}
251
/*
 * Instantiate the generic interval tree for struct usnic_uiom_interval_node.
 * Nodes are linked through their 'rb' member and keyed on the closed range
 * [START(node), LAST(node)] of unsigned long values. __subtree_last caches
 * the largest LAST() in each subtree. The generated functions use the prefix
 * usnic_uiom_interval_tree (_insert, _remove, _iter_first, _iter_next, ...).
 * Because the ITSTATIC argument is empty, they have external linkage.
 */
INTERVAL_TREE_DEFINE(struct usnic_uiom_interval_node, rb,
			unsigned long, __subtree_last,
			START, LAST, , usnic_uiom_interval_tree)
255