/* * Copyright 2014-2016 CyberVision, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.kaaproject.kaa.server.transports.http.transport; import org.apache.avro.specific.SpecificRecordBase; import org.apache.commons.codec.binary.Base64; import org.kaaproject.kaa.common.avro.AvroByteArrayConverter; import org.kaaproject.kaa.common.endpoint.CommonEpConstans; import org.kaaproject.kaa.common.endpoint.security.MessageEncoderDecoder; import org.kaaproject.kaa.common.hash.EndpointObjectHash; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.DataInputStream; import java.io.DataOutputStream; import java.io.EOFException; import java.io.IOException; import java.net.HttpURLConnection; import java.net.MalformedURLException; import java.net.URL; import java.nio.ByteBuffer; import java.security.GeneralSecurityException; import java.security.KeyPair; import java.security.KeyPairGenerator; import java.security.NoSuchAlgorithmException; import java.security.PrivateKey; import java.security.PublicKey; import java.util.List; import java.util.Map; import java.util.Random; import java.util.Vector; /** * Abstract HTTP Test Client Class. * * @author Andrey Panasenko <apanasenko@cybervisiontech.com> */ abstract public class HttpTestClient<T extends SpecificRecordBase, R extends SpecificRecordBase> implements Runnable { /** * The Constant logger. */ protected static final Logger logger = LoggerFactory .getLogger(HttpTestClient.class); /** * Random generator */ protected static Random rnd = new Random(); /** * Destination URL connection */ private HttpURLConnection connection; /** * Multipart objects container */ private MultipartObjects objects; /** * Test ID, random generated */ private int testId; /** * byte array for signature */ private byte[] signature; /** * byte array for encrypted SessionKey */ private byte[] key; /** * byte array for POST Data */ private byte[] data; /** * encoder/decoder */ private MessageEncoderDecoder crypt; /** * Client Private Key */ private PrivateKey clientPrivateKey; /** * Client Public Key */ private PublicKey clientPublicKey; /** * Client Public Key Hash */ private EndpointObjectHash clientPublicKeyHash; /** * AVRO request converter */ private AvroByteArrayConverter<T> requestConverter; /** * AVRO response converter */ private AvroByteArrayConverter<R> responseConverter; /** * Activity interface */ private HttpActivity<R> activity; /** * generated test SyncRequest */ private T request; /** * Constructor. * * @param serverPublicKey - server public key * @param commandName - command name, used as end of URL * @param activity - Activity interface implementation. * @throws MalformedURLException - throws if URL is incorrect * @throws Exception - throws if request creation failed */ public HttpTestClient(PublicKey serverPublicKey, String commandName, HttpActivity<R> activity) throws MalformedURLException, Exception { testId = rnd.nextInt(); this.activity = activity; //TODO: replace int bindPort = 7888; String url = "http://localhost:" + bindPort + "/domain/" + commandName; connection = (HttpURLConnection) new URL(url).openConnection(); objects = new MultipartObjects(); requestConverter = new AvroByteArrayConverter<>(getRequestConverterClass()); responseConverter = new AvroByteArrayConverter<>(getResponseConverterClass()); init(serverPublicKey); } /** * Generate String with random ascii symbols from 48 till 122 with length size. * * @param size of String * @return String with random ascii symbols */ public static String getRandomString(int size) { return MultipartObjects.getRandomString(size); } /** * generate random bytes array with size * * @param size of bytes * @return byte[] array of random bytes */ public static byte[] getRandomBytes(int size) { byte[] rndbytes = new byte[size]; rnd.nextBytes(rndbytes); return rndbytes; } /** * Initialization of request keys and encoder/decoder * * @param serverPublicKey - server public key * @throws Exception - if key generation failed. */ private void init(PublicKey serverPublicKey) throws Exception { KeyPairGenerator clientKeyGen; try { clientKeyGen = KeyPairGenerator.getInstance("RSA"); clientKeyGen.initialize(2048); KeyPair clientKeyPair = clientKeyGen.genKeyPair(); clientPrivateKey = clientKeyPair.getPrivate(); clientPublicKey = clientKeyPair.getPublic(); } catch (NoSuchAlgorithmException e) { throw new Exception(e.toString()); } crypt = new MessageEncoderDecoder(clientPrivateKey, clientPublicKey, serverPublicKey); try { key = crypt.getEncodedSessionKey(); } catch (GeneralSecurityException e) { throw new Exception(e.toString()); } ByteBuffer publicKeyBuffer = ByteBuffer.wrap(EndpointObjectHash.fromSha1(clientPublicKey.getEncoded()).getData()); clientPublicKeyHash = EndpointObjectHash.fromBytes(publicKeyBuffer.array()); } /** * Post initialization, encrypt and sign request * * @param request - request to encrypt and sign * @throws Exception - in case of encrypt error */ protected void postInit(T request) throws Exception { try { byte[] requestBodyRaw = requestConverter.toByteArray(request); data = crypt.encodeData(requestBodyRaw); signature = crypt.sign(data); if (signature.length > 256) { throw new Exception("Error signature length must not be more than 256, but " + signature.length); } } catch (IOException | GeneralSecurityException e) { throw new Exception(e.toString()); } objects.addObject(CommonEpConstans.REQUEST_SIGNATURE_ATTR_NAME, signature); objects.addObject(CommonEpConstans.REQUEST_KEY_ATTR_NAME, key); objects.addObject(CommonEpConstans.REQUEST_DATA_ATTR_NAME, data); } /* (non-Javadoc) * @see java.lang.Runnable#run() */ @Override public void run() { logger.trace("Test: " + testId + " started..."); IOException error = null; try { //connection.setChunkedStreamingMode(2048); connection.setRequestMethod("POST"); connection.setDoOutput(true); connection.setRequestProperty("Content-Type", objects.getContentType()); DataOutputStream out = new DataOutputStream(connection.getOutputStream()); objects.dumbObjects(out); out.flush(); out.close(); } catch (IOException e) { e.printStackTrace(); error = e; } List<Byte> bodyArray = new Vector<>(); try { DataInputStream r = new DataInputStream(connection.getInputStream()); while (true) { bodyArray.add(new Byte(r.readByte())); } } catch (EOFException eof) { } catch (IOException e) { e.printStackTrace(); error = e; } byte[] body = new byte[bodyArray.size()]; for (int i = 0; i < body.length; i++) { body[i] = bodyArray.get(i); } processComplete(error, connection.getHeaderFields(), body); } /** * push Response to client invocation code * * @param e - set if error received during HTTP request processing * @param header - header list * @param body - body byte array */ private void processComplete(IOException e, Map<String, List<String>> header, byte[] body) { if (e != null) { e.printStackTrace(); activity.httpRequestComplete(e, this.testId, null); return; } try { R response = decodeHttpResponse(header, body); activity.httpRequestComplete(null, this.testId, response); } catch (Exception e1) { e1.printStackTrace(); activity.httpRequestComplete(e1, this.testId, null); } } /** * Decode http response to Response * * @return type R Response */ protected R decodeHttpResponse(Map<String, List<String>> header, byte[] body) throws Exception { if (header.containsKey(CommonEpConstans.SIGNATURE_HEADER_NAME) && header.get(CommonEpConstans.SIGNATURE_HEADER_NAME) != null && header.get(CommonEpConstans.SIGNATURE_HEADER_NAME).size() > 0) { String sigHeader = header.get(CommonEpConstans.SIGNATURE_HEADER_NAME).get(0); byte[] respSignature = Base64.decodeBase64(sigHeader); byte[] respData = body; crypt.verify(respData, respSignature); logger.trace("Test " + getId() + " response verified, body size " + body.length); byte[] respDecoded = crypt.decodeData(respData); return responseConverter.fromByteArray(respDecoded); } else { throw new Exception("HTTP response incorrect, no signature fields " + CommonEpConstans.SIGNATURE_HEADER_NAME); } } /** * Test ID getter. * * @return int Test ID */ public int getId() { return testId; } /** * Client Public Key getter. * * @return the clientPublicKey */ public PublicKey getClientPublicKey() { return clientPublicKey; } /** * Client Public Key Hash getter. * * @return the clientPublicKeyHash */ public EndpointObjectHash getClientPublicKeyHash() { return clientPublicKeyHash; } /** * @return the request */ public T getRequest() { return request; } /** * * @param request */ public void setRequest(T request) { this.request = request; } /** * Gets the request converter class. * * @return the request converter class */ protected abstract Class<T> getRequestConverterClass(); /** * Gets the response converter class. * * @return the response converter class */ protected abstract Class<R> getResponseConverterClass(); }