Add netblock examples and documantion
[python-rwhoisd.git] / rwhoisd / QueryProcessor.py
1 # This file is part of python-rwhoisd
2 #
3 # Copyright (C) 2003, David E. Blacka
4 #
5 # $Id: QueryProcessor.py,v 1.3 2003/04/28 16:44:56 davidb Exp $
6 #
7 # This program is free software; you can redistribute it and/or modify
8 # it under the terms of the GNU General Public License as published by
9 # the Free Software Foundation; either version 2 of the License, or
10 # (at your option) any later version.
11 #
12 # This program is distributed in the hope that it will be useful, but
13 # WITHOUT ANY WARRANTY; without even the implied warranty of
14 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
15 # General Public License for more details.
16 #
17 # You should have received a copy of the GNU General Public License
18 # along with this program; if not, write to the Free Software
19 # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
20 # USA
21
22 import sys, re
23 import Cidr, Rwhois, QueryParser
24
25 class QueryProcessor:
26
27     def __init__(self, db):
28         self.db = db
29
30     def _filter_obj_term(self, obj, term):
31         """Given a rwhoisobject and a query term (a 3 element tuple:
32         attr, operator, value), determine if the object satisfies the
33         term.  Returns True if the object matches the term, False if
34         not."""
35
36         attr, op, searchval = term
37         res = False
38
39         # filter by named attribute
40         if attr:
41             vals = obj.get_attr(attr)
42             if not vals:
43                 res = False
44             else:
45                 res = match_values(searchval, vals)
46             if op == "!=": return not res
47             return res
48         # filter by general term
49         else:
50             for val in obj.values():
51                 if match_value(searchval, val):
52                     return True
53             return False
54
55     def _filter_obj(self, obj, terms):
56         """Given a rwhoisobject and a list of query terms (i.e., a
57         whole AND clause), return True if the object satisfies the
58         terms."""
59
60         for term in terms:
61             if not self._filter_obj_term(obj, term): return False
62         return True
63
64     def _filter_results(self, reslist, terms):
65         """Given list of result objects (not simply the ids returned
66         from the search) and a list of query terms (i.e., a query
67         clause), remove elements that do not satisfy the terms.
68         Returns a list of objects that satisfy the filters."""
69
70         if not terms: return reslist
71         return [ x for x in reslist if self._filter_obj(x, terms) ]
72
73     def process_query_clause(self, clause, max=0):
74         """Process a query clause (a grouping of terms ANDed
75         together).  This is where the indexed searches actually get
76         done.  The technique used here is to search on one index and
77         use the rest of the clause to filter the results.  Returns a
78         QueryResult object"""
79
80         # the technique is to do an index search on the first (or
81         # maybe best) indexed term (bare terms are always considered
82         # indexed), and filter those results with the remaining terms.
83
84         # Note: this could be better if we found the "optimal" query
85         # term.  One approach may be to create a cost function and
86         # search for the minimum cost term.
87
88         # Note: another approach might be to actually do indexed
89         # searches on all applicable terms (bare or using an indexed
90         # attribute) and find the intersection of the results.
91
92         st  = None
93         sti = 0
94
95         orig_clause = clause[:]
96         
97         # find the first searchable term:
98         for term, i in zip(clause, xrange(sys.maxint)):
99             attr, op, value = term
100             if op == "!=": continue
101             if not attr or self.db.is_indexed_attr(attr):
102                 st, sti = term, i
103                 break
104         if not st:
105             raise Rwhois.RwhoisError, (351, "No indexed terms in query clause")
106
107         # remove the search term from the clause, what remains is the
108         # filter.
109         del clause[sti]
110
111         # if we have an attribute name, search on that.
112         if st[0]:
113             res = self.db.search_attr(st[0], st[2], max)
114         else:
115             if Cidr.valid_cidr(st[2].strip("*")):
116                 res = self.db.search_cidr(st[2], max)
117             else:
118                 res = self.db.search_normal(st[2], max)
119
120         objs = self._filter_results(self.db.fetch_objects(res.list()), clause)
121
122         queryres = QueryResult(objs)
123
124         # look for referrals
125         refs = self.process_referral_search(orig_clause)
126         queryres.add_referrals(refs)
127         
128         return queryres
129
130     def _is_in_autharea(self, value):
131         """Returns True if value could be considered to be contained
132         within an authority area.  That is, is a subnet of a
133         network-type authority area or a subdomain of a domainname
134         type authority area."""
135         
136         aas = self.db.get_authareas()
137         
138         if isinstance(value, Cidr.Cidr):
139             for aa in aas:
140                 cv = Cidr.valid_cidr(aa)
141                 if cv and cv.is_supernet(value):
142                     return True
143         else:
144             for aa in aas:
145                 if is_domainname(aa) and is_subdomain(aa, value):
146                     return True
147         return False
148
149     def _referral_search_cidr(self, cv, value):
150         """Return the IndexResult of a referral search for value, or
151         None if the value doesn't qualify for a Cidr referral
152         search."""
153         
154         if not cv: return None
155         if not self._is_in_autharea(cv): return None
156         return self.db.search_referral(value)
157
158     def _referral_search_domain(self, value):
159         """Return the IndexResult of a referral search for value, or
160         None if the value doesn't qualify for a domain referral
161         search."""
162         
163         if not is_domainname(value): return None
164         if not self._is_in_autharea(value): return None
165         dn = value
166         res = None
167         while dn:
168             res = self.db.search_referral(dn)
169             if res.list(): break
170             dn = reduce_domain(dn)
171         return res
172
173     def _referral_search_term(self, value):
174         """Return the IndexResult of a referral search for value, or
175         None if the value didn't qualify for a referral search."""
176         
177         cv = Cidr.valid_cidr(value)
178         if cv:
179             return self._referral_search_cidr(cv, value)
180         elif is_domainname(value):
181             return self._referral_search_domain(value)
182         return None
183         
184     def process_referral_search(self, clause):
185         """Given a query clause, attempt to search for referrals
186         associated with the terms.  Return a list of referral strings
187         that matched terms in the clause (if any).  The only terms
188         that actually get searched are the ones that look
189         'heirarchical'.  For now, the attribute part of the term is
190         essentially ignored, so a search for something like
191         'name=127.0.0.1' might concievably generate a referral, when
192         perhaps it shouldn't."""
193         
194         # first check to see if the search is explictly for a referral
195         for term in clause:
196             if (term[0] == "class-name" and term[1] == "="
197                 and term[2] == "referral") or term[0] == "referred-auth-area":
198                 # in which case, we return nothing
199                 return []
200
201         referrals = []
202
203         # look for heirarchical-looking terms.
204         for attr, op, value in clause:
205             if op == "!=": continue
206             res = self._referral_search_term(value)
207             if not res or not res.list():
208                 continue
209
210             ref_objs = self.db.fetch_objects(res.list())
211             ref_strs = [x for y in ref_objs for x in y.get_attr("referral")]
212             referrals.extend(ref_strs)
213
214         return referrals
215
216         
217     def process_full_query(self, query, max=0):
218         """Given a parsed query object, process it by unioning the
219         results of the various ORed together clauses"""
220
221         # shortcut for the very common single clause case:
222         if len(query.clauses) == 1:
223             res = self.process_query_clause(query.clauses[0], max)
224             return res
225
226         # otherwise, union the results from all the causes
227         res = QueryResult()
228         for clause in query.clauses:
229             res.extend(self.process_query_clause(clause), max)
230             if max and len(res) >= max:
231                 res.truncate(max)
232                 break
233
234         return res
235
236     def process_query(self, session, queryline):
237         """Given a session config and a query line, parse the query,
238         perform any searches, return any referrals."""
239         
240         if not session.queryparser:
241             session.queryparser = QueryParser.get_parser()
242
243         # parse the query
244         try:
245             query = QueryParser.parse(session.queryparser, queryline)
246         except Rwhois.RwhoisError, x:
247             session.wfile.write(Rwhois.error_message(x))
248             return
249         
250         max = session.limit
251         if max: max += 1
252
253         query_result = self.process_full_query(query, max)
254
255         objects   = query_result.objects()
256         referrals = query_result.referrals()
257         
258         if not objects and not referrals:
259             session.wfile.write(Rwhois.error_message(230))
260             # session.wfile.write("\r\n")
261             return
262
263         limit_exceeded = False
264         if session.limit and len(objects) > session.limit:
265             del objects[session.limit:]
266             limit_exceeded = True
267             
268         for obj in objects:
269             session.wfile.write(obj.to_wire_str())
270             session.wfile.write("\r\n")
271
272         if referrals:
273             if objects:
274                 session.wfile.write("\r\n")
275             session.wfile.write("\r\n".join(referrals))
276             session.wfile.write("\r\n")
277                                 
278         if limit_exceeded:
279             session.wfile.write(Rwhois.error_message(330))
280         else:
281             session.wfile.write(Rwhois.ok())
282
283 class QueryResult:
284
285     def __init__(self, objs=[], referrals=[]):
286         self.data  = objs
287         self.ids   = [ x.getid() for x in objs ]
288         self._dict = dict(zip(self.ids, self.ids))
289         self.refs  = referrals
290
291     def extend(self, list):
292         if isinstance(list, type(self)):
293             list = list.objects()
294         new_objs = [ x for x in list if not self._dict.has_key(x.getid()) ]
295         new_ids = [ x.getid() for x in new_objs ]
296         self.data.extend(new_objs)
297         self.ids.extend(new_ids)
298         self._dict.update(dict(zip(new_ids, new_ids)))
299
300     def add_referrals(self, referrals):
301         self.refs.extend(referrals)
302     
303     def objects(self):
304         return self.data
305
306     def referrals(self):
307         return self.refs
308     
309     def ids(self):
310         return self.ids
311
312     def truncate(self, n=0):
313         to_del = self.ids[n:]
314         for i in to_del: del self._dict[i]
315         self.ids = self.ids[:n]
316         self.data = self.data[:n]
317
318         
319 def match_value(searchval, val):
320     """Determine if a search value matches a data value.  If both
321     matching terms are valid CIDR objects, then they are matched
322     according the CIDR wildcard rules (i.e., a single trailing * is a
323     supernet search, ** is a subnet search).  If the search value is
324     not wildcarded, then they are just tested for numeric equality.
325     Otherwise, the terms are compared using string semantics
326     (substring, prefix, suffix, and exact match."""
327
328     if match_cidr(searchval, val): return True
329
330     # normalize the values for comparison.
331     searchval = searchval.lower()
332     val = val.lower()
333
334     # the substring case
335     if searchval.startswith("*") and searchval.endswith("*"):
336         sv = searchval.strip("*");
337         if val.find(sv) >= 0:
338             return True
339         else:
340             return False
341     # the suffix case
342     elif searchval.startswith("*"):
343         sv = searchval.lstrip("*")
344         return val.endswith(sv)
345     # the prefix case
346     elif searchval.endswith("*"):
347         sv = searchval.rstrip("*")
348         return val.startswith(sv)
349     # the exact match case
350     else:
351         return searchval == val
352
353 def match_values(searchval, val_list):
354
355     for val in val_list:
356         if match_value(searchval, val): return True
357     return False
358
359 def match_cidr(searchval, val):
360     """If both terms are valid CIDR values (minus any trailing
361     wildcards of the search value), compare according the CIDR
362     wildcard rules: subnet, supernet, and exact match.  If both terms
363     are not CIDR address, return False."""
364
365
366     sv = Cidr.valid_cidr(searchval.rstrip("*"))
367     rv = Cidr.valid_cidr(val)
368
369     if not sv or not rv: return False
370
371     if (searchval.endswith("**")):
372         return rv.is_subnet(sv)
373     elif (searchval.endswith("*")):
374         return rv.is_supernet(sv)
375     else:
376         return rv == sv
377
378
379 # this forms a pretty basic heuristic to see of a value looks like a
380 # domain name.
381 domain_regex = re.compile("[a-z0-9-]+\.[a-z0-9-.]+", re.I)
382
383 def is_domainname(value):
384     if domain_regex.match(value):
385         return True
386     return False
387
388 def is_subdomain(domain, subdomain):
389     domain = domain.lower();
390     subdomain = subdomain.lower();
391     
392     dlist = domain.split('.')
393     sdlist = subdomain.split('.')
394     
395     if len(dlist) > len(sdlist): return False
396     if len(dlist) == len(sdlist): return domain == subdomain
397
398     dlist.reverse();
399     sdlist.reverse()
400
401     return dlist == sdlist[:len(dlist)]
402
403 def reduce_domain(domain):
404     dlist = domain.split('.')
405     dlist.pop(0)
406     return '.'.join(dlist)
407
408 def is_heirarchical(value):
409     if cidr.valid_cidr(value): return True
410     if is_domainname(value): return True
411     return False
412
413 if __name__ == '__main__':
414
415     import MemDB, Session
416     
417     db = MemDB.MemDB()
418
419     print "loading schema:", sys.argv[1]
420     db.init_schema(sys.argv[1])
421     for data_file in sys.argv[2:]:
422         print "loading data file:", data_file
423         db.load_data(data_file)
424     db.index_data()
425
426     QueryParser.db = db
427     processor = QueryProcessor(db)
428
429     session = Session.Context()
430     session.wfile = sys.stdout
431     
432     while 1:
433         line = sys.stdin.readline().strip();
434         if not line: break
435         if line.startswith("#"): continue
436
437         print "parsing: '%s'" % line
438         processor.process_query(session, line)
439         session.wfile.write("\r\n");
440         session.wfile.flush()