/*
* dnssecjava - a DNSSEC validating stub resolver for Java
* Copyright (c) 2013-2015 Ingo Bauersachs
*
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the Eclipse Public License v1.0
* which accompanies this distribution, and is available at
* http://www.eclipse.org/legal/epl-v10.html
*/
package org.jitsi.dnssec;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.StringReader;
import java.util.Date;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.logging.LogManager;
import org.jitsi.dnssec.validator.ValidatingResolver;
import org.joda.time.DateTime;
import org.joda.time.format.ISODateTimeFormat;
import org.junit.Assert;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.rules.TestRule;
import org.junit.rules.TestWatcher;
import org.junit.runner.Description;
import org.junit.runner.RunWith;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import org.powermock.reflect.Whitebox;
import org.xbill.DNS.ARecord;
import org.xbill.DNS.DClass;
import org.xbill.DNS.DNSSEC;
import org.xbill.DNS.DNSSEC.DNSSECException;
import org.xbill.DNS.Master;
import org.xbill.DNS.Message;
import org.xbill.DNS.Name;
import org.xbill.DNS.RRset;
import org.xbill.DNS.Record;
import org.xbill.DNS.Section;
import org.xbill.DNS.SimpleResolver;
import org.xbill.DNS.TXTRecord;
import org.xbill.DNS.Type;
import static org.powermock.api.mockito.PowerMockito.whenNew;
@RunWith(PowerMockRunner.class)
@PrepareForTest({DNSSEC.class, TestInvalid.class})
public abstract class TestBase {
private final static boolean offline = !Boolean.getBoolean("org.jitsi.dnssecjava.online");
private final static boolean partialOffline = "partial".equals(System.getProperty("org.jitsi.dnssecjava.offline"));
private final static boolean record = Boolean.getBoolean("org.jitsi.dnssecjava.record");
private boolean unboundTest = false;
private boolean alwaysOffline = false;
private Map<String, Message> queryResponsePairs = new HashMap<String, Message>();
private MessageReader messageReader = new MessageReader();
private FileWriter w;
private BufferedReader r;
protected final static String localhost = "127.0.0.1";
protected ValidatingResolver resolver;
protected String testName;
static {
try {
LogManager.getLogManager().readConfiguration(TestBase.class.getResourceAsStream("logging.properties"));
}
catch (Exception e) {
throw new RuntimeException(e);
}
}
@Rule
public TestRule watcher = new TestWatcher() {
@Override
protected void starting(Description description) {
unboundTest = false;
testName = description.getMethodName();
try {
// do not record or process unbound unit tests offline
alwaysOffline = description.getAnnotation(AlwaysOffline.class) != null;
if (description.getClassName().contains("unbound")) {
unboundTest = true;
return;
}
String filename = "/recordings/" + description.getClassName().replace(".", "_") + "/" + testName;
File f = new File("./src/test/resources" + filename);
if ((record || !f.exists()) && !alwaysOffline) {
f.getParentFile().getParentFile().mkdir();
f.getParentFile().mkdir();
w = new FileWriter(f.getAbsoluteFile());
w.write("#Date: " + new DateTime().toString(ISODateTimeFormat.dateTimeNoMillis()));
w.write("\n");
}
else if (offline || partialOffline || alwaysOffline) {
PrepareMocks pm = description.getAnnotation(PrepareMocks.class);
if (pm != null) {
Whitebox.invokeMethod(TestBase.this, pm.value());
}
InputStream stream = getClass().getResourceAsStream(filename);
if (stream != null) {
r = new BufferedReader(new InputStreamReader(stream));
long millis = DateTime.parse(r.readLine().substring("#Date: ".length()), ISODateTimeFormat.dateTimeNoMillis()).getMillis();
whenNew(Date.class).withNoArguments().thenReturn(new Date(millis));
whenNew(Date.class).withArguments(Mockito.anyLong()).thenAnswer(new Answer<Date>(){
@Override
public Date answer(InvocationOnMock invocationOnMock) throws Throwable {
return new Date((Long)invocationOnMock.getArguments()[0]);
}
});
Message m;
while ((m = messageReader.readMessage(r)) != null) {
queryResponsePairs.put(key(m), m);
}
r.close();
}
}
}
catch (Exception e) {
System.err.println(e);
throw new RuntimeException(e);
}
}
@Override
protected void finished(Description description) {
try {
if (record && w != null) {
w.flush();
w.close();
w = null;
}
}
catch (IOException e) {
throw new RuntimeException(e);
}
}
};
@BeforeClass
public static void setupClass() {
R.setBundle(null);
R.setUseNeutralMessages(true);
}
@Before
public void setup() throws NumberFormatException, IOException, DNSSECException {
resolver = new ValidatingResolver(new SimpleResolver("62.192.5.131") {
@Override
public Message send(Message query) throws IOException {
System.out.println("---" + key(query));
Message response = queryResponsePairs.get(key(query));
if (response != null) {
return response;
}
else if ((offline && !partialOffline) || unboundTest || alwaysOffline) {
Assert.fail("Response for " + key(query) + " not found.");
}
Message networkResult = super.send(query);
if (record) {
w.write(networkResult.toString());
w.write("\n\n###############################################\n\n");
}
return networkResult;
}
});
resolver.loadTrustAnchors(getClass().getResourceAsStream("/trust_anchors"));
System.err.println("--------------");
}
protected void add(Message m) throws IOException {
this.add(key(m), m, true);
}
protected void add(String query, Message response) throws IOException {
this.add(query, response, true);
}
protected void add(String query, Message response, boolean clear) throws IOException {
queryResponsePairs.put(query, messageFromString(response.toString()));
// reset the resolver so any cached stuff is cleared
if (!clear) {
return;
}
try {
setup();
}
catch (NumberFormatException e) {
throw new IOException(e);
}
catch (DNSSECException e) {
throw new IOException(e);
}
}
protected Message get(Name target, int type) {
return queryResponsePairs.get(key(target, type));
}
protected void clear() {
queryResponsePairs.clear();
}
protected Message createMessage(String query) throws IOException {
return Message.newQuery(Record.newRecord(Name.fromString(query.split("/")[0]), Type.value(query.split("/")[1]), DClass.IN));
}
protected Message messageFromString(String message) throws IOException {
return messageReader.readMessage(new StringReader(message));
}
@SuppressWarnings("unchecked")
protected String firstA(Message response) {
RRset[] sectionRRsets = response.getSectionRRsets(Section.ANSWER);
if (sectionRRsets.length > 0) {
Iterator<Record> rrs = sectionRRsets[0].rrs();
while (rrs.hasNext()) {
Record r = rrs.next();
if (r.getType() == Type.A) {
return ((ARecord)r).getAddress().getHostAddress();
}
}
}
return null;
}
protected String getReason(Message m) {
for (RRset set : m.getSectionRRsets(Section.ADDITIONAL)) {
if (set.getName().equals(Name.root) && set.getType() == Type.TXT && set.getDClass() == ValidatingResolver.VALIDATION_REASON_QCLASS) {
StringBuilder sb = new StringBuilder();
@SuppressWarnings("unchecked")
List<String> strings = (List<String>)((TXTRecord)set.first()).getStrings();
for (String part : strings){
sb.append(part);
}
return sb.toString();
}
}
return null;
}
protected boolean isEmptyAnswer(Message response) {
RRset[] sectionRRsets = response.getSectionRRsets(Section.ANSWER);
return sectionRRsets.length == 0;
}
private String key(Name n, int t) {
return n + "/" + Type.string(t);
}
private String key(Record r) {
return key(r.getName(), r.getType());
}
private String key(Message m) {
return key(m.getQuestion());
}
protected Record toRecord(String data){
try {
InputStream in = new ByteArrayInputStream(data.getBytes("UTF-8"));
Master m = new Master(in, Name.root);
return m._nextRecord();
}
catch (IOException e) {
throw new RuntimeException(e);
}
}
}