/*
* Copyright 2011 Google Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
package net.java.otr4j.session;
import static org.junit.Assert.*;
import java.util.List;
import java.util.Properties;
import javax.crypto.interfaces.DHPublicKey;
import net.java.otr4j.OtrException;
import net.java.otr4j.OtrKeyManagerImpl;
import net.java.otr4j.OtrKeyManagerStore;
import net.java.otr4j.crypto.SM;
import net.java.otr4j.crypto.SM.SMException;
import net.java.otr4j.session.OtrSm;
import net.java.otr4j.session.OtrSm.OtrSmEngineHost;
import net.java.otr4j.session.SessionID;
import net.java.otr4j.session.TLV;
import org.easymock.EasyMock;
import org.easymock.EasyMockSupport;
import org.jivesoftware.smack.util.Base64;
import org.junit.Before;
import org.junit.Test;
/**
 * Exercises the SMP exchange implemented by {@link OtrSm} between two
 * in-process parties: "a", which starts the exchange, and "b", which answers.
 *
 * <p>Each party gets its own in-memory key store, key manager, mocked
 * {@link Session} and mocked {@link OtrSmEngineHost}. The tests check the
 * {@code smstate.nextExpected} value after every step and whether each key
 * manager reports its session as verified once the exchange has finished.
 */
public class OtrSmTest extends EasyMockSupport {

    /** Secret that party b always answers with in {@code runMiddleOfProtocol}. */
    private static final String RESPONDER_SECRET = "xyz";

    /** Secret for party a that differs from {@link #RESPONDER_SECRET}. */
    private static final String MISMATCHED_SECRET = "abc";

    /** Question text passed by the {@code _question} test variants. */
    private static final String QUESTION = "qqq";

    /**
     * {@link OtrKeyManagerStore} that keeps every entry in an in-memory
     * {@link Properties} object; this class performs no I/O.
     */
    class MemoryPropertiesStore implements OtrKeyManagerStore {

        private final Properties properties = new Properties();

        public MemoryPropertiesStore() {
        }

        /**
         * Stores a boolean flag under {@code id}.
         *
         * <p>The actual value is stored (the previous version always wrote
         * {@code "true"}), so a flag set to {@code false} reads back as
         * {@code false} from {@link #getPropertyBoolean(String, boolean)}.
         */
        public void setProperty(String id, boolean value) {
            properties.setProperty(id, Boolean.toString(value));
        }

        /** Stores {@code value} under {@code id}, Base64-encoded. */
        public void setProperty(String id, byte[] value) {
            properties.setProperty(id, new String(Base64.encodeBytes(value)));
        }

        /** Removes the entry stored under {@code id}, if there is one. */
        public void removeProperty(String id) {
            properties.remove(id);
        }

        /**
         * Returns the Base64-decoded bytes stored under {@code id}.
         *
         * @return the decoded bytes, or {@code null} when nothing is stored
         */
        public byte[] getPropertyBytes(String id) {
            String encoded = properties.getProperty(id);
            return encoded == null ? null : Base64.decode(encoded);
        }

        /**
         * Returns the boolean stored under {@code id}.
         *
         * @return the stored flag, or {@code defaultValue} when {@code id} is
         *         {@code null} or nothing is stored under it
         */
        public boolean getPropertyBoolean(String id, boolean defaultValue) {
            // Explicit checks instead of catching the NullPointerException
            // that a null key or a missing entry used to trigger.
            if (id == null) {
                return defaultValue;
            }
            Object stored = properties.get(id);
            if (stored == null) {
                return defaultValue;
            }
            return Boolean.parseBoolean(stored.toString());
        }
    }

    // SMP state machines under test, one per party.
    OtrSm sm_a;
    OtrSm sm_b;

    // Per-party collaborators, created fresh in setUp().
    private OtrKeyManagerImpl manager_a;
    private OtrKeyManagerImpl manager_b;
    private SessionID sessionId_a;
    private SessionID sessionId_b;
    private OtrSmEngineHost host_a;
    private OtrSmEngineHost host_b;
    private Session session_a;
    private Session session_b;

    /**
     * Builds both parties: key managers over in-memory stores, mocked sessions
     * whose {@code getS()} is stubbed from a pair of cross-wired auth
     * contexts, generated key pairs with each side's public key saved by the
     * other side's manager, nice-mocked engine hosts, and finally the two
     * {@link OtrSm} instances.
     */
    @Before
    public void setUp() throws Exception {
        manager_a = new OtrKeyManagerImpl(new MemoryPropertiesStore());
        manager_b = new OtrKeyManagerImpl(new MemoryPropertiesStore());

        session_a = createMock(Session.class);
        session_b = createMock(Session.class);

        // Give each auth context the other side's DH public key; presumably
        // this makes both getS() values agree -- TODO confirm against
        // AuthContextImpl.
        AuthContextImpl authContextA = new AuthContextImpl(session_a);
        AuthContextImpl authContextB = new AuthContextImpl(session_b);
        authContextA.setRemoteDHPublicKey((DHPublicKey) authContextB.getLocalDHKeyPair().getPublic());
        authContextB.setRemoteDHPublicKey((DHPublicKey) authContextA.getLocalDHKeyPair().getPublic());
        EasyMock.expect(session_a.getS()).andStubReturn(authContextA.getS());
        EasyMock.expect(session_b.getS()).andStubReturn(authContextB.getS());

        sessionId_a = new SessionID("a1", "ua", "xmpp");
        sessionId_b = new SessionID("a1", "ub", "xmpp");

        manager_a.generateLocalKeyPair(sessionId_a);
        manager_b.generateLocalKeyPair(sessionId_b);
        manager_a.savePublicKey(sessionId_a, manager_b.loadLocalKeyPair(sessionId_b).getPublic());
        manager_b.savePublicKey(sessionId_b, manager_a.loadLocalKeyPair(sessionId_a).getPublic());

        host_a = createNiceMock(OtrSmEngineHost.class);
        host_b = createNiceMock(OtrSmEngineHost.class);

        sm_a = new OtrSm(session_a, manager_a, sessionId_a, host_a);
        sm_b = new OtrSm(session_b, manager_b, sessionId_b, host_b);
    }

    /** Same secret on both sides, no question: both sessions end up verified. */
    @Test
    public void testSuccess() throws Exception {
        startAndRunExchange(null, RESPONDER_SECRET);
        assertTrue(manager_b.isVerified(sessionId_b));
        assertTrue(manager_a.isVerified(sessionId_a));
    }

    /** Same secret on both sides, with a question: both sessions end up verified. */
    @Test
    public void testSuccess_question() throws Exception {
        startAndRunExchange(QUESTION, RESPONDER_SECRET);
        assertTrue(manager_b.isVerified(sessionId_b));
        assertTrue(manager_a.isVerified(sessionId_a));
    }

    /** Different secrets, no question: neither session ends up verified. */
    @Test
    public void testFailure() throws Exception {
        startAndRunExchange(null, MISMATCHED_SECRET);
        assertFalse(manager_b.isVerified(sessionId_b));
        assertFalse(manager_a.isVerified(sessionId_a));
    }

    /** Different secrets, with a question: neither session ends up verified. */
    @Test
    public void testFailure_question() throws Exception {
        startAndRunExchange(QUESTION, MISMATCHED_SECRET);
        assertFalse(manager_b.isVerified(sessionId_b));
        assertFalse(manager_a.isVerified(sessionId_a));
    }

    /**
     * Switches all mocks to replay mode, lets party a start the exchange and
     * then drives the remaining steps through {@link #runMiddleOfProtocol}.
     *
     * @param question        question passed to party a's
     *                        {@code initRespondSmp}; may be {@code null}
     * @param initiatorSecret secret passed to party a's {@code initRespondSmp}
     */
    private void startAndRunExchange(String question, String initiatorSecret)
            throws SMException, OtrException {
        replayAll();

        List<TLV> openingTlvs = sm_a.initRespondSmp(question, initiatorSecret, true);
        assertEquals(SM.EXPECT2, sm_a.smstate.nextExpected);
        assertEquals(1, openingTlvs.size());

        runMiddleOfProtocol(openingTlvs);
    }

    /**
     * Plays the rest of the exchange after party a has produced its opening
     * TLV, asserting the {@code nextExpected} state and the pending TLVs of
     * both sides after every step. Party b always answers with
     * {@link #RESPONDER_SECRET}.
     *
     * @param tlvs the TLVs returned by party a's opening call; only the first
     *             element is used
     */
    private void runMiddleOfProtocol(List<TLV> tlvs) throws SMException, OtrException {
        // b receives a's opening message; it has not answered yet.
        sm_b.processTlv(tlvs.get(0));
        assertEquals(SM.EXPECT1, sm_b.smstate.nextExpected);
        assertNull(sm_b.getPendingTlvs());

        // b answers with its own secret (third argument false here, true
        // when a started the exchange).
        List<TLV> replyTlvs = sm_b.initRespondSmp(null, RESPONDER_SECRET, false);
        assertEquals(SM.EXPECT3, sm_b.smstate.nextExpected);
        assertEquals(1, replyTlvs.size());

        // a handles b's answer and queues one TLV; nobody is verified yet.
        sm_a.processTlv(replyTlvs.get(0));
        assertEquals(SM.EXPECT4, sm_a.smstate.nextExpected);
        assertEquals(1, sm_a.getPendingTlvs().size());
        assertFalse(manager_a.isVerified(sessionId_a));
        assertFalse(manager_b.isVerified(sessionId_b));

        // b handles a's queued TLV, returns to EXPECT1 and queues one TLV.
        sm_b.processTlv(sm_a.getPendingTlvs().get(0));
        assertEquals(SM.EXPECT1, sm_b.smstate.nextExpected);
        assertEquals(1, sm_b.getPendingTlvs().size());

        // a handles b's queued TLV, returns to EXPECT1 with nothing pending.
        sm_a.processTlv(sm_b.getPendingTlvs().get(0));
        assertEquals(SM.EXPECT1, sm_a.smstate.nextExpected);
        assertNull(sm_a.getPendingTlvs());
    }
}