/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.sshd.client.keyverifier; import java.io.File; import java.io.IOException; import java.net.SocketAddress; import java.net.URL; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.StandardCopyOption; import java.security.KeyPair; import java.security.PublicKey; import java.util.Collection; import java.util.Collections; import java.util.Map; import java.util.TreeMap; import java.util.concurrent.atomic.AtomicInteger; import org.apache.sshd.client.ClientFactoryManager; import org.apache.sshd.client.config.hosts.KnownHostEntry; import org.apache.sshd.client.config.hosts.KnownHostHashValue; import org.apache.sshd.client.keyverifier.KnownHostsServerKeyVerifier.HostEntryPair; import org.apache.sshd.client.session.ClientSession; import org.apache.sshd.common.NamedFactory; import org.apache.sshd.common.config.keys.AuthorizedKeyEntry; import org.apache.sshd.common.config.keys.KeyUtils; import org.apache.sshd.common.config.keys.PublicKeyEntry; import org.apache.sshd.common.config.keys.PublicKeyEntryResolver; import org.apache.sshd.common.mac.Mac; import org.apache.sshd.common.random.JceRandomFactory; import org.apache.sshd.common.util.GenericUtils; import 
org.apache.sshd.common.util.ValidateUtils;
import org.apache.sshd.common.util.net.SshdSocketAddress;
import org.apache.sshd.util.test.BaseTestSupport;
import org.apache.sshd.util.test.Utils;
import org.junit.BeforeClass;
import org.junit.FixMethodOrder;
import org.junit.Test;
import org.junit.runners.MethodSorters;
import org.mockito.Mockito;

/**
 * Exercises {@link KnownHostsServerKeyVerifier} against the entries of the
 * {@link KnownHostEntry#STD_HOSTS_FILENAME} test resource.
 *
 * @author <a href="mailto:dev@mina.apache.org">Apache MINA SSHD Project</a>
 */
@FixMethodOrder(MethodSorters.NAME_ASCENDING)
public class KnownHostsServerKeyVerifierTest extends BaseTestSupport {
    /** Host name under which the (single) hashed entry of the test file is registered. */
    private static final String HASHED_HOST = "192.168.1.61";
    /** Port used for every simulated server address built by these tests. */
    private static final int SERVER_PORT = 7365;
    /** Resolved server key per host, looked up case-insensitively. */
    private static final Map<String, PublicKey> HOST_KEYS = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
    /** Original test-resource entries keyed by host - see {@link #loadEntries(Path)}. */
    private static Map<String, KnownHostEntry> hostsEntries;
    /** Location of the known-hosts test resource. */
    private static Path entriesFile;

    public KnownHostsServerKeyVerifierTest() {
        super();
    }

    /**
     * Loads the known-hosts test resource and resolves the public key of every host listed in it
     * into {@link #HOST_KEYS}.
     *
     * @throws Exception If the resource is missing or a key cannot be resolved
     */
    @BeforeClass
    public static void loadHostsEntries() throws Exception {
        URL url = KnownHostsServerKeyVerifierTest.class.getResource(KnownHostEntry.STD_HOSTS_FILENAME);
        assertNotNull("Missing test file resource", url);
        entriesFile = new File(url.toURI()).toPath();
        outputDebugMessage("loadHostsEntries(%s)", entriesFile);
        hostsEntries = loadEntries(entriesFile);

        // Cannot use forEach because of the potential IOException/GeneralSecurityException being thrown
        for (Map.Entry<String, KnownHostEntry> ke : hostsEntries.entrySet()) {
            String host = ke.getKey();
            KnownHostEntry entry = ke.getValue();
            AuthorizedKeyEntry authEntry = ValidateUtils.checkNotNull(entry.getKeyEntry(), "No key extracted from %s", entry);
            PublicKey key = authEntry.resolvePublicKey(PublicKeyEntryResolver.FAILING);
            assertNull("Multiple keys for host=" + host, HOST_KEYS.put(host, key));
        }
    }

    /**
     * Checks that every host already recorded in the known-hosts file is accepted without
     * consulting the delegate verifier and without rewriting the file.
     */
    @Test
    public void testNoUpdatesNoNewHostsAuthentication() throws Exception {
        final AtomicInteger delegateCount = new AtomicInteger(0);
        ServerKeyVerifier delegate = (clientSession, remoteAddress, serverKey) -> {
            delegateCount.incrementAndGet();
            fail("verifyServerKey(" + clientSession + ")[" + remoteAddress + "] unexpected invocation");
            return false;
        };

        final AtomicInteger updateCount = new AtomicInteger(0);
        ServerKeyVerifier verifier = new KnownHostsServerKeyVerifier(delegate, createKnownHostsCopy()) {
            @Override
            protected KnownHostEntry updateKnownHostsFile(
                    ClientSession clientSession, SocketAddress remoteAddress, PublicKey serverKey,
                    Path file, Collection<HostEntryPair> knownHosts)
                    throws Exception {
                updateCount.incrementAndGet();
                fail("updateKnownHostsFile(" + clientSession + ")[" + remoteAddress + "] unexpected invocation: " + file);
                return super.updateKnownHostsFile(clientSession, remoteAddress, serverKey, file, knownHosts);
            }
        };

        HOST_KEYS.forEach((host, serverKey) -> {
            KnownHostEntry entry = hostsEntries.get(host);
            outputDebugMessage("Verify host=%s", entry);
            assertTrue("Failed to verify server=" + entry, invokeVerifier(verifier, host, serverKey));
            assertEquals("Unexpected delegate invocation for host=" + entry, 0, delegateCount.get());
            assertEquals("Unexpected update invocation for host=" + entry, 0, updateCount.get());
        });
    }

    /**
     * Starts from a non-existent known-hosts file. Each previously unknown host must reach the
     * delegate and trigger a file update. The resulting file must contain exactly the original
     * entries: list/hashed host patterns are replaced by the single host, and any text starting at
     * {@code comment-} is removed.
     */
    @Test
    public void testFileUpdatedOnEveryNewHost() throws Exception {
        final AtomicInteger delegateCount = new AtomicInteger(0);
        ServerKeyVerifier delegate = (clientSession, remoteAddress, serverKey) -> {
            delegateCount.incrementAndGet();
            return true;
        };

        Path path = getKnownHostCopyPath();
        Files.deleteIfExists(path);

        final AtomicInteger updateCount = new AtomicInteger(0);
        ServerKeyVerifier verifier = new KnownHostsServerKeyVerifier(delegate, path) {
            @Override
            protected KnownHostEntry updateKnownHostsFile(
                    ClientSession clientSession, SocketAddress remoteAddress, PublicKey serverKey,
                    Path file, Collection<HostEntryPair> knownHosts)
                    throws Exception {
                updateCount.incrementAndGet();
                return super.updateKnownHostsFile(clientSession, remoteAddress, serverKey, file, knownHosts);
            }
        };

        int verificationCount = 0;
        // Cannot use forEach because the verification count variable is not effectively final
        for (Map.Entry<String, PublicKey> ke : HOST_KEYS.entrySet()) {
            String host = ke.getKey();
            PublicKey serverKey = ke.getValue();
            KnownHostEntry entry = hostsEntries.get(host);
            outputDebugMessage("Verify host=%s", entry);
            assertTrue("Failed to verify server=" + entry, invokeVerifier(verifier, host, serverKey));
            verificationCount++;
            assertEquals("Mismatched number of delegate counts for server=" + entry, verificationCount, delegateCount.get());
            assertEquals("Mismatched number of update counts for server=" + entry, verificationCount, updateCount.get());
        }

        // make sure we have all the original entries and ONLY them
        Map<String, KnownHostEntry> updatedEntries = loadEntries(path);
        hostsEntries.forEach((host, expected) -> {
            KnownHostEntry actual = updatedEntries.remove(host);
            assertNotNull("No updated entry for host=" + host, actual);

            String expLine = expected.getConfigLine();
            // if original is a list or hashed then replace them with the expected host
            if ((expLine.indexOf(',') > 0) || (expLine.indexOf(KnownHostHashValue.HASHED_HOST_DELIMITER) >= 0)) {
                int pos = expLine.indexOf(' ');
                expLine = host + expLine.substring(pos);
            }

            int pos = expLine.indexOf("comment-");
            if (pos > 0) {
                expLine = expLine.substring(0, pos).trim();
            }

            assertEquals("Mismatched entry data for host=" + host, expLine, actual.getConfigLine());
        });

        assertTrue("Unexpected extra updated hosts: " + updatedEntries, updatedEntries.isEmpty());
    }

    /**
     * Checks that the verifier writes hashed host values when {@code getHostValueDigester} returns
     * a digester. After the file is reloaded, every entry must be hashed, and those hashed entries
     * must still validate the original hosts.
     */
    @Test
    public void testWriteHashedHostValues() throws Exception {
        Path path = getKnownHostCopyPath();
        Files.deleteIfExists(path);

        KnownHostsServerKeyVerifier verifier = new KnownHostsServerKeyVerifier(AcceptAllServerKeyVerifier.INSTANCE, path) {
            @Override
            protected NamedFactory<Mac> getHostValueDigester(
                    ClientSession clientSession, SocketAddress remoteAddress, String hostIdentity) {
                return KnownHostHashValue.DEFAULT_DIGEST;
            }
        };

        ClientFactoryManager manager = Mockito.mock(ClientFactoryManager.class);
        Mockito.when(manager.getRandomFactory()).thenReturn(JceRandomFactory.INSTANCE);

        ClientSession session = Mockito.mock(ClientSession.class);
        Mockito.when(session.getFactoryManager()).thenReturn(manager);

        HOST_KEYS.forEach((host, serverKey) -> {
            KnownHostEntry entry = hostsEntries.get(host);
            outputDebugMessage("Write host=%s", entry);

            SocketAddress address = new SshdSocketAddress(host, SERVER_PORT);
            Mockito.when(session.getConnectAddress()).thenReturn(address);
            assertTrue("Failed to validate server=" + entry, verifier.verifyServerKey(session, address, serverKey));
        });

        // force re-read to ensure all values are hashed
        Collection<HostEntryPair> keys = verifier.reloadKnownHosts(path);
        for (HostEntryPair ke : keys) {
            KnownHostEntry entry = ke.getHostEntry();
            assertNotNull("No hashing for entry=" + entry, entry.getHashedEntry());
        }
        verifier.setLoadedHostsEntries(keys);

        // make sure can still validate the original hosts
        HOST_KEYS.forEach((host, serverKey) -> {
            KnownHostEntry entry = hostsEntries.get(host);
            outputDebugMessage("Re-validate host=%s", entry);

            SocketAddress address = new SshdSocketAddress(host, SERVER_PORT);
            Mockito.when(session.getConnectAddress()).thenReturn(address);
            assertTrue("Failed to re-validate server=" + entry, verifier.verifyServerKey(session, address, serverKey));
        });
    }

    /**
     * Checks that a server key different from the recorded one is rejected by the default
     * {@code acceptModifiedServerKey} policy. That method must be consulted exactly once per host.
     */
    @Test
    public void testRejectModifiedServerKey() throws Exception {
        KeyPair kp = Utils.generateKeyPair(KeyUtils.RSA_ALGORITHM, 1024);
        final PublicKey modifiedKey = kp.getPublic();
        final AtomicInteger acceptCount = new AtomicInteger(0);
        ServerKeyVerifier verifier = new KnownHostsServerKeyVerifier(AcceptAllServerKeyVerifier.INSTANCE, createKnownHostsCopy()) {
            @Override
            public boolean acceptModifiedServerKey(
                    ClientSession clientSession, SocketAddress remoteAddress, KnownHostEntry entry,
                    PublicKey expected, PublicKey actual)
                    throws Exception {
                acceptCount.incrementAndGet();
                assertSame("Mismatched actual key for " + remoteAddress, modifiedKey, actual);
                return super.acceptModifiedServerKey(clientSession, remoteAddress, entry, expected, actual);
            }
        };

        int validationCount = 0;
        // Cannot use forEach because the validation count variable is not effectively final
        for (Map.Entry<String, KnownHostEntry> ke : hostsEntries.entrySet()) {
            String host = ke.getKey();
            KnownHostEntry entry = ke.getValue();
            outputDebugMessage("Verify host=%s", entry);
            assertFalse("Unexpected verification success for " + entry, invokeVerifier(verifier, host, modifiedKey));
            validationCount++;
            assertEquals("Mismatched invocation count for host=" + entry, validationCount, acceptCount.get());
        }
    }

    /**
     * Checks that accepting a modified server key rewrites each host's entry so that its key data
     * becomes the new key. No extra entries may be added.
     */
    @Test
    public void testAcceptModifiedServerKeyUpdatesFile() throws Exception {
        KeyPair kp = Utils.generateKeyPair(KeyUtils.RSA_ALGORITHM, 1024);
        final PublicKey modifiedKey = kp.getPublic();
        Path path = createKnownHostsCopy();
        ServerKeyVerifier verifier = new KnownHostsServerKeyVerifier(AcceptAllServerKeyVerifier.INSTANCE, path) {
            @Override
            public boolean acceptModifiedServerKey(
                    ClientSession clientSession, SocketAddress remoteAddress, KnownHostEntry entry,
                    PublicKey expected, PublicKey actual)
                    throws Exception {
                assertSame("Mismatched actual key for " + remoteAddress, modifiedKey, actual);
                return true;
            }
        };

        hostsEntries.forEach((host, entry) -> {
            outputDebugMessage("Verify host=%s", entry);
            assertTrue("Failed to verify " + entry, invokeVerifier(verifier, host, modifiedKey));
        });

        String expected = PublicKeyEntry.toString(modifiedKey);
        Map<String, KnownHostEntry> updatedKeys = loadEntries(path);
        hostsEntries.forEach((host, original) -> {
            KnownHostEntry updated = updatedKeys.remove(host);
            assertNotNull("No updated entry for " + original, updated);

            String actual = updated.getConfigLine();
            // skip the optional marker and the host patterns - what remains is the key data
            int pos = actual.indexOf(' ', findHostPatternsStart(actual));
            actual = GenericUtils.trimToEmpty(actual.substring(pos + 1));
            assertEquals("Mismatched updated value for host=" + host, expected, actual);
        });

        assertTrue("Unexpected extra updated entries: " + updatedKeys, updatedKeys.isEmpty());
    }

    /**
     * Copies the known-hosts test resource to this test's private location, replacing any
     * previous copy.
     *
     * @return The path of the copy
     * @throws IOException If the copy fails
     */
    private Path createKnownHostsCopy() throws IOException {
        Path file = getKnownHostCopyPath();
        Files.copy(entriesFile, file, StandardCopyOption.REPLACE_EXISTING);
        return file;
    }

    /**
     * Builds the per-test temporary file location, named after this class and the current test.
     * The parent folder is verified via {@code assertHierarchyTargetFolderExists}.
     *
     * @return The per-test file path (the file itself is not created)
     * @throws IOException If the folder hierarchy check fails
     */
    private Path getKnownHostCopyPath() throws IOException {
        Path file = getTempTargetRelativeFile(getClass().getSimpleName(), getCurrentTestName());
        assertHierarchyTargetFolderExists(file.getParent());
        return file;
    }

    /**
     * Runs the verifier on a mocked session whose connect address is {@code host:SERVER_PORT}.
     *
     * @param verifier The verifier under test
     * @param host The simulated server host
     * @param serverKey The key presented by the simulated server
     * @return The verifier's verdict
     */
    private boolean invokeVerifier(ServerKeyVerifier verifier, String host, PublicKey serverKey) {
        SocketAddress address = new SshdSocketAddress(host, SERVER_PORT);
        ClientSession session = Mockito.mock(ClientSession.class);
        Mockito.when(session.getConnectAddress()).thenReturn(address);
        Mockito.when(session.toString()).thenReturn(getCurrentTestName() + "[" + host + "]");
        return verifier.verifyServerKey(session, address, serverKey);
    }

    /**
     * Finds where the host patterns start in a known-hosts configuration line. A line that begins
     * with {@link KnownHostEntry#MARKER_INDICATOR} carries a marker token first, so its patterns
     * start at the first non-space character after that token. Any other line starts with its
     * patterns.
     *
     * @param line The configuration line
     * @return Index of the first character of the host patterns. This is 0 for an unmarked line,
     *         and also 0 for a marked line that contains no space at all.
     */
    private static int findHostPatternsStart(String line) {
        if (line.isEmpty() || (line.charAt(0) != KnownHostEntry.MARKER_INDICATOR)) {
            return 0;
        }

        // indexOf(...) + 1 is 0 when there is no space, which keeps the "no space" case at index 0
        int pos = line.indexOf(' ') + 1;
        while ((pos < line.length()) && (line.charAt(pos) == ' ')) {
            pos++;
        }
        return pos;
    }

    /**
     * Reads the known-hosts entries of a file and maps every host to its entry. Plain entries are
     * registered under each comma-separated host pattern. The hashed entry is registered under
     * {@link #HASHED_HOST}, and at most one hashed entry is allowed.
     *
     * @param file The known-hosts file to read
     * @return A case-insensitive host to entry map, or an empty map if the file has no entries
     * @throws IOException If the file cannot be read
     */
    private static Map<String, KnownHostEntry> loadEntries(Path file) throws IOException {
        Collection<KnownHostEntry> entries = KnownHostEntry.readKnownHostEntries(file);
        if (GenericUtils.isEmpty(entries)) {
            return Collections.emptyMap();
        }

        Map<String, KnownHostEntry> hostsMap = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
        for (KnownHostEntry entry : entries) {
            String line = entry.getConfigLine();
            outputDebugMessage("loadEntries(%s) processing %s", file, line);
            // extract hosts - skipping any leading marker so it is not mistaken for a host pattern
            int start = findHostPatternsStart(line);
            int pos = line.indexOf(' ', start);
            assertTrue("No host patterns in line=" + line, pos > start);
            String patterns = line.substring(start, pos);
            if (entry.getHashedEntry() != null) {
                assertNull("Multiple hashed entries in file", hostsMap.put(HASHED_HOST, entry));
            } else {
                String[] addrs = GenericUtils.split(patterns, ',');
                for (String a : addrs) {
                    assertNull("Multiple entries for address=" + a, hostsMap.put(a, entry));
                }
            }
        }

        return hostsMap;
    }
}