copyright and license notices
[python-rwhoisd.git] / rwhoisd / MemIndex.py
1 # This file is part of python-rwhoisd
2 #
3 # Copyright (C) 2003, David E. Blacka
4 #
5 # $Id: MemIndex.py,v 1.2 2003/04/28 16:43:19 davidb Exp $
6 #
7 # This program is free software; you can redistribute it and/or modify
8 # it under the terms of the GNU General Public License as published by
9 # the Free Software Foundation; either version 2 of the License, or
10 # (at your option) any later version.
11 #
12 # This program is distributed in the hope that it will be useful, but
13 # WITHOUT ANY WARRANTY; without even the implied warranty of
14 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
15 # General Public License for more details.
16 #
17 # You should have received a copy of the GNU General Public License
18 # along with this program; if not, write to the Free Software
19 # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
20 # USA
21
22 import bisect, types
23 import Cidr
24
25 class MemIndex:
26     """This class implements a simple in-memory key-value map.  This
27     index supports efficient prefix matching (as well as pretty
28     efficient exact matching).  Internally, it is implemented as a
29     sorted list supporting binary searches."""
30
31     # NOTE: This implementation is probably far from as efficient as
32     # it could be.  Ideally, we should avoid the use of custom
33     # comparison functions, so that list.sort() will use built-in
34     # comparitors.  This would mean storing raw key tuples in index as
35     # opposed to element objects.  Also, it would mean avoiding the
36     # use of objects (like Cidr) as keys in favor of a primitive type.
37     # In the Cidr case, we would either have to use longs or strings,
38     # as Python doesn't seem to have an unsigned 32-bit integer type.
39     
40     def __init__(self):
41         self.index = []
42         self.sorted = False
43
44     def add(self, key, value=None):
45         """Add a key-value pair to the map.  If the map is already in
46         the prepared state, this operation will preserve it, so don't
47         use this if many elements are to be added at once.  The 'key'
48         argument may be a 2 element tuple, in which case 'value' is
49         ignored."""
50
51         if isinstance(key, types.TupleType):
52             el = element(key[0], key[1])
53         else:
54             el = element(key, value)
55
56         if self.sorted:
57             i = bisect.bisect_left(self.index, el)
58             while i < len(self.index):
59                 if self.index[i].total_equals(el):
60                     break
61                 if self.index[i] != el:
62                     self.index.insert(i, el)
63                     break
64                 i += 1
65             else:
66                 self.index.append(el)
67         else:
68             self.index.append(el)
69
70     def addlist(self, list):
71         """Add the entire list of elements to the map.  The elements
72         of 'list' may be 2 element tuples or actual 'element' objects.
73         Use this method to add many elements at once."""
74
75         self.sorted = False
76         for i in list:
77             if isinstance(i, types.TupleType):
78                 self.index.append(element(i[0], i[1]))
79             elif isinstance(i, element):
80                 self.index.append(i)
81
82     def prepare(self):
83         """Put the map in a prepared state, if necessary."""
84
85         n = len(self.index)
86         if not n: return
87         if not self.sorted:
88             self.index.sort()
89             # unique the index
90             last = self.index[0]
91             lasti = i = 1
92             while i < n:
93                 if not self.index[i].total_equals(last):
94                     self.index[lasti] = last = self.index[i]
95                     lasti += 1
96                 i += 1
97             self.index[lasti:]
98             self.sorted = True
99
100     def _find(self, key):
101         """Return the (search_element, index) tuple.  Used internally
102         only."""
103         
104         self.prepare()
105         search_el = element(key, None)
106         i = bisect.bisect_left(self.index, search_el)
107         if i > len(self.index) or i < 0:
108             print "warning: bisect.bisect_left returned something " + \
109                   "unexpected:", i, len(self.index)
110         return (search_el, i)
111
112     def find(self, key, prefix_match=False, max=0):
113         """Return a list of values whose keys string match 'key'.  If
114         prefix_match is True, then keys will match if 'key' is a
115         prefix of the element key."""
116
117         search_el, i = self._find(key)
118         res = []
119         while i < len(self.index):
120             if max and len(res) == max: break
121             if search_el.equals(self.index[i], prefix_match):
122                 res.append(self.index[i].value)
123                 i += 1
124             else:
125                 break
126         return res
127
128 class CidrMemIndex(MemIndex):
129     """This is an in-memory map that has been extended to support CIDR
130     searching semantics."""
131
132     # NOTE: this structure lends to fairly efficient exact searches
133     # (O[log2N]), effience subnet searches (also O[log2N]), but not
134     # terribly efficient supernet searches (O[32log2N]), because we
135     # have to potentially do 32 exact matches.  If we want efficient
136     # supernet searches, we will probably have to use some sort of
137     # general (i.e., not binary) search tree datastructure, as there
138     # is no sorted ordering that will efficiently give supernets that
139     # I can think of.
140
141     def add(self, key, value=None):
142         if isinstance(key, types.TupleType):
143             MemIndex.add(self, (Cidr.valid_cidr(key[0]), key[1]), value)
144         else:
145             MemIndex.add(self, Cidr.valid_cidr(key), value)
146         return
147
148     def addlist(self, list):
149
150         # make sure the keys are Cidr objects
151         for i in list:
152             if isinstance(i, types.TupleType):
153                 i = (Cidr.valid_cidr(el[0]), el[1])
154             elif isinstance(el, element):
155                 i.key = Cidr.valid_cidr(i.key)
156         
157         MemIndex.addlist(self, list)
158         return
159     
160     def find_exact(self, key, max = 0):
161
162         key = Cidr.valid_cidr(key)
163         search_el, i = self._find(key)
164         res = []
165         while i < len(self.index) and self.index[i].key == key:
166             res.append(self.index[i].value)
167             if max and len(res) == max: break
168             i += 1
169         return res
170     
171     def find_subnets(self, key, max = 0):
172         """Return all values that are subnets of 'key', including any
173         that match 'key' itself."""
174
175         key = Cidr.valid_cidr(key)
176         search_el, i = self._find(key)
177
178         res = []
179         while i < len(self.index) and self.index[i].key.is_subnet(key):
180             if max and len(res) == max: break
181             res.append(self.index[i].value)
182             i += 1
183         return res
184
185     def find_supernets(self, key, max = 0):
186         """Return all values that are supernets of 'key', including
187         any that match 'key' itself."""
188
189         key = Cidr.valid_cidr(key)
190         k = key.clone()
191         res = []
192         while k.netlen >= 0:
193             k.calc()
194             res += self.find_exact(k, max)
195             if max and len(res) >= max:
196                 return res[:max]
197             k.netlen -= 1
198
199         
200         return res
201
202     def find(self, key, prefix_match=0, max=0):
203         """Return either the exact match of 'key', or the closest
204         supernet of 'key'.  If prefix_match is True, then find all
205         supernets of 'key'"""
206
207         key = Cidr.valid_cidr(key)
208         if prefix_match == 0:
209             res = self.find_exact(key, max)
210                 
211             if not res:
212                 # now do a modified supernet search that stops after
213                 # the first proper supernet, but gets all values
214                 # matching that supernet key
215                 k = key.clone()
216                 k.netlen -= 1
217                 while not res and k.netlen >= 0:
218                     k.calc()
219                     res = self.find_exact(k, max)
220                     k.netlen -= 1
221             return res
222         
223         # for now, a prefix match means all supernets
224         return self.find_supernets(key, max)
225
226 class ComboMemIndex:
227     """This is an in-memory map that contains both a normal string
228     index and a CIDR index.  Valid CIDR values we be applied against
229     the CIDR index.  Other values will be applied against the normal
230     index."""
231     
232     def __init__(self):
233         self.normal_index = MemIndex()
234         self.cidr_index   = CidrMemIndex()
235
236     def add(self, key, value = None):
237         """Add a key,value pair to the correct map.  See MemIndex for
238         the behavior of this method"""
239         
240         if isinstance(key, types.TupleType):
241             k = key[0]
242         else:
243             k = key
244         c = Cidr.valid_cidr(key)
245         if c:
246             self.cidr_index.add(key, value)
247         else:
248             self.normal_index.add(key, value)
249         return
250
251     def addlist(self, list):
252         """Add a list of elements or key, value tuples to the
253         appropriate maps."""
254         
255         cidr_list = []
256         normal_list = []
257         
258         for i in list:
259             if isinstance(i, element):
260                 k, v = i.key, i.value
261             elif isinstance(i, types.TupleType):
262                 k, v = i[:2]
263             
264             c = Cidr.valid_cidr(k)
265             if c:
266                 cidr_list.append((c, v))
267             else:
268                 normal_list.append((k, v))
269
270         if cidr_list:
271             self.cidr_index.addlist(cidr_list)
272         if normal_list:
273             self.normal_index.addlist(normal_list)
274         return
275
276     def prepare(self):
277         """Prepare the internally held maps for searching."""
278
279         self.cidr_index.prepare()
280         self.normal_index.prepare()
281
282     def find(self, key, prefix_match=False, max=0):
283         """Return a list of values whose keys match 'key'."""
284
285         c = Cidr.valid_cidr(key)
286         if c:
287             return self.cidr_index.find(c, prefix_match, max)
288         return self.normal_index.find(key, prefix_match, max)
289
290     def find_exact(self, key, max = 0):
291         """Return a list of values whose keys match 'key'.  if 'key'
292         is not a CIDR value, then this is the same as find()."""
293
294         c = Cidr.valid_cidr(key)
295         if c:
296             return self.cidr_index.find_exact(c, max)
297         return self.normal_index.find(key, False, max)
298
299     def find_subnets(self, key, max = 0):
300         """If 'key' is a CIDR value (either a Cidr object or a valid
301         CIDR string representation, do a find_subnets on the internal
302         CidrMemIndex, otherwise return None."""
303         
304         c = Cidr.valid_cidr(key)
305         if c: return self.cidr_index.find_subnets(key, max)
306         return None
307
308     def find_supernets(self, key, max = 0):
309         """If 'key' is a CIDR value (either a Cidr object or a valid
310         CIDR string representation, do a find_supernets on the internal
311         CidrMemIndex, otherwise return None."""
312
313         c = Cidr.valid_cidr(key)
314         if c: return self.cidr_index.find_supernets(key, max)
315         return None
316     
317 class element:
318     """This is the base element class.  It basically exists to
319     simplify sorting."""
320     
321     def __init__(self, key, value):
322         self.key   = key
323         self.value = value
324
325     def __cmp__(self, other):
326         """Compare only on the key."""
327
328         if not type(self.key) == type(other.key):
329             print "other is incompatible type?", repr(other.key), other.key
330         if self.key < other.key:
331             return -1
332         if self.key == other.key:
333             return 0
334         return 1
335
336     def __str__(self):
337         return "<" + str(self.key) + ", " + str(self.value) + ">"
338
339     def __repr__(self):
340         return "element" + str(self)
341     
342     def __hash__(self):
343         return self.key.__hash__()
344
345     def equals(self, other, prefix_match=0):
346         if prefix_match:
347             return self.key == other.key[:len(self.key)]
348         return self.key == other.key
349
350     def total_equals(self, other):
351         if not isinstance(other, type(self)): return False
352         return self.key == other.key and self.value == other.value
353
354 if __name__ == "__main__":
355
356     source = [ ("foo", "foo-id"), ("bar", "bar-id"), ("baz", "baz-id"),
357                ("foobar", "foo-id-2"), ("barnone", "bar-id-2"),
358                ("zygnax", "z-id") ]
359
360     mi = MemIndex()
361     mi.addlist(source)
362
363     print "finding foobar:"
364     res = mi.find("foobar")
365     print res
366
367     print "finding foo*:"
368     res = mi.find("foo", 1)
369     print res
370
371     print "finding baz:"
372     res = mi.find("baz")
373     print res
374
375     print "adding bork"
376     mi.add("bork", "bork-id")
377
378     print "finding b*:"
379     res = mi.find("b", 1)
380     print res
381
382     ci = CidrMemIndex()
383
384     ci.add(Cidr.Cidr("127.0.0.1/24"), "net-local-1");
385     ci.add(Cidr.Cidr("127.0.0.1/32"), "net-local-2");
386     ci.add(Cidr.Cidr("216.168.224.0", 22), "net-vrsn-1")
387     ci.add(Cidr.Cidr("216.168.252.1", 32), "net-vrsn-2")
388     ci.add(Cidr.Cidr("24.36.191.0/24"), "net-foo-c")
389     ci.add(Cidr.Cidr("24.36.191.32/27"), "net-foo-sub-c")
390     ci.add(Cidr.Cidr("24.36/16"), "net-foo-b")
391
392     print "finding exactly 127.0.0.0/24"
393     res = ci.find(Cidr.Cidr("127.0.0.0/24"))
394     print res
395
396     print "finding exactly 127.0.0.16/32"
397     res = ci.find(Cidr.Cidr("127.0.0.16/32"))
398     print res
399
400     print "finding supernets of 127.0.0.16/32"
401     res = ci.find_supernets(Cidr.Cidr("127.0.0.16/32"))
402     print res
403     
404     print "finding supernets of 24.36.191.32/27"
405     res = ci.find(Cidr.Cidr("24.36.191.32/27"), 1)
406     print res
407
408     print "finding supernets of 24.36.191.33/27"
409     res = ci.find_supernets(Cidr.Cidr("24.36.191.33/27"))
410     print res
411
412     print "finding supernets of 24.36.191.64/27"
413     res = ci.find_supernets(Cidr.Cidr("24.36.191.64/27"))
414     print res
415
416     print "finding subnets of 127.0/16"
417     res = ci.find_subnets(Cidr.Cidr("127.0/16"))
418     print res