Initial revision
[python-rwhoisd.git] / rwhoisd / MemIndex.py
1 import bisect, types
2 import Cidr
3
4 class MemIndex:
5     """This class implements a simple in-memory key-value map.  This
6     index supports efficient prefix matching (as well as pretty
7     efficient exact matching).  Internally, it is implemented as a
8     sorted list supporting binary searches."""
9
10     # NOTE: This implementation is probably far from as efficient as
11     # it could be.  Ideally, we should avoid the use of custom
12     # comparison functions, so that list.sort() will use built-in
13     # comparitors.  This would mean storing raw key tuples in index as
14     # opposed to element objects.  Also, it would mean avoiding the
15     # use of objects (like Cidr) as keys in favor of a primitive type.
16     # In the Cidr case, we would either have to use longs or strings,
17     # as Python doesn't seem to have an unsigned 32-bit integer type.
18     
19     def __init__(self):
20         self.index = []
21         self.sorted = False
22
23     def add(self, key, value=None):
24         """Add a key-value pair to the map.  If the map is already in
25         the prepared state, this operation will preserve it, so don't
26         use this if many elements are to be added at once.  The 'key'
27         argument may be a 2 element tuple, in which case 'value' is
28         ignored."""
29
30         if isinstance(key, types.TupleType):
31             el = element(key[0], key[1])
32         else:
33             el = element(key, value)
34
35         if self.sorted:
36             i = bisect.bisect_left(self.index, el)
37             while i < len(self.index):
38                 if self.index[i].total_equals(el):
39                     break
40                 if self.index[i] != el:
41                     self.index.insert(i, el)
42                     break
43                 i += 1
44             else:
45                 self.index.append(el)
46         else:
47             self.index.append(el)
48
49     def addlist(self, list):
50         """Add the entire list of elements to the map.  The elements
51         of 'list' may be 2 element tuples or actual 'element' objects.
52         Use this method to add many elements at once."""
53
54         self.sorted = False
55         for i in list:
56             if isinstance(i, types.TupleType):
57                 self.index.append(element(i[0], i[1]))
58             elif isinstance(i, element):
59                 self.index.append(i)
60
61     def prepare(self):
62         """Put the map in a prepared state, if necessary."""
63
64         n = len(self.index)
65         if not n: return
66         if not self.sorted:
67             self.index.sort()
68             # unique the index
69             last = self.index[0]
70             lasti = i = 1
71             while i < n:
72                 if not self.index[i].total_equals(last):
73                     self.index[lasti] = last = self.index[i]
74                     lasti += 1
75                 i += 1
76             self.index[lasti:]
77             self.sorted = True
78
79     def _find(self, key):
80         """Return the (search_element, index) tuple.  Used internally
81         only."""
82         
83         self.prepare()
84         search_el = element(key, None)
85         i = bisect.bisect_left(self.index, search_el)
86         if i > len(self.index) or i < 0:
87             print "warning: bisect.bisect_left returned something " + \
88                   "unexpected:", i, len(self.index)
89         return (search_el, i)
90
91     def find(self, key, prefix_match=False, max=0):
92         """Return a list of values whose keys string match 'key'.  If
93         prefix_match is True, then keys will match if 'key' is a
94         prefix of the element key."""
95
96         search_el, i = self._find(key)
97         res = []
98         while i < len(self.index):
99             if max and len(res) == max: break
100             if search_el.equals(self.index[i], prefix_match):
101                 res.append(self.index[i].value)
102                 i += 1
103             else:
104                 break
105         return res
106
107 class CidrMemIndex(MemIndex):
108     """This is an in-memory map that has been extended to support CIDR
109     searching semantics."""
110
111     # NOTE: this structure lends to fairly efficient exact searches
112     # (O[log2N]), effience subnet searches (also O[log2N]), but not
113     # terribly efficient supernet searches (O[32log2N]), because we
114     # have to potentially do 32 exact matches.  If we want efficient
115     # supernet searches, we will probably have to use some sort of
116     # general (i.e., not binary) search tree datastructure, as there
117     # is no sorted ordering that will efficiently give supernets that
118     # I can think of.
119
120     def add(self, key, value=None):
121         if isinstance(key, types.TupleType):
122             MemIndex.add(self, (Cidr.valid_cidr(key[0]), key[1]), value)
123         else:
124             MemIndex.add(self, Cidr.valid_cidr(key), value)
125         return
126
127     def addlist(self, list):
128
129         # make sure the keys are Cidr objects
130         for i in list:
131             if isinstance(i, types.TupleType):
132                 i = (Cidr.valid_cidr(el[0]), el[1])
133             elif isinstance(el, element):
134                 i.key = Cidr.valid_cidr(i.key)
135         
136         MemIndex.addlist(self, list)
137         return
138     
139     def find_exact(self, key, max = 0):
140
141         key = Cidr.valid_cidr(key)
142         search_el, i = self._find(key)
143         res = []
144         while i < len(self.index) and self.index[i].key == key:
145             res.append(self.index[i].value)
146             if max and len(res) == max: break
147             i += 1
148         return res
149     
150     def find_subnets(self, key, max = 0):
151         """Return all values that are subnets of 'key', including any
152         that match 'key' itself."""
153
154         key = Cidr.valid_cidr(key)
155         search_el, i = self._find(key)
156
157         res = []
158         while i < len(self.index) and self.index[i].key.is_subnet(key):
159             if max and len(res) == max: break
160             res.append(self.index[i].value)
161             i += 1
162         return res
163
164     def find_supernets(self, key, max = 0):
165         """Return all values that are supernets of 'key', including
166         any that match 'key' itself."""
167
168         key = Cidr.valid_cidr(key)
169         k = key.clone()
170         res = []
171         while k.netlen >= 0:
172             k.calc()
173             res += self.find_exact(k, max)
174             if max and len(res) >= max:
175                 return res[:max]
176             k.netlen -= 1
177
178         
179         return res
180
181     def find(self, key, prefix_match=0, max=0):
182         """Return either the exact match of 'key', or the closest
183         supernet of 'key'.  If prefix_match is True, then find all
184         supernets of 'key'"""
185
186         key = Cidr.valid_cidr(key)
187         if prefix_match == 0:
188             res = self.find_exact(key, max)
189                 
190             if not res:
191                 # now do a modified supernet search that stops after
192                 # the first proper supernet, but gets all values
193                 # matching that supernet key
194                 k = key.clone()
195                 k.netlen -= 1
196                 while not res and k.netlen >= 0:
197                     k.calc()
198                     res = self.find_exact(k, max)
199                     k.netlen -= 1
200             return res
201         
202         # for now, a prefix match means all supernets
203         return self.find_supernets(key, max)
204
205 class ComboMemIndex:
206     """This is an in-memory map that contains both a normal string
207     index and a CIDR index.  Valid CIDR values we be applied against
208     the CIDR index.  Other values will be applied against the normal
209     index."""
210     
211     def __init__(self):
212         self.normal_index = MemIndex()
213         self.cidr_index   = CidrMemIndex()
214
215     def add(self, key, value = None):
216         """Add a key,value pair to the correct map.  See MemIndex for
217         the behavior of this method"""
218         
219         if isinstance(key, types.TupleType):
220             k = key[0]
221         else:
222             k = key
223         c = Cidr.valid_cidr(key)
224         if c:
225             self.cidr_index.add(key, value)
226         else:
227             self.normal_index.add(key, value)
228         return
229
230     def addlist(self, list):
231         """Add a list of elements or key, value tuples to the
232         appropriate maps."""
233         
234         cidr_list = []
235         normal_list = []
236         
237         for i in list:
238             if isinstance(i, element):
239                 k, v = i.key, i.value
240             elif isinstance(i, types.TupleType):
241                 k, v = i[:2]
242             
243             c = Cidr.valid_cidr(k)
244             if c:
245                 cidr_list.append((c, v))
246             else:
247                 normal_list.append((k, v))
248
249         if cidr_list:
250             self.cidr_index.addlist(cidr_list)
251         if normal_list:
252             self.normal_index.addlist(normal_list)
253         return
254
255     def prepare(self):
256         """Prepare the internally held maps for searching."""
257
258         self.cidr_index.prepare()
259         self.normal_index.prepare()
260
261     def find(self, key, prefix_match=False, max=0):
262         """Return a list of values whose keys match 'key'."""
263
264         c = Cidr.valid_cidr(key)
265         if c:
266             return self.cidr_index.find(c, prefix_match, max)
267         return self.normal_index.find(key, prefix_match, max)
268
269     def find_exact(self, key, max = 0):
270         """Return a list of values whose keys match 'key'.  if 'key'
271         is not a CIDR value, then this is the same as find()."""
272
273         c = Cidr.valid_cidr(key)
274         if c:
275             return self.cidr_index.find_exact(c, max)
276         return self.normal_index.find(key, False, max)
277
278     def find_subnets(self, key, max = 0):
279         """If 'key' is a CIDR value (either a Cidr object or a valid
280         CIDR string representation, do a find_subnets on the internal
281         CidrMemIndex, otherwise return None."""
282         
283         c = Cidr.valid_cidr(key)
284         if c: return self.cidr_index.find_subnets(key, max)
285         return None
286
287     def find_supernets(self, key, max = 0):
288         """If 'key' is a CIDR value (either a Cidr object or a valid
289         CIDR string representation, do a find_supernets on the internal
290         CidrMemIndex, otherwise return None."""
291
292         c = Cidr.valid_cidr(key)
293         if c: return self.cidr_index.find_supernets(key, max)
294         return None
295     
296 class element:
297     """This is the base element class.  It basically exists to
298     simplify sorting."""
299     
300     def __init__(self, key, value):
301         self.key   = key
302         self.value = value
303
304     def __cmp__(self, other):
305         """Compare only on the key."""
306
307         if not type(self.key) == type(other.key):
308             print "other is incompatible type?", repr(other.key), other.key
309         if self.key < other.key:
310             return -1
311         if self.key == other.key:
312             return 0
313         return 1
314
315     def __str__(self):
316         return "<" + str(self.key) + ", " + str(self.value) + ">"
317
318     def __repr__(self):
319         return "element" + str(self)
320     
321     def __hash__(self):
322         return self.key.__hash__()
323
324     def equals(self, other, prefix_match=0):
325         if prefix_match:
326             return self.key == other.key[:len(self.key)]
327         return self.key == other.key
328
329     def total_equals(self, other):
330         if not isinstance(other, type(self)): return False
331         return self.key == other.key and self.value == other.value
332
333 if __name__ == "__main__":
334
335     source = [ ("foo", "foo-id"), ("bar", "bar-id"), ("baz", "baz-id"),
336                ("foobar", "foo-id-2"), ("barnone", "bar-id-2"),
337                ("zygnax", "z-id") ]
338
339     mi = MemIndex()
340     mi.addlist(source)
341
342     print "finding foobar:"
343     res = mi.find("foobar")
344     print res
345
346     print "finding foo*:"
347     res = mi.find("foo", 1)
348     print res
349
350     print "finding baz:"
351     res = mi.find("baz")
352     print res
353
354     print "adding bork"
355     mi.add("bork", "bork-id")
356
357     print "finding b*:"
358     res = mi.find("b", 1)
359     print res
360
361     ci = CidrMemIndex()
362
363     ci.add(Cidr.Cidr("127.0.0.1/24"), "net-local-1");
364     ci.add(Cidr.Cidr("127.0.0.1/32"), "net-local-2");
365     ci.add(Cidr.Cidr("216.168.224.0", 22), "net-vrsn-1")
366     ci.add(Cidr.Cidr("216.168.252.1", 32), "net-vrsn-2")
367     ci.add(Cidr.Cidr("24.36.191.0/24"), "net-foo-c")
368     ci.add(Cidr.Cidr("24.36.191.32/27"), "net-foo-sub-c")
369     ci.add(Cidr.Cidr("24.36/16"), "net-foo-b")
370
371     print "finding exactly 127.0.0.0/24"
372     res = ci.find(Cidr.Cidr("127.0.0.0/24"))
373     print res
374
375     print "finding exactly 127.0.0.16/32"
376     res = ci.find(Cidr.Cidr("127.0.0.16/32"))
377     print res
378
379     print "finding supernets of 127.0.0.16/32"
380     res = ci.find_supernets(Cidr.Cidr("127.0.0.16/32"))
381     print res
382     
383     print "finding supernets of 24.36.191.32/27"
384     res = ci.find(Cidr.Cidr("24.36.191.32/27"), 1)
385     print res
386
387     print "finding supernets of 24.36.191.33/27"
388     res = ci.find_supernets(Cidr.Cidr("24.36.191.33/27"))
389     print res
390
391     print "finding supernets of 24.36.191.64/27"
392     res = ci.find_supernets(Cidr.Cidr("24.36.191.64/27"))
393     print res
394
395     print "finding subnets of 127.0/16"
396     res = ci.find_subnets(Cidr.Cidr("127.0/16"))
397     print res