Update TODO based on work for 0.4.1
[python-rwhoisd.git] / rwhoisd / QueryProcessor.py
1 # This file is part of python-rwhoisd
2 #
3 # Copyright (C) 2003, David E. Blacka
4 #
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 2 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful, but
11 # WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13 # General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program; if not, write to the Free Software
17 # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
18 # USA
19
20 import sys, re
21 import Cidr, Rwhois, QueryParser
22
23 class QueryProcessor:
24
25     def __init__(self, db):
26         self.db = db
27
28     def _filter_obj_term(self, obj, term):
29         """Given a rwhoisobject and a query term (a 3 element tuple:
30         attr, operator, value), determine if the object satisfies the
31         term.  Returns True if the object matches the term, False if
32         not."""
33
34         attr, op, searchval = term
35         res = False
36
37         # filter by named attribute
38         if attr:
39             vals = obj.get_attr(attr)
40             if not vals:
41                 res = False
42             else:
43                 res = match_values(searchval, vals)
44             if op == "!=": return not res
45             return res
46         # filter by general term
47         else:
48             for val in obj.values():
49                 if match_value(searchval, val):
50                     return True
51             return False
52
53     def _filter_obj(self, obj, terms):
54         """Given a rwhoisobject and a list of query terms (i.e., a
55         whole AND clause), return True if the object satisfies the
56         terms."""
57
58         for term in terms:
59             if not self._filter_obj_term(obj, term): return False
60         return True
61
62     def _filter_results(self, reslist, terms):
63         """Given list of result objects (not simply the ids returned
64         from the search) and a list of query terms (i.e., a query
65         clause), remove elements that do not satisfy the terms.
66         Returns a list of objects that satisfy the filters."""
67
68         if not terms: return reslist
69         return [ x for x in reslist if self._filter_obj(x, terms) ]
70
71     def process_query_clause(self, clause, max=0):
72         """Process a query clause (a grouping of terms ANDed
73         together).  This is where the indexed searches actually get
74         done.  The technique used here is to search on one index and
75         use the rest of the clause to filter the results.  Returns a
76         QueryResult object"""
77
78         # the technique is to do an index search on the first (or
79         # maybe best) indexed term (bare terms are always considered
80         # indexed), and filter those results with the remaining terms.
81
82         # Note: this could be better if we found the "optimal" query
83         # term.  One approach may be to create a cost function and
84         # search for the minimum cost term.
85
86         # Note: another approach might be to actually do indexed
87         # searches on all applicable terms (bare or using an indexed
88         # attribute) and find the intersection of the results.
89
90         st  = None
91         sti = 0
92
93         orig_clause = clause[:]
94         
95         # find the first searchable term:
96         for term, i in zip(clause, xrange(sys.maxint)):
97             attr, op, value = term
98             if op == "!=": continue
99             if not attr or self.db.is_indexed_attr(attr):
100                 st, sti = term, i
101                 break
102         if not st:
103             raise Rwhois.RwhoisError, (351, "No indexed terms in query clause")
104
105         # remove the search term from the clause, what remains is the
106         # filter.
107         del clause[sti]
108
109         # if we have an attribute name, search on that.
110         if st[0]:
111             res = self.db.search_attr(st[0], st[2], max)
112         else:
113             if Cidr.valid_cidr(st[2].strip("*")):
114                 res = self.db.search_cidr(st[2], max)
115             else:
116                 res = self.db.search_normal(st[2], max)
117
118         objs = self._filter_results(self.db.fetch_objects(res.list()), clause)
119
120         queryres = QueryResult(objs)
121
122         # look for referrals
123         refs = self.process_referral_search(orig_clause)
124         queryres.add_referrals(refs)
125         
126         return queryres
127
128     def _is_in_autharea(self, value):
129         """Returns True if value could be considered to be contained
130         within an authority area.  That is, is a subnet of a
131         network-type authority area or a subdomain of a domainname
132         type authority area."""
133         
134         aas = self.db.get_authareas()
135         
136         if isinstance(value, Cidr.Cidr):
137             for aa in aas:
138                 cv = Cidr.valid_cidr(aa)
139                 if cv and cv.is_supernet(value):
140                     return True
141         else:
142             for aa in aas:
143                 if is_domainname(aa) and is_subdomain(aa, value):
144                     return True
145         return False
146
147     def _referral_search_cidr(self, cv, value):
148         """Return the IndexResult of a referral search for value, or
149         None if the value doesn't qualify for a Cidr referral
150         search."""
151         
152         if not cv: return None
153         if not self._is_in_autharea(cv): return None
154         return self.db.search_referral(value)
155
156     def _referral_search_domain(self, value):
157         """Return the IndexResult of a referral search for value, or
158         None if the value doesn't qualify for a domain referral
159         search."""
160         
161         if not is_domainname(value): return None
162         if not self._is_in_autharea(value): return None
163         dn = value
164         res = None
165         while dn:
166             res = self.db.search_referral(dn)
167             if res.list(): break
168             dn = reduce_domain(dn)
169         return res
170
171     def _referral_search_term(self, value):
172         """Return the IndexResult of a referral search for value, or
173         None if the value didn't qualify for a referral search."""
174         
175         cv = Cidr.valid_cidr(value)
176         if cv:
177             return self._referral_search_cidr(cv, value)
178         elif is_domainname(value):
179             return self._referral_search_domain(value)
180         return None
181         
182     def process_referral_search(self, clause):
183         """Given a query clause, attempt to search for referrals
184         associated with the terms.  Return a list of referral strings
185         that matched terms in the clause (if any).  The only terms
186         that actually get searched are the ones that look
187         'heirarchical'.  For now, the attribute part of the term is
188         essentially ignored, so a search for something like
189         'name=127.0.0.1' might concievably generate a referral, when
190         perhaps it shouldn't."""
191         
192         # first check to see if the search is explictly for a referral
193         for term in clause:
194             if (term[0] == "class-name" and term[1] == "="
195                 and term[2] == "referral") or term[0] == "referred-auth-area":
196                 # in which case, we return nothing
197                 return []
198
199         referrals = []
200
201         # look for heirarchical-looking terms.
202         for attr, op, value in clause:
203             if op == "!=": continue
204             res = self._referral_search_term(value)
205             if not res or not res.list():
206                 continue
207
208             ref_objs = self.db.fetch_objects(res.list())
209             ref_strs = [x for y in ref_objs for x in y.get_attr("referral")]
210             referrals.extend(ref_strs)
211
212         return referrals
213
214         
215     def process_full_query(self, query, max=0):
216         """Given a parsed query object, process it by unioning the
217         results of the various ORed together clauses"""
218
219         # shortcut for the very common single clause case:
220         if len(query.clauses) == 1:
221             res = self.process_query_clause(query.clauses[0], max)
222             return res
223
224         # otherwise, union the results from all the causes
225         res = QueryResult()
226         for clause in query.clauses:
227             res.extend(self.process_query_clause(clause), max)
228             if max and len(res) >= max:
229                 res.truncate(max)
230                 break
231
232         return res
233
234     def process_query(self, session, queryline):
235         """Given a session config and a query line, parse the query,
236         perform any searches, return any referrals."""
237         
238         if not session.queryparser:
239             session.queryparser = QueryParser.get_parser()
240
241         # parse the query
242         try:
243             query = QueryParser.parse(session.queryparser, queryline)
244         except Rwhois.RwhoisError, x:
245             session.wfile.write(Rwhois.error_message(x))
246             return
247         
248         max = session.limit
249         if max: max += 1
250
251         query_result = self.process_full_query(query, max)
252
253         objects   = query_result.objects()
254         referrals = query_result.referrals()
255         
256         if not objects and not referrals:
257             session.wfile.write(Rwhois.error_message(230))
258             # session.wfile.write("\r\n")
259             return
260
261         limit_exceeded = False
262         if session.limit and len(objects) > session.limit:
263             del objects[session.limit:]
264             limit_exceeded = True
265             
266         for obj in objects:
267             session.wfile.write(obj.to_wire_str())
268             session.wfile.write("\r\n")
269
270         if referrals:
271             if objects:
272                 session.wfile.write("\r\n")
273             session.wfile.write("\r\n".join(referrals))
274             session.wfile.write("\r\n")
275                                 
276         if limit_exceeded:
277             session.wfile.write(Rwhois.error_message(330))
278         else:
279             session.wfile.write(Rwhois.ok())
280
281 class QueryResult:
282
283     def __init__(self, objs=[], referrals=[]):
284         self.data  = objs
285         self.ids   = [ x.getid() for x in objs ]
286         self._dict = dict(zip(self.ids, self.ids))
287         self.refs  = referrals
288
289     def extend(self, list):
290         if isinstance(list, type(self)):
291             list = list.objects()
292         new_objs = [ x for x in list if not self._dict.has_key(x.getid()) ]
293         new_ids = [ x.getid() for x in new_objs ]
294         self.data.extend(new_objs)
295         self.ids.extend(new_ids)
296         self._dict.update(dict(zip(new_ids, new_ids)))
297
298     def add_referrals(self, referrals):
299         self.refs.extend(referrals)
300     
301     def objects(self):
302         return self.data
303
304     def referrals(self):
305         return self.refs
306     
307     def ids(self):
308         return self.ids
309
310     def truncate(self, n=0):
311         to_del = self.ids[n:]
312         for i in to_del: del self._dict[i]
313         self.ids = self.ids[:n]
314         self.data = self.data[:n]
315
316         
317 def match_value(searchval, val):
318     """Determine if a search value matches a data value.  If both
319     matching terms are valid CIDR objects, then they are matched
320     according the CIDR wildcard rules (i.e., a single trailing * is a
321     supernet search, ** is a subnet search).  If the search value is
322     not wildcarded, then they are just tested for numeric equality.
323     Otherwise, the terms are compared using string semantics
324     (substring, prefix, suffix, and exact match."""
325
326     if match_cidr(searchval, val): return True
327
328     # normalize the values for comparison.
329     searchval = searchval.lower()
330     val = val.lower()
331
332     # the substring case
333     if searchval.startswith("*") and searchval.endswith("*"):
334         sv = searchval.strip("*");
335         if val.find(sv) >= 0:
336             return True
337         else:
338             return False
339     # the suffix case
340     elif searchval.startswith("*"):
341         sv = searchval.lstrip("*")
342         return val.endswith(sv)
343     # the prefix case
344     elif searchval.endswith("*"):
345         sv = searchval.rstrip("*")
346         return val.startswith(sv)
347     # the exact match case
348     else:
349         return searchval == val
350
351 def match_values(searchval, val_list):
352
353     for val in val_list:
354         if match_value(searchval, val): return True
355     return False
356
357 def match_cidr(searchval, val):
358     """If both terms are valid CIDR values (minus any trailing
359     wildcards of the search value), compare according the CIDR
360     wildcard rules: subnet, supernet, and exact match.  If both terms
361     are not CIDR address, return False."""
362
363
364     sv = Cidr.valid_cidr(searchval.rstrip("*"))
365     rv = Cidr.valid_cidr(val)
366
367     if not sv or not rv: return False
368
369     if (searchval.endswith("**")):
370         return rv.is_subnet(sv)
371     elif (searchval.endswith("*")):
372         return rv.is_supernet(sv)
373     else:
374         return rv == sv
375
376
377 # this forms a pretty basic heuristic to see of a value looks like a
378 # domain name.
379 domain_regex = re.compile("[a-z0-9-]+\.[a-z0-9-.]+", re.I)
380
381 def is_domainname(value):
382     if domain_regex.match(value):
383         return True
384     return False
385
386 def is_subdomain(domain, subdomain):
387     domain = domain.lower();
388     subdomain = subdomain.lower();
389     
390     dlist = domain.split('.')
391     sdlist = subdomain.split('.')
392     
393     if len(dlist) > len(sdlist): return False
394     if len(dlist) == len(sdlist): return domain == subdomain
395
396     dlist.reverse();
397     sdlist.reverse()
398
399     return dlist == sdlist[:len(dlist)]
400
401 def reduce_domain(domain):
402     dlist = domain.split('.')
403     dlist.pop(0)
404     return '.'.join(dlist)
405
406 def is_heirarchical(value):
407     if cidr.valid_cidr(value): return True
408     if is_domainname(value): return True
409     return False
410
411 if __name__ == '__main__':
412
413     import MemDB, Session
414     
415     db = MemDB.MemDB()
416
417     print "loading schema:", sys.argv[1]
418     db.init_schema(sys.argv[1])
419     for data_file in sys.argv[2:]:
420         print "loading data file:", data_file
421         db.load_data(data_file)
422     db.index_data()
423
424     QueryParser.db = db
425     processor = QueryProcessor(db)
426
427     session = Session.Context()
428     session.wfile = sys.stdout
429     
430     while 1:
431         line = sys.stdin.readline().strip();
432         if not line: break
433         if line.startswith("#"): continue
434
435         print "parsing: '%s'" % line
436         processor.process_query(session, line)
437         session.wfile.write("\r\n");
438         session.wfile.flush()