Update TODO based on work for 0.4.1
[python-rwhoisd.git] / rwhoisd / Cidr.py
1 # This file is part of python-rwhoisd
2 #
3 # Copyright (C) 2003, 2008 David E. Blacka
4 #
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 2 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful, but
11 # WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13 # General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program; if not, write to the Free Software
17 # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
18 # USA
19
20 import socket, types, copy, bisect, struct
21 import v6addr
22
23 def new(address, netlen = -1):
24     """Construct either a CidrV4 or CidrV6 object."""
25
26     # ints are probably v4 addresses.
27     if isinstance(address, int):
28         return CidrV4(address, netlen)
29     # longs could be v4 addresses, but we will only assume so if the
30     # value is small.
31     if isinstance(address, long):
32         if address <= pow(2, 32):
33             return CidrV4(address, netlen)
34         return CidrV6(address, netlen)
35     # otherwise, a colon in the address is a dead giveaway.
36     if ":" in address:
37         return CidrV6(address, netlen)
38     return CidrV4(address, netlen)
39
40 class Cidr:
41     """A class representing a generic CIDRized network value."""
42
43     def _initialize(self, address, netlen):
44         """This a common constructor that is used by the subclasses."""
45
46         if isinstance(address, int) or \
47                 isinstance(address, long) and netlen >= 0:
48             self.numaddr, self.netlen = address, netlen
49             self.addr = self._convert_ipaddr(address)
50             self.calc()
51             return
52
53         if not self.is_valid_cidr(address):
54             raise ValueError, \
55                 repr(address) + " is not a valid CIDR representation"
56
57         if netlen < 0:
58             if type(address) == types.StringType:
59                 if "/" in address:
60                     self.addr, self.netlen = address.split("/", 1)
61                 else:
62                     self.addr, self.netlen = address, self._max_netlen()
63             elif type(address) == types.TupleType:
64                 self.addr, self.netlen = address
65             else:
66                 raise TypeError, "address must be a string or a tuple"
67         else:
68             self.addr, self.netlen = address, netlen
69
70
71         # convert string network lengths to integer
72         if type(self.netlen) == types.StringType:
73             self.netlen = int(self.netlen)
74
75         self.calc()
76
77     def __str__(self):
78         return self.addr + "/" + str(self.netlen)
79
80     def __repr__(self):
81         return "<" + str(self) + ">"
82
83     def __cmp__(self, other):
84         """One CIDR network block is less than another if the start
85         address is numerically less or if the block is larger.  That
86         is, supernets will sort before subnets.  This ordering allows
87         for an efficient search for subnets of a given network."""
88
89         res = self._base_mask(self.numaddr) - other._base_mask(other.numaddr)
90         if res == 0: res = self.netlen - other.netlen
91         if res < 0: return -1
92         if res > 0: return 1
93         return 0
94
95     def calc(self):
96         """This method should be called after any change to the main
97         internal state: netlen or numaddr."""
98
99         # make sure the network length is valid
100         if not self.is_valid_netlen(self.netlen):
101             raise TypeError, "network length must be between 0 and %d" % \
102                 (self._max_netlen())
103
104         # convert the string ipv4 address to a 32bit number
105         self.numaddr = self._convert_ipstr(self.addr)
106         # calculate our netmask
107         self.mask = self._mask(self.netlen)
108         # force the cidr address into correct masked notation
109         self.numaddr &= self.mask
110
111         # convert the number back to a string to normalize the string
112         self.addr = self._convert_ipaddr(self.numaddr)
113
114     def is_supernet(self, other):
115         """returns True if the other Cidr object is a supernet (an
116         enclosing network block) of this one.  A Cidr object is a
117         supernet of itself."""
118         return other.numaddr & self.mask == self.numaddr
119
120     def is_subnet(self, other):
121         """returns True if the other Cidr object is a subnet (an
122         enclosednetwork block) of this one.  A Cidr object is a
123         subnet of itself."""
124         return self.numaddr & other.mask == other.numaddr
125
126     def netmask(self):
127         """return the netmask of this Cidr network"""
128         return self._convert_ipaddr(self.mask)
129
130     def length(self):
131         """return the length (in number of addresses) of this network block"""
132         return 1 << (self._max_netlen() - self.netlen);
133
134     def end(self):
135         """return the last IP address in this network block"""
136         return self._convert_ipaddr(self.numaddr + self.length() - 1)
137
138     def to_netblock(self):
139         return (self.addr, self.end())
140
141     def clone(self):
142         # we can get away with a shallow copy (so far)
143         return copy.copy(self)
144
145     def is_ipv6(self):
146         if isinstance(self, CidrV6): return True
147         return False
148
149     def is_valid_cidr(self, address):
150         if "/" in address:
151             addr, netlen = address.split("/", 1)
152             netlen = int(netlen)
153         else:
154             addr, netlen = address, 0
155         return self._is_valid_address(addr) and self.is_valid_netlen(netlen)
156
157     def is_valid_netlen(self, netlen):
158         if netlen < 0: return False
159         if netlen > self._max_netlen(): return False
160         return True
161
162
163 class CidrV4(Cidr):
164     """A class representing a CIDRized IPv4 network value.
165
166     Specifically, it is representing a contiguous IPv4 network block
167     that can be expressed as a ip-address/network-length pair."""
168
169     base_mask = 0xFFFFFFFF
170     msb_mask  = 0x80000000
171
172     def __init__(self, address, netlen = -1):
173         """This takes either a formatted string in CIDR notation:
174         (e.g., "127.0.0.1/32"), A tuple consisting of an formatting
175         string IPv4 address and a numeric network length, or the same
176         as two arguments."""
177
178         self._initialize(address, netlen)
179
180     def _is_valid_address(self, address):
181         """Returns True if the address is a legal IPv4 address."""
182         try:
183             self._convert_ipstr(address)
184             return True
185         except socket.error:
186             return False
187
188     def _base_mask(self, numaddr):
189         return numaddr & CidrV4.base_mask
190
191     def _max_netlen(self):
192         return 32
193
194     def _convert_ipstr(self, addr):
195         packed_numaddr = socket.inet_aton(addr)
196         return struct.unpack("!I", packed_numaddr)[0]
197
198     def _convert_ipaddr(self, numaddr):
199         packed_numaddr = struct.pack("!I", numaddr)
200         return socket.inet_ntoa(packed_numaddr)
201
202     def _mask(self, len):
203         return self._base_mask(CidrV4.base_mask << (32 - len))
204
205 class CidrV6(Cidr):
206     """A class representing a CIDRized IPv6 network value.
207
208     Specifically, it is representing a contiguous IPv6 network block
209     that can be expressed as a ipv6-address/network-length pair."""
210
211     base_mask  = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFL # 128-bits of all ones.
212     msb_mask   = 0x80000000000000000000000000000000L
213     lower_mask = 0x0000000000000000FFFFFFFFFFFFFFFFL
214     upper_mask = 0xFFFFFFFFFFFFFFFF0000000000000000L
215
216     def __init__(self, address, netlen = -1):
217
218         self._initialize(address, netlen)
219
220     def _is_valid_address(self, address):
221         try:
222             self._convert_ipstr(address)
223             return True
224         except socket.error, e:
225             print "Failed to convert address string '%s': " + str(e) % (address)
226             return False
227
228     def _base_mask(self, numaddr):
229         return numaddr & CidrV6.base_mask
230
231     def _max_netlen(self):
232         return 128
233
234     def _convert_ipstr(self, addr):
235         packed_numaddr = socket.inet_pton(socket.AF_INET6, addr)
236         upper, lower = struct.unpack("!QQ", packed_numaddr);
237         return (upper << 64) | lower
238
239     def _convert_ipaddr(self, numaddr):
240         upper = (numaddr & CidrV6.upper_mask) >> 64;
241         lower = numaddr & CidrV6.lower_mask;
242         packed_numaddr = struct.pack("!QQ", upper, lower)
243         return socket.inet_ntop(socket.AF_INET6, packed_numaddr)
244
245     def _mask(self, len):
246         return self._base_mask(CidrV6.base_mask << (128 - len))
247
248
249 def valid_cidr(address):
250     """Returns the converted Cidr object if 'address' is valid CIDR
251     notation, False if not.  For the purposes of this module, valid
252     CIDR notation consists of a IPv4 or IPv6 address with an optional
253     trailing "/netlen"."""
254
255     if isinstance(address, Cidr): return address
256     try:
257         c = new(address)
258         return c
259     except (ValueError, socket.error):
260         return False
261
262 def netblock_to_cidr(start, end):
263     """Convert an arbitrary network block expressed as a start and end
264     address (inclusive) into a series of valid CIDR blocks."""
265
266     def largest_prefix(length, max_netlen, msb_mask):
267         # calculates the largest network length (smallest mask length)
268         # that can fit within the block length.
269         i = 1; v = length
270         while i <= max_netlen:
271             if v & msb_mask: break
272             i += 1; v <<= 1
273         return i
274     def netlen_to_mask(n, max_netlen, base_mask):
275         # convert the network length into its netmask
276         return ~((1 << (max_netlen - n)) - 1) & base_mask
277     def netlen_to_length(n, max_netlen, base_mask):
278         return 1 << (max_netlen - n) & base_mask
279
280     # convert the start and ending addresses of the netblock to Cidr
281     # object, mostly so we can get the numeric versions of their
282     # addresses.
283     cs = valid_cidr(start)
284     ce = valid_cidr(end)
285
286     # if either the start or ending addresses aren't valid addresses,
287     # quit now.
288     if not cs or not ce:
289         return None
290     # if the start and ending addresses aren't in the same family, quit now
291     if cs.is_ipv6() != ce.is_ipv6():
292         return None
293
294     max_netlen = cs._max_netlen()
295     msb_mask = cs.msb_mask
296     base_mask = cs.base_mask
297
298     # calculate the number of IP address in the netblock
299     block_len = ce.numaddr - cs.numaddr
300     # calcuate the largest CIDR block size that fits
301     netlen = largest_prefix(block_len + 1, max_netlen, msb_mask)
302
303     res = []; s = cs.numaddr
304     while block_len > 0:
305         mask = netlen_to_mask(netlen, max_netlen, base_mask)
306         # check to see if our current network length is valid
307         if (s & mask) != s:
308             # if not, shrink the network block size
309             netlen += 1
310             continue
311         # otherwise, we have a valid CIDR block, so add it to the list
312         res.append(new(s, netlen))
313         # and setup for the next round:
314         cur_len = netlen_to_length(netlen, max_netlen, base_mask)
315         s         += cur_len
316         block_len -= cur_len
317         netlen = largest_prefix(block_len + 1, max_netlen, msb_mask)
318     return res
319
320 # test driver
321 if __name__ == "__main__":
322     import sys
323     a = new("127.00.000.1/24")
324     b = new("127.0.0.1", 32)
325     c = new("24.232.119.192", 26)
326     d = new("24.232.119.0", 24)
327     e = new("24.224.0.0", 11)
328     f = new("216.168.111.0/27");
329     g = new("127.0.0.2/31");
330     h = new("127.0.0.16/32")
331     i = new("3ffe:4:201e:beef::0/64");
332     j = new("2001:3c01::/32")
333
334     print f.addr
335     print j.addr
336
337     try:
338         bad = new("24.261.119.0", 32)
339     except ValueError, x:
340         print "error:", x
341
342     print "cidr:", a, "num addresses:", a.length(), "ending address", \
343         a.end(), "netmask", a.netmask()
344
345     print "cidr:", j, "num addresses:", j.length(), "ending address", \
346         j.end(), "netmask", j.netmask()
347
348     clist = [a, b, c, d, e, f, g, h, i , j]
349     print "unsorted list of cidr objects:\n  ", clist
350
351
352     clist.sort()
353     print "sorted list of cidr object:\n  ", clist
354
355     k = new("2001:3c01::1:0", 120)
356     print "supernet: ", str(j), " supernet of ", str(k), "? ", \
357         str(j.is_supernet(k))
358     print "supernet: ", str(k), " supernet of ", str(j), "? ", \
359         str(k.is_supernet(j))
360     print "subnet: ", str(j), " subnet of ", str(k), "? ", \
361         str(j.is_subnet(k))
362     print "subnet: ", str(k), " subnet of ", str(j), "? ", \
363         str(k.is_subnet(j))
364
365     netblocks = [ ("192.168.10.0", "192.168.10.255"),
366                   ("192.168.10.0", "192.168.10.63"),
367                   ("172.16.0.0", "172.16.127.255"),
368                   ("24.33.41.22", "24.33.41.37"),
369                   ("196.11.1.0", "196.11.30.255"),
370                   ("192.247.1.0", "192.247.10.255"),
371                   ("10.131.43.3", "10.131.44.7"),
372                   ("3ffe:4:5::", "3ffe:4:5::ffff"),
373                   ("3ffe:4:5::", "3ffe:4:6::1")]
374
375     for start, end in netblocks:
376         print "netblock %s - %s:" % (start, end)
377         blocks = netblock_to_cidr(start, end)
378         print blocks