update dnsjava library
[captive-validator.git] / src / com / versign / tat / dnssec / SMessage.java
1 /***************************** -*- Java -*- ********************************\
2  *                                                                         *
3  *   Copyright (c) 2009 VeriSign, Inc. All rights reserved.                *
4  *                                                                         *
5  * This software is provided solely in connection with the terms of the    *
6  * license agreement.  Any other use without the prior express written     *
7  * permission of VeriSign is completely prohibited.  The software and      *
8  * documentation are "Commercial Items", as that term is defined in 48     *
9  * C.F.R.  section 2.101, consisting of "Commercial Computer Software" and *
10  * "Commercial Computer Software Documentation" as such terms are defined  *
11  * in 48 C.F.R. section 252.227-7014(a)(5) and 48 C.F.R. section           *
12  * 252.227-7014(a)(1), and used in 48 C.F.R. section 12.212 and 48 C.F.R.  *
13  * section 227.7202, as applicable.  Pursuant to the above and other       *
14  * relevant sections of the Code of Federal Regulations, as applicable,    *
15  * VeriSign's publications, commercial computer software, and commercial   *
16  * computer software documentation are distributed and licensed to United  *
17  * States Government end users with only those rights as granted to all    *
18  * other end users, according to the terms and conditions contained in the *
19  * license agreement(s) that accompany the products and software           *
20  * documentation.                                                          *
21  *                                                                         *
22 \***************************************************************************/
23
24 package com.verisign.tat.dnssec;
25
26 import org.xbill.DNS.*;
27
28 import java.util.*;
29
30
31 /**
32  * This class represents a DNS message with resolver/validator state.
33  */
34 public class SMessage {
35     private static SRRset [] empty_srrset_array = new SRRset[0];
36     private Header           mHeader;
37     private Record           mQuestion;
38     private OPTRecord        mOPTRecord;
39     private List<SRRset> []  mSection;
40     private SecurityStatus   mSecurityStatus;
41
42     @SuppressWarnings("unchecked")
43     public SMessage(Header h) {
44         mSection                                = (List<SRRset> []) new List[3];
45         mHeader                                 = h;
46         mSecurityStatus                         = new SecurityStatus();
47     }
48
49     public SMessage(int id) {
50         this(new Header(id));
51     }
52
53     public SMessage() {
54         this(new Header(0));
55     }
56
57     public SMessage(Message m) {
58         this(m.getHeader());
59         mQuestion      = m.getQuestion();
60         mOPTRecord     = m.getOPT();
61
62         for (int i = Section.ANSWER; i <= Section.ADDITIONAL; i++) {
63             RRset [] rrsets = m.getSectionRRsets(i);
64
65             for (int j = 0; j < rrsets.length; j++) {
66                 addRRset(rrsets[j], i);
67             }
68         }
69     }
70
71     public Header getHeader() {
72         return mHeader;
73     }
74
75     public void setHeader(Header h) {
76         mHeader = h;
77     }
78
79     public void setQuestion(Record r) {
80         mQuestion = r;
81     }
82
83     public Record getQuestion() {
84         return mQuestion;
85     }
86
87     public Name getQName() {
88         return getQuestion().getName();
89     }
90
91     public int getQType() {
92         return getQuestion().getType();
93     }
94
95     public int getQClass() {
96         return getQuestion().getDClass();
97     }
98
99     public void setOPT(OPTRecord r) {
100         mOPTRecord = r;
101     }
102
103     public OPTRecord getOPT() {
104         return mOPTRecord;
105     }
106
107     public List<SRRset> getSectionList(int section) {
108         if ((section <= Section.QUESTION) || (section > Section.ADDITIONAL)) {
109             throw new IllegalArgumentException("Invalid section.");
110         }
111
112         if (mSection[section - 1] == null) {
113             mSection[section - 1] = new LinkedList<SRRset>();
114         }
115
116         return (List<SRRset>) mSection[section - 1];
117     }
118
119     public void addRRset(SRRset srrset, int section) {
120         if ((section <= Section.QUESTION) || (section > Section.ADDITIONAL)) {
121             throw new IllegalArgumentException("Invalid section");
122         }
123
124         if (srrset.getType() == Type.OPT) {
125             mOPTRecord = (OPTRecord) srrset.first();
126
127             return;
128         }
129
130         List<SRRset> sectionList = getSectionList(section);
131         sectionList.add(srrset);
132     }
133
134     public void addRRset(RRset rrset, int section) {
135         if (rrset instanceof SRRset) {
136             addRRset((SRRset) rrset, section);
137
138             return;
139         }
140
141         SRRset srrset = new SRRset(rrset);
142         addRRset(srrset, section);
143     }
144
145     public void prependRRsets(List<SRRset> rrsets, int section) {
146         if ((section <= Section.QUESTION) || (section > Section.ADDITIONAL)) {
147             throw new IllegalArgumentException("Invalid section");
148         }
149
150         List<SRRset> sectionList = getSectionList(section);
151         sectionList.addAll(0, rrsets);
152     }
153
154     public SRRset [] getSectionRRsets(int section) {
155         List<SRRset> slist = getSectionList(section);
156
157         return (SRRset []) slist.toArray(empty_srrset_array);
158     }
159
160     public SRRset [] getSectionRRsets(int section, int qtype) {
161         List<SRRset> slist = getSectionList(section);
162
163         if (slist.size() == 0) {
164             return new SRRset[0];
165         }
166
167         ArrayList<SRRset> result = new ArrayList<SRRset>(slist.size());
168
169         for (SRRset rrset : slist) {
170             if (rrset.getType() == qtype) {
171                 result.add(rrset);
172             }
173         }
174
175         return (SRRset []) result.toArray(empty_srrset_array);
176     }
177
178     public void deleteRRset(SRRset rrset, int section) {
179         List<SRRset> slist = getSectionList(section);
180
181         if (slist.size() == 0) {
182             return;
183         }
184
185         slist.remove(rrset);
186     }
187
188     public void clear(int section) {
189         if ((section < Section.QUESTION) || (section > Section.ADDITIONAL)) {
190             throw new IllegalArgumentException("Invalid section.");
191         }
192
193         if (section == Section.QUESTION) {
194             mQuestion = null;
195
196             return;
197         }
198
199         if (section == Section.ADDITIONAL) {
200             mOPTRecord = null;
201         }
202
203         mSection[section - 1] = null;
204     }
205
206     public void clear() {
207         for (int s = Section.QUESTION; s <= Section.ADDITIONAL; s++) {
208             clear(s);
209         }
210     }
211
212     public int getRcode() {
213         // FIXME: might want to do what Message does and handle extended rcodes.
214         return mHeader.getRcode();
215     }
216
217     public int getStatus() {
218         return mSecurityStatus.getStatus();
219     }
220
221     public void setStatus(byte status) {
222         mSecurityStatus.setStatus(status);
223     }
224
225     public SecurityStatus getSecurityStatus() {
226         return mSecurityStatus;
227     }
228
229     public void setSecurityStatus(SecurityStatus s) {
230         if (s == null) {
231             return;
232         }
233
234         mSecurityStatus = s;
235     }
236
237     public Message getMessage() {
238         // Generate our new message.
239         Message m = new Message(mHeader.getID());
240
241         // Convert the header
242         // We do this for two reasons: 1) setCount() is package scope, so we
243         // can't do that, and 2) setting the header on a message after creating
244         // the message frequently gets stuff out of sync, leading to malformed
245         // wire format messages.
246         Header h = m.getHeader();
247         h.setOpcode(mHeader.getOpcode());
248         h.setRcode(mHeader.getRcode());
249
250         for (int i = 0; i < 16; i++) {
251             if (Flags.isFlag(i)) {
252                 if (mHeader.getFlag(i)) {
253                     h.setFlag(i);
254                 } else {
255                     h.unsetFlag(i);
256                 }
257             }
258         }
259
260         // Add all the records. -- this will set the counts correctly in the
261         // message header.
262         if (mQuestion != null) {
263             m.addRecord(mQuestion, Section.QUESTION);
264         }
265
266         for (int sec = Section.ANSWER; sec <= Section.ADDITIONAL; sec++) {
267             List<SRRset> slist = getSectionList(sec);
268
269             for (SRRset rrset : slist) {
270                 for (Iterator<Record> j = rrset.rrs(); j.hasNext();) {
271                     m.addRecord(j.next(), sec);
272                 }
273
274                 for (Iterator<RRSIGRecord> j = rrset.sigs(); j.hasNext();) {
275                     m.addRecord(j.next(), sec);
276                 }
277             }
278         }
279
280         if (mOPTRecord != null) {
281             m.addRecord(mOPTRecord, Section.ADDITIONAL);
282         }
283
284         return m;
285     }
286
287     public int getCount(int section) {
288         if (section == Section.QUESTION) {
289             return (mQuestion == null) ? 0 : 1;
290         }
291
292         List<SRRset> sectionList = getSectionList(section);
293
294         if (sectionList == null) {
295             return 0;
296         }
297
298         if (sectionList.size() == 0) {
299             return 0;
300         }
301
302         int count = 0;
303
304         for (SRRset sr : sectionList) {
305             count += sr.totalSize();
306         }
307
308         return count;
309     }
310
311     public String toString() {
312         return getMessage().toString();
313     }
314
315     /**
316      * Find a specific (S)RRset in a given section.
317      *
318      * @param name
319      *            the name of the RRset.
320      * @param type
321      *            the type of the RRset.
322      * @param dclass
323      *            the class of the RRset.
324      * @param section
325      *            the section to look in (ANSWER -> ADDITIONAL)
326      *
327      * @return The SRRset if found, null otherwise.
328      */
329     public SRRset findRRset(Name name, int type, int dclass, int section) {
330         if ((section <= Section.QUESTION) || (section > Section.ADDITIONAL)) {
331             throw new IllegalArgumentException("Invalid section.");
332         }
333
334         SRRset [] rrsets = getSectionRRsets(section);
335
336         for (int i = 0; i < rrsets.length; i++) {
337             if (rrsets[i].getName().equals(name) &&
338                     (rrsets[i].getType() == type) &&
339                     (rrsets[i].getDClass() == dclass)) {
340                 return rrsets[i];
341             }
342         }
343
344         return null;
345     }
346
347     /**
348      * Find an "answer" RRset. This will look for RRsets in the ANSWER section
349      * that match the <qname,qtype,qclass>, taking into consideration CNAMEs.
350      *
351      * @param qname
352      *            The starting search name.
353      * @param qtype
354      *            The search type.
355      * @param qclass
356      *            The search class.
357      *
358      * @return a SRRset matching the query. This SRRset may have a different
359      *         name from qname, due to following a CNAME chain.
360      */
361     public SRRset findAnswerRRset(Name qname, int qtype, int qclass) {
362         SRRset [] srrsets = getSectionRRsets(Section.ANSWER);
363
364         for (int i = 0; i < srrsets.length; i++) {
365             if (srrsets[i].getName().equals(qname) &&
366                     (srrsets[i].getType() == Type.CNAME)) {
367                 CNAMERecord cname = (CNAMERecord) srrsets[i].first();
368                 qname = cname.getTarget();
369
370                 continue;
371             }
372
373             if (srrsets[i].getName().equals(qname) &&
374                     (srrsets[i].getType() == qtype) &&
375                     (srrsets[i].getDClass() == qclass)) {
376                 return srrsets[i];
377             }
378         }
379
380         return null;
381     }
382 }