/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sshd.common.util;
import java.math.BigInteger;
import java.security.*;
import java.security.interfaces.*;
import java.security.spec.*;
import org.apache.mina.core.buffer.IoBuffer;
import org.apache.sshd.common.KeyPairProvider;
import org.apache.sshd.common.SshConstants;
import org.apache.sshd.common.SshException;
/**
 * A growable byte buffer with independent read and write positions and
 * helpers to encode and decode the SSH wire data types: big-endian 32 and
 * 64 bit integers, booleans, length-prefixed strings and byte arrays,
 * multiple precision integers, and RSA / DSA keys.
 * <p>
 * Bytes between the read position ({@link #rpos()}) and the write position
 * ({@link #wpos()}) are the data still available for reading. Write methods
 * append at the write position and replace the backing array with a larger
 * one when it is full. The class performs no synchronization.
 *
 * @author <a href="mailto:dev@mina.apache.org">Apache MINA SSHD Project</a>
 */
public final class Buffer {

    /** Capacity requested by the no-argument constructor. */
    public static final int DEFAULT_SIZE = 256;

    /** Largest length prefix accepted when reading a string or byte array item. */
    private static final int MAX_ITEM_LENGTH = 32768;

    /**
     * Charset used to convert between {@link String} and bytes. It is fixed
     * so the encoded form does not depend on the platform default charset.
     */
    private static final java.nio.charset.Charset UTF8 = java.nio.charset.Charset.forName("UTF-8");

    private byte[] data;
    private int rpos;
    private int wpos;

    /**
     * Creates an empty buffer, ready for writing, with {@link #DEFAULT_SIZE} capacity.
     */
    public Buffer() {
        this(DEFAULT_SIZE);
    }

    /**
     * Creates an empty buffer, ready for writing.
     *
     * @param size the requested capacity; it is rounded up to a power of two
     */
    public Buffer(int size) {
        this(new byte[getNextPowerOf2(size)], false);
    }

    /**
     * Wraps the given array (without copying it) so that its whole content
     * is available for reading.
     *
     * @param data the backing array
     */
    public Buffer(byte[] data) {
        this(data, 0, data.length, true);
    }

    /**
     * Wraps the given array (without copying it).
     *
     * @param data the backing array
     * @param read {@code true} to make the whole array content available for
     *             reading, {@code false} to start with an empty buffer
     */
    public Buffer(byte[] data, boolean read) {
        this(data, 0, data.length, read);
    }

    /**
     * Wraps a region of the given array (without copying it) so that the
     * region is available for reading.
     *
     * @param data the backing array
     * @param off  the index of the first byte of the region
     * @param len  the number of bytes in the region
     */
    public Buffer(byte[] data, int off, int len) {
        this(data, off, len, true);
    }

    /**
     * Wraps a region of the given array (without copying it).
     *
     * @param data the backing array
     * @param off  the index of the first byte of the region; used as the
     *             initial read position
     * @param len  the number of readable bytes in the region
     * @param read {@code true} to make the {@code len} bytes starting at
     *             {@code off} available for reading, {@code false} to start
     *             with no readable data and append from {@code off}
     */
    public Buffer(byte[] data, int off, int len, boolean read) {
        this.data = data;
        this.rpos = off;
        // The write position is an absolute index, so it must include the offset;
        // otherwise available() is wrong (possibly negative) whenever off != 0.
        this.wpos = off + (read ? len : 0);
    }

    @Override
    public String toString() {
        return "Buffer [rpos=" + rpos + ", wpos=" + wpos + ", size=" + data.length + "]";
    }

    /*======================
       Global methods
     ======================*/

    /**
     * @return the current read position
     */
    public int rpos() {
        return rpos;
    }

    /**
     * Sets the read position. No bounds check is performed.
     *
     * @param rpos the new read position
     */
    public void rpos(int rpos) {
        this.rpos = rpos;
    }

    /**
     * @return the current write position
     */
    public int wpos() {
        return wpos;
    }

    /**
     * Sets the write position, growing the backing array first when the new
     * position lies beyond its current length.
     *
     * @param wpos the new write position
     */
    public void wpos(int wpos) {
        ensureCapacity(wpos - this.wpos);
        this.wpos = wpos;
    }

    /**
     * @return the number of bytes between the read and the write position
     */
    public int available() {
        return wpos - rpos;
    }

    /**
     * Returns the backing array itself, not a copy. The buffer switches to a
     * new array when it has to grow, so a previously returned reference may
     * become stale after a write.
     *
     * @return the backing array
     */
    public byte[] array() {
        return data;
    }

    /**
     * Moves the readable bytes to the start of the backing array and resets
     * the read position to zero.
     */
    public void compact() {
        int remaining = available();
        if (remaining > 0) {
            System.arraycopy(data, rpos, data, 0, remaining);
        }
        wpos -= rpos;
        rpos = 0;
    }

    /**
     * Copies the readable bytes into a new array without consuming them.
     *
     * @return a new array holding the readable bytes; empty when nothing is available
     */
    public byte[] getCompactData() {
        int l = available();
        if (l <= 0) {
            return new byte[0];
        }
        byte[] b = new byte[l];
        System.arraycopy(data, rpos, b, 0, l);
        return b;
    }

    /**
     * Resets both the read and the write position to zero. The backing array
     * content is left untouched.
     */
    public void clear() {
        rpos = 0;
        wpos = 0;
    }

    /**
     * @return the result of {@code BufferUtils.printHex} applied to the readable region
     */
    public String printHex() {
        return BufferUtils.printHex(array(), rpos(), available());
    }

    /*======================
          Read methods
     ======================*/

    /**
     * Reads one byte.
     *
     * @return the byte at the read position
     * @throws BufferException if no byte is available
     */
    public byte getByte() {
        ensureAvailable(1);
        return data[rpos++];
    }

    /**
     * Reads a 32 bit big-endian integer.
     *
     * @return the value as a signed {@code int}
     * @throws BufferException if fewer than 4 bytes are available
     */
    public int getInt() {
        return (int) getUInt();
    }

    /**
     * Reads a 32 bit big-endian integer as an unsigned value.
     *
     * @return the value, in the range 0 to 2^32 - 1
     * @throws BufferException if fewer than 4 bytes are available
     */
    public long getUInt() {
        ensureAvailable(4);
        // Operands are evaluated left to right, so bytes are consumed in wire order.
        return ((data[rpos++] & 0xffL) << 24)
             | ((data[rpos++] & 0xffL) << 16)
             | ((data[rpos++] & 0xffL) << 8)
             |  (data[rpos++] & 0xffL);
    }

    /**
     * Reads a 64 bit big-endian integer.
     *
     * @return the value
     * @throws BufferException if fewer than 8 bytes are available
     */
    public long getLong() {
        ensureAvailable(8);
        // Each byte is widened to long BEFORE shifting: shifting an int by 32 or
        // more only uses the low five bits of the distance and would place the
        // four most significant bytes at the wrong position.
        return ((data[rpos++] & 0xffL) << 56)
             | ((data[rpos++] & 0xffL) << 48)
             | ((data[rpos++] & 0xffL) << 40)
             | ((data[rpos++] & 0xffL) << 32)
             | ((data[rpos++] & 0xffL) << 24)
             | ((data[rpos++] & 0xffL) << 16)
             | ((data[rpos++] & 0xffL) << 8)
             |  (data[rpos++] & 0xffL);
    }

    /**
     * Reads one byte as a boolean.
     *
     * @return {@code false} if the byte is zero, {@code true} otherwise
     * @throws BufferException if no byte is available
     */
    public boolean getBoolean() {
        return getByte() != 0;
    }

    /**
     * Reads a length-prefixed string and decodes it as UTF-8.
     *
     * @return the decoded string
     * @throws IllegalStateException if the length prefix is negative or larger than 32768
     * @throws BufferException       if fewer bytes than announced are available
     */
    public String getString() {
        int len = getInt();
        checkItemLength(len);
        ensureAvailable(len);
        String s = new String(data, rpos, len, UTF8);
        rpos += len;
        return s;
    }

    /**
     * Reads a length-prefixed string without decoding it.
     *
     * @return the raw bytes of the string
     * @see #getBytes()
     */
    public byte[] getStringAsBytes() {
        return getBytes();
    }

    /**
     * Reads a multiple precision integer stored as a length-prefixed,
     * big-endian two's complement byte array.
     *
     * @return the decoded value; {@link BigInteger#ZERO} when the item is empty
     * @see #getBytes()
     */
    public BigInteger getMPInt() {
        byte[] bytes = getMPIntAsBytes();
        // A zero-length item stands for zero; BigInteger rejects an empty array.
        return bytes.length == 0 ? BigInteger.ZERO : new BigInteger(bytes);
    }

    /**
     * Reads a multiple precision integer without decoding it.
     *
     * @return the raw bytes of the integer
     * @see #getBytes()
     */
    public byte[] getMPIntAsBytes() {
        return getBytes();
    }

    /**
     * Reads a length-prefixed byte array.
     *
     * @return a new array holding the item content
     * @throws IllegalStateException if the length prefix is negative or larger than 32768
     * @throws BufferException       if fewer bytes than announced are available
     */
    public byte[] getBytes() {
        int len = getInt();
        checkItemLength(len);
        byte[] b = new byte[len];
        getRawBytes(b);
        return b;
    }

    /**
     * Reads exactly {@code buf.length} bytes into the given array.
     *
     * @param buf the destination array
     * @throws BufferException if fewer than {@code buf.length} bytes are available
     */
    public void getRawBytes(byte[] buf) {
        getRawBytes(buf, 0, buf.length);
    }

    /**
     * Reads exactly {@code len} bytes into a region of the given array.
     *
     * @param buf the destination array
     * @param off the index in {@code buf} of the first byte to write
     * @param len the number of bytes to read
     * @throws BufferException if fewer than {@code len} bytes are available
     */
    public void getRawBytes(byte[] buf, int off, int len) {
        ensureAvailable(len);
        System.arraycopy(data, rpos, buf, off, len);
        rpos += len;
    }

    /**
     * Reads a length-prefixed public key. While the key is being decoded the
     * readable region is restricted to the announced key length; the original
     * write position is always restored afterwards.
     *
     * @return the decoded key
     * @throws SshException          if the key cannot be built from the decoded values
     * @throws IllegalStateException if the length prefix is negative or the key
     *                               algorithm is not supported
     * @throws BufferException       if fewer bytes than announced are available
     */
    public PublicKey getPublicKey() throws SshException {
        int ow = wpos;
        int len = getInt();
        // The length is read from the buffer content and must be validated before
        // it is used as a limit; otherwise bytes beyond the written data could be read.
        if (len < 0) {
            throw new IllegalStateException("Bad item length: " + len);
        }
        ensureAvailable(len);
        wpos = rpos + len;
        try {
            return getRawPublicKey();
        } finally {
            wpos = ow;
        }
    }

    /**
     * Reads a public key that is not preceded by a length prefix: the
     * algorithm name followed by the key parameters (e, n for RSA; p, q, g, y
     * for DSA).
     *
     * @return the decoded key
     * @throws SshException          if the key cannot be built from the decoded values
     * @throws IllegalStateException if the algorithm name is neither
     *                               {@code KeyPairProvider.SSH_RSA} nor
     *                               {@code KeyPairProvider.SSH_DSS}
     */
    public PublicKey getRawPublicKey() throws SshException {
        try {
            PublicKey key;
            String keyAlg = getString();
            if (KeyPairProvider.SSH_RSA.equals(keyAlg)) {
                BigInteger e = getMPInt();
                BigInteger n = getMPInt();
                KeyFactory keyFactory = SecurityUtils.getKeyFactory("RSA");
                key = keyFactory.generatePublic(new RSAPublicKeySpec(n, e));
            } else if (KeyPairProvider.SSH_DSS.equals(keyAlg)) {
                BigInteger p = getMPInt();
                BigInteger q = getMPInt();
                BigInteger g = getMPInt();
                BigInteger y = getMPInt();
                KeyFactory keyFactory = SecurityUtils.getKeyFactory("DSA");
                key = keyFactory.generatePublic(new DSAPublicKeySpec(y, p, q, g));
            } else {
                throw new IllegalStateException("Unsupported algorithm: " + keyAlg);
            }
            return key;
        } catch (InvalidKeySpecException e) {
            throw new SshException(e);
        } catch (NoSuchAlgorithmException e) {
            throw new SshException(e);
        } catch (NoSuchProviderException e) {
            throw new SshException(e);
        }
    }

    /**
     * Reads a key pair: the algorithm name followed by the key parameters
     * (e, n, d, qInv, q, p for RSA; p, q, g, y, x for DSA), in the order
     * written by {@link #putKeyPair(KeyPair)}.
     *
     * @return the decoded key pair
     * @throws SshException          if a key cannot be built from the decoded values
     * @throws IllegalStateException if the algorithm name is neither
     *                               {@code KeyPairProvider.SSH_RSA} nor
     *                               {@code KeyPairProvider.SSH_DSS}
     */
    public KeyPair getKeyPair() throws SshException {
        try {
            PublicKey pub;
            PrivateKey prv;
            String keyAlg = getString();
            if (KeyPairProvider.SSH_RSA.equals(keyAlg)) {
                BigInteger e = getMPInt();
                BigInteger n = getMPInt();
                BigInteger d = getMPInt();
                BigInteger qInv = getMPInt();
                BigInteger q = getMPInt();
                BigInteger p = getMPInt();
                // The CRT exponents are not transmitted; derive them from d, p and q.
                BigInteger dP = d.remainder(p.subtract(BigInteger.ONE));
                BigInteger dQ = d.remainder(q.subtract(BigInteger.ONE));
                KeyFactory keyFactory = SecurityUtils.getKeyFactory("RSA");
                pub = keyFactory.generatePublic(new RSAPublicKeySpec(n, e));
                prv = keyFactory.generatePrivate(new RSAPrivateCrtKeySpec(n, e, d, p, q, dP, dQ, qInv));
            } else if (KeyPairProvider.SSH_DSS.equals(keyAlg)) {
                BigInteger p = getMPInt();
                BigInteger q = getMPInt();
                BigInteger g = getMPInt();
                BigInteger y = getMPInt();
                BigInteger x = getMPInt();
                KeyFactory keyFactory = SecurityUtils.getKeyFactory("DSA");
                pub = keyFactory.generatePublic(new DSAPublicKeySpec(y, p, q, g));
                prv = keyFactory.generatePrivate(new DSAPrivateKeySpec(x, p, q, g));
            } else {
                throw new IllegalStateException("Unsupported algorithm: " + keyAlg);
            }
            return new KeyPair(pub, prv);
        } catch (InvalidKeySpecException e) {
            throw new SshException(e);
        } catch (NoSuchAlgorithmException e) {
            throw new SshException(e);
        } catch (NoSuchProviderException e) {
            throw new SshException(e);
        }
    }

    /**
     * Reads one byte and maps it to a message type.
     *
     * @return the message type
     * @throws IllegalStateException if the byte does not map to a known message
     * @throws BufferException       if no byte is available
     */
    public SshConstants.Message getCommand() {
        byte b = getByte();
        SshConstants.Message cmd = SshConstants.Message.fromByte(b);
        if (cmd == null) {
            throw new IllegalStateException("Unknown command code: " + b);
        }
        return cmd;
    }

    /**
     * Verifies that at least {@code a} bytes can be read.
     *
     * @param a the number of bytes required
     * @throws BufferException if fewer bytes are available
     */
    private void ensureAvailable(int a) {
        if (available() < a) {
            throw new BufferException("Underflow");
        }
    }

    /**
     * Validates the length prefix of a string or byte array item.
     *
     * @param len the length read from the buffer
     * @throws IllegalStateException if the length is negative or larger than 32768
     */
    private static void checkItemLength(int len) {
        if (len < 0 || len > MAX_ITEM_LENGTH) {
            throw new IllegalStateException("Bad item length: " + len);
        }
    }

    /*======================
         Write methods
     ======================*/

    /**
     * Appends one byte.
     *
     * @param b the byte to write
     */
    public void putByte(byte b) {
        ensureCapacity(1);
        data[wpos++] = b;
    }

    /**
     * Appends the readable bytes of another buffer. The read position of the
     * source buffer is not modified.
     *
     * @param buffer the source buffer
     */
    public void putBuffer(Buffer buffer) {
        int r = buffer.available();
        ensureCapacity(r);
        System.arraycopy(buffer.data, buffer.rpos, data, wpos, r);
        wpos += r;
    }

    /**
     * Appends the remaining bytes of a MINA buffer, read through its
     * {@code get(byte[], int, int)} method.
     *
     * @param buffer the source buffer
     */
    public void putBuffer(IoBuffer buffer) {
        int r = buffer.remaining();
        ensureCapacity(r);
        buffer.get(data, wpos, r);
        wpos += r;
    }

    /**
     * Writes 32 bits in big-endian order.
     *
     * @param i the value to write; only its 32 least significant bits are used
     */
    public void putInt(long i) {
        ensureCapacity(4);
        data[wpos++] = (byte) (i >> 24);
        data[wpos++] = (byte) (i >> 16);
        data[wpos++] = (byte) (i >>  8);
        data[wpos++] = (byte) (i      );
    }

    /**
     * Writes 64 bits in big-endian order.
     *
     * @param i the value to write
     */
    public void putLong(long i) {
        ensureCapacity(8);
        data[wpos++] = (byte) (i >> 56);
        data[wpos++] = (byte) (i >> 48);
        data[wpos++] = (byte) (i >> 40);
        data[wpos++] = (byte) (i >> 32);
        data[wpos++] = (byte) (i >> 24);
        data[wpos++] = (byte) (i >> 16);
        data[wpos++] = (byte) (i >>  8);
        data[wpos++] = (byte) (i      );
    }

    /**
     * Writes a boolean as a single byte.
     *
     * @param b the value; written as 1 when {@code true}, 0 when {@code false}
     */
    public void putBoolean(boolean b) {
        putByte(b ? (byte) 1 : (byte) 0);
    }

    /**
     * Writes a byte array preceded by its 32 bit length.
     *
     * @param b the bytes to write
     */
    public void putBytes(byte[] b) {
        putBytes(b, 0, b.length);
    }

    /**
     * Writes a region of a byte array preceded by its 32 bit length.
     *
     * @param b   the source array
     * @param off the index of the first byte to write
     * @param len the number of bytes to write
     */
    public void putBytes(byte[] b, int off, int len) {
        putInt(len);
        putRawBytes(b, off, len);
    }

    /**
     * Writes a string, encoded as UTF-8, preceded by the 32 bit length of
     * its encoded form.
     *
     * @param string the string to write
     */
    public void putString(String string) {
        putString(string.getBytes(UTF8));
    }

    /**
     * Writes an already encoded string preceded by its 32 bit length.
     *
     * @param str the encoded string
     */
    public void putString(byte[] str) {
        putInt(str.length);
        putRawBytes(str);
    }

    /**
     * Writes a multiple precision integer using the bytes returned by
     * {@link BigInteger#toByteArray()}.
     *
     * @param bi the value to write
     */
    public void putMPInt(BigInteger bi) {
        putMPInt(bi.toByteArray());
    }

    /**
     * Writes a multiple precision integer given as big-endian magnitude
     * bytes. When the most significant bit of the first byte is set, a zero
     * byte is inserted in front so the value is not read back as negative.
     * An empty array is written as a zero-length item.
     *
     * @param foo the big-endian bytes of the value
     */
    public void putMPInt(byte[] foo) {
        int i = foo.length;
        // Guard on the length so that an empty array does not fail on foo[0].
        if (i > 0 && (foo[0] & 0x80) != 0) {
            i++;
            putInt(i);
            putByte((byte) 0);
        } else {
            putInt(i);
        }
        putRawBytes(foo);
    }

    /**
     * Appends the given bytes without a length prefix.
     *
     * @param d the bytes to write
     */
    public void putRawBytes(byte[] d) {
        putRawBytes(d, 0, d.length);
    }

    /**
     * Appends a region of the given array without a length prefix.
     *
     * @param d   the source array
     * @param off the index of the first byte to write
     * @param len the number of bytes to write
     */
    public void putRawBytes(byte[] d, int off, int len) {
        ensureCapacity(len);
        System.arraycopy(d, off, data, wpos, len);
        wpos += len;
    }

    /**
     * Writes a public key preceded by the 32 bit length of its encoded form.
     *
     * @param key the key to write
     * @throws IllegalStateException if the key is neither an RSA nor a DSA public key
     */
    public void putPublicKey(PublicKey key) {
        // Reserve room for the length, write the key, then go back and patch
        // the length once the encoded size is known.
        int ow = wpos;
        putInt(0);
        int ow1 = wpos;
        putRawPublicKey(key);
        int ow2 = wpos;
        wpos = ow;
        putInt(ow2 - ow1);
        wpos = ow2;
    }

    /**
     * Writes a public key without a length prefix: the algorithm name
     * followed by the key parameters (e, n for RSA; p, q, g, y for DSA).
     *
     * @param key the key to write
     * @throws IllegalStateException if the key is neither an RSA nor a DSA public key
     */
    public void putRawPublicKey(PublicKey key) {
        if (key instanceof RSAPublicKey) {
            RSAPublicKey rsa = (RSAPublicKey) key;
            putString(KeyPairProvider.SSH_RSA);
            putMPInt(rsa.getPublicExponent());
            putMPInt(rsa.getModulus());
        } else if (key instanceof DSAPublicKey) {
            DSAPublicKey dsa = (DSAPublicKey) key;
            putString(KeyPairProvider.SSH_DSS);
            putMPInt(dsa.getParams().getP());
            putMPInt(dsa.getParams().getQ());
            putMPInt(dsa.getParams().getG());
            putMPInt(dsa.getY());
        } else {
            throw new IllegalStateException("Unsupported algorithm: " + key.getAlgorithm());
        }
    }

    /**
     * Writes a key pair: the algorithm name followed by the key parameters
     * (e, n, d, qInv, q, p for RSA; p, q, g, y, x for DSA).
     *
     * @param key the key pair to write
     * @throws IllegalStateException if the private key is not an RSA CRT key
     *                               and the public key is not a DSA key
     */
    public void putKeyPair(KeyPair key) {
        if (key.getPrivate() instanceof RSAPrivateCrtKey) {
            putString(KeyPairProvider.SSH_RSA);
            putMPInt(((RSAPublicKey) key.getPublic()).getPublicExponent());
            putMPInt(((RSAPublicKey) key.getPublic()).getModulus());
            putMPInt(((RSAPrivateCrtKey) key.getPrivate()).getPrivateExponent());
            putMPInt(((RSAPrivateCrtKey) key.getPrivate()).getCrtCoefficient());
            putMPInt(((RSAPrivateCrtKey) key.getPrivate()).getPrimeQ());
            putMPInt(((RSAPrivateCrtKey) key.getPrivate()).getPrimeP());
        } else if (key.getPublic() instanceof DSAPublicKey) {
            putString(KeyPairProvider.SSH_DSS);
            putMPInt(((DSAPublicKey) key.getPublic()).getParams().getP());
            putMPInt(((DSAPublicKey) key.getPublic()).getParams().getQ());
            putMPInt(((DSAPublicKey) key.getPublic()).getParams().getG());
            putMPInt(((DSAPublicKey) key.getPublic()).getY());
            putMPInt(((DSAPrivateKey) key.getPrivate()).getX());
        } else {
            throw new IllegalStateException("Unsupported algorithm: " + key.getPublic().getAlgorithm());
        }
    }

    /**
     * Writes a message type as a single byte.
     *
     * @param cmd the message type
     */
    public void putCommand(SshConstants.Message cmd) {
        putByte(cmd.toByte());
    }

    /**
     * Makes sure {@code capacity} more bytes can be written at the write
     * position, replacing the backing array with a larger copy if needed.
     *
     * @param capacity the number of bytes about to be written
     */
    private void ensureCapacity(int capacity) {
        if (data.length - wpos < capacity) {
            int cw = wpos + capacity;
            byte[] tmp = new byte[getNextPowerOf2(cw)];
            System.arraycopy(data, 0, tmp, 0, data.length);
            data = tmp;
        }
    }

    /**
     * Thrown when a read requires more bytes than the buffer holds.
     */
    public static class BufferException extends RuntimeException {
        public BufferException(String message) {
            super(message);
        }
    }

    /**
     * Rounds a size up to a power of two.
     *
     * @param i the requested size
     * @return the smallest power of two that is at least {@code i} (1 when
     *         {@code i} is not positive); {@code i} itself when it exceeds
     *         2^30, because no larger power of two fits in an {@code int}
     */
    private static int getNextPowerOf2(int i) {
        // Above 2^30 the doubling below would overflow to a negative value and
        // then to zero, and the loop would never terminate.
        if (i > (1 << 30)) {
            return i;
        }
        int j = 1;
        while (j < i) {
            j <<= 1;
        }
        return j;
    }
}