Update TODO based on work for 0.4.1
[python-rwhoisd.git] / rwhoisd / MemIndex.py
1 # This file is part of python-rwhoisd
2 #
3 # Copyright (C) 2003, David E. Blacka
4 #
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 2 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful, but
11 # WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13 # General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program; if not, write to the Free Software
17 # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
18 # USA
19
20 import bisect, types
21 import Cidr
22
23 class MemIndex:
24     """This class implements a simple in-memory key-value map.  This
25     index supports efficient prefix matching (as well as pretty
26     efficient exact matching).  Internally, it is implemented as a
27     sorted list supporting binary searches."""
28
29     # NOTE: This implementation is probably far from as efficient as
30     # it could be.  Ideally, we should avoid the use of custom
31     # comparison functions, so that list.sort() will use built-in
32     # comparitors.  This would mean storing raw key tuples in index as
33     # opposed to element objects.  Also, it would mean avoiding the
34     # use of objects (like Cidr) as keys in favor of a primitive type.
35     # In the Cidr case, we would either have to use longs or strings,
36     # as Python doesn't seem to have an unsigned 32-bit integer type.
37
38     def __init__(self):
39         self.index = []
40         self.sorted = False
41
42     def add(self, key, value=None):
43         """Add a key-value pair to the map.  If the map is already in
44         the prepared state, this operation will preserve it, so don't
45         use this if many elements are to be added at once.  The 'key'
46         argument may be a 2 element tuple, in which case 'value' is
47         ignored."""
48
49         if isinstance(key, types.TupleType):
50             el = element(key[0], key[1])
51         else:
52             el = element(key, value)
53
54         if self.sorted:
55             i = bisect.bisect_left(self.index, el)
56             while i < len(self.index):
57                 if self.index[i].total_equals(el):
58                     break
59                 if self.index[i] != el:
60                     self.index.insert(i, el)
61                     break
62                 i += 1
63             else:
64                 self.index.append(el)
65         else:
66             self.index.append(el)
67
68     def addlist(self, list):
69         """Add the entire list of elements to the map.  The elements
70         of 'list' may be 2 element tuples or actual 'element' objects.
71         Use this method to add many elements at once."""
72
73         self.sorted = False
74         for i in list:
75             if isinstance(i, types.TupleType):
76                 self.index.append(element(i[0], i[1]))
77             elif isinstance(i, element):
78                 self.index.append(i)
79
80     def prepare(self):
81         """Put the map in a prepared state, if necessary."""
82
83         n = len(self.index)
84         if not n: return
85         if not self.sorted:
86             self.index.sort()
87             # unique the index
88             last = self.index[0]
89             lasti = i = 1
90             while i < n:
91                 if not self.index[i].total_equals(last):
92                     self.index[lasti] = last = self.index[i]
93                     lasti += 1
94                 i += 1
95             self.index[lasti:]
96             self.sorted = True
97
98     def _find(self, key):
99         """Return the (search_element, index) tuple.  Used internally
100         only."""
101
102         self.prepare()
103         search_el = element(key, None)
104         i = bisect.bisect_left(self.index, search_el)
105         if i > len(self.index) or i < 0:
106             print "warning: bisect.bisect_left returned something " + \
107                   "unexpected:", i, len(self.index)
108         return (search_el, i)
109
110     def find(self, key, prefix_match=False, max=0):
111         """Return a list of values whose keys string match 'key'.  If
112         prefix_match is True, then keys will match if 'key' is a
113         prefix of the element key."""
114
115         search_el, i = self._find(key)
116         res = []
117         while i < len(self.index):
118             if max and len(res) == max: break
119             if search_el.equals(self.index[i], prefix_match):
120                 res.append(self.index[i].value)
121                 i += 1
122             else:
123                 break
124         return res
125
126 class CidrMemIndex(MemIndex):
127     """This is an in-memory map that has been extended to support CIDR
128     searching semantics."""
129
130     # NOTE: this structure lends to fairly efficient exact searches
131     # (O[log2N]), efficient subnet searches (also O[log2N]), but not
132     # terribly efficient supernet searches (O[32 * log2N] or O[128 *
133     # log2N] for IPv6), because we have to potentially do 32 (or 128!)
134     # exact matches.  If we want efficient supernet searches, we will
135     # probably have to use some sort of general (i.e., not binary)
136     # search tree datastructure, as there is no sorted ordering that
137     # will efficiently give supernets that I can think of.
138
139     # convert a key, value pair into a list of (cidr, value) tuples.
140     # It can be a list with more than one element if key is actually a
141     # netblock.
142     def _conv_key_value(self, key, value):
143         res = []
144         if isinstance(key, Cidr.Cidr):
145             res.append((key, value))
146             return res
147         if self.is_netblock(key):
148             cidrs = self.parse_netblock(key)
149             for c in cidrs:
150                 res.append((c, value))
151         else:
152             c = Cidr.valid_cidr(key)
153             res.append((c, value))
154         return res
155
156     # convert a (key, value) tuple into a list of (cidr, value)
157     # tuples.
158     def _conv_tuple(self, tuple):
159         return self._conv_key_value(tuple[0], tuple[1])
160
161     def add(self, key, value=None):
162         if isinstance(key, types.TupleType):
163             l = self._conv_tuple(key)
164         else:
165             l = self._conv_key_value(key, value)
166
167         for k, v in l:
168             MemIndex.add(self, k, v)
169         return
170
171     def addlist(self, list):
172
173         res_list = []
174         # make sure the keys are Cidr objects
175         for i in list:
176             if isinstance(i, types.TupleType):
177                 l = self._conv_tuple(i)
178             elif isinstance(el, element):
179                 l = self._conv_key_value(i.key, i.value)
180             res_list.append(l)
181
182         MemIndex.addlist(self, res_list)
183         return
184
185     def is_netblock(self, key):
186         if "-" in key: return True
187         return False
188
189     def parse_netblock(self, key):
190         start, end = key.split("-", 1);
191         start = start.strip()
192         end = end.strip()
193
194         return Cidr.netblock_to_cidr(start, end)
195
196     def find_exact(self, key, max = 0):
197
198         key = Cidr.valid_cidr(key)
199         search_el, i = self._find(key)
200         res = []
201         while i < len(self.index) and self.index[i].key == key:
202             res.append(self.index[i].value)
203             if max and len(res) == max: break
204             i += 1
205         return res
206
207     def find_subnets(self, key, max = 0):
208         """Return all values that are subnets of 'key', including any
209         that match 'key' itself."""
210
211         key = Cidr.valid_cidr(key)
212         search_el, i = self._find(key)
213
214         res = set()
215         while i < len(self.index) and self.index[i].key.is_subnet(key):
216             if max and len(res) == max: break
217             res.add(self.index[i].value)
218             i += 1
219         return list(res)
220
221     def find_supernets(self, key, max = 0):
222         """Return all values that are supernets of 'key', including
223         any that match 'key' itself."""
224
225         key = Cidr.valid_cidr(key)
226         k = key.clone()
227         res = []
228         while k.netlen >= 0:
229             k.calc()
230             res += self.find_exact(k, max)
231             if max and len(res) >= max:
232                 return res[:max]
233             k.netlen -= 1
234
235
236         return res
237
238     def find(self, key, prefix_match=0, max=0):
239         """Return either the exact match of 'key', or the closest
240         supernet of 'key'.  If prefix_match is True, then find all
241         supernets of 'key'"""
242
243         key = Cidr.valid_cidr(key)
244         if prefix_match == 0:
245             res = self.find_exact(key, max)
246
247             if not res:
248                 # now do a modified supernet search that stops after
249                 # the first proper supernet, but gets all values
250                 # matching that supernet key
251                 k = key.clone()
252                 k.netlen -= 1
253                 while not res and k.netlen >= 0:
254                     k.calc()
255                     res = self.find_exact(k, max)
256                     k.netlen -= 1
257             return res
258
259         # for now, a prefix match means all supernets
260         return self.find_supernets(key, max)
261
262 class ComboMemIndex:
263     """This is an in-memory map that contains both a normal string
264     index and a CIDR index.  Valid CIDR values we be applied against
265     the CIDR index.  Other values will be applied against the normal
266     index."""
267
268     def __init__(self):
269         self.normal_index = MemIndex()
270         self.cidr_index   = CidrMemIndex()
271
272     def add(self, key, value = None):
273         """Add a key,value pair to the correct map.  See MemIndex for
274         the behavior of this method"""
275
276         if isinstance(key, types.TupleType):
277             k = key[0]
278         else:
279             k = key
280         c = Cidr.valid_cidr(key)
281         if c:
282             self.cidr_index.add(key, value)
283         else:
284             self.normal_index.add(key, value)
285         return
286
287     def addlist(self, list):
288         """Add a list of elements or key, value tuples to the
289         appropriate maps."""
290
291         cidr_list = []
292         normal_list = []
293
294         for i in list:
295             if isinstance(i, element):
296                 k, v = i.key, i.value
297             elif isinstance(i, types.TupleType):
298                 k, v = i[:2]
299
300             c = Cidr.valid_cidr(k)
301             if c:
302                 cidr_list.append((c, v))
303             else:
304                 normal_list.append((k, v))
305
306         if cidr_list:
307             self.cidr_index.addlist(cidr_list)
308         if normal_list:
309             self.normal_index.addlist(normal_list)
310         return
311
312     def prepare(self):
313         """Prepare the internally held maps for searching."""
314
315         self.cidr_index.prepare()
316         self.normal_index.prepare()
317
318     def find(self, key, prefix_match=False, max=0):
319         """Return a list of values whose keys match 'key'."""
320
321         c = Cidr.valid_cidr(key)
322         if c:
323             return self.cidr_index.find(c, prefix_match, max)
324         return self.normal_index.find(key, prefix_match, max)
325
326     def find_exact(self, key, max = 0):
327         """Return a list of values whose keys match 'key'.  if 'key'
328         is not a CIDR value, then this is the same as find()."""
329
330         c = Cidr.valid_cidr(key)
331         if c:
332             return self.cidr_index.find_exact(c, max)
333         return self.normal_index.find(key, False, max)
334
335     def find_subnets(self, key, max = 0):
336         """If 'key' is a CIDR value (either a Cidr object or a valid
337         CIDR string representation, do a find_subnets on the internal
338         CidrMemIndex, otherwise return None."""
339
340         c = Cidr.valid_cidr(key)
341         if c: return self.cidr_index.find_subnets(key, max)
342         return None
343
344     def find_supernets(self, key, max = 0):
345         """If 'key' is a CIDR value (either a Cidr object or a valid
346         CIDR string representation, do a find_supernets on the internal
347         CidrMemIndex, otherwise return None."""
348
349         c = Cidr.valid_cidr(key)
350         if c: return self.cidr_index.find_supernets(key, max)
351         return None
352
353 class element:
354     """This is the base element class.  It basically exists to
355     simplify sorting."""
356
357     def __init__(self, key, value):
358         self.key   = key
359         self.value = value
360
361     def __cmp__(self, other):
362         """Compare only on the key."""
363
364         if not type(self.key) == type(other.key):
365             print "other is incompatible type?", repr(other.key), other.key
366         if self.key < other.key:
367             return -1
368         if self.key == other.key:
369             return 0
370         return 1
371
372     def __str__(self):
373         return "<" + str(self.key) + ", " + str(self.value) + ">"
374
375     def __repr__(self):
376         return "element" + str(self)
377
378     def __hash__(self):
379         return self.key.__hash__()
380
381     def equals(self, other, prefix_match=0):
382         if prefix_match:
383             return self.key == other.key[:len(self.key)]
384         return self.key == other.key
385
386     def total_equals(self, other):
387         if not isinstance(other, type(self)): return False
388         return self.key == other.key and self.value == other.value
389
390 if __name__ == "__main__":
391
392     source = [ ("foo", "foo-id"), ("bar", "bar-id"), ("baz", "baz-id"),
393                ("foobar", "foo-id-2"), ("barnone", "bar-id-2"),
394                ("zygnax", "z-id") ]
395
396     mi = MemIndex()
397     mi.addlist(source)
398
399     print "finding foobar:"
400     res = mi.find("foobar")
401     print res
402
403     print "finding foo*:"
404     res = mi.find("foo", 1)
405     print res
406
407     print "finding baz:"
408     res = mi.find("baz")
409     print res
410
411     print "adding bork"
412     mi.add("bork", "bork-id")
413
414     print "finding b*:"
415     res = mi.find("b", 1)
416     print res
417
418     ci = CidrMemIndex()
419
420     ci.add("127.0.0.1/24", "net-local-1");
421     ci.add("127.0.0.1/32", "net-local-2");
422     ci.add(Cidr.new("216.168.224.0", 22), "net-vrsn-1")
423     ci.add(Cidr.new("216.168.252.1", 32), "net-vrsn-2")
424     ci.add("24.36.191.0/24", "net-foo-c")
425     ci.add("24.36.191.32/27", "net-foo-sub-c")
426     ci.add("24.36/16", "net-foo-b")
427     ci.add("3ffe:4:5::0/48", "net-foo-d6")
428     ci.add("3ffe:4:5:6::0/64", "net-foo-e6")
429     ci.add("48.12.6.0 - 48.12.6.95", "net-bar-1")
430
431     print "finding exactly 127.0.0.0/24"
432     res = ci.find(Cidr.new("127.0.0.0/24"))
433     print res
434
435     print "finding exactly 127.0.0.16/32"
436     res = ci.find(Cidr.new("127.0.0.16/32"))
437     print res
438
439     print "finding exactly 3ffe:4:5:6::0/64"
440     res = ci.find(Cidr.valid_cidr("3ffe:4:5:6::/64"))
441     print res
442
443     print "finding supernets of 127.0.0.16/32"
444     res = ci.find_supernets(Cidr.new("127.0.0.16/32"))
445     print res
446
447     print "finding supernets of 24.36.191.32/27"
448     res = ci.find(Cidr.new("24.36.191.32/27"), 1)
449     print res
450
451     print "finding supernets of 24.36.191.33/27"
452     res = ci.find_supernets(Cidr.new("24.36.191.33/27"))
453     print res
454
455     print "finding supernets of 24.36.191.64/27"
456     res = ci.find_supernets(Cidr.new("24.36.191.64/27"))
457     print res
458
459     print "finding supernets of 3ffe:4:5:6:7::0/80"
460     res = ci.find_supernets(Cidr.valid_cidr("3ffe:4:5:6:7::0/80"))
461     print res
462
463     print "finding supernets of 48.12.6.90"
464     res = ci.find_supernets(Cidr.valid_cidr("48.12.6.90"))
465     print res
466
467     print "finding subnets of 127.0/16"
468     res = ci.find_subnets(Cidr.new("127.0/16"))
469     print res
470
471     print "finding subnets of 3ffe:4::0/32"
472     res = ci.find_subnets(Cidr.valid_cidr("3ffe:4::0/32"))
473     print res
474
475     print "finding subnets of 48.12.0.0/16"
476     res = ci.find_subnets(Cidr.valid_cidr("48.12.0.0/16"))
477     print res