Change Cidr.Cidr.create() to just Cidr.new(), which seems a bit more concise
[python-rwhoisd.git] / rwhoisd / Cidr.py
1 # This file is part of python-rwhoisd
2 #
3 # Copyright (C) 2003, 2008 David E. Blacka
4 #
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 2 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful, but
11 # WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13 # General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program; if not, write to the Free Software
17 # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
18 # USA
19
20 import socket, types, copy, bisect, struct
21
22 def new(address, netlen = -1):
23     """Construct either a CidrV4 or CidrV6 object."""
24
25     # ints are probably v4 addresses.
26     if isinstance(address, int):
27         return CidrV4(address, netlen)
28     # longs could be v4 addresses, but we will only assume so if the
29     # value is small.
30     if isinstance(address, long):
31         if address <= pow(2, 32):
32             return CidrV4(address, netlen)
33         return CidrV6(address, netlen)
34     # otherwise, a colon in the address is a dead giveaway.
35     if ":" in address:
36         return CidrV6(address, netlen)
37     return CidrV4(address, netlen)
38
39 class Cidr:
40     """A class representing a generic CIDRized network value."""
41
42
43
44     def _initialize(self, address, netlen):
45         """This a common constructor that is used by the subclasses."""
46
47         if isinstance(address, int) or \
48                 isinstance(address, long) and netlen >= 0:
49             self.numaddr, self.netlen = address, netlen
50             self.addr = self._convert_ipaddr(address)
51             self.calc()
52             return
53
54         if not self.is_valid_cidr(address):
55             raise ValueError, \
56                 repr(address) + " is not a valid CIDR representation"
57
58         if netlen < 0:
59             if type(address) == types.StringType:
60                 if "/" in address:
61                     self.addr, self.netlen = address.split("/", 1)
62                 else:
63                     self.addr, self.netlen = address, self._max_netlen()
64             elif type(address) == types.TupleType:
65                 self.addr, self.netlen = address
66             else:
67                 raise TypeError, "address must be a string or a tuple"
68         else:
69             self.addr, self.netlen = address, netlen
70
71
72         # convert string network lengths to integer
73         if type(self.netlen) == types.StringType:
74             self.netlen = int(self.netlen)
75
76         self.calc()
77
78     def __str__(self):
79         return self.addr + "/" + str(self.netlen)
80
81     def __repr__(self):
82         return "<" + str(self) + ">"
83
84     def __cmp__(self, other):
85         """One CIDR network block is less than another if the start
86         address is numerically less or if the block is larger.  That
87         is, supernets will sort before subnets.  This ordering allows
88         for an efficient search for subnets of a given network."""
89
90         res = self._base_mask(self.numaddr) - other._base_mask(other.numaddr)
91         if res == 0: res = self.netlen - other.netlen
92         if res < 0: return -1
93         if res > 0: return 1
94         return 0
95
96     def calc(self):
97         """This method should be called after any change to the main
98         internal state: netlen or numaddr."""
99
100         # make sure the network length is valid
101         if not self.is_valid_netlen(self.netlen):
102             raise TypeError, "network length must be between 0 and %d" % \
103                 (self._max_netlen())
104
105         # convert the string ipv4 address to a 32bit number
106         self.numaddr = self._convert_ipstr(self.addr)
107         # calculate our netmask
108         self.mask = self._mask(self.netlen)
109         # force the cidr address into correct masked notation
110         self.numaddr &= self.mask
111
112         # convert the number back to a string to normalize the string
113         self.addr = self._convert_ipaddr(self.numaddr)
114
115     def is_supernet(self, other):
116         """returns True if the other Cidr object is a supernet (an
117         enclosing network block) of this one.  A Cidr object is a
118         supernet of itself."""
119         return other.numaddr & self.mask == self.numaddr
120
121     def is_subnet(self, other):
122         """returns True if the other Cidr object is a subnet (an
123         enclosednetwork block) of this one.  A Cidr object is a
124         subnet of itself."""
125         return self.numaddr & other.mask == other.numaddr
126
127     def netmask(self):
128         """return the netmask of this Cidr network"""
129         return self._convert_ipaddr(self.mask)
130
131     def length(self):
132         """return the length (in number of addresses) of this network block"""
133         return 1 << (self._max_netlen() - self.netlen);
134
135     def end(self):
136         """return the last IP address in this network block"""
137         return self._convert_ipaddr(self.numaddr + self.length() - 1)
138
139     def to_netblock(self):
140         return (self.addr, self.end())
141
142     def clone(self):
143         # we can get away with a shallow copy (so far)
144         return copy.copy(self)
145
146     def is_ipv6(self):
147         if isinstance(self, CidrV6): return True
148         return False
149
150     def is_valid_cidr(self, address):
151         if "/" in address:
152             addr, netlen = address.split("/", 1)
153             netlen = int(netlen)
154         else:
155             addr, netlen = address, 0
156         return self._is_valid_address(addr) and self.is_valid_netlen(netlen)
157
158     def is_valid_netlen(self, netlen):
159         if netlen < 0: return False
160         if netlen > self._max_netlen(): return False
161         return True
162
163
164 class CidrV4(Cidr):
165     """A class representing a CIDRized IPv4 network value.
166
167     Specifically, it is representing a contiguous IPv4 network block
168     that can be expressed as a ip-address/network-length pair."""
169
170     base_mask = 0xFFFFFFFF
171     msb_mask  = 0x80000000
172
173     def __init__(self, address, netlen = -1):
174         """This takes either a formatted string in CIDR notation:
175         (e.g., "127.0.0.1/32"), A tuple consisting of an formatting
176         string IPv4 address and a numeric network length, or the same
177         as two arguments."""
178
179         self._initialize(address, netlen)
180
181     def _is_valid_address(self, address):
182         """Returns True if the address is a legal IPv4 address."""
183         try:
184             self._convert_ipstr(address)
185             return True
186         except socket.error:
187             return False
188
189     def _base_mask(self, numaddr):
190         return numaddr & CidrV4.base_mask
191
192     def _max_netlen(self):
193         return 32
194
195     def _convert_ipstr(self, addr):
196         packed_numaddr = socket.inet_aton(addr)
197         return struct.unpack("!I", packed_numaddr)[0]
198
199     def _convert_ipaddr(self, numaddr):
200         packed_numaddr = struct.pack("!I", numaddr)
201         return socket.inet_ntoa(packed_numaddr)
202
203     def _mask(self, len):
204         return self._base_mask(CidrV4.base_mask << (32 - len))
205
206 class CidrV6(Cidr):
207     """A class representing a CIDRized IPv6 network value.
208
209     Specifically, it is representing a contiguous IPv6 network block
210     that can be expressed as a ipv6-address/network-length pair."""
211
212     base_mask  = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFL # 128-bits of all ones.
213     msb_mask   = 0x80000000000000000000000000000000L
214     lower_mask = 0x0000000000000000FFFFFFFFFFFFFFFFL
215     upper_mask = 0xFFFFFFFFFFFFFFFF0000000000000000L
216
217     def __init__(self, address, netlen = -1):
218
219         self._initialize(address, netlen)
220
221     def _is_valid_address(self, address):
222         try:
223             self._convert_ipstr(address)
224             return True
225         except socket.error, e:
226             print "Failed to convert address string '%s': " + str(e) % (address)
227             return False
228
229     def _base_mask(self, numaddr):
230         return numaddr & CidrV6.base_mask
231
232     def _max_netlen(self):
233         return 128
234
235     def _convert_ipstr(self, addr):
236         packed_numaddr = socket.inet_pton(socket.AF_INET6, addr)
237         upper, lower = struct.unpack("!QQ", packed_numaddr);
238         return (upper << 64) | lower
239
240     def _convert_ipaddr(self, numaddr):
241         upper = (numaddr & CidrV6.upper_mask) >> 64;
242         lower = numaddr & CidrV6.lower_mask;
243         packed_numaddr = struct.pack("!QQ", upper, lower)
244         return socket.inet_ntop(socket.AF_INET6, packed_numaddr)
245
246     def _mask(self, len):
247         return self._base_mask(CidrV6.base_mask << (128 - len))
248
249
250 def valid_cidr(address):
251     """Returns the converted Cidr object if 'address' is valid CIDR
252     notation, False if not.  For the purposes of this module, valid
253     CIDR notation consists of a IPv4 or IPv6 address with an optional
254     trailing "/netlen"."""
255
256     if isinstance(address, Cidr): return address
257     try:
258         c = new(address)
259         return c
260     except (ValueError, socket.error):
261         return False
262
263 def netblock_to_cidr(start, end):
264     """Convert an arbitrary network block expressed as a start and end
265     address (inclusive) into a series of valid CIDR blocks."""
266
267     def largest_prefix(length, max_netlen, msb_mask):
268         # calculates the largest network length (smallest mask length)
269         # that can fit within the block length.
270         i = 1; v = length
271         while i <= max_netlen:
272             if v & msb_mask: break
273             i += 1; v <<= 1
274         return i
275     def netlen_to_mask(n, max_netlen, base_mask):
276         # convert the network length into its netmask
277         return ~((1 << (max_netlen - n)) - 1) & base_mask
278     def netlen_to_length(n, max_netlen, base_mask):
279         return 1 << (max_netlen - n) & base_mask
280
281     # convert the start and ending addresses of the netblock to Cidr
282     # object, mostly so we can get the numeric versions of their
283     # addresses.
284     cs = valid_cidr(start)
285     ce = valid_cidr(end)
286
287     # if either the start or ending addresses aren't valid addresses,
288     # quit now.
289     if not cs or not ce:
290         return None
291     # if the start and ending addresses aren't in the same family, quit now
292     if cs.is_ipv6() != ce.is_ipv6():
293         return None
294
295     max_netlen = cs._max_netlen()
296     msb_mask = cs.msb_mask
297     base_mask = cs.base_mask
298
299     # calculate the number of IP address in the netblock
300     block_len = ce.numaddr - cs.numaddr
301     # calcuate the largest CIDR block size that fits
302     netlen = largest_prefix(block_len + 1, max_netlen, msb_mask)
303
304     res = []; s = cs.numaddr
305     while block_len > 0:
306         mask = netlen_to_mask(netlen, max_netlen, base_mask)
307         # check to see if our current network length is valid
308         if (s & mask) != s:
309             # if not, shrink the network block size
310             netlen += 1
311             continue
312         # otherwise, we have a valid CIDR block, so add it to the list
313         res.append(new(s, netlen))
314         # and setup for the next round:
315         cur_len = netlen_to_length(netlen, max_netlen, base_mask)
316         s         += cur_len
317         block_len -= cur_len
318         netlen = largest_prefix(block_len + 1, max_netlen, msb_mask)
319     return res
320
321 # test driver
322 if __name__ == "__main__":
323     import sys
324     a = new("127.00.000.1/24")
325     b = new("127.0.0.1", 32)
326     c = new("24.232.119.192", 26)
327     d = new("24.232.119.0", 24)
328     e = new("24.224.0.0", 11)
329     f = new("216.168.111.0/27");
330     g = new("127.0.0.2/31");
331     h = new("127.0.0.16/32")
332     i = new("3ffe:4:201e:beef::0/64");
333     j = new("2001:3c01::/32")
334
335     print f.addr
336     print j.addr
337
338     try:
339         bad = new("24.261.119.0", 32)
340     except ValueError, x:
341         print "error:", x
342
343     print "cidr:", a, "num addresses:", a.length(), "ending address", \
344         a.end(), "netmask", a.netmask()
345
346     print "cidr:", j, "num addresses:", j.length(), "ending address", \
347         j.end(), "netmask", j.netmask()
348
349     clist = [a, b, c, d, e, f, g, h, i , j]
350     print "unsorted list of cidr objects:\n  ", clist
351
352
353     clist.sort()
354     print "sorted list of cidr object:\n  ", clist
355
356     k = new("2001:3c01::1:0", 120)
357     print "supernet: ", str(j), " supernet of ", str(k), "? ", \
358         str(j.is_supernet(k))
359     print "supernet: ", str(k), " supernet of ", str(j), "? ", \
360         str(k.is_supernet(j))
361     print "subnet: ", str(j), " subnet of ", str(k), "? ", \
362         str(j.is_subnet(k))
363     print "subnet: ", str(k), " subnet of ", str(j), "? ", \
364         str(k.is_subnet(j))
365
366     netblocks = [ ("192.168.10.0", "192.168.10.255"),
367                   ("192.168.10.0", "192.168.10.63"),
368                   ("172.16.0.0", "172.16.127.255"),
369                   ("24.33.41.22", "24.33.41.37"),
370                   ("196.11.1.0", "196.11.30.255"),
371                   ("192.247.1.0", "192.247.10.255"),
372                   ("10.131.43.3", "10.131.44.7"),
373                   ("3ffe:4:5::", "3ffe:4:5::ffff"),
374                   ("3ffe:4:5::", "3ffe:4:6::1")]
375
376     for start, end in netblocks:
377         print "netblock %s - %s:" % (start, end)
378         blocks = netblock_to_cidr(start, end)
379         print blocks