/*
* avenir: Predictive analytic based on Hadoop Map Reduce
* Author: Pranab Ghosh
*
* Licensed under the Apache License, Version 2.0 (the "License"); you
* may not use this file except in compliance with the License. You may
* obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. See the License for the specific language governing
* permissions and limitations under the License.
*/
package org.avenir.util;
import org.chombo.util.BasicUtils;
import org.chombo.util.TabularData;
/**
* Markov state transition probability matrix
* @author pranab
*
*/
public class StateTransitionProbability extends TabularData {
private int scale = 100;
private double[][] dTable;
private int floatPrecision = 3;
/**
*
*/
public StateTransitionProbability() {
super();
}
/**
* @param numRow
* @param numCol
*/
public StateTransitionProbability(int numRow, int numCol) {
super(numRow, numCol);
}
/**
* @param rowLabels
* @param colLabels
*/
public StateTransitionProbability(String[] rowLabels, String[] colLabels) {
super(rowLabels, colLabels);
}
/**
* @param scale
*/
public void setScale(int scale) {
this.scale = scale;
}
/**
* @param scale
* @return
*/
public StateTransitionProbability withScale(int scale) {
this.scale = scale;
return this;
}
/**
* @param floatPrecision
* @return
*/
public StateTransitionProbability withFloatPrecision(int floatPrecision) {
this.floatPrecision = floatPrecision;
return this;
}
/**
*
*/
public void normalizeRows() {
//laplace correction
for (int r = 0; r < numRow; ++r) {
boolean gotZeroCount = false;
for (int c = 0; c < numCol && !gotZeroCount; ++c) {
gotZeroCount = table[r][c] == 0;
}
if (gotZeroCount) {
for (int c = 0; c < numCol; ++c) {
table[r][c] += 1;
}
}
}
//normalize
int rowSum = 0;
if (scale == 1) {
dTable = new double[numRow][numCol];
}
for (int r = 0; r < numRow; ++r) {
rowSum = getRowSum(r);
for (int c = 0; c < numCol; ++c) {
if (scale > 1) {
table[r][c] = (table[r][c] * scale) / rowSum;
} else {
dTable[r][c] = ((double)table[r][c]) / rowSum;
}
}
}
}
/* (non-Javadoc)
* @see org.chombo.util.TabularData#toString()
*/
public String toString() {
StringBuilder stBld = new StringBuilder();
for (int i = 0; i < numRow; ++i) {
stBld.append(serializeRow(i)).append(DELIMETER);
}
return stBld.substring(0, stBld.length()-1);
}
/* (non-Javadoc)
* @see org.chombo.util.TabularData#serializeRow(int)
*/
public String serializeRow(int row) {
StringBuilder stBld = new StringBuilder();
for (int c = 0; c < numCol; ++c) {
if (scale > 1) {
stBld.append(table[row][c]).append(DELIMETER);
} else {
stBld.append(BasicUtils.formatDouble(dTable[row][c], floatPrecision)).append(DELIMETER);
}
}
return stBld.substring(0, stBld.length()-1);
}
/* (non-Javadoc)
* @see org.chombo.util.TabularData#deseralizeRow(java.lang.String, int)
*/
public void deseralizeRow(String data, int row) {
String[] items = data.split(DELIMETER);
int k = 0;
for (int c = 0; c < numCol; ++c) {
if (scale > 1) {
table[row][c] = Integer.parseInt(items[k++]);
} else {
dTable[row][c] = Double.parseDouble(items[k++]);
}
}
}
/**
* @param items
* @param start
* @param row
*/
public void deseralizeRow(String[] items, int start, int row) {
int k = start;
for (int c = 0; c < numCol; ++c) {
if (scale > 1) {
table[row][c] = Integer.parseInt(items[k++]);
} else {
dTable[row][c] = Double.parseDouble(items[k++]);
}
}
}
/**
* @param rowLabel
* @param colLabel
* @return
*/
public double get(String rowLabel, String colLabel) {
double value = 0;
int[] rowCol = getRowCol(rowLabel, colLabel);
if (scale > 1) {
value = table[rowCol[0]][rowCol[1]];
} else {
value = dTable[rowCol[0]][rowCol[1]];
}
return value;
}
}