package org.primftpd.services;
import org.apache.sshd.server.PublickeyAuthenticator;
import org.apache.sshd.server.session.ServerSession;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.math.BigInteger;
import java.security.PublicKey;
import java.util.List;
/**
 * {@link PublickeyAuthenticator} that accepts a login when the public key
 * offered by the client matches one of the configured keys.
 *
 * <p>The username is not taken into account. Configured keys of type
 * {@code org.bouncycastle.jce.provider.JCEECPublicKey} are compared by the
 * affine coordinates of their public point, so that a client key of a
 * different implementation class can still match; all other configured keys
 * are compared with {@link Object#equals(Object)}.
 */
public class PubKeyAuthenticator implements PublickeyAuthenticator {

    protected final Logger logger = LoggerFactory.getLogger(getClass());

    private final List<PublicKey> pubKeys;

    /**
     * @param pubKeys keys that are allowed to log in; the list is used as
     *                passed (no copy is made), {@code null} is treated like
     *                an empty list
     */
    public PubKeyAuthenticator(List<PublicKey> pubKeys) {
        this.pubKeys = pubKeys;
    }

    /**
     * Checks the offered key against all configured keys.
     *
     * @param username ignored
     * @param key      key offered by the client, may be {@code null}
     * @param session  ignored
     * @return {@code true} if {@code key} matches one of the configured keys,
     *         {@code false} otherwise (including when {@code key} is {@code null}
     *         or no keys are configured)
     */
    @Override
    public boolean authenticate(String username, PublicKey key, ServerSession session) {
        logger.debug("attempting pub key auth, user: {}, client key class: '{}'",
            username, className(key));

        // nothing can match; also keeps the logging below free of null dereferences
        if (key == null || pubKeys == null) {
            return false;
        }

        // never mind username
        for (PublicKey configuredKey : pubKeys) {
            boolean keyEquals = keysEqual(configuredKey, key);
            logger.debug("pub key auth, success: {}, server key class '{}'",
                keyEquals, className(configuredKey));
            if (keyEquals) {
                return true;
            }
        }
        return false;
    }

    /**
     * Compares a configured key with the key offered by the client.
     *
     * @return {@code true} if both keys are non-null and considered equal;
     *         {@code false} otherwise, also when the components of an EC
     *         client key could not be read (the failure is logged)
     */
    private boolean keysEqual(PublicKey serverKey, PublicKey clientKey) {
        if (serverKey == null || clientKey == null) {
            return false;
        }
        if (!(serverKey instanceof org.bouncycastle.jce.provider.JCEECPublicKey)) {
            return serverKey.equals(clientKey);
        }

        // The client key may be of another implementation class than the
        // server key, so equals() is not used here. Only the affine x and y
        // coordinates of the public points are compared.
        try {
            org.bouncycastle.math.ec.ECPoint serverKeyQ =
                ((org.bouncycastle.jce.provider.JCEECPublicKey) serverKey).getQ();
            BigInteger serverKeyX = serverKeyQ.getAffineXCoord().toBigInteger();
            BigInteger serverKeyY = serverKeyQ.getAffineYCoord().toBigInteger();

            BigInteger[] clientKeyPoint = clientKeyPointQ(clientKey);

            return serverKeyX.equals(clientKeyPoint[0]) && serverKeyY.equals(clientKeyPoint[1]);
        } catch (Exception e) {
            // broad catch on purpose: reflection and provider code may fail in
            // many ways, any failure has to result in a rejected key
            logger.error(
                "Could not get component of client key to compare with server key. Client key class: "
                    + clientKey.getClass().getName(),
                e);
            return false;
        }
    }

    /**
     * Reads the public point of an EC client key.
     *
     * <p>The standard {@link java.security.interfaces.ECPublicKey} accessor is
     * tried first. If the key does not implement that interface, or the
     * accessor does not deliver coordinates, the point is read via reflection
     * from a field named {@code q}.
     *
     * @return array of length 2: x coordinate at index 0, y coordinate at index 1
     */
    private BigInteger[] clientKeyPointQ(PublicKey clientKey)
        throws NoSuchFieldException, IllegalAccessException, NoSuchMethodException, InvocationTargetException
    {
        if (clientKey instanceof java.security.interfaces.ECPublicKey) {
            try {
                java.security.spec.ECPoint w =
                    ((java.security.interfaces.ECPublicKey) clientKey).getW();
                // the point at infinity has null coordinates
                if (w != null && w.getAffineX() != null && w.getAffineY() != null) {
                    return new BigInteger[]{w.getAffineX(), w.getAffineY()};
                }
            } catch (RuntimeException e) {
                logger.debug("could not read point via ECPublicKey.getW(), falling back to reflection", e);
            }
        }

        Field field = findField(clientKey.getClass(), "q");
        field.setAccessible(true);
        Object pointQ = field.get(clientKey);
        if (pointQ == null) {
            throw new IllegalStateException("field 'q' of client key is null");
        }

        BigInteger x = pointCoord("X", pointQ);
        BigInteger y = pointCoord("Y", pointQ);

        return new BigInteger[]{x, y};
    }

    /**
     * Reads one coordinate of a point object via reflection.
     *
     * <p>Looks for a public no-arg method {@code get<coord>} and, if there is
     * none, for {@code getAffine<coord>Coord}. The returned object must offer a
     * public no-arg method {@code toBigInteger}.
     *
     * @param coord {@code "X"} or {@code "Y"}
     * @param point point object to read from
     * @throws NoSuchMethodException if no suitable accessor exists
     */
    private BigInteger pointCoord(String coord, Object point)
        throws NoSuchMethodException, InvocationTargetException, IllegalAccessException
    {
        String methodName = "get" + coord;
        Method getCoord = findPublicNoArgMethod(point.getClass(), methodName);
        if (getCoord == null) {
            getCoord = findPublicNoArgMethod(point.getClass(), "getAffine" + coord + "Coord");
        }
        if (getCoord == null) {
            throw new NoSuchMethodException(methodName);
        }

        Object fieldElement = getCoord.invoke(point);

        // getMethod() also finds implementations inherited from a superclass
        Method toBigInt = fieldElement.getClass().getMethod("toBigInteger");
        return (BigInteger) toBigInt.invoke(fieldElement);
    }

    /**
     * Searches a field in the given class and all of its superclasses.
     *
     * @throws NoSuchFieldException if no class in the hierarchy declares the field
     */
    private static Field findField(Class<?> type, String name) throws NoSuchFieldException {
        for (Class<?> current = type; current != null; current = current.getSuperclass()) {
            try {
                return current.getDeclaredField(name);
            } catch (NoSuchFieldException ignored) {
                // not declared here, continue with the superclass
            }
        }
        throw new NoSuchFieldException(name);
    }

    /**
     * @return the public method with the given name and no parameters,
     *         or {@code null} if there is none
     */
    private static Method findPublicNoArgMethod(Class<?> type, String name) {
        try {
            return type.getMethod(name);
        } catch (NoSuchMethodException ignored) {
            // absence is reported as null, the caller tries an alternative name
            return null;
        }
    }

    /** Null-safe class name for log output. */
    private static String className(Object obj) {
        return obj == null ? "null" : obj.getClass().getName();
    }
}