add a db method to get the list of authority areas (used in referral processing)
[python-rwhoisd.git] / rwhoisd / MemDB.py
1 import bisect, types
2 import MemIndex, Cidr
3 from Rwhois import rwhoisobject
4
5 class MemDB:
6
7     def __init__(self):
8
9         # a dictonary holding the various attribute indexes.  The keys
10         # are lowercase attribute names, values are MemIndex or
11         # CidrMemIndex objects.
12         self.indexes = {}
13
14         # a dictonary holding the actual rwhoisobjects.  keys are
15         # string IDs, values are rwhoisobject instances.
16         self.main_index = {}
17
18         # dictonary holding all of the seen attributes.  keys are
19         # lowercase attribute names, value is a character indicating
20         # the index type (if indexed), or None if not indexed.  Index
21         # type characters a 'N' for normal string index, 'C' for CIDR
22         # index.
23         self.attrs = {}
24
25         # Lists containing attribute names that have indexes by type.
26         # This exists so unconstrained searches can just iterate over
27         # them.
28         self.normal_indexes = []
29         self.cidr_indexes   = []
30
31         # dictonary holding all of the seen class names.  keys are
32         # lowercase classnames, value is always None.
33         self.classes = {}
34
35         # dictionary holding all of the seen auth-areas.  keys are
36         # lowercase authority area names, value is always None.
37         self.authareas = {}
38
39     def init_schema(self, schema_file):
40         """Initialize the schema from a schema file.  Currently the
41         schema file is a list of 'attribute_name = index_type' pairs,
42         one per line.  index_type is one of N or C, where N means a
43         normal string index, and C means a CIDR index.
44
45         It should be noted that this database implementation
46         implements a global namespace for attributes, which isn't
47         really correct according to RFC 2167.  RFC 2167 dictates that
48         different authority area are actually autonomous and thus have
49         separate schemas."""
50
51         # initialize base schema
52
53         self.attrs['id']         = "N"
54         self.attrs['auth-area']  = None
55         self.attrs['class-name'] = None
56         self.attrs['updated']    = None
57         self.attrs['referred-auth-area'] = "R"
58
59         sf = open(schema_file, "r")
60
61         for line in sf.xreadlines():
62             line = line.strip()
63             if not line or line.startswith("#"): continue
64
65             attr, it = line.split("=")
66             self.attrs[attr.strip().lower()] = it.strip()[0].upper()
67
68         for attr, index_type in self.attrs.items():
69             if index_type == "N":
70                 # normal index
71                 self.indexes[attr] = MemIndex.MemIndex()
72                 self.normal_indexes.append(attr)
73             elif index_type == "A":
74                 # "all" index -- both a normal and a cidr index
75                 self.indexes[attr] = MemIndex.ComboMemIndex()
76                 self.normal_indexes.append(attr)
77                 self.cidr_indexes.append(attr)
78             elif index_type == "R":
79                 # referral index, an all index that must be searched
80                 # explictly by attribute
81                 self.indexes[attr] = MemIndex.ComboMemIndex()
82             elif index_type == "C":
83                 # a cidr index
84                 self.indexes[attr] = MemIndex.CidrMemIndex()
85                 self.cidr_indexes.append(attr)
86         return
87
88     def add_object(self, obj):
89         """Add an rwhoisobject to the raw indexes, including the
90         master index."""
91
92         # add the object to the main index
93         id = obj.getid()
94         if not id: return
95         id = id.lower()
96
97         self.main_index[id] = obj
98
99         for a,v in obj.items():
100             # note the attribute.
101             index_type = self.attrs.setdefault(a, None)
102             v = v.lower()
103             # make sure that we note the auth-area and class
104             if a == 'auth-area':
105                 self.authareas.setdefault(v, None)
106             elif a == 'class-name':
107                 self.classes.setdefault(v, None)
108
109             if index_type:
110                 index = self.indexes[a]
111                 index.add(v, id)
112
113     def load_data(self, data_file):
114         """Load data from rwhoisd-style TXT files (i.e., attr:value,
115         records separated with a "---" bare line)."""
116
117         df = open(data_file, "r")
118         obj = rwhoisobject()
119
120         for line in df.xreadlines():
121             line = line.strip()
122             if line.startswith("#"): continue
123             if not line or line.startswith("---"):
124                 # we've reached the end of an object, so index it.
125                 self.add_object(obj)
126                 # reset obj
127                 obj = rwhoisobject()
128                 continue
129
130             a, v = line.split(":", 1)
131             obj.add_attr(a, v.lstrip())
132
133         self.add_object(obj)
134         return
135
136     def index_data(self):
137         """Prepare the indexes for searching.  Currently, this isn't
138         strictly necessary (the indexes will prepare themselves when
139         necessary), but it should elminate a penalty on initial
140         searches"""
141
142         for i in self.indexes.values():
143             i.prepare()
144         return
145
146     def is_attribute(self, attr):
147         return self.attrs.has_key(attr.lower())
148
149     def is_indexed_attr(self, attr):
150         if self.is_attribute(attr):
151             return self.attrs[attr.lower()]
152         return False
153
154     def is_objectclass(self, objectclass):
155         return self.classes.has_key(objectclass.lower())
156
157     def is_autharea(self, aa):
158         return self.authareas.has_key(aa.lower())
159
160     def get_authareas(self):
161         return self.authareas.keys()
162     
163     def fetch_objects(self, id_list):
164         return [ self.main_index[x] for x in id_list
165                  if self.main_index.has_key(x) ]
166
167     def search_attr(self, attr, value, max = 0):
168
169         """Search for a value in a particular attribute's index.  If
170         the attribute is cidr indexed, an attempt to convert value
171         into a Cidr object will be made.  Returns a list of object ids
172         (or an empty list if nothing was found)"""
173
174         attr = attr.lower()
175         index_type = self.attrs.get(attr)
176         index = self.indexes.get(attr)
177         if not index: return []
178
179         super_prefix_match = False
180         if value.endswith("**"):
181             super_prefix_match = True
182
183         prefix_match = False
184         if value.endswith("*"):
185             value = value.rstrip("*")
186             prefix_match = True
187
188         if index_type == 'C' and not isinstance(value, Cidr.Cidr):
189             value = Cidr.valid_cidr(value)
190         else:
191             value = value.strip().lower()
192
193         if index_type == 'C' and super_prefix_match:
194             return index.find_subnets(value, max)
195
196         res = index.find(value, prefix_match, max)
197         return IndexResult(res)
198
199     def search_normal(self, value, max = 0):
200         """Search for a value in the 'normal' (string keyed) indexes.
201         Returns a list of object ids, or an empty list if nothing was
202         found."""
203
204         res = IndexResult()
205
206         for attr in self.normal_indexes:
207             res.extend(self.search_attr(attr, value, max))
208             if max:
209                 if len(res) >= max:
210                     res.truncate(max)
211                     return res
212         return res
213
214     def search_cidr(self, value, max = 0):
215         """Search for a value in the cidr indexes.  Returns a list of
216         object ids, or an empty list if nothing was found."""
217
218         res = IndexResult()
219         for attr in self.cidr_indexes:
220             res.extend(self.search_attr(attr, value, max))
221             if max:
222                 if len(res) >= max:
223                     res.truncate(max)
224                     return res
225         return res
226
227     def search_referral(self, value, max = 0):
228         """Given a heirarchal value, search for referrals.  Returns a
229         list of object ids or an empty list."""
230
231         return self.search_attr("referred-auth-area", value, max)
232
233     def object_iterator(self):
234         return self.main_index.itervalues()
235
236 class IndexResult:
237     def __init__(self, list=None):
238         if not list: list = []
239         self.data = list
240         self._dict = dict(zip(self.data, self.data))
241
242     def extend(self, list):
243         if isinstance(list, type(self)):
244             list = list.list()
245         new_els = [ x for x in list if not self._dict.has_key(x) ]
246         self.data.extend(new_els)
247         self._dict.update(dict(zip(new_els, new_els)))
248
249     def list(self):
250         return self.data
251
252     def truncate(self, n=0):
253         to_del = self.data[n:]
254         for i in to_del: del self._dict[i]
255         self.data = self.data[:n]
256
257
258 # test driver
259 if __name__ == "__main__":
260     import sys
261     db = MemDB()
262
263     print "loading schema:", sys.argv[1]
264     db.init_schema(sys.argv[1])
265     for data_file in sys.argv[2:]:
266         print "loading data file:", data_file
267         db.load_data(data_file)
268     db.index_data()
269
270     print "Schema: authority areas"
271     for a in db.authareas.keys():
272         print "   %s" % a
273     print "Schema: classes"
274     for c in db.classes.keys():
275         print "   %s" % c
276     print "Schema: attributes"
277     for a in db.attrs.keys():
278         print "   %s" % a
279
280     print "Is 'Network' a class?", db.is_objectclass("Network")
281         
282 #    for k, v in db.main_index.items():
283 #        print "main_index[", k, "]:", v
284
285     print "searching for a.com"
286     res = db.search_attr("domain-name", "a.com")
287     print res.list()
288     print [ str(x) for x in db.fetch_objects(res.list()) ]
289
290     print "searching for doe"
291     res = db.search_normal("doe")
292     print res.list()
293     print [ str(x) for x in db.fetch_objects(res.list()) ]
294
295     print "searching for 10.0.0.2"
296     res = db.search_cidr("10.0.0.2")
297     print res.list()
298     print [ str(x) for x in db.fetch_objects(res.list()) ]
299
300     print "searching for fddi.a.com"
301     res = db.search_normal("fddi.a.com")
302     print res.list()
303
304     print "searching referral index for fddi.a.com"
305     res = db.search_attr("referred-auth-area", "fddi.a.com")
306     print res.list()
307     print [ str(x) for x in db.fetch_objects(res.list()) ]
308
309