Initial revision
[python-rwhoisd.git] / rwhoisd / MemDB.py
1 import bisect, types
2 import MemIndex, Cidr
3 from Rwhois import rwhoisobject
4
5 class MemDB:
6
7     def __init__(self):
8
9         # a dictonary holding the various attribute indexes.  The keys
10         # are lowercase attribute names, values are MemIndex or
11         # CidrMemIndex objects.
12         self.indexes = {}
13
14         # a dictonary holding the actual rwhoisobjects.  keys are
15         # string IDs, values are rwhoisobject instances.
16         self.main_index = {}
17
18         # dictonary holding all of the seen attributes.  keys are
19         # lowercase attribute names, value is a character indicating
20         # the index type (if indexed), or None if not indexed.  Index
21         # type characters a 'N' for normal string index, 'C' for CIDR
22         # index.
23         self.attrs = {}
24
25         # Lists containing attribute names that have indexes by type.
26         # This exists so unconstrained searches can just iterate over
27         # them.
28         self.normal_indexes = []
29         self.cidr_indexes = []
30
31         # dictonary holding all of the seen class names.  keys are
32         # lowercase classnames, value is always None.
33         self.classes = {}
34
35         # dictionary holding all of the seen auth-areas.  keys are
36         # lowercase authority area names, value is always None.
37         self.authareas = {}
38
39     def init_schema(self, schema_file):
40         """Initialize the schema from a schema file.  Currently the
41         schema file is a list of 'attribute_name = index_type' pairs,
42         one per line.  index_type is one of N or C, where N means a
43         normal string index, and C means a CIDR index.
44
45         It should be noted that this database implementation
46         implements a global namespace for attributes, which isn't
47         really correct according to RFC 2167.  RFC 2167 dictates that
48         different authority area are actually autonomous and thus have
49         separate schemas."""
50
51         # initialize base schema
52
53         self.attrs['id']         = "N"
54         self.attrs['auth-area']  = None
55         self.attrs['class-name'] = None
56         self.attrs['updated']    = None
57         self.attrs['referred-auth-area'] = "R"
58
59         sf = open(schema_file, "r")
60
61         for line in sf.xreadlines():
62             line = line.strip()
63             if not line or line.startswith("#"): continue
64
65             attr, it = line.split("=")
66             self.attrs[attr.strip().lower()] = it.strip()[0].upper()
67
68         for attr, index_type in self.attrs.items():
69             if index_type == "N":
70                 # normal index
71                 self.indexes[attr] = MemIndex.MemIndex()
72                 self.normal_indexes.append(attr)
73             elif index_type == "A":
74                 # "all" index -- both a normal and a cidr index
75                 self.indexes[attr] = MemIndex.ComboMemIndex()
76                 self.normal_indexes.append(attr)
77                 self.cidr_indexes.append(attr)
78             elif index_type == "R":
79                 # referral index, an all index that must be searched
80                 # explictly by attribute
81                 self.indexes[attr] = MemIndex.ComboMemIndex()
82             elif index_type == "C":
83                 # a cidr index
84                 self.indexes[attr] = MemIndex.CidrMemIndex()
85                 self.cidr_indexes.append(attr)
86         return
87
88     def add_object(self, obj):
89         """Add an rwhoisobject to the raw indexes, including the
90         master index."""
91
92         # add the object to the main index
93         id = obj.getid()
94         if not id: return
95         id = id.lower()
96
97         self.main_index[id] = obj
98
99         for a,v in obj.items():
100             # note the attribute.
101             index_type = self.attrs.setdefault(a, None)
102             v = v.lower()
103             # make sure that we note the auth-area and class
104             if a == 'auth-area':
105                 self.authareas.setdefault(v, None)
106             elif a == 'class-name':
107                 self.classes.setdefault(v, None)
108
109             if index_type:
110                 index = self.indexes[a]
111                 index.add(v, id)
112
113     def load_data(self, data_file):
114         """Load data from rwhoisd-style TXT files (i.e., attr:value,
115         records separated with a "---" bare line)."""
116
117         df = open(data_file, "r")
118         obj = rwhoisobject()
119
120         for line in df.xreadlines():
121             line = line.strip()
122             if line.startswith("#"): continue
123             if not line or line.startswith("---"):
124                 # we've reached the end of an object, so index it.
125                 self.add_object(obj)
126                 # reset obj
127                 obj = rwhoisobject()
128                 continue
129
130             a, v = line.split(":", 1)
131             obj.add_attr(a, v.lstrip())
132
133         self.add_object(obj)
134         return
135
136     def index_data(self):
137         """Prepare the indexes for searching.  Currently, this isn't
138         strictly necessary (the indexes will prepare themselves when
139         necessary), but it should elminate a penalty on initial
140         searches"""
141
142         for i in self.indexes.values():
143             i.prepare()
144         return
145
146     def is_attribute(self, attr):
147         return self.attrs.has_key(attr.lower())
148
149     def is_indexed_attr(self, attr):
150         if self.is_attribute(attr):
151             return self.attrs[attr.lower()]
152         return False
153
154     def is_objectclass(self, objectclass):
155         return self.classes.has_key(objectclass.lower())
156
157     def is_autharea(self, aa):
158         return self.authareas.has_key(aa.lower())
159
160     def fetch_objects(self, id_list):
161         return [ self.main_index[x] for x in id_list
162                  if self.main_index.has_key(x) ]
163
164     def search_attr(self, attr, value, max = 0):
165
166         """Search for a value in a particular attribute's index.  If
167         the attribute is cidr indexed, an attempt to convert value
168         into a Cidr object will be made.  Returns a list of object ids
169         (or an empty list if nothing was found)"""
170
171         attr = attr.lower()
172         index_type = self.attrs.get(attr)
173         index = self.indexes.get(attr)
174         if not index: return []
175
176         super_prefix_match = False
177         if value.endswith("**"):
178             super_prefix_match = True
179
180         prefix_match = False
181         if value.endswith("*"):
182             value = value.rstrip("*")
183             prefix_match = True
184
185         if index_type == 'C' and not isinstance(value, Cidr.Cidr):
186             value = Cidr.valid_cidr(value)
187         else:
188             value = value.strip().lower()
189
190         if index_type == 'C' and super_prefix_match:
191             return index.find_subnets(value, max)
192
193         res = index.find(value, prefix_match, max)
194         return IndexResult(res)
195
196     def search_normal(self, value, max = 0):
197         """Search for a value in the 'normal' (string keyed) indexes.
198         Returns a list of object ids, or an empty list if nothing was
199         found."""
200
201         res = IndexResult()
202
203         for attr in self.normal_indexes:
204             res.extend(self.search_attr(attr, value, max))
205             if max:
206                 if len(res) >= max:
207                     res.truncate(max)
208                     return res
209         return res
210
211     def search_cidr(self, value, max = 0):
212         """Search for a value in the cidr indexes.  Returns a list of
213         object ids, or an empty list if nothing was found."""
214
215         res = IndexResult()
216         for attr in self.cidr_indexes:
217             res.extend(self.search_attr(attr, value, max))
218             if max:
219                 if len(res) >= max:
220                     res.truncate(max)
221                     return res
222         return res
223
224     def search_referral(self, value, max = 0):
225         """Given a heirarchal value, search for referrals.  Returns a
226         list of object ids or an empty list."""
227
228         return self.search_attr("referred-auth-area", value, max)
229
230     def object_iterator(self):
231         return self.main_index.itervalues()
232
233 class IndexResult:
234     def __init__(self, list=None):
235         if not list: list = []
236         self.data = list
237         self._dict = dict(zip(self.data, self.data))
238
239     def extend(self, list):
240         if isinstance(list, type(self)):
241             list = list.list()
242         new_els = [ x for x in list if not self._dict.has_key(x) ]
243         self.data.extend(new_els)
244         self._dict.update(dict(zip(new_els, new_els)))
245
246     def list(self):
247         return self.data
248
249     def truncate(self, n=0):
250         to_del = self.data[n:]
251         for i in to_del: del self._dict[i]
252         self.data = self.data[:n]
253
254
255 # test driver
256 if __name__ == "__main__":
257     import sys
258     db = MemDB()
259
260     print "loading schema:", sys.argv[1]
261     db.init_schema(sys.argv[1])
262     for data_file in sys.argv[2:]:
263         print "loading data file:", data_file
264         db.load_data(data_file)
265     db.index_data()
266
267     print "Schema: authority areas"
268     for a in db.authareas.keys():
269         print "   %s" % a
270     print "Schema: classes"
271     for c in db.classes.keys():
272         print "   %s" % c
273     print "Schema: attributes"
274     for a in db.attrs.keys():
275         print "   %s" % a
276
277     print "Is 'Network' a class?", db.is_objectclass("Network")
278         
279 #    for k, v in db.main_index.items():
280 #        print "main_index[", k, "]:", v
281
282     print "searching for a.com"
283     res = db.search_attr("domain-name", "a.com")
284     print res.list()
285     print [ str(x) for x in db.fetch_objects(res.list()) ]
286
287     print "searching for doe"
288     res = db.search_normal("doe")
289     print res.list()
290     print [ str(x) for x in db.fetch_objects(res.list()) ]
291
292     print "searching for 10.0.0.2"
293     res = db.search_cidr("10.0.0.2")
294     print res.list()
295     print [ str(x) for x in db.fetch_objects(res.list()) ]
296
297     print "searching for fddi.a.com"
298     res = db.search_normal("fddi.a.com")
299     print res.list()
300
301     print "searching referral index for fddi.a.com"
302     res = db.search_attr("referred-auth-area", "fddi.a.com")
303     print res.list()
304     print [ str(x) for x in db.fetch_objects(res.list()) ]
305
306