Initial revision
[python-rwhoisd.git] / rwhoisd / QueryProcessor.py
1 import sys
2 import Cidr, Rwhois, QueryParser
3
4 class QueryProcessor:
5
6     def __init__(self, db):
7         self.db = db
8
9     def _filter_obj_term(self, obj, term):
10         """Given a rwhoisobject and a query term (a 3 element tuple:
11         attr, operator, value), determine if the object satisfies the
12         term.  Returns True if the object matches the term, False if
13         not."""
14
15         attr, op, searchval = term
16         res = False
17
18         # filter by named attribute
19         if attr:
20             vals = obj.get_attr(attr)
21             if not vals:
22                 res = False
23             else:
24                 res = match_values(searchval, vals)
25             if op == "!=": return not res
26             return res
27         # filter by general term
28         else:
29             for val in obj.values():
30                 if match_value(searchval, val):
31                     return True
32             return False
33
34     def _filter_obj(self, obj, terms):
35         """Given a rwhoisobject and a list of query terms (i.e., a
36         whole AND clause), return True if the object satisfies the
37         terms."""
38
39         for term in terms:
40             if not self._filter_obj_term(obj, term): return False
41         return True
42
43     def _filter_results(self, reslist, terms):
44         """Given list of result objects (not simply the ids returned
45         from the search) and a list of query terms (i.e., a query
46         clause), remove elements that do not satisfy the terms.
47         Returns a list of objects that satisfy the filters."""
48
49         if not terms: return reslist
50         return [ x for x in reslist if self._filter_obj(x, terms) ]
51
52     def process_query_clause(self, clause, max=0):
53         """Process a query clause (a grouping of terms ANDed
54         together).  This is where the indexed searches actually get
55         done.  The technique used here is to search on one index and
56         use the rest of the clause to filter the results.  Returns a
57         QueryResult object"""
58
59         # the technique is to do an index search on the first (or
60         # maybe best) indexed term (bare terms are always considered
61         # indexed), and filter those results with the remaining terms.
62
63         # Note: this could be better if we found the "optimal" query
64         # term.  One approach may be to create a cost function and
65         # search for the minimum cost term.
66
67         # Note: another approach might be to actually do indexed
68         # searches on all applicable terms (bare or using an indexed
69         # attribute) and find the intersection of the results.
70
71         # FIXME: need to put in the referral chasing logic here, I
72         # think.
73         
74         st  = None
75         sti = 0
76
77         # find the first searchable term:
78         for term, i in zip(clause, xrange(sys.maxint)):
79             attr, op, value = term
80             if op == "!=": continue
81             if not attr or self.db.is_indexed_attr(attr):
82                 st, sti = term, i
83                 break
84         if not st:
85             raise Rwhois.RwhoisError, (351, "No indexed terms in query clause")
86
87         # remove the search term from the clause, what remains is the
88         # filter.
89         del clause[sti]
90
91         # if we have an attribute name, search on that.
92         if st[0]:
93             res = self.db.search_attr(st[0], st[2], max)
94         else:
95             if Cidr.valid_cidr(st[2].strip("*")):
96                 res = self.db.search_cidr(st[2], max)
97             else:
98                 res = self.db.search_normal(st[2], max)
99
100         objs = self._filter_results(self.db.fetch_objects(res.list()), clause)
101
102         return QueryResult(objs)
103
104     def process_full_query(self, query, max=0):
105         """Given a parsed query object, process it by unioning the
106         results of the various ORed together clauses"""
107
108         # shortcut for the very common single clause case:
109         if len(query.clauses) == 1:
110             res = self.process_query_clause(query.clauses[0])
111             return res
112
113         res = QueryResult()
114         for clause in query.clauses:
115             res.extend(self.process_query_clause(clause))
116             if max and len(res) >= max:
117                 res.truncate(max)
118                 break
119
120         return res
121
122     def process_query(self, session, queryline):
123         """Given a session config and a query line, parse the query,
124         perform any searches, return any referrals."""
125         
126         if not session.queryparser:
127             session.queryparser = QueryParser.get_parser()
128
129         # parse the query
130         try:
131             query = QueryParser.parse(session.queryparser, queryline)
132         except Rwhois.RwhoisError, x:
133             session.wfile.write(Rwhois.error_message(x))
134             return
135         
136         max = session.limit
137         if max: max += 1
138
139         query_result = self.process_full_query(query, max)
140
141         objects   = query_result.objects()
142         referrals = query_result.referrals()
143         
144         if not objects and not referrals:
145             session.wfile.write(Rwhois.error_message(230))
146             # session.wfile.write("\r\n")
147             return
148
149         for obj in objects:
150             session.wfile.write(obj.to_wire_str())
151             session.wfile.write("\r\n")
152
153         if referrals:
154             session.wfile.write("\r\n".join(referrals))
155             session.wfile.write("\r\n")
156                                 
157         if session.limit and len(objects) > session.limit:
158             session.wfile.write(330)
159         else:
160             session.wfile.write(Rwhois.ok())
161
162 class QueryResult:
163
164     def __init__(self, objs=[], referrals=[]):
165         self.data  = objs
166         self.ids   = [ x.getid() for x in objs ]
167         self._dict = dict(zip(self.ids, self.ids))
168         self.refs  = referrals
169
170     def extend(self, list):
171         if isinstance(list, type(self)):
172             list = list.objects()
173         new_objs = [ x for x in list if not self._dict.has_key(x.getid()) ]
174         new_ids = [ x.getid() for x in new_objs ]
175         self.data.extend(new_objs)
176         self.ids.extend(new_ids)
177         self._dict.update(dict(zip(new_ids, new_ids)))
178
179     def add_referrals(self, referrals):
180         self.refs.extend(referrals)
181     
182     def objects(self):
183         return self.data
184
185     def referrals(self):
186         return self.refs
187     
188     def ids(self):
189         return self.ids
190
191     def truncate(self, n=0):
192         to_del = self.ids[n:]
193         for i in to_del: del self._dict[i]
194         self.ids = self.ids[:n]
195         self.data = self.data[:n]
196
197         
198 def match_value(searchval, val):
199     """Determine if a search value matches a data value.  If both
200     matching terms are valid CIDR objects, then they are matched
201     according the CIDR wildcard rules (i.e., a single trailing * is a
202     supernet search, ** is a subnet search).  If the search value is
203     not wildcarded, then they are just tested for numeric equality.
204     Otherwise, the terms are compared using string semantics
205     (substring, prefix, suffix, and exact match."""
206
207     if match_cidr(searchval, val): return True
208
209     # normalize the values for comparison.
210     searchval = searchval.lower()
211     val = val.lower()
212
213     # the substring case
214     if searchval.startswith("*") and searchval.endswith("*"):
215         sv = searchval.strip("*");
216         if val.find(sv) >= 0:
217             return True
218         else:
219             return False
220     # the suffix case
221     elif searchval.startswith("*"):
222         sv = searchval.lstrip("*")
223         return val.endswith(sv)
224     # the prefix case
225     elif searchval.endswith("*"):
226         sv = searchval.rstrip("*")
227         return val.startswith(sv)
228     # the exact match case
229     else:
230         return searchval == val
231
232 def match_values(searchval, val_list):
233
234     for val in val_list:
235         if match_value(searchval, val): return True
236     return False
237
238 def match_cidr(searchval, val):
239     """If both terms are valid CIDR values (minus any trailing
240     wildcards of the search value), compare according the CIDR
241     wildcard rules: subnet, supernet, and exact match.  If both terms
242     are not CIDR address, return False."""
243
244
245     sv = Cidr.valid_cidr(searchval.rstrip("*"))
246     rv = Cidr.valid_cidr(val)
247
248     if not sv or not rv: return False
249
250     if (searchval.endswith("**")):
251         return rv.is_subnet(sv)
252     elif (searchval.endswith("*")):
253         return rv.is_supernet(sv)
254     else:
255         return rv == sv
256
257
258 if __name__ == '__main__':
259
260     import MemDB, Session
261     
262     db = MemDB.MemDB()
263
264     print "loading schema:", sys.argv[1]
265     db.init_schema(sys.argv[1])
266     for data_file in sys.argv[2:]:
267         print "loading data file:", data_file
268         db.load_data(data_file)
269     db.index_data()
270
271     QueryParser.db = db
272     processor = QueryProcessor(db)
273
274     session = Session.Context()
275     session.wfile = sys.stdout
276     
277     while 1:
278         line = sys.stdin.readline().strip();
279         if not line: break
280         if line.startswith("#"): continue
281
282         print "parsing: '%s'" % line
283         processor.process_query(session, line)
284         session.wfile.write("\r\n");
285         session.wfile.flush()