e3bcd75f6c58ba71a8b4d9d7eb98b5114ddad822
[python-rwhoisd.git] / rwhoisd / QueryProcessor.py
1 import sys, re
2 import Cidr, Rwhois, QueryParser
3
4 class QueryProcessor:
5
6     def __init__(self, db):
7         self.db = db
8
9     def _filter_obj_term(self, obj, term):
10         """Given a rwhoisobject and a query term (a 3 element tuple:
11         attr, operator, value), determine if the object satisfies the
12         term.  Returns True if the object matches the term, False if
13         not."""
14
15         attr, op, searchval = term
16         res = False
17
18         # filter by named attribute
19         if attr:
20             vals = obj.get_attr(attr)
21             if not vals:
22                 res = False
23             else:
24                 res = match_values(searchval, vals)
25             if op == "!=": return not res
26             return res
27         # filter by general term
28         else:
29             for val in obj.values():
30                 if match_value(searchval, val):
31                     return True
32             return False
33
34     def _filter_obj(self, obj, terms):
35         """Given a rwhoisobject and a list of query terms (i.e., a
36         whole AND clause), return True if the object satisfies the
37         terms."""
38
39         for term in terms:
40             if not self._filter_obj_term(obj, term): return False
41         return True
42
43     def _filter_results(self, reslist, terms):
44         """Given list of result objects (not simply the ids returned
45         from the search) and a list of query terms (i.e., a query
46         clause), remove elements that do not satisfy the terms.
47         Returns a list of objects that satisfy the filters."""
48
49         if not terms: return reslist
50         return [ x for x in reslist if self._filter_obj(x, terms) ]
51
52     def process_query_clause(self, clause, max=0):
53         """Process a query clause (a grouping of terms ANDed
54         together).  This is where the indexed searches actually get
55         done.  The technique used here is to search on one index and
56         use the rest of the clause to filter the results.  Returns a
57         QueryResult object"""
58
59         # the technique is to do an index search on the first (or
60         # maybe best) indexed term (bare terms are always considered
61         # indexed), and filter those results with the remaining terms.
62
63         # Note: this could be better if we found the "optimal" query
64         # term.  One approach may be to create a cost function and
65         # search for the minimum cost term.
66
67         # Note: another approach might be to actually do indexed
68         # searches on all applicable terms (bare or using an indexed
69         # attribute) and find the intersection of the results.
70
71         st  = None
72         sti = 0
73
74         orig_clause = clause[:]
75         
76         # find the first searchable term:
77         for term, i in zip(clause, xrange(sys.maxint)):
78             attr, op, value = term
79             if op == "!=": continue
80             if not attr or self.db.is_indexed_attr(attr):
81                 st, sti = term, i
82                 break
83         if not st:
84             raise Rwhois.RwhoisError, (351, "No indexed terms in query clause")
85
86         # remove the search term from the clause, what remains is the
87         # filter.
88         del clause[sti]
89
90         # if we have an attribute name, search on that.
91         if st[0]:
92             res = self.db.search_attr(st[0], st[2], max)
93         else:
94             if Cidr.valid_cidr(st[2].strip("*")):
95                 res = self.db.search_cidr(st[2], max)
96             else:
97                 res = self.db.search_normal(st[2], max)
98
99         objs = self._filter_results(self.db.fetch_objects(res.list()), clause)
100
101         queryres = QueryResult(objs)
102
103         # look for referrals
104         refs = self.process_referral_search(orig_clause)
105         queryres.add_referrals(refs)
106         
107         return queryres
108
109     def _is_in_autharea(self, value):
110         """Returns True if value could be considered to be contained
111         within an authority area.  That is, is a subnet of a
112         network-type authority area or a subdomain of a domainname
113         type authority area."""
114         
115         aas = self.db.get_authareas()
116         
117         if isinstance(value, Cidr.Cidr):
118             for aa in aas:
119                 cv = Cidr.valid_cidr(aa)
120                 if cv and cv.is_supernet(value):
121                     return True
122         else:
123             for aa in aas:
124                 if is_domainname(aa) and is_subdomain(aa, value):
125                     return True
126         return False
127
128     def _referral_search_cidr(self, cv, value):
129         """Return the IndexResult of a referral search for value, or
130         None if the value doesn't qualify for a Cidr referral
131         search."""
132         
133         if not cv: return None
134         if not self._is_in_autharea(cv): return None
135         return self.db.search_referral(value)
136
137     def _referral_search_domain(self, value):
138         """Return the IndexResult of a referral search for value, or
139         None if the value doesn't qualify for a domain referral
140         search."""
141         
142         if not is_domainname(value): return None
143         if not self._is_in_autharea(value): return None
144         dn = value
145         res = None
146         while dn:
147             res = self.db.search_referral(dn)
148             if res.list(): break
149             dn = reduce_domain(dn)
150         return res
151
152     def _referral_search_term(self, value):
153         """Return the IndexResult of a referral search for value, or
154         None if the value didn't qualify for a referral search."""
155         
156         cv = Cidr.valid_cidr(value)
157         if cv:
158             return self._referral_search_cidr(cv, value)
159         elif is_domainname(value):
160             return self._referral_search_domain(value)
161         return None
162         
163     def process_referral_search(self, clause):
164         """Given a query clause, attempt to search for referrals
165         associated with the terms.  Return a list of referral strings
166         that matched terms in the clause (if any).  The only terms
167         that actually get searched are the ones that look
168         'heirarchical'.  For now, the attribute part of the term is
169         essentially ignored, so a search for something like
170         'name=127.0.0.1' might concievably generate a referral, when
171         perhaps it shouldn't."""
172         
173         # first check to see if the search is explictly for a referral
174         for term in clause:
175             if (term[0] == "class-name" and term[1] == "="
176                 and term[2] == "referral") or term[0] == "referred-auth-area":
177                 # in which case, we return nothing
178                 return []
179
180         referrals = []
181
182         # look for heirarchical-looking terms.
183         for attr, op, value in clause:
184             if op == "!=": continue
185             res = self._referral_search_term(value)
186             if not res or not res.list():
187                 continue
188
189             ref_objs = self.db.fetch_objects(res.list())
190             ref_strs = [x for y in ref_objs for x in y.get_attr("referral")]
191             referrals.extend(ref_strs)
192
193         return referrals
194
195         
196     def process_full_query(self, query, max=0):
197         """Given a parsed query object, process it by unioning the
198         results of the various ORed together clauses"""
199
200         # shortcut for the very common single clause case:
201         if len(query.clauses) == 1:
202             res = self.process_query_clause(query.clauses[0])
203             return res
204
205         # otherwise, union the results from all the causes
206         res = QueryResult()
207         for clause in query.clauses:
208             res.extend(self.process_query_clause(clause))
209             if max and len(res) >= max:
210                 res.truncate(max)
211                 break
212
213         return res
214
215     def process_query(self, session, queryline):
216         """Given a session config and a query line, parse the query,
217         perform any searches, return any referrals."""
218         
219         if not session.queryparser:
220             session.queryparser = QueryParser.get_parser()
221
222         # parse the query
223         try:
224             query = QueryParser.parse(session.queryparser, queryline)
225         except Rwhois.RwhoisError, x:
226             session.wfile.write(Rwhois.error_message(x))
227             return
228         
229         max = session.limit
230         if max: max += 1
231
232         query_result = self.process_full_query(query, max)
233
234         objects   = query_result.objects()
235         referrals = query_result.referrals()
236         
237         if not objects and not referrals:
238             session.wfile.write(Rwhois.error_message(230))
239             # session.wfile.write("\r\n")
240             return
241
242         for obj in objects:
243             session.wfile.write(obj.to_wire_str())
244             session.wfile.write("\r\n")
245
246         if referrals:
247             if objects:
248                 session.wfile.write("\r\n")
249             session.wfile.write("\r\n".join(referrals))
250             session.wfile.write("\r\n")
251                                 
252         if session.limit and len(objects) > session.limit:
253             session.wfile.write(330)
254         else:
255             session.wfile.write(Rwhois.ok())
256
257 class QueryResult:
258
259     def __init__(self, objs=[], referrals=[]):
260         self.data  = objs
261         self.ids   = [ x.getid() for x in objs ]
262         self._dict = dict(zip(self.ids, self.ids))
263         self.refs  = referrals
264
265     def extend(self, list):
266         if isinstance(list, type(self)):
267             list = list.objects()
268         new_objs = [ x for x in list if not self._dict.has_key(x.getid()) ]
269         new_ids = [ x.getid() for x in new_objs ]
270         self.data.extend(new_objs)
271         self.ids.extend(new_ids)
272         self._dict.update(dict(zip(new_ids, new_ids)))
273
274     def add_referrals(self, referrals):
275         self.refs.extend(referrals)
276     
277     def objects(self):
278         return self.data
279
280     def referrals(self):
281         return self.refs
282     
283     def ids(self):
284         return self.ids
285
286     def truncate(self, n=0):
287         to_del = self.ids[n:]
288         for i in to_del: del self._dict[i]
289         self.ids = self.ids[:n]
290         self.data = self.data[:n]
291
292         
293 def match_value(searchval, val):
294     """Determine if a search value matches a data value.  If both
295     matching terms are valid CIDR objects, then they are matched
296     according the CIDR wildcard rules (i.e., a single trailing * is a
297     supernet search, ** is a subnet search).  If the search value is
298     not wildcarded, then they are just tested for numeric equality.
299     Otherwise, the terms are compared using string semantics
300     (substring, prefix, suffix, and exact match."""
301
302     if match_cidr(searchval, val): return True
303
304     # normalize the values for comparison.
305     searchval = searchval.lower()
306     val = val.lower()
307
308     # the substring case
309     if searchval.startswith("*") and searchval.endswith("*"):
310         sv = searchval.strip("*");
311         if val.find(sv) >= 0:
312             return True
313         else:
314             return False
315     # the suffix case
316     elif searchval.startswith("*"):
317         sv = searchval.lstrip("*")
318         return val.endswith(sv)
319     # the prefix case
320     elif searchval.endswith("*"):
321         sv = searchval.rstrip("*")
322         return val.startswith(sv)
323     # the exact match case
324     else:
325         return searchval == val
326
327 def match_values(searchval, val_list):
328
329     for val in val_list:
330         if match_value(searchval, val): return True
331     return False
332
333 def match_cidr(searchval, val):
334     """If both terms are valid CIDR values (minus any trailing
335     wildcards of the search value), compare according the CIDR
336     wildcard rules: subnet, supernet, and exact match.  If both terms
337     are not CIDR address, return False."""
338
339
340     sv = Cidr.valid_cidr(searchval.rstrip("*"))
341     rv = Cidr.valid_cidr(val)
342
343     if not sv or not rv: return False
344
345     if (searchval.endswith("**")):
346         return rv.is_subnet(sv)
347     elif (searchval.endswith("*")):
348         return rv.is_supernet(sv)
349     else:
350         return rv == sv
351
352
353 # this forms a pretty basic heuristic to see of a value looks like a
354 # domain name.
355 domain_regex = re.compile("[a-z0-9-]+\.[a-z0-9-.]+", re.I)
356
357 def is_domainname(value):
358     if domain_regex.match(value):
359         return True
360     return False
361
362 def is_subdomain(domain, subdomain):
363     domain = domain.lower();
364     subdomain = subdomain.lower();
365     
366     dlist = domain.split('.')
367     sdlist = subdomain.split('.')
368     
369     if len(dlist) > len(sdlist): return False
370     if len(dlist) == len(sdlist): return domain == subdomain
371
372     dlist.reverse();
373     sdlist.reverse()
374
375     return dlist == sdlist[:len(dlist)]
376
377 def reduce_domain(domain):
378     dlist = domain.split('.')
379     dlist.pop(0)
380     return '.'.join(dlist)
381
382 def is_heirarchical(value):
383     if cidr.valid_cidr(value): return True
384     if is_domainname(value): return True
385     return False
386
387 if __name__ == '__main__':
388
389     import MemDB, Session
390     
391     db = MemDB.MemDB()
392
393     print "loading schema:", sys.argv[1]
394     db.init_schema(sys.argv[1])
395     for data_file in sys.argv[2:]:
396         print "loading data file:", data_file
397         db.load_data(data_file)
398     db.index_data()
399
400     QueryParser.db = db
401     processor = QueryProcessor(db)
402
403     session = Session.Context()
404     session.wfile = sys.stdout
405     
406     while 1:
407         line = sys.stdin.readline().strip();
408         if not line: break
409         if line.startswith("#"): continue
410
411         print "parsing: '%s'" % line
412         processor.process_query(session, line)
413         session.wfile.write("\r\n");
414         session.wfile.flush()