Allow netblocks to be used in CIDR indexes.
[python-rwhoisd.git] / rwhoisd / MemIndex.py
index 996ef42..88a6cf0 100644 (file)
@@ -36,7 +36,7 @@ class MemIndex:
     # use of objects (like Cidr) as keys in favor of a primitive type.
     # In the Cidr case, we would either have to use longs or strings,
     # as Python doesn't seem to have an unsigned 32-bit integer type.
-    
+
     def __init__(self):
         self.index = []
         self.sorted = False
@@ -100,7 +100,7 @@ class MemIndex:
     def _find(self, key):
         """Return the (search_element, index) tuple.  Used internally
         only."""
-        
+
         self.prepare()
         search_el = element(key, None)
         i = bisect.bisect_left(self.index, search_el)
@@ -131,32 +131,70 @@ class CidrMemIndex(MemIndex):
 
     # NOTE: this structure lends to fairly efficient exact searches
     # (O[log2N]), efficient subnet searches (also O[log2N]), but not
-    # terribly efficient supernet searches (O[32 * log2N]), because we
-    # have to potentially do 32 exact matches.  If we want efficient
-    # supernet searches, we will probably have to use some sort of
-    # general (i.e., not binary) search tree datastructure, as there
-    # is no sorted ordering that will efficiently give supernets that
-    # I can think of.
+    # terribly efficient supernet searches (O[32 * log2N] or O[128 *
+    # log2N] for IPv6), because we have to potentially do 32 (or 128!)
+    # exact matches.  If we want efficient supernet searches, we will
+    # probably have to use some sort of general (i.e., not binary)
+    # search tree datastructure, as there is no sorted ordering that
+    # will efficiently give supernets that I can think of.
+
+    # convert a key, value pair into a list of (cidr, value) tuples.
+    # It can be a list with more than one element if key is actually a
+    # netblock.
+    def _conv_key_value(self, key, value):
+        res = []
+        if isinstance(key, Cidr.Cidr):
+            res.append((key, value))
+            return res
+        if self.is_netblock(key):
+            cidrs = self.parse_netblock(key)
+            for c in cidrs:
+                res.append((c, value))
+        else:
+            c = Cidr.valid_cidr(key)
+            res.append((c, value))
+        return res
+
+    # convert a (key, value) tuple into a list of (cidr, value)
+    # tuples.
+    def _conv_tuple(self, tuple):
+        return self._conv_key_value(tuple[0], tuple[1])
 
     def add(self, key, value=None):
         if isinstance(key, types.TupleType):
-            MemIndex.add(self, (Cidr.valid_cidr(key[0]), key[1]), value)
+            l = self._conv_tuple(key)
         else:
-            MemIndex.add(self, Cidr.valid_cidr(key), value)
+            l = self._conv_key_value(key, value)
+
+        for k, v in l:
+            MemIndex.add(self, k, v)
         return
 
     def addlist(self, list):
 
+        res_list = []
         # make sure the keys are Cidr objects
         for i in list:
             if isinstance(i, types.TupleType):
-                i = (Cidr.valid_cidr(el[0]), el[1])
+                l = self._conv_tuple(i)
             elif isinstance(el, element):
-                i.key = Cidr.valid_cidr(i.key)
-        
-        MemIndex.addlist(self, list)
+                l = self._conv_key_value(i.key, i.value)
+            res_list.append(l)
+
+        MemIndex.addlist(self, res_list)
         return
-    
+
+    def is_netblock(self, key):
+        if "-" in key: return True
+        return False
+
+    def parse_netblock(self, key):
+        start, end = key.split("-", 1);
+        start = start.strip()
+        end = end.strip()
+
+        return Cidr.netblock_to_cidr(start, end)
+
     def find_exact(self, key, max = 0):
 
         key = Cidr.valid_cidr(key)
@@ -167,7 +205,7 @@ class CidrMemIndex(MemIndex):
             if max and len(res) == max: break
             i += 1
         return res
-    
+
     def find_subnets(self, key, max = 0):
         """Return all values that are subnets of 'key', including any
         that match 'key' itself."""
@@ -175,12 +213,12 @@ class CidrMemIndex(MemIndex):
         key = Cidr.valid_cidr(key)
         search_el, i = self._find(key)
 
-        res = []
+        res = set()
         while i < len(self.index) and self.index[i].key.is_subnet(key):
             if max and len(res) == max: break
-            res.append(self.index[i].value)
+            res.add(self.index[i].value)
             i += 1
-        return res
+        return list(res)
 
     def find_supernets(self, key, max = 0):
         """Return all values that are supernets of 'key', including
@@ -196,7 +234,7 @@ class CidrMemIndex(MemIndex):
                 return res[:max]
             k.netlen -= 1
 
-        
+
         return res
 
     def find(self, key, prefix_match=0, max=0):
@@ -207,7 +245,7 @@ class CidrMemIndex(MemIndex):
         key = Cidr.valid_cidr(key)
         if prefix_match == 0:
             res = self.find_exact(key, max)
-                
+
             if not res:
                 # now do a modified supernet search that stops after
                 # the first proper supernet, but gets all values
@@ -219,7 +257,7 @@ class CidrMemIndex(MemIndex):
                     res = self.find_exact(k, max)
                     k.netlen -= 1
             return res
-        
+
         # for now, a prefix match means all supernets
         return self.find_supernets(key, max)
 
@@ -228,7 +266,7 @@ class ComboMemIndex:
     index and a CIDR index.  Valid CIDR values we be applied against
     the CIDR index.  Other values will be applied against the normal
     index."""
-    
+
     def __init__(self):
         self.normal_index = MemIndex()
         self.cidr_index   = CidrMemIndex()
@@ -236,7 +274,7 @@ class ComboMemIndex:
     def add(self, key, value = None):
         """Add a key,value pair to the correct map.  See MemIndex for
         the behavior of this method"""
-        
+
         if isinstance(key, types.TupleType):
             k = key[0]
         else:
@@ -251,16 +289,16 @@ class ComboMemIndex:
     def addlist(self, list):
         """Add a list of elements or key, value tuples to the
         appropriate maps."""
-        
+
         cidr_list = []
         normal_list = []
-        
+
         for i in list:
             if isinstance(i, element):
                 k, v = i.key, i.value
             elif isinstance(i, types.TupleType):
                 k, v = i[:2]
-            
+
             c = Cidr.valid_cidr(k)
             if c:
                 cidr_list.append((c, v))
@@ -300,7 +338,7 @@ class ComboMemIndex:
         """If 'key' is a CIDR value (either a Cidr object or a valid
         CIDR string representation, do a find_subnets on the internal
         CidrMemIndex, otherwise return None."""
-        
+
         c = Cidr.valid_cidr(key)
         if c: return self.cidr_index.find_subnets(key, max)
         return None
@@ -313,11 +351,11 @@ class ComboMemIndex:
         c = Cidr.valid_cidr(key)
         if c: return self.cidr_index.find_supernets(key, max)
         return None
-    
+
 class element:
     """This is the base element class.  It basically exists to
     simplify sorting."""
-    
+
     def __init__(self, key, value):
         self.key   = key
         self.value = value
@@ -338,7 +376,7 @@ class element:
 
     def __repr__(self):
         return "element" + str(self)
-    
+
     def __hash__(self):
         return self.key.__hash__()
 
@@ -381,38 +419,61 @@ if __name__ == "__main__":
 
     ci = CidrMemIndex()
 
-    ci.add(Cidr.Cidr("127.0.0.1/24"), "net-local-1");
-    ci.add(Cidr.Cidr("127.0.0.1/32"), "net-local-2");
-    ci.add(Cidr.Cidr("216.168.224.0", 22), "net-vrsn-1")
-    ci.add(Cidr.Cidr("216.168.252.1", 32), "net-vrsn-2")
-    ci.add(Cidr.Cidr("24.36.191.0/24"), "net-foo-c")
-    ci.add(Cidr.Cidr("24.36.191.32/27"), "net-foo-sub-c")
-    ci.add(Cidr.Cidr("24.36/16"), "net-foo-b")
+    ci.add("127.0.0.1/24", "net-local-1");
+    ci.add("127.0.0.1/32", "net-local-2");
+    ci.add(Cidr.Cidr.create("216.168.224.0", 22), "net-vrsn-1")
+    ci.add(Cidr.Cidr.create("216.168.252.1", 32), "net-vrsn-2")
+    ci.add("24.36.191.0/24", "net-foo-c")
+    ci.add("24.36.191.32/27", "net-foo-sub-c")
+    ci.add("24.36/16", "net-foo-b")
+    ci.add("3ffe:4:5::0/48", "net-foo-d6")
+    ci.add("3ffe:4:5:6::0/64", "net-foo-e6")
+    ci.add("48.12.6.0 - 48.12.6.95", "net-bar-1")
 
     print "finding exactly 127.0.0.0/24"
-    res = ci.find(Cidr.Cidr("127.0.0.0/24"))
+    res = ci.find(Cidr.Cidr.create("127.0.0.0/24"))
     print res
 
     print "finding exactly 127.0.0.16/32"
-    res = ci.find(Cidr.Cidr("127.0.0.16/32"))
+    res = ci.find(Cidr.Cidr.create("127.0.0.16/32"))
+    print res
+
+    print "finding exactly 3ffe:4:5:6::0/64"
+    res = ci.find(Cidr.valid_cidr("3ffe:4:5:6::/64"))
     print res
 
     print "finding supernets of 127.0.0.16/32"
-    res = ci.find_supernets(Cidr.Cidr("127.0.0.16/32"))
+    res = ci.find_supernets(Cidr.Cidr.create("127.0.0.16/32"))
     print res
-    
+
     print "finding supernets of 24.36.191.32/27"
-    res = ci.find(Cidr.Cidr("24.36.191.32/27"), 1)
+    res = ci.find(Cidr.Cidr.create("24.36.191.32/27"), 1)
     print res
 
     print "finding supernets of 24.36.191.33/27"
-    res = ci.find_supernets(Cidr.Cidr("24.36.191.33/27"))
+    res = ci.find_supernets(Cidr.Cidr.create("24.36.191.33/27"))
     print res
 
     print "finding supernets of 24.36.191.64/27"
-    res = ci.find_supernets(Cidr.Cidr("24.36.191.64/27"))
+    res = ci.find_supernets(Cidr.Cidr.create("24.36.191.64/27"))
+    print res
+
+    print "finding supernets of 3ffe:4:5:6:7::0/80"
+    res = ci.find_supernets(Cidr.valid_cidr("3ffe:4:5:6:7::0/80"))
+    print res
+
+    print "finding supernets of 48.12.6.90"
+    res = ci.find_supernets(Cidr.valid_cidr("48.12.6.90"))
     print res
 
     print "finding subnets of 127.0/16"
-    res = ci.find_subnets(Cidr.Cidr("127.0/16"))
+    res = ci.find_subnets(Cidr.Cidr.create("127.0/16"))
+    print res
+
+    print "finding subnets of 3ffe:4::0/32"
+    res = ci.find_subnets(Cidr.valid_cidr("3ffe:4::0/32"))
+    print res
+
+    print "finding subnets of 48.12.0.0/16"
+    res = ci.find_subnets(Cidr.valid_cidr("48.12.0.0/16"))
     print res