Change Cidr.Cidr.create() to just Cidr.new(), which seems a bit more concise
[python-rwhoisd.git] / rwhoisd / MemIndex.py
1 # This file is part of python-rwhoisd
2 #
3 # Copyright (C) 2003, David E. Blacka
4 #
5 # $Id: MemIndex.py,v 1.2 2003/04/28 16:43:19 davidb Exp $
6 #
7 # This program is free software; you can redistribute it and/or modify
8 # it under the terms of the GNU General Public License as published by
9 # the Free Software Foundation; either version 2 of the License, or
10 # (at your option) any later version.
11 #
12 # This program is distributed in the hope that it will be useful, but
13 # WITHOUT ANY WARRANTY; without even the implied warranty of
14 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
15 # General Public License for more details.
16 #
17 # You should have received a copy of the GNU General Public License
18 # along with this program; if not, write to the Free Software
19 # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
20 # USA
21
22 import bisect, types
23 import Cidr
24
25 class MemIndex:
26     """This class implements a simple in-memory key-value map.  This
27     index supports efficient prefix matching (as well as pretty
28     efficient exact matching).  Internally, it is implemented as a
29     sorted list supporting binary searches."""
30
31     # NOTE: This implementation is probably far from as efficient as
32     # it could be.  Ideally, we should avoid the use of custom
33     # comparison functions, so that list.sort() will use built-in
34     # comparitors.  This would mean storing raw key tuples in index as
35     # opposed to element objects.  Also, it would mean avoiding the
36     # use of objects (like Cidr) as keys in favor of a primitive type.
37     # In the Cidr case, we would either have to use longs or strings,
38     # as Python doesn't seem to have an unsigned 32-bit integer type.
39
40     def __init__(self):
41         self.index = []
42         self.sorted = False
43
44     def add(self, key, value=None):
45         """Add a key-value pair to the map.  If the map is already in
46         the prepared state, this operation will preserve it, so don't
47         use this if many elements are to be added at once.  The 'key'
48         argument may be a 2 element tuple, in which case 'value' is
49         ignored."""
50
51         if isinstance(key, types.TupleType):
52             el = element(key[0], key[1])
53         else:
54             el = element(key, value)
55
56         if self.sorted:
57             i = bisect.bisect_left(self.index, el)
58             while i < len(self.index):
59                 if self.index[i].total_equals(el):
60                     break
61                 if self.index[i] != el:
62                     self.index.insert(i, el)
63                     break
64                 i += 1
65             else:
66                 self.index.append(el)
67         else:
68             self.index.append(el)
69
70     def addlist(self, list):
71         """Add the entire list of elements to the map.  The elements
72         of 'list' may be 2 element tuples or actual 'element' objects.
73         Use this method to add many elements at once."""
74
75         self.sorted = False
76         for i in list:
77             if isinstance(i, types.TupleType):
78                 self.index.append(element(i[0], i[1]))
79             elif isinstance(i, element):
80                 self.index.append(i)
81
82     def prepare(self):
83         """Put the map in a prepared state, if necessary."""
84
85         n = len(self.index)
86         if not n: return
87         if not self.sorted:
88             self.index.sort()
89             # unique the index
90             last = self.index[0]
91             lasti = i = 1
92             while i < n:
93                 if not self.index[i].total_equals(last):
94                     self.index[lasti] = last = self.index[i]
95                     lasti += 1
96                 i += 1
97             self.index[lasti:]
98             self.sorted = True
99
100     def _find(self, key):
101         """Return the (search_element, index) tuple.  Used internally
102         only."""
103
104         self.prepare()
105         search_el = element(key, None)
106         i = bisect.bisect_left(self.index, search_el)
107         if i > len(self.index) or i < 0:
108             print "warning: bisect.bisect_left returned something " + \
109                   "unexpected:", i, len(self.index)
110         return (search_el, i)
111
112     def find(self, key, prefix_match=False, max=0):
113         """Return a list of values whose keys string match 'key'.  If
114         prefix_match is True, then keys will match if 'key' is a
115         prefix of the element key."""
116
117         search_el, i = self._find(key)
118         res = []
119         while i < len(self.index):
120             if max and len(res) == max: break
121             if search_el.equals(self.index[i], prefix_match):
122                 res.append(self.index[i].value)
123                 i += 1
124             else:
125                 break
126         return res
127
128 class CidrMemIndex(MemIndex):
129     """This is an in-memory map that has been extended to support CIDR
130     searching semantics."""
131
132     # NOTE: this structure lends to fairly efficient exact searches
133     # (O[log2N]), efficient subnet searches (also O[log2N]), but not
134     # terribly efficient supernet searches (O[32 * log2N] or O[128 *
135     # log2N] for IPv6), because we have to potentially do 32 (or 128!)
136     # exact matches.  If we want efficient supernet searches, we will
137     # probably have to use some sort of general (i.e., not binary)
138     # search tree datastructure, as there is no sorted ordering that
139     # will efficiently give supernets that I can think of.
140
141     # convert a key, value pair into a list of (cidr, value) tuples.
142     # It can be a list with more than one element if key is actually a
143     # netblock.
144     def _conv_key_value(self, key, value):
145         res = []
146         if isinstance(key, Cidr.Cidr):
147             res.append((key, value))
148             return res
149         if self.is_netblock(key):
150             cidrs = self.parse_netblock(key)
151             for c in cidrs:
152                 res.append((c, value))
153         else:
154             c = Cidr.valid_cidr(key)
155             res.append((c, value))
156         return res
157
158     # convert a (key, value) tuple into a list of (cidr, value)
159     # tuples.
160     def _conv_tuple(self, tuple):
161         return self._conv_key_value(tuple[0], tuple[1])
162
163     def add(self, key, value=None):
164         if isinstance(key, types.TupleType):
165             l = self._conv_tuple(key)
166         else:
167             l = self._conv_key_value(key, value)
168
169         for k, v in l:
170             MemIndex.add(self, k, v)
171         return
172
173     def addlist(self, list):
174
175         res_list = []
176         # make sure the keys are Cidr objects
177         for i in list:
178             if isinstance(i, types.TupleType):
179                 l = self._conv_tuple(i)
180             elif isinstance(el, element):
181                 l = self._conv_key_value(i.key, i.value)
182             res_list.append(l)
183
184         MemIndex.addlist(self, res_list)
185         return
186
187     def is_netblock(self, key):
188         if "-" in key: return True
189         return False
190
191     def parse_netblock(self, key):
192         start, end = key.split("-", 1);
193         start = start.strip()
194         end = end.strip()
195
196         return Cidr.netblock_to_cidr(start, end)
197
198     def find_exact(self, key, max = 0):
199
200         key = Cidr.valid_cidr(key)
201         search_el, i = self._find(key)
202         res = []
203         while i < len(self.index) and self.index[i].key == key:
204             res.append(self.index[i].value)
205             if max and len(res) == max: break
206             i += 1
207         return res
208
209     def find_subnets(self, key, max = 0):
210         """Return all values that are subnets of 'key', including any
211         that match 'key' itself."""
212
213         key = Cidr.valid_cidr(key)
214         search_el, i = self._find(key)
215
216         res = set()
217         while i < len(self.index) and self.index[i].key.is_subnet(key):
218             if max and len(res) == max: break
219             res.add(self.index[i].value)
220             i += 1
221         return list(res)
222
223     def find_supernets(self, key, max = 0):
224         """Return all values that are supernets of 'key', including
225         any that match 'key' itself."""
226
227         key = Cidr.valid_cidr(key)
228         k = key.clone()
229         res = []
230         while k.netlen >= 0:
231             k.calc()
232             res += self.find_exact(k, max)
233             if max and len(res) >= max:
234                 return res[:max]
235             k.netlen -= 1
236
237
238         return res
239
240     def find(self, key, prefix_match=0, max=0):
241         """Return either the exact match of 'key', or the closest
242         supernet of 'key'.  If prefix_match is True, then find all
243         supernets of 'key'"""
244
245         key = Cidr.valid_cidr(key)
246         if prefix_match == 0:
247             res = self.find_exact(key, max)
248
249             if not res:
250                 # now do a modified supernet search that stops after
251                 # the first proper supernet, but gets all values
252                 # matching that supernet key
253                 k = key.clone()
254                 k.netlen -= 1
255                 while not res and k.netlen >= 0:
256                     k.calc()
257                     res = self.find_exact(k, max)
258                     k.netlen -= 1
259             return res
260
261         # for now, a prefix match means all supernets
262         return self.find_supernets(key, max)
263
264 class ComboMemIndex:
265     """This is an in-memory map that contains both a normal string
266     index and a CIDR index.  Valid CIDR values we be applied against
267     the CIDR index.  Other values will be applied against the normal
268     index."""
269
270     def __init__(self):
271         self.normal_index = MemIndex()
272         self.cidr_index   = CidrMemIndex()
273
274     def add(self, key, value = None):
275         """Add a key,value pair to the correct map.  See MemIndex for
276         the behavior of this method"""
277
278         if isinstance(key, types.TupleType):
279             k = key[0]
280         else:
281             k = key
282         c = Cidr.valid_cidr(key)
283         if c:
284             self.cidr_index.add(key, value)
285         else:
286             self.normal_index.add(key, value)
287         return
288
289     def addlist(self, list):
290         """Add a list of elements or key, value tuples to the
291         appropriate maps."""
292
293         cidr_list = []
294         normal_list = []
295
296         for i in list:
297             if isinstance(i, element):
298                 k, v = i.key, i.value
299             elif isinstance(i, types.TupleType):
300                 k, v = i[:2]
301
302             c = Cidr.valid_cidr(k)
303             if c:
304                 cidr_list.append((c, v))
305             else:
306                 normal_list.append((k, v))
307
308         if cidr_list:
309             self.cidr_index.addlist(cidr_list)
310         if normal_list:
311             self.normal_index.addlist(normal_list)
312         return
313
314     def prepare(self):
315         """Prepare the internally held maps for searching."""
316
317         self.cidr_index.prepare()
318         self.normal_index.prepare()
319
320     def find(self, key, prefix_match=False, max=0):
321         """Return a list of values whose keys match 'key'."""
322
323         c = Cidr.valid_cidr(key)
324         if c:
325             return self.cidr_index.find(c, prefix_match, max)
326         return self.normal_index.find(key, prefix_match, max)
327
328     def find_exact(self, key, max = 0):
329         """Return a list of values whose keys match 'key'.  if 'key'
330         is not a CIDR value, then this is the same as find()."""
331
332         c = Cidr.valid_cidr(key)
333         if c:
334             return self.cidr_index.find_exact(c, max)
335         return self.normal_index.find(key, False, max)
336
337     def find_subnets(self, key, max = 0):
338         """If 'key' is a CIDR value (either a Cidr object or a valid
339         CIDR string representation, do a find_subnets on the internal
340         CidrMemIndex, otherwise return None."""
341
342         c = Cidr.valid_cidr(key)
343         if c: return self.cidr_index.find_subnets(key, max)
344         return None
345
346     def find_supernets(self, key, max = 0):
347         """If 'key' is a CIDR value (either a Cidr object or a valid
348         CIDR string representation, do a find_supernets on the internal
349         CidrMemIndex, otherwise return None."""
350
351         c = Cidr.valid_cidr(key)
352         if c: return self.cidr_index.find_supernets(key, max)
353         return None
354
355 class element:
356     """This is the base element class.  It basically exists to
357     simplify sorting."""
358
359     def __init__(self, key, value):
360         self.key   = key
361         self.value = value
362
363     def __cmp__(self, other):
364         """Compare only on the key."""
365
366         if not type(self.key) == type(other.key):
367             print "other is incompatible type?", repr(other.key), other.key
368         if self.key < other.key:
369             return -1
370         if self.key == other.key:
371             return 0
372         return 1
373
374     def __str__(self):
375         return "<" + str(self.key) + ", " + str(self.value) + ">"
376
377     def __repr__(self):
378         return "element" + str(self)
379
380     def __hash__(self):
381         return self.key.__hash__()
382
383     def equals(self, other, prefix_match=0):
384         if prefix_match:
385             return self.key == other.key[:len(self.key)]
386         return self.key == other.key
387
388     def total_equals(self, other):
389         if not isinstance(other, type(self)): return False
390         return self.key == other.key and self.value == other.value
391
392 if __name__ == "__main__":
393
394     source = [ ("foo", "foo-id"), ("bar", "bar-id"), ("baz", "baz-id"),
395                ("foobar", "foo-id-2"), ("barnone", "bar-id-2"),
396                ("zygnax", "z-id") ]
397
398     mi = MemIndex()
399     mi.addlist(source)
400
401     print "finding foobar:"
402     res = mi.find("foobar")
403     print res
404
405     print "finding foo*:"
406     res = mi.find("foo", 1)
407     print res
408
409     print "finding baz:"
410     res = mi.find("baz")
411     print res
412
413     print "adding bork"
414     mi.add("bork", "bork-id")
415
416     print "finding b*:"
417     res = mi.find("b", 1)
418     print res
419
420     ci = CidrMemIndex()
421
422     ci.add("127.0.0.1/24", "net-local-1");
423     ci.add("127.0.0.1/32", "net-local-2");
424     ci.add(Cidr.new("216.168.224.0", 22), "net-vrsn-1")
425     ci.add(Cidr.new("216.168.252.1", 32), "net-vrsn-2")
426     ci.add("24.36.191.0/24", "net-foo-c")
427     ci.add("24.36.191.32/27", "net-foo-sub-c")
428     ci.add("24.36/16", "net-foo-b")
429     ci.add("3ffe:4:5::0/48", "net-foo-d6")
430     ci.add("3ffe:4:5:6::0/64", "net-foo-e6")
431     ci.add("48.12.6.0 - 48.12.6.95", "net-bar-1")
432
433     print "finding exactly 127.0.0.0/24"
434     res = ci.find(Cidr.new("127.0.0.0/24"))
435     print res
436
437     print "finding exactly 127.0.0.16/32"
438     res = ci.find(Cidr.new("127.0.0.16/32"))
439     print res
440
441     print "finding exactly 3ffe:4:5:6::0/64"
442     res = ci.find(Cidr.valid_cidr("3ffe:4:5:6::/64"))
443     print res
444
445     print "finding supernets of 127.0.0.16/32"
446     res = ci.find_supernets(Cidr.new("127.0.0.16/32"))
447     print res
448
449     print "finding supernets of 24.36.191.32/27"
450     res = ci.find(Cidr.new("24.36.191.32/27"), 1)
451     print res
452
453     print "finding supernets of 24.36.191.33/27"
454     res = ci.find_supernets(Cidr.new("24.36.191.33/27"))
455     print res
456
457     print "finding supernets of 24.36.191.64/27"
458     res = ci.find_supernets(Cidr.new("24.36.191.64/27"))
459     print res
460
461     print "finding supernets of 3ffe:4:5:6:7::0/80"
462     res = ci.find_supernets(Cidr.valid_cidr("3ffe:4:5:6:7::0/80"))
463     print res
464
465     print "finding supernets of 48.12.6.90"
466     res = ci.find_supernets(Cidr.valid_cidr("48.12.6.90"))
467     print res
468
469     print "finding subnets of 127.0/16"
470     res = ci.find_subnets(Cidr.new("127.0/16"))
471     print res
472
473     print "finding subnets of 3ffe:4::0/32"
474     res = ci.find_subnets(Cidr.valid_cidr("3ffe:4::0/32"))
475     print res
476
477     print "finding subnets of 48.12.0.0/16"
478     res = ci.find_subnets(Cidr.valid_cidr("48.12.0.0/16"))
479     print res