package org.shanbo.feluca.distribute.model.horizon;
import gnu.trove.list.array.TFloatArrayList;
import gnu.trove.list.array.TIntArrayList;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import org.msgpack.rpc.Client;
import org.msgpack.rpc.loop.EventLoop;
import org.shanbo.feluca.data2.HashPartitioner;
import org.shanbo.feluca.util.concurrent.ConcurrentExecutor;
/**
 * <p>The matrix path still needs testing.
 * <p>Usage: follow the fetch-then-update pattern.
 * <p>TODO: a delete method is still needed.
 * @author lgn
 *
 */
public class MModelClient {

    /** Initial capacity of the id/weight lists held by the buffers. */
    private static final int INITIAL_CAPACITY = 512;

    /** Buffer-size threshold used for vectors and matrices created through this client. */
    private static final int DEFAULT_BUFFER_SIZE = 512;

    /** Task used for this client's own shard: nothing is pushed and the result is always 1. */
    private static final Callable<Integer> LOCAL_SHARD_TASK = new Callable<Integer>() {
        public Integer call() {
            return 1;
        }
    };

    /**
     * Staging area for the (id, weight) pairs of one vector that are about to be pushed.
     * <p>{@link #fidToBufer(int)} copies the current value of {@code weights[fid]}, so later
     * writes to the backing array do not change an already-buffered weight.
     */
    public static class VectorBuffer{
        /** Backing local vector that buffered weights are read from. */
        public float[] weights;
        TIntArrayList idBuffer;
        TFloatArrayList weightBuffer;
        /** Size above which {@link #clear()} reallocates the lists instead of resetting them. */
        int bufferSize;

        public VectorBuffer( int bufferSize, float[] weights){
            this.weights = weights;
            this.bufferSize = bufferSize;
            this.idBuffer = new TIntArrayList(INITIAL_CAPACITY);
            this.weightBuffer = new TFloatArrayList(INITIAL_CAPACITY);
        }

        /**
         * Empties the buffer. If it grew beyond {@code max(512, bufferSize)} entries, the lists are
         * replaced by fresh 512-capacity ones (dropping the grown capacity). Otherwise they are
         * reset in place.
         */
        public void clear(){
            if (idBuffer.size() > Math.max(INITIAL_CAPACITY, bufferSize)){
                this.idBuffer = new TIntArrayList(INITIAL_CAPACITY);
                this.weightBuffer = new TFloatArrayList(INITIAL_CAPACITY);
            }else{
                this.idBuffer.resetQuick();
                this.weightBuffer.resetQuick();
            }
        }

        /**
         * Appends {@code fid} and a copy of {@code weights[fid]}.
         * (The misspelled name is kept for compatibility with existing callers.)
         */
        public void fidToBufer(int fid){
            this.idBuffer.add(fid);
            this.weightBuffer.add(weights[fid]);
        }

        /** @return a new array holding the buffered ids, in insertion order */
        public int[] getIds(){
            return idBuffer.toArray();
        }

        /** @return a new array holding the buffered weights, aligned with {@link #getIds()} */
        public float[] getWeights(){
            return weightBuffer.toArray();
        }

        @Override
        public String toString(){
            return idBuffer.toString() + " : " + weightBuffer.toString();
        }
    }

    /**
     * Staging area for the (id, row) pairs of one matrix that are about to be pushed.
     * <p>Unlike {@link VectorBuffer}, rows are buffered by reference ({@code matrix[fid]}).
     * The values actually sent are therefore whatever the rows hold when they are serialized.
     */
    public static class MatrixBuffer{
        /** Backing local matrix; buffered rows are references into it. */
        public float[][] matrix;
        TIntArrayList idBuffer;
        ArrayList<float[]> weightBuffer; //ref
        /** Size above which {@link #clear()} reallocates the lists instead of resetting them. */
        int bufferSize;

        public MatrixBuffer(int bufferSize, float[][] matrix){
            this.matrix = matrix;
            this.bufferSize = bufferSize;
            this.idBuffer = new TIntArrayList(INITIAL_CAPACITY);
            this.weightBuffer = new ArrayList<float[]>(INITIAL_CAPACITY);
        }

        /**
         * Empties the buffer. If it grew beyond {@code max(512, bufferSize)} entries, the lists are
         * replaced by fresh 512-capacity ones. Otherwise they are cleared in place.
         */
        public void clear(){
            if (idBuffer.size() > Math.max(INITIAL_CAPACITY, bufferSize)){
                this.idBuffer = new TIntArrayList(INITIAL_CAPACITY);
                this.weightBuffer = new ArrayList<float[]>(INITIAL_CAPACITY);
            }else{
                this.idBuffer.resetQuick();
                this.weightBuffer.clear();
            }
        }

        /**
         * Appends {@code fid} and a reference to row {@code matrix[fid]}.
         * (The misspelled name is kept for compatibility with existing callers.)
         */
        public void fidToBufer(int fid){
            this.idBuffer.add(fid);
            this.weightBuffer.add(matrix[fid]);
        }

        /** @return a new array holding the buffered ids, in insertion order */
        public int[] getIds(){
            return idBuffer.toArray();
        }

        /**
         * @return a new outer array whose elements are the buffered row references,
         *         aligned with {@link #getIds()}
         */
        public float[][] getWeights(){
            return weightBuffer.toArray(new float[weightBuffer.size()][]);
        }
    }

    /** Remote half of an update: sends an already-snapshotted batch through one shard's proxy. */
    private interface ShardPush {
        Integer pushTo(MModelRPC target) throws Exception;
    }

    HashMap<String, VectorBuffer> vectorBuffers ; //vector size is determine by creator
    HashMap<String, MatrixBuffer> matrixBuffers;
    MModelLocal local;
    EventLoop loop;
    /** One RPC client per data server; index {@code i} is treated as shard {@code i}. */
    Client[] clients;
    /** RPC proxies aligned with {@link #clients}. */
    MModelRPC[] matrixModels;
    List<String> dataServerAddresses;
    /** Shard id of this client; fids the partitioner maps here are the ones this client pushes. */
    int shardId;
    HashPartitioner partitioner;

    public MModelClient(List<String> dataServerAddresses, int shardId, MModelLocal local){
        this.dataServerAddresses = dataServerAddresses;
        this.shardId = shardId;
        this.local = local;
        clients = new Client[dataServerAddresses.size()];
        matrixModels = new MModelRPC[dataServerAddresses.size()];
        partitioner = new HashPartitioner(dataServerAddresses.size());
        this.vectorBuffers = new HashMap<String, MModelClient.VectorBuffer>(3);
        this.matrixBuffers = new HashMap<String, MModelClient.MatrixBuffer>(3);
    }

    /**
     * Opens one msgpack-rpc client (and proxy) per data server on the default event loop.
     * <p>Each address is split on {@code ':'}. The first part is the host, and the RPC port is
     * the second part plus {@code MModelRPC.PORT_AWAY}.
     *
     * @throws NumberFormatException if a port part is not an integer
     * @throws UnknownHostException propagated when a client cannot be created for a host
     * @throws IllegalArgumentException if an address has no {@code :port} part
     */
    public void connect() throws NumberFormatException, UnknownHostException{
        loop = EventLoop.defaultEventLoop();
        for(int i = 0; i < clients.length; i++){
            String address = dataServerAddresses.get(i);
            String[] hostPort = address.split(":");
            if (hostPort.length < 2){
                throw new IllegalArgumentException("data server address must be host:port, got: " + address);
            }
            clients[i] = new Client(hostPort[0], Integer.parseInt(hostPort[1]) + MModelRPC.PORT_AWAY, loop);
            matrixModels[i] = clients[i].proxy(MModelRPC.class);
        }
    }

    /**
     * Closes every opened client and shuts down the event loop.
     * <p>Does nothing if {@link #connect()} was never called or the client is already closed.
     * Clients left unset by a failed {@code connect()} are skipped. The loop is shut down even
     * if closing a client throws.
     */
    public synchronized void close(){
        if (loop == null){
            return;
        }
        try{
            for(Client client : clients){
                if (client != null){
                    client.close();
                }
            }
        }finally{
            loop.shutdown();
            loop = null; // makes a second close() a no-op
        }
        System.out.println("modelClients of #" + clients.length + " all closed");
    }

    /**
     * Creates a vector in the local model and registers a push buffer linked to it.
     */
    public void createVector(final String vectorName, int globalVectorSize, final float defaultValue, float vibration){
        local.vectorCreate(vectorName, globalVectorSize, defaultValue, vibration); //create to local
        vectorBuffers.put(vectorName, new VectorBuffer(DEFAULT_BUFFER_SIZE, local.vectors.get(vectorName))); //link to buffer
    }

    /** @return the local vector registered under {@code vectorName}, or {@code null} if none */
    public float[] getVector(String vectorName){
        return local.vectors.get(vectorName);
    }

    /** @return the local matrix registered under {@code matrixName}, or {@code null} if none */
    public float[][] getMatrix(String matrixName){
        return local.matrixes.get(matrixName);
    }

    /**
     * Pushes this shard's part of a vector to every other shard.
     * <p>Only the {@code fids} that the partitioner assigns to this shard are buffered, together
     * with their current local weights. That single id/weight snapshot is then sent to each
     * remote shard asynchronously. The local shard gets a no-op task whose result is 1.
     * <p>The snapshot is taken before the tasks are submitted. Reusing the buffer in a later
     * call therefore cannot corrupt data that earlier tasks are still sending.
     *
     * @param vectorName name previously passed to {@link #createVector}
     * @param fids feature ids to consider; ids owned by other shards are ignored
     * @return one future per shard, in shard order (assuming {@code ConcurrentExecutor.asyncExecute}
     *         preserves input order)
     * @throws IllegalArgumentException if {@code vectorName} was never created through this client
     */
    public List<Future<Integer>> vectorUpdate(final String vectorName, final int[] fids) throws InterruptedException, ExecutionException{
        VectorBuffer vector = requireVectorBuffer(vectorName);
        vector.clear();
        for(int fid : fids){
            if (partitioner.decideShard(fid) == shardId){
                vector.fidToBufer(fid);
            }
        }
        final int[] ids = vector.getIds();
        final float[] weights = vector.getWeights();
        return pushToOtherShards(new ShardPush() {
            public Integer pushTo(MModelRPC target) throws Exception {
                return target.vectorUpdate(vectorName, ids, weights);
            }
        });
    }

    /**
     * Creates a matrix in the local model and registers a push buffer linked to it.
     */
    public void createMatrix( String matrixName, int globalRowSize, int columnSize, float defaultValue, float vibration) {
        local.matrixCreate(matrixName, globalRowSize, columnSize, defaultValue, vibration);
        matrixBuffers.put(matrixName, new MatrixBuffer(DEFAULT_BUFFER_SIZE, local.matrixes.get(matrixName)));
    }

    /**
     * Pushes this shard's rows of a matrix to every other shard.
     * <p>Only the {@code fids} that the partitioner assigns to this shard are buffered. The
     * id list and the outer row array are snapshotted before submission. The rows themselves
     * are live references into the local matrix.
     *
     * @param matrixName name previously passed to {@link #createMatrix}
     * @param fids row ids to consider; ids owned by other shards are ignored
     * @return one future per shard, in shard order (assuming {@code ConcurrentExecutor.asyncExecute}
     *         preserves input order)
     * @throws IllegalArgumentException if {@code matrixName} was never created through this client
     */
    public List<Future<Integer>> matrixUpdate(final String matrixName, final int[] fids) throws InterruptedException, ExecutionException{
        MatrixBuffer matrix = requireMatrixBuffer(matrixName);
        matrix.clear();
        for(int fid : fids){
            if (partitioner.decideShard(fid) == shardId){
                matrix.fidToBufer(fid);
            }
        }
        final int[] ids = matrix.getIds();
        final float[][] rows = matrix.getWeights();
        return pushToOtherShards(new ShardPush() {
            public Integer pushTo(MModelRPC target) throws Exception {
                return target.matrixUpdate(matrixName, ids, rows);
            }
        });
    }

    /** Looks up the push buffer of a vector, failing with a descriptive message if it is absent. */
    private VectorBuffer requireVectorBuffer(String vectorName){
        VectorBuffer buffer = vectorBuffers.get(vectorName);
        if (buffer == null){
            throw new IllegalArgumentException("vector '" + vectorName + "' has not been created; call createVector first");
        }
        return buffer;
    }

    /** Looks up the push buffer of a matrix, failing with a descriptive message if it is absent. */
    private MatrixBuffer requireMatrixBuffer(String matrixName){
        MatrixBuffer buffer = matrixBuffers.get(matrixName);
        if (buffer == null){
            throw new IllegalArgumentException("matrix '" + matrixName + "' has not been created; call createMatrix first");
        }
        return buffer;
    }

    /**
     * Builds one task per shard and submits them via {@code ConcurrentExecutor.asyncExecute}.
     * This shard gets a no-op task; every other shard gets {@code push} with its proxy.
     */
    private List<Future<Integer>> pushToOtherShards(final ShardPush push) throws InterruptedException, ExecutionException{
        ArrayList<Callable<Integer>> pushCallables = new ArrayList<Callable<Integer>>(clients.length);
        for(int i = 0; i < clients.length; i++){
            if (i == shardId){
                pushCallables.add(LOCAL_SHARD_TASK); // this shard's own data needs no push
            }else{
                final MModelRPC target = matrixModels[i];
                pushCallables.add(new Callable<Integer>() {
                    public Integer call() throws Exception {
                        return push.pushTo(target);
                    }
                });
            }
        }
        return ConcurrentExecutor.asyncExecute(pushCallables);
    }
}