aca257b20f825dedd013e9e183ea59d7e72b6b0d
[captive-validator.git] / src / com / versign / tat / dnssec / SMessage.java
1 /*
2  * $Id$
3  * 
4  * Copyright (c) 2005 VeriSign. All rights reserved.
5  * 
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions are met:
8  * 
9  * 1. Redistributions of source code must retain the above copyright notice,
10  * this list of conditions and the following disclaimer. 2. Redistributions in
11  * binary form must reproduce the above copyright notice, this list of
12  * conditions and the following disclaimer in the documentation and/or other
13  * materials provided with the distribution. 3. The name of the author may not
14  * be used to endorse or promote products derived from this software without
15  * specific prior written permission.
16  * 
17  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
18  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
19  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN
20  * NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
21  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
22  * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24  * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25  * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27  *  
28  */
29
30 package com.versign.tat.dnssec;
31
32 import java.util.*;
33
34 import org.xbill.DNS.*;
35
36 /**
37  * This class represents a DNS message with resolver/validator state.
38  */
39 public class SMessage {
40     private Header          mHeader;
41
42     private Record          mQuestion;
43     private OPTRecord       mOPTRecord;
44     private List<SRRset>[]          mSection;
45     private SecurityStatus  mSecurityStatus;
46
47     private static SRRset[] empty_srrset_array = new SRRset[0];
48
49     @SuppressWarnings("unchecked")
50     public SMessage(Header h) {
51         mSection = (List<SRRset>[]) new List[3];
52         mHeader = h;
53         mSecurityStatus = new SecurityStatus();
54     }
55
56     public SMessage(int id) {
57         this(new Header(id));
58     }
59
60     public SMessage() {
61         this(new Header(0));
62     }
63
64     public SMessage(Message m) {
65         this(m.getHeader());
66         mQuestion = m.getQuestion();
67         mOPTRecord = m.getOPT();
68
69         for (int i = Section.ANSWER; i <= Section.ADDITIONAL; i++) {
70             RRset[] rrsets = m.getSectionRRsets(i);
71
72             for (int j = 0; j < rrsets.length; j++) {
73                 addRRset(rrsets[j], i);
74             }
75         }
76     }
77
78     public Header getHeader() {
79         return mHeader;
80     }
81
82     public void setHeader(Header h) {
83         mHeader = h;
84     }
85
86     public void setQuestion(Record r) {
87         mQuestion = r;
88     }
89
90     public Record getQuestion() {
91         return mQuestion;
92     }
93
94     public Name getQName() {
95         return getQuestion().getName();
96     }
97
98     public int getQType() {
99         return getQuestion().getType();
100     }
101
102     public int getQClass() {
103         return getQuestion().getDClass();
104     }
105
106     public void setOPT(OPTRecord r) {
107         mOPTRecord = r;
108     }
109
110     public OPTRecord getOPT() {
111         return mOPTRecord;
112     }
113
114     public List<SRRset> getSectionList(int section) {
115         if (section <= Section.QUESTION || section > Section.ADDITIONAL)
116             throw new IllegalArgumentException("Invalid section.");
117
118         if (mSection[section - 1] == null) {
119             mSection[section - 1] = new LinkedList<SRRset>();
120         }
121
122         return (List<SRRset>) mSection[section - 1];
123     }
124
125     public void addRRset(SRRset srrset, int section) {
126         if (section <= Section.QUESTION || section > Section.ADDITIONAL)
127             throw new IllegalArgumentException("Invalid section");
128
129         if (srrset.getType() == Type.OPT) {
130             mOPTRecord = (OPTRecord) srrset.first();
131             return;
132         }
133
134         List<SRRset> sectionList = getSectionList(section);
135         sectionList.add(srrset);
136     }
137
138     public void addRRset(RRset rrset, int section) {
139         if (rrset instanceof SRRset) {
140             addRRset((SRRset) rrset, section);
141             return;
142         }
143
144         SRRset srrset = new SRRset(rrset);
145         addRRset(srrset, section);
146     }
147
148     public void prependRRsets(List<SRRset> rrsets, int section) {
149         if (section <= Section.QUESTION || section > Section.ADDITIONAL)
150             throw new IllegalArgumentException("Invalid section");
151
152         List<SRRset> sectionList = getSectionList(section);
153         sectionList.addAll(0, rrsets);
154     }
155
156     public SRRset[] getSectionRRsets(int section) {
157         List<SRRset> slist = getSectionList(section);
158
159         return (SRRset[]) slist.toArray(empty_srrset_array);
160     }
161
162     public SRRset[] getSectionRRsets(int section, int qtype) {
163         List<SRRset> slist = getSectionList(section);
164
165         if (slist.size() == 0) return new SRRset[0];
166
167         ArrayList<SRRset> result = new ArrayList<SRRset>(slist.size());
168         for (SRRset rrset : slist) {
169             if (rrset.getType() == qtype) result.add(rrset);
170         }
171
172         return (SRRset[]) result.toArray(empty_srrset_array);
173     }
174
175     public void deleteRRset(SRRset rrset, int section) {
176         List<SRRset> slist = getSectionList(section);
177
178         if (slist.size() == 0) return;
179
180         slist.remove(rrset);
181     }
182
183     public void clear(int section) {
184         if (section < Section.QUESTION || section > Section.ADDITIONAL)
185             throw new IllegalArgumentException("Invalid section.");
186
187         if (section == Section.QUESTION) {
188             mQuestion = null;
189             return;
190         }
191         if (section == Section.ADDITIONAL) {
192             mOPTRecord = null;
193         }
194
195         mSection[section - 1] = null;
196     }
197
198     public void clear() {
199         for (int s = Section.QUESTION; s <= Section.ADDITIONAL; s++) {
200             clear(s);
201         }
202     }
203
204     public int getRcode() {
205         // FIXME: might want to do what Message does and handle extended rcodes.
206         return mHeader.getRcode();
207     }
208
209     public int getStatus() {
210         return mSecurityStatus.getStatus();
211     }
212
213     public void setStatus(byte status) {
214         mSecurityStatus.setStatus(status);
215     }
216
217     public SecurityStatus getSecurityStatus() {
218         return mSecurityStatus;
219     }
220
221     public void setSecurityStatus(SecurityStatus s) {
222         if (s == null) return;
223         mSecurityStatus = s;
224     }
225
226     public Message getMessage() {
227         // Generate our new message.
228         Message m = new Message(mHeader.getID());
229
230         // Convert the header
231         // We do this for two reasons: 1) setCount() is package scope, so we
232         // can't do that, and 2) setting the header on a message after creating
233         // the message frequently gets stuff out of sync, leading to malformed
234         // wire format messages.
235         Header h = m.getHeader();
236         h.setOpcode(mHeader.getOpcode());
237         h.setRcode(mHeader.getRcode());
238         for (int i = 0; i < 16; i++) {
239             if (Flags.isFlag(i)) {
240                 if (mHeader.getFlag(i)) {
241                     h.setFlag(i);
242                 } else {
243                     h.unsetFlag(i);
244                 }
245             }
246         }
247
248         // Add all the records. -- this will set the counts correctly in the
249         // message header.
250
251         if (mQuestion != null) {
252             m.addRecord(mQuestion, Section.QUESTION);
253         }
254
255         for (int sec = Section.ANSWER; sec <= Section.ADDITIONAL; sec++) {
256             List<SRRset> slist = getSectionList(sec);
257             for (SRRset rrset : slist) {
258                 for (Iterator<Record> j = rrset.rrs(); j.hasNext(); ) {
259                     m.addRecord(j.next(), sec);
260                 }
261                 for (Iterator<RRSIGRecord> j = rrset.sigs(); j.hasNext(); ) {
262                     m.addRecord(j.next(), sec);
263                 }
264             }
265         }
266
267         if (mOPTRecord != null) {
268             m.addRecord(mOPTRecord, Section.ADDITIONAL);
269         }
270
271         return m;
272     }
273
274     public int getCount(int section) {
275         if (section == Section.QUESTION) {
276             return mQuestion == null ? 0 : 1;
277         }
278         List<SRRset> sectionList = getSectionList(section);
279         if (sectionList == null) return 0;
280         if (sectionList.size() == 0) return 0;
281
282         int count = 0;
283         for (SRRset sr : sectionList) {
284             count += sr.totalSize();
285         }
286
287         return count;
288     }
289
290     public String toString() {
291         return getMessage().toString();
292     }
293
294     /**
295      * Find a specific (S)RRset in a given section.
296      * 
297      * @param name
298      *            the name of the RRset.
299      * @param type
300      *            the type of the RRset.
301      * @param dclass
302      *            the class of the RRset.
303      * @param section
304      *            the section to look in (ANSWER -> ADDITIONAL)
305      * 
306      * @return The SRRset if found, null otherwise.
307      */
308     public SRRset findRRset(Name name, int type, int dclass, int section) {
309         if (section <= Section.QUESTION || section > Section.ADDITIONAL)
310             throw new IllegalArgumentException("Invalid section.");
311
312         SRRset[] rrsets = getSectionRRsets(section);
313
314         for (int i = 0; i < rrsets.length; i++) {
315             if (rrsets[i].getName().equals(name) && rrsets[i].getType() == type
316                 && rrsets[i].getDClass() == dclass) {
317                 return rrsets[i];
318             }
319         }
320
321         return null;
322     }
323
324     /**
325      * Find an "answer" RRset. This will look for RRsets in the ANSWER section
326      * that match the <qname,qtype,qclass>, taking into consideration CNAMEs.
327      * 
328      * @param qname
329      *            The starting search name.
330      * @param qtype
331      *            The search type.
332      * @param qclass
333      *            The search class.
334      * 
335      * @return a SRRset matching the query. This SRRset may have a different
336      *         name from qname, due to following a CNAME chain.
337      */
338     public SRRset findAnswerRRset(Name qname, int qtype, int qclass) {
339         SRRset[] srrsets = getSectionRRsets(Section.ANSWER);
340
341         for (int i = 0; i < srrsets.length; i++) {
342             if (srrsets[i].getName().equals(qname)
343                 && srrsets[i].getType() == Type.CNAME) {
344                 CNAMERecord cname = (CNAMERecord) srrsets[i].first();
345                 qname = cname.getTarget();
346                 continue;
347             }
348
349             if (srrsets[i].getName().equals(qname)
350                 && srrsets[i].getType() == qtype
351                 && srrsets[i].getDClass() == qclass) {
352                 return srrsets[i];
353             }
354         }
355
356         return null;
357     }
358
359 }