added referral processing
[python-rwhoisd.git] / rwhoisd / QueryProcessor.py
index 4e9c8d2..e3bcd75 100644 (file)
@@ -1,4 +1,4 @@
-import sys
+import sys, re
 import Cidr, Rwhois, QueryParser
 
 class QueryProcessor:
@@ -68,12 +68,11 @@ class QueryProcessor:
         # searches on all applicable terms (bare or using an indexed
         # attribute) and find the intersection of the results.
 
-        # FIXME: need to put in the referral chasing logic here, I
-        # think.
-        
         st  = None
         sti = 0
 
+        orig_clause = clause[:]
+        
         # find the first searchable term:
         for term, i in zip(clause, xrange(sys.maxint)):
             attr, op, value = term
@@ -99,8 +98,101 @@ class QueryProcessor:
 
         objs = self._filter_results(self.db.fetch_objects(res.list()), clause)
 
-        return QueryResult(objs)
+        queryres = QueryResult(objs)
+
+        # look for referrals
+        refs = self.process_referral_search(orig_clause)
+        queryres.add_referrals(refs)
+        
+        return queryres
+
+    def _is_in_autharea(self, value):
+        """Returns True if value could be considered to be contained
+        within an authority area.  That is, is a subnet of a
+        network-type authority area or a subdomain of a domainname
+        type authority area."""
+        
+        aas = self.db.get_authareas()
+        
+        if isinstance(value, Cidr.Cidr):
+            for aa in aas:
+                cv = Cidr.valid_cidr(aa)
+                if cv and cv.is_supernet(value):
+                    return True
+        else:
+            for aa in aas:
+                if is_domainname(aa) and is_subdomain(aa, value):
+                    return True
+        return False
+
+    def _referral_search_cidr(self, cv, value):
+        """Return the IndexResult of a referral search for value, or
+        None if the value doesn't qualify for a Cidr referral
+        search."""
+        
+        if not cv: return None
+        if not self._is_in_autharea(cv): return None
+        return self.db.search_referral(value)
+
+    def _referral_search_domain(self, value):
+        """Return the IndexResult of a referral search for value, or
+        None if the value doesn't qualify for a domain referral
+        search."""
+        
+        if not is_domainname(value): return None
+        if not self._is_in_autharea(value): return None
+        dn = value
+        res = None
+        while dn:
+            res = self.db.search_referral(dn)
+            if res.list(): break
+            dn = reduce_domain(dn)
+        return res
+
+    def _referral_search_term(self, value):
+        """Return the IndexResult of a referral search for value, or
+        None if the value didn't qualify for a referral search."""
+        
+        cv = Cidr.valid_cidr(value)
+        if cv:
+            return self._referral_search_cidr(cv, value)
+        elif is_domainname(value):
+            return self._referral_search_domain(value)
+        return None
+        
+    def process_referral_search(self, clause):
+        """Given a query clause, attempt to search for referrals
+        associated with the terms.  Return a list of referral strings
+        that matched terms in the clause (if any).  The only terms
+        that actually get searched are the ones that look
+        'heirarchical'.  For now, the attribute part of the term is
+        essentially ignored, so a search for something like
+        'name=127.0.0.1' might concievably generate a referral, when
+        perhaps it shouldn't."""
+        
+        # first check to see if the search is explictly for a referral
+        for term in clause:
+            if (term[0] == "class-name" and term[1] == "="
+                and term[2] == "referral") or term[0] == "referred-auth-area":
+                # in which case, we return nothing
+                return []
+
+        referrals = []
+
+        # look for heirarchical-looking terms.
+        for attr, op, value in clause:
+            if op == "!=": continue
+            res = self._referral_search_term(value)
+            if not res or not res.list():
+                continue
+
+            ref_objs = self.db.fetch_objects(res.list())
+            ref_strs = [x for y in ref_objs for x in y.get_attr("referral")]
+            referrals.extend(ref_strs)
+
+        return referrals
 
+        
     def process_full_query(self, query, max=0):
         """Given a parsed query object, process it by unioning the
         results of the various ORed together clauses"""
@@ -110,6 +202,7 @@ class QueryProcessor:
             res = self.process_query_clause(query.clauses[0])
             return res
 
+        # otherwise, union the results from all the causes
         res = QueryResult()
         for clause in query.clauses:
             res.extend(self.process_query_clause(clause))
@@ -151,6 +244,8 @@ class QueryProcessor:
             session.wfile.write("\r\n")
 
         if referrals:
+            if objects:
+                session.wfile.write("\r\n")
             session.wfile.write("\r\n".join(referrals))
             session.wfile.write("\r\n")
                                 
@@ -255,6 +350,40 @@ def match_cidr(searchval, val):
         return rv == sv
 
 
+# this forms a pretty basic heuristic to see of a value looks like a
+# domain name.
+domain_regex = re.compile("[a-z0-9-]+\.[a-z0-9-.]+", re.I)
+
+def is_domainname(value):
+    if domain_regex.match(value):
+        return True
+    return False
+
+def is_subdomain(domain, subdomain):
+    domain = domain.lower();
+    subdomain = subdomain.lower();
+    
+    dlist = domain.split('.')
+    sdlist = subdomain.split('.')
+    
+    if len(dlist) > len(sdlist): return False
+    if len(dlist) == len(sdlist): return domain == subdomain
+
+    dlist.reverse();
+    sdlist.reverse()
+
+    return dlist == sdlist[:len(dlist)]
+
+def reduce_domain(domain):
+    dlist = domain.split('.')
+    dlist.pop(0)
+    return '.'.join(dlist)
+
+def is_heirarchical(value):
+    if cidr.valid_cidr(value): return True
+    if is_domainname(value): return True
+    return False
+
 if __name__ == '__main__':
 
     import MemDB, Session