IPv6 Cidr code now working
[python-rwhoisd.git] / rwhoisd / Cidr.py
1 # This file is part of python-rwhoisd
2 #
3 # Copyright (C) 2003, 2008 David E. Blacka
4 #
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 2 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful, but
11 # WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13 # General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program; if not, write to the Free Software
17 # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
18 # USA
19
20 import socket, types, copy, bisect, struct
21
22 class Cidr:
23     """A class representing a generic CIDRized network value."""    
24
25     @staticmethod
26     def create(address, netlen = -1):
27         """Construct either a CidrV4 or CidrV6 object."""
28         if isinstance(address, int):
29             return CidrV4(address, netlen)
30         if isinstance(address, long):
31             if address <= pow(2, 32):
32                 return CidrV4(address, netlen)
33             return CidrV6(address, netlen)
34         if ":" in address:
35             return CidrV6(address, netlen)
36         return CidrV4(address, netlen)
37
38
39     def _initialize(self, address, netlen):
40         """This a common constructor that is used by the subclasses."""
41
42         if isinstance(address, int) or \
43                 isinstance(address, long) and netlen >= 0:
44             self.numaddr, self.netlen = address, netlen
45             self.addr = self._convert_ipaddr(address)
46             self.calc()
47             return
48
49         if not self.is_valid_cidr(address):
50             raise ValueError, \
51                 repr(address) + " is not a valid CIDR representation"
52
53         if netlen < 0:
54             if type(address) == types.StringType:
55                 if "/" in address:
56                     self.addr, self.netlen = address.split("/", 1)
57                 else:
58                     self.addr, self.netlen = address, self._max_netlen()
59             elif type(address) == types.TupleType:
60                 self.addr, self.netlen = address
61             else:
62                 raise TypeError, "address must be a string or a tuple"
63         else:
64             self.addr, self.netlen = address, netlen
65
66
67         # convert string network lengths to integer
68         if type(self.netlen) == types.StringType:
69             self.netlen = int(self.netlen)
70
71         self.calc()
72
73     def __str__(self):
74         return self.addr + "/" + str(self.netlen)
75
76     def __repr__(self):
77         return "<" + str(self) + ">"
78
79     def __cmp__(self, other):
80         """One CIDR network block is less than another if the start
81         address is numerically less or if the block is larger.  That
82         is, supernets will sort before subnets.  This ordering allows
83         for an efficient search for subnets of a given network."""
84
85         res = self._base_mask(self.numaddr) - other._base_mask(other.numaddr)
86         if res == 0: res = self.netlen - other.netlen
87         if res < 0: return -1
88         if res > 0: return 1
89         return 0
90
91     def calc(self):
92         """This method should be called after any change to the main
93         internal state: netlen or numaddr."""
94
95         # make sure the network length is valid
96         if not self.is_valid_netlen(self.netlen):
97             raise TypeError, "network length must be between 0 and %d" % \
98                 (self._max_netlen())
99
100         # convert the string ipv4 address to a 32bit number
101         self.numaddr = self._convert_ipstr(self.addr)
102         # calculate our netmask
103         self.mask = self._mask(self.netlen)
104         # force the cidr address into correct masked notation
105         self.numaddr &= self.mask
106
107         # convert the number back to a string to normalize the string
108         self.addr = self._convert_ipaddr(self.numaddr)
109
110     def is_supernet(self, other):
111         """returns True if the other Cidr object is a supernet (an
112         enclosing network block) of this one.  A Cidr object is a
113         supernet of itself."""
114         return other.numaddr & self.mask == self.numaddr
115
116     def is_subnet(self, other):
117         """returns True if the other Cidr object is a subnet (an
118         enclosednetwork block) of this one.  A Cidr object is a
119         subnet of itself."""
120         return self.numaddr & other.mask == other.numaddr
121
122     def netmask(self):
123         """return the netmask of this Cidr network"""
124         return self._convert_ipaddr(self.mask)
125     
126     def length(self):
127         """return the length (in number of addresses) of this network block"""
128         return 1 << (self._max_netlen() - self.netlen);
129
130     def end(self):
131         """return the last IP address in this network block"""
132         return self._convert_ipaddr(self.numaddr + self.length() - 1)
133
134     def to_netblock(self):
135         return (self.addr, self.end())
136
137     def clone(self):
138         # we can get away with a shallow copy (so far)
139         return copy.copy(self)
140     
141     def is_ipv6(self):
142         if isinstance(self, CidrV6): return True
143         return False
144
145     def is_valid_cidr(self, address):
146         if "/" in address:
147             addr, netlen = address.split("/", 1)
148             netlen = int(netlen)
149         else:
150             addr, netlen = address, 0
151         return self._is_valid_address(addr) and self.is_valid_netlen(netlen)
152
153     def is_valid_netlen(self, netlen):
154         if netlen < 0: return False
155         if netlen > self._max_netlen(): return False
156         return True
157
158
159 class CidrV4(Cidr):
160     """A class representing a CIDRized IPv4 network value.
161
162     Specifically, it is representing a contiguous IPv4 network block
163     that can be expressed as a ip-address/network-length pair."""
164
165     base_mask = 0xFFFFFFFF
166     msb_mask  = 0x80000000
167
168     def __init__(self, address, netlen = -1):
169         """This takes either a formatted string in CIDR notation:
170         (e.g., "127.0.0.1/32"), A tuple consisting of an formatting
171         string IPv4 address and a numeric network length, or the same
172         as two arguments."""
173         
174         self._initialize(address, netlen)
175         
176     def _is_valid_address(self, address):
177         """Returns True if the address is a legal IPv4 address."""
178         try:
179             self._convert_ipstr(address)
180             return True
181         except socket.error:
182             return False
183
184     def _base_mask(self, numaddr):
185         return numaddr & CidrV4.base_mask
186
187     def _max_netlen(self):
188         return 32
189
190     def _convert_ipstr(self, addr):
191         packed_numaddr = socket.inet_aton(addr)
192         return struct.unpack("!I", packed_numaddr)[0]
193
194     def _convert_ipaddr(self, numaddr):
195         packed_numaddr = struct.pack("!I", numaddr)
196         return socket.inet_ntoa(packed_numaddr)
197
198     def _mask(self, len):
199         return self._base_mask(CidrV4.base_mask << (32 - len))
200
201 class CidrV6(Cidr):
202     """A class representing a CIDRized IPv6 network value.
203
204     Specifically, it is representing a contiguous IPv6 network block
205     that can be expressed as a ipv6-address/network-length pair."""
206     
207     base_mask  = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFL # 128-bits of all ones.
208     msb_mask   = 0x80000000000000000000000000000000L
209     lower_mask = 0x0000000000000000FFFFFFFFFFFFFFFFL
210     upper_mask = 0xFFFFFFFFFFFFFFFF0000000000000000L
211
212     def __init__(self, address, netlen = -1):
213         
214         self._initialize(address, netlen)
215
216     def _is_valid_address(self, address):
217         try:
218             self._convert_ipstr(address)
219             return True
220         except socket.error, e:
221             print "Failed to convert address string '%s': " + str(e) % (address)
222             return False
223
224     def _base_mask(self, numaddr):
225         return numaddr & CidrV6.base_mask
226
227     def _max_netlen(self):
228         return 128
229
230     def _convert_ipstr(self, addr):
231         packed_numaddr = socket.inet_pton(socket.AF_INET6, addr)
232         upper, lower = struct.unpack("!QQ", packed_numaddr);
233         return (upper << 64) | lower
234     
235     def _convert_ipaddr(self, numaddr):
236         upper = (numaddr & CidrV6.upper_mask) >> 64;
237         lower = numaddr & CidrV6.lower_mask;
238         packed_numaddr = struct.pack("!QQ", upper, lower)
239         return socket.inet_ntop(socket.AF_INET6, packed_numaddr)
240
241     def _mask(self, len):
242         return self._base_mask(CidrV6.base_mask << (128 - len))
243
244
245 def valid_cidr(address):
246     """Returns the converted Cidr object if 'address' is valid CIDR
247     notation, False if not.  For the purposes of this module, valid
248     CIDR notation consists of a IPv4 or IPv6 address with an optional
249     trailing "/netlen"."""
250
251     if isinstance(address, Cidr): return address
252     try:
253         c = Cidr.create(address)
254         return c
255     except (ValueError, socket.error):
256         return False
257
258 def netblock_to_cidr(start, end):
259     """Convert an arbitrary network block expressed as a start and end
260     address (inclusive) into a series of valid CIDR blocks."""
261
262     def largest_prefix(length, max_netlen, msb_mask):
263         # calculates the largest network length (smallest mask length)
264         # that can fit within the block length.
265         i = 1; v = length
266         while i <= max_netlen:
267             if v & msb_mask: break
268             i += 1; v <<= 1
269         return i
270     def netlen_to_mask(n, max_netlen, base_mask):
271         # convert the network length into its netmask
272         return ~((1 << (max_netlen - n)) - 1) & base_mask
273     def netlen_to_length(n, max_netlen, base_mask):
274         return 1 << (max_netlen - n) & base_mask
275
276     # convert the start and ending addresses of the netblock to Cidr
277     # object, mostly so we can get the numeric versions of their
278     # addresses.
279     cs = valid_cidr(start)
280     ce = valid_cidr(end)
281
282     # if either the start or ending addresses aren't valid addresses,
283     # quit now.
284     if not cs or not ce:
285         return None
286     # if the start and ending addresses aren't in the same family, quit now
287     if cs.is_ipv6() != ce.is_ipv6():
288         return None
289     
290     max_netlen = cs._max_netlen()
291     msb_mask = cs.msb_mask
292     base_mask = cs.base_mask
293
294     # calculate the number of IP address in the netblock
295     block_len = ce.numaddr - cs.numaddr
296     # calcuate the largest CIDR block size that fits
297     netlen = largest_prefix(block_len + 1, max_netlen, msb_mask)
298     
299     res = []; s = cs.numaddr
300     while block_len > 0:
301         mask = netlen_to_mask(netlen, max_netlen, base_mask)
302         # check to see if our current network length is valid
303         if (s & mask) != s:
304             # if not, shrink the network block size
305             netlen += 1
306             continue
307         # otherwise, we have a valid CIDR block, so add it to the list
308         res.append(Cidr.create(s, netlen))
309         # and setup for the next round:
310         cur_len = netlen_to_length(netlen, max_netlen, base_mask)
311         s         += cur_len
312         block_len -= cur_len
313         netlen = largest_prefix(block_len + 1, max_netlen, msb_mask)
314     return res
315
316 # test driver
317 if __name__ == "__main__":
318     import sys
319     a = Cidr.create("127.00.000.1/24")
320     b = Cidr.create("127.0.0.1", 32)
321     c = Cidr.create("24.232.119.192", 26)
322     d = Cidr.create("24.232.119.0", 24)
323     e = Cidr.create("24.224.0.0", 11)
324     f = Cidr.create("216.168.111.0/27");
325     g = Cidr.create("127.0.0.2/31");
326     h = Cidr.create("127.0.0.16/32")
327     i = Cidr.create("3ffe:4:201e:beef::0/64");
328     j = Cidr.create("2001:3c01::/32")
329
330     print f.addr
331     print j.addr
332     
333     try:
334         bad = Cidr.create("24.261.119.0", 32)
335     except ValueError, x:
336         print "error:", x
337     
338     print "cidr:", a, "num addresses:", a.length(), "ending address", \
339         a.end(), "netmask", a.netmask()
340
341     print "cidr:", j, "num addresses:", j.length(), "ending address", \
342         j.end(), "netmask", j.netmask()
343
344     clist = [a, b, c, d, e, f, g, h, i , j]
345     print "unsorted list of cidr objects:\n  ", clist
346
347
348     clist.sort()
349     print "sorted list of cidr object:\n  ", clist
350
351     k = Cidr.create("2001:3c01::1:0", 120)
352     print "supernet: ", str(j), " supernet of ", str(k), "? ", \
353         str(j.is_supernet(k))
354     print "supernet: ", str(k), " supernet of ", str(j), "? ", \
355         str(k.is_supernet(j))
356     print "subnet: ", str(j), " subnet of ", str(k), "? ", \
357         str(j.is_subnet(k))
358     print "subnet: ", str(k), " subnet of ", str(j), "? ", \
359         str(k.is_subnet(j))
360
361     netblocks = [ ("192.168.10.0", "192.168.10.255"),
362                   ("192.168.10.0", "192.168.10.63"),
363                   ("172.16.0.0", "172.16.127.255"),
364                   ("24.33.41.22", "24.33.41.37"),
365                   ("196.11.1.0", "196.11.30.255"),
366                   ("192.247.1.0", "192.247.10.255"),
367                   ("10.131.43.3", "10.131.44.7"),
368                   ("3ffe:4:5::", "3ffe:4:5::ffff"),
369                   ("3ffe:4:5::", "3ffe:4:6::1")]
370                   
371     for start, end in netblocks:
372         print "netblock %s - %s:" % (start, end)
373         blocks = netblock_to_cidr(start, end)
374         print blocks