somewhat gratutious reformatting by eclipse
[captive-validator.git] / src / com / verisign / tat / dnssec / SMessage.java
1 /***************************** -*- Java -*- ********************************\
2  *                                                                         *
3  *   Copyright (c) 2009 VeriSign, Inc. All rights reserved.                *
4  *                                                                         *
5  * This software is provided solely in connection with the terms of the    *
6  * license agreement.  Any other use without the prior express written     *
7  * permission of VeriSign is completely prohibited.  The software and      *
8  * documentation are "Commercial Items", as that term is defined in 48     *
9  * C.F.R.  section 2.101, consisting of "Commercial Computer Software" and *
10  * "Commercial Computer Software Documentation" as such terms are defined  *
11  * in 48 C.F.R. section 252.227-7014(a)(5) and 48 C.F.R. section           *
12  * 252.227-7014(a)(1), and used in 48 C.F.R. section 12.212 and 48 C.F.R.  *
13  * section 227.7202, as applicable.  Pursuant to the above and other       *
14  * relevant sections of the Code of Federal Regulations, as applicable,    *
15  * VeriSign's publications, commercial computer software, and commercial   *
16  * computer software documentation are distributed and licensed to United  *
17  * States Government end users with only those rights as granted to all    *
18  * other end users, according to the terms and conditions contained in the *
19  * license agreement(s) that accompany the products and software           *
20  * documentation.                                                          *
21  *                                                                         *
22 \***************************************************************************/
23
24 package com.verisign.tat.dnssec;
25
26 import org.xbill.DNS.*;
27
28 import java.util.*;
29
30 /**
31  * This class represents a DNS message with resolver/validator state.
32  */
33 public class SMessage {
34     private static SRRset[] empty_srrset_array = new SRRset[0];
35     private Header mHeader;
36     private Record mQuestion;
37     private OPTRecord mOPTRecord;
38     private List<SRRset>[] mSection;
39     private SecurityStatus mSecurityStatus;
40
41     @SuppressWarnings("unchecked")
42     public SMessage(Header h) {
43         mSection = (List<SRRset>[]) new List[3];
44         mHeader = h;
45         mSecurityStatus = new SecurityStatus();
46     }
47
48     public SMessage(int id) {
49         this(new Header(id));
50     }
51
52     public SMessage() {
53         this(new Header(0));
54     }
55
56     public SMessage(Message m) {
57         this(m.getHeader());
58         mQuestion = m.getQuestion();
59         mOPTRecord = m.getOPT();
60
61         for (int i = Section.ANSWER; i <= Section.ADDITIONAL; i++) {
62             RRset[] rrsets = m.getSectionRRsets(i);
63
64             for (int j = 0; j < rrsets.length; j++) {
65                 addRRset(rrsets[j], i);
66             }
67         }
68     }
69
70     public Header getHeader() {
71         return mHeader;
72     }
73
74     public void setHeader(Header h) {
75         mHeader = h;
76     }
77
78     public void setQuestion(Record r) {
79         mQuestion = r;
80     }
81
82     public Record getQuestion() {
83         return mQuestion;
84     }
85
86     public Name getQName() {
87         return getQuestion().getName();
88     }
89
90     public int getQType() {
91         return getQuestion().getType();
92     }
93
94     public int getQClass() {
95         return getQuestion().getDClass();
96     }
97
98     public void setOPT(OPTRecord r) {
99         mOPTRecord = r;
100     }
101
102     public OPTRecord getOPT() {
103         return mOPTRecord;
104     }
105
106     public List<SRRset> getSectionList(int section) {
107         if ((section <= Section.QUESTION) || (section > Section.ADDITIONAL)) {
108             throw new IllegalArgumentException("Invalid section.");
109         }
110
111         if (mSection[section - 1] == null) {
112             mSection[section - 1] = new LinkedList<SRRset>();
113         }
114
115         return (List<SRRset>) mSection[section - 1];
116     }
117
118     public void addRRset(SRRset srrset, int section) {
119         if ((section <= Section.QUESTION) || (section > Section.ADDITIONAL)) {
120             throw new IllegalArgumentException("Invalid section");
121         }
122
123         if (srrset.getType() == Type.OPT) {
124             mOPTRecord = (OPTRecord) srrset.first();
125
126             return;
127         }
128
129         List<SRRset> sectionList = getSectionList(section);
130         sectionList.add(srrset);
131     }
132
133     public void addRRset(RRset rrset, int section) {
134         if (rrset instanceof SRRset) {
135             addRRset((SRRset) rrset, section);
136
137             return;
138         }
139
140         SRRset srrset = new SRRset(rrset);
141         addRRset(srrset, section);
142     }
143
144     public void prependRRsets(List<SRRset> rrsets, int section) {
145         if ((section <= Section.QUESTION) || (section > Section.ADDITIONAL)) {
146             throw new IllegalArgumentException("Invalid section");
147         }
148
149         List<SRRset> sectionList = getSectionList(section);
150         sectionList.addAll(0, rrsets);
151     }
152
153     public SRRset[] getSectionRRsets(int section) {
154         List<SRRset> slist = getSectionList(section);
155
156         return (SRRset[]) slist.toArray(empty_srrset_array);
157     }
158
159     public SRRset[] getSectionRRsets(int section, int qtype) {
160         List<SRRset> slist = getSectionList(section);
161
162         if (slist.size() == 0) {
163             return new SRRset[0];
164         }
165
166         ArrayList<SRRset> result = new ArrayList<SRRset>(slist.size());
167
168         for (SRRset rrset : slist) {
169             if (rrset.getType() == qtype) {
170                 result.add(rrset);
171             }
172         }
173
174         return (SRRset[]) result.toArray(empty_srrset_array);
175     }
176
177     public void deleteRRset(SRRset rrset, int section) {
178         List<SRRset> slist = getSectionList(section);
179
180         if (slist.size() == 0) {
181             return;
182         }
183
184         slist.remove(rrset);
185     }
186
187     public void clear(int section) {
188         if ((section < Section.QUESTION) || (section > Section.ADDITIONAL)) {
189             throw new IllegalArgumentException("Invalid section.");
190         }
191
192         if (section == Section.QUESTION) {
193             mQuestion = null;
194
195             return;
196         }
197
198         if (section == Section.ADDITIONAL) {
199             mOPTRecord = null;
200         }
201
202         mSection[section - 1] = null;
203     }
204
205     public void clear() {
206         for (int s = Section.QUESTION; s <= Section.ADDITIONAL; s++) {
207             clear(s);
208         }
209     }
210
211     public int getRcode() {
212         // FIXME: might want to do what Message does and handle extended rcodes.
213         return mHeader.getRcode();
214     }
215
216     public int getStatus() {
217         return mSecurityStatus.getStatus();
218     }
219
220     public void setStatus(byte status) {
221         mSecurityStatus.setStatus(status);
222     }
223
224     public SecurityStatus getSecurityStatus() {
225         return mSecurityStatus;
226     }
227
228     public void setSecurityStatus(SecurityStatus s) {
229         if (s == null) {
230             return;
231         }
232
233         mSecurityStatus = s;
234     }
235
236     public Message getMessage() {
237         // Generate our new message.
238         Message m = new Message(mHeader.getID());
239
240         // Convert the header
241         // We do this for two reasons: 1) setCount() is package scope, so we
242         // can't do that, and 2) setting the header on a message after creating
243         // the message frequently gets stuff out of sync, leading to malformed
244         // wire format messages.
245         Header h = m.getHeader();
246         h.setOpcode(mHeader.getOpcode());
247         h.setRcode(mHeader.getRcode());
248
249         for (int i = 0; i < 16; i++) {
250             if (Flags.isFlag(i)) {
251                 if (mHeader.getFlag(i)) {
252                     h.setFlag(i);
253                 } else {
254                     h.unsetFlag(i);
255                 }
256             }
257         }
258
259         // Add all the records. -- this will set the counts correctly in the
260         // message header.
261         if (mQuestion != null) {
262             m.addRecord(mQuestion, Section.QUESTION);
263         }
264
265         for (int sec = Section.ANSWER; sec <= Section.ADDITIONAL; sec++) {
266             List<SRRset> slist = getSectionList(sec);
267
268             for (SRRset rrset : slist) {
269                 for (Iterator<Record> j = rrset.rrs(); j.hasNext();) {
270                     m.addRecord(j.next(), sec);
271                 }
272
273                 for (Iterator<RRSIGRecord> j = rrset.sigs(); j.hasNext();) {
274                     m.addRecord(j.next(), sec);
275                 }
276             }
277         }
278
279         if (mOPTRecord != null) {
280             m.addRecord(mOPTRecord, Section.ADDITIONAL);
281         }
282
283         return m;
284     }
285
286     public int getCount(int section) {
287         if (section == Section.QUESTION) {
288             return (mQuestion == null) ? 0 : 1;
289         }
290
291         List<SRRset> sectionList = getSectionList(section);
292
293         if (sectionList == null) {
294             return 0;
295         }
296
297         if (sectionList.size() == 0) {
298             return 0;
299         }
300
301         int count = 0;
302
303         for (SRRset sr : sectionList) {
304             count += sr.totalSize();
305         }
306
307         return count;
308     }
309
310     public String toString() {
311         return getMessage().toString();
312     }
313
314     /**
315      * Find a specific (S)RRset in a given section.
316      * 
317      * @param name
318      *            the name of the RRset.
319      * @param type
320      *            the type of the RRset.
321      * @param dclass
322      *            the class of the RRset.
323      * @param section
324      *            the section to look in (ANSWER -> ADDITIONAL)
325      * 
326      * @return The SRRset if found, null otherwise.
327      */
328     public SRRset findRRset(Name name, int type, int dclass, int section) {
329         if ((section <= Section.QUESTION) || (section > Section.ADDITIONAL)) {
330             throw new IllegalArgumentException("Invalid section.");
331         }
332
333         SRRset[] rrsets = getSectionRRsets(section);
334
335         for (int i = 0; i < rrsets.length; i++) {
336             if (rrsets[i].getName().equals(name)
337                     && (rrsets[i].getType() == type)
338                     && (rrsets[i].getDClass() == dclass)) {
339                 return rrsets[i];
340             }
341         }
342
343         return null;
344     }
345
346     /**
347      * Find an "answer" RRset. This will look for RRsets in the ANSWER section
348      * that match the <qname,qtype,qclass>, taking into consideration CNAMEs.
349      * 
350      * @param qname
351      *            The starting search name.
352      * @param qtype
353      *            The search type.
354      * @param qclass
355      *            The search class.
356      * 
357      * @return a SRRset matching the query. This SRRset may have a different
358      *         name from qname, due to following a CNAME chain.
359      */
360     public SRRset findAnswerRRset(Name qname, int qtype, int qclass) {
361         SRRset[] srrsets = getSectionRRsets(Section.ANSWER);
362
363         for (int i = 0; i < srrsets.length; i++) {
364             if (srrsets[i].getName().equals(qname)
365                     && (srrsets[i].getType() == Type.CNAME)) {
366                 CNAMERecord cname = (CNAMERecord) srrsets[i].first();
367                 qname = cname.getTarget();
368
369                 continue;
370             }
371
372             if (srrsets[i].getName().equals(qname)
373                     && (srrsets[i].getType() == qtype)
374                     && (srrsets[i].getDClass() == qclass)) {
375                 return srrsets[i];
376             }
377         }
378
379         return null;
380     }
381 }