add test for MemIndex.py; fix bug with adding lists of elements to Cidr/ComboMemIndex
[python-rwhoisd.git] / rwhoisd / MemIndex.py
1 # This file is part of python-rwhoisd
2 #
3 # Copyright (C) 2003, David E. Blacka
4 #
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 2 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful, but
11 # WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13 # General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program; if not, write to the Free Software
17 # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
18 # USA
19
20 import bisect, types
21 import Cidr
22
23 class MemIndex:
24     """This class implements a simple in-memory key-value map.  This
25     index supports efficient prefix matching (as well as pretty
26     efficient exact matching).  Internally, it is implemented as a
27     sorted list supporting binary searches."""
28
29     # NOTE: This implementation is probably far from as efficient as
30     # it could be.  Ideally, we should avoid the use of custom
31     # comparison functions, so that list.sort() will use built-in
32     # comparitors.  This would mean storing raw key tuples in index as
33     # opposed to element objects.  Also, it would mean avoiding the
34     # use of objects (like Cidr) as keys in favor of a primitive type.
35     # In the Cidr case, we would either have to use longs or strings,
36     # as Python doesn't seem to have an unsigned 32-bit integer type.
37
38     def __init__(self):
39         self.index = []
40         self.sorted = False
41
42     def add(self, key, value=None):
43         """Add a key-value pair to the map.  If the map is already in
44         the prepared state, this operation will preserve it, so don't
45         use this if many elements are to be added at once.  The 'key'
46         argument may be a 2 element tuple, in which case 'value' is
47         ignored."""
48
49         if isinstance(key, types.TupleType):
50             el = element(key[0], key[1])
51         else:
52             el = element(key, value)
53
54         if self.sorted:
55             i = bisect.bisect_left(self.index, el)
56             while i < len(self.index):
57                 if self.index[i].total_equals(el):
58                     break
59                 if self.index[i] != el:
60                     self.index.insert(i, el)
61                     break
62                 i += 1
63             else:
64                 self.index.append(el)
65         else:
66             self.index.append(el)
67
68     def addlist(self, list):
69         """Add the entire list of elements to the map.  The elements
70         of 'list' may be 2 element tuples or actual 'element' objects.
71         Use this method to add many elements at once."""
72
73         self.sorted = False
74         for i in list:
75             if isinstance(i, types.TupleType):
76                 self.index.append(element(i[0], i[1]))
77             elif isinstance(i, element):
78                 self.index.append(i)
79
80     def prepare(self):
81         """Put the map in a prepared state, if necessary."""
82
83         n = len(self.index)
84         if not n: return
85         if not self.sorted:
86             self.index.sort()
87             # unique the index
88             last = self.index[0]
89             lasti = i = 1
90             while i < n:
91                 if not self.index[i].total_equals(last):
92                     self.index[lasti] = last = self.index[i]
93                     lasti += 1
94                 i += 1
95             self.index[lasti:]
96             self.sorted = True
97
98     def _find(self, key):
99         """Return the (search_element, index) tuple.  Used internally
100         only."""
101
102         self.prepare()
103         search_el = element(key, None)
104         i = bisect.bisect_left(self.index, search_el)
105         if i > len(self.index) or i < 0:
106             print "warning: bisect.bisect_left returned something " + \
107                   "unexpected:", i, len(self.index)
108         return (search_el, i)
109
110     def find(self, key, prefix_match=False, max=0):
111         """Return a list of values whose keys string match 'key'.  If
112         prefix_match is True, then keys will match if 'key' is a
113         prefix of the element key."""
114
115         search_el, i = self._find(key)
116         res = []
117         while i < len(self.index):
118             if max and len(res) == max: break
119             if search_el.equals(self.index[i], prefix_match):
120                 res.append(self.index[i].value)
121                 i += 1
122             else:
123                 break
124         return res
125
126 class CidrMemIndex(MemIndex):
127     """This is an in-memory map that has been extended to support CIDR
128     searching semantics."""
129
130     # NOTE: this structure lends to fairly efficient exact searches
131     # (O[log2N]), efficient subnet searches (also O[log2N]), but not
132     # terribly efficient supernet searches (O[32 * log2N] or O[128 *
133     # log2N] for IPv6), because we have to potentially do 32 (or 128!)
134     # exact matches.  If we want efficient supernet searches, we will
135     # probably have to use some sort of general (i.e., not binary)
136     # search tree datastructure, as there is no sorted ordering that
137     # will efficiently give supernets that I can think of.
138
139     # convert a key, value pair into a list of (cidr, value) tuples.
140     # It can be a list with more than one element if key is actually a
141     # netblock.
142     def _conv_key_value(self, key, value):
143         res = []
144         if isinstance(key, Cidr.Cidr):
145             res.append((key, value))
146             return res
147         if self.is_netblock(key):
148             cidrs = self.parse_netblock(key)
149             for c in cidrs:
150                 res.append((c, value))
151         else:
152             c = Cidr.valid_cidr(key)
153             res.append((c, value))
154         return res
155
156     # convert a (key, value) tuple into a list of (cidr, value)
157     # tuples.
158     def _conv_tuple(self, tuple):
159         return self._conv_key_value(tuple[0], tuple[1])
160
161     def add(self, key, value=None):
162         if isinstance(key, types.TupleType):
163             l = self._conv_tuple(key)
164         else:
165             l = self._conv_key_value(key, value)
166
167         for k, v in l:
168             MemIndex.add(self, k, v)
169         return
170
171     def addlist(self, list):
172
173         res_list = []
174         # make sure the keys are Cidr objects
175         for i in list:
176             if isinstance(i, types.TupleType):
177                 l = self._conv_tuple(i)
178             elif isinstance(el, element):
179                 l = self._conv_key_value(i.key, i.value)
180             res_list.append(l)
181
182         # flatten the list of lists of tuples to just a list of tuples
183         res_list = [ x for y in res_list for x in y ]
184         MemIndex.addlist(self, res_list)
185         return
186
187     def is_netblock(self, key):
188         if "-" in key: return True
189         return False
190
191     def parse_netblock(self, key):
192         start, end = key.split("-", 1);
193         start = start.strip()
194         end = end.strip()
195
196         return Cidr.netblock_to_cidr(start, end)
197
198     def find_exact(self, key, max = 0):
199
200         key = Cidr.valid_cidr(key)
201         search_el, i = self._find(key)
202         res = []
203         while i < len(self.index) and self.index[i].key == key:
204             res.append(self.index[i].value)
205             if max and len(res) == max: break
206             i += 1
207         return res
208
209     def find_subnets(self, key, max = 0):
210         """Return all values that are subnets of 'key', including any
211         that match 'key' itself."""
212
213         key = Cidr.valid_cidr(key)
214         search_el, i = self._find(key)
215
216         res = set()
217         while i < len(self.index) and self.index[i].key.is_subnet(key):
218             if max and len(res) == max: break
219             res.add(self.index[i].value)
220             i += 1
221         return list(res)
222
223     def find_supernets(self, key, max = 0):
224         """Return all values that are supernets of 'key', including
225         any that match 'key' itself."""
226
227         key = Cidr.valid_cidr(key)
228         k = key.clone()
229         res = []
230         while k.netlen >= 0:
231             k.calc()
232             res += self.find_exact(k, max)
233             if max and len(res) >= max:
234                 return res[:max]
235             k.netlen -= 1
236
237
238         return res
239
240     def find(self, key, prefix_match=0, max=0):
241         """Return either the exact match of 'key', or the closest
242         supernet of 'key'.  If prefix_match is True, then find all
243         supernets of 'key'"""
244
245         key = Cidr.valid_cidr(key)
246         if prefix_match == 0:
247             res = self.find_exact(key, max)
248
249             if not res:
250                 # now do a modified supernet search that stops after
251                 # the first proper supernet, but gets all values
252                 # matching that supernet key
253                 k = key.clone()
254                 k.netlen -= 1
255                 while not res and k.netlen >= 0:
256                     k.calc()
257                     res = self.find_exact(k, max)
258                     k.netlen -= 1
259             return res
260
261         # for now, a prefix match means all supernets
262         return self.find_supernets(key, max)
263
264 class ComboMemIndex:
265     """This is an in-memory map that contains both a normal string
266     index and a CIDR index.  Valid CIDR values we be applied against
267     the CIDR index.  Other values will be applied against the normal
268     index."""
269
270     def __init__(self):
271         self.normal_index = MemIndex()
272         self.cidr_index   = CidrMemIndex()
273
274     def add(self, key, value = None):
275         """Add a key,value pair to the correct map.  See MemIndex for
276         the behavior of this method"""
277
278         if isinstance(key, types.TupleType):
279             k = key[0]
280         else:
281             k = key
282         c = Cidr.valid_cidr(key)
283         if c:
284             self.cidr_index.add(key, value)
285         else:
286             self.normal_index.add(key, value)
287         return
288
289     def addlist(self, list):
290         """Add a list of elements or key, value tuples to the
291         appropriate maps."""
292
293         cidr_list = []
294         normal_list = []
295
296         for i in list:
297             if isinstance(i, element):
298                 k, v = i.key, i.value
299             elif isinstance(i, types.TupleType):
300                 k, v = i[:2]
301
302             c = Cidr.valid_cidr(k)
303             if c:
304                 cidr_list.append((c, v))
305             else:
306                 normal_list.append((k, v))
307
308         if cidr_list:
309             self.cidr_index.addlist(cidr_list)
310         if normal_list:
311             self.normal_index.addlist(normal_list)
312         return
313
314     def prepare(self):
315         """Prepare the internally held maps for searching."""
316
317         self.cidr_index.prepare()
318         self.normal_index.prepare()
319
320     def find(self, key, prefix_match=False, max=0):
321         """Return a list of values whose keys match 'key'."""
322
323         c = Cidr.valid_cidr(key)
324         if c:
325             return self.cidr_index.find(c, prefix_match, max)
326         return self.normal_index.find(key, prefix_match, max)
327
328     def find_exact(self, key, max = 0):
329         """Return a list of values whose keys match 'key'.  if 'key'
330         is not a CIDR value, then this is the same as find()."""
331
332         c = Cidr.valid_cidr(key)
333         if c:
334             return self.cidr_index.find_exact(c, max)
335         return self.normal_index.find(key, False, max)
336
337     def find_subnets(self, key, max = 0):
338         """If 'key' is a CIDR value (either a Cidr object or a valid
339         CIDR string representation, do a find_subnets on the internal
340         CidrMemIndex, otherwise return None."""
341
342         c = Cidr.valid_cidr(key)
343         if c: return self.cidr_index.find_subnets(key, max)
344         return None
345
346     def find_supernets(self, key, max = 0):
347         """If 'key' is a CIDR value (either a Cidr object or a valid
348         CIDR string representation, do a find_supernets on the internal
349         CidrMemIndex, otherwise return None."""
350
351         c = Cidr.valid_cidr(key)
352         if c: return self.cidr_index.find_supernets(key, max)
353         return None
354
355 class element:
356     """This is the base element class.  It basically exists to
357     simplify sorting."""
358
359     def __init__(self, key, value):
360         self.key   = key
361         self.value = value
362
363     def __cmp__(self, other):
364         """Compare only on the key."""
365
366         if not type(self.key) == type(other.key):
367             print "other is incompatible type?", repr(other.key), other.key
368         if self.key < other.key:
369             return -1
370         if self.key == other.key:
371             return 0
372         return 1
373
374     def __str__(self):
375         return "<" + str(self.key) + ", " + str(self.value) + ">"
376
377     def __repr__(self):
378         return "element" + str(self)
379
380     def __hash__(self):
381         return self.key.__hash__()
382
383     def equals(self, other, prefix_match=0):
384         if prefix_match:
385             return self.key == other.key[:len(self.key)]
386         return self.key == other.key
387
388     def total_equals(self, other):
389         if not isinstance(other, type(self)): return False
390         return self.key == other.key and self.value == other.value
391