package edu.umd.hooka.alignment.hmm;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import org.apache.hadoop.io.Writable;
/**
 * Table of jump scores, optionally conditioned on a small integer
 * ("condition") value, plus one extra cell for the null transition.
 * Presumably the transition table of an HMM aligner (going by the package
 * name) -- confirm against callers.
 *
 * <p>Storage layout, as established by the constructor and accessors:
 * <ul>
 *   <li><b>homogeneous</b>: a single row of length {@code 2 * maxDist + 1};
 *       jump {@code j} lives at index {@code j + maxDist}.</li>
 *   <li><b>non-homogeneous</b>: row {@code c} has length {@code 2 * c + 1};
 *       jump {@code j} under condition {@code c} lives at index
 *       {@code j + c}.</li>
 * </ul>
 * The special jump/coordinate value {@code -1000} addresses the null
 * transition cell instead of a row cell.
 *
 * <p>Only {@code homogeneous}, {@code maxDist}, the rows and the null
 * transition value are serialized; {@code modelNull},
 * {@code extraIPrevFactors} and {@code extraIFactors} are not.
 */
public class ATable implements Writable, Cloneable {
    /** Sentinel jump/coordinate value that selects the null-transition cell. */
    private static final int NULL_JUMP = -1000;
    /** Whether {@link #normalize()} applies additive smoothing. */
    private static final boolean SMOOTH = true;
    /** Pseudo-count added to every cell when smoothing. */
    private static final float SMOOTHING_ALPHA = 0.00001f;
    /** Whether an all-zero row is re-seeded uniformly (true) or with a decaying shape (false). */
    private static final boolean INITIALIZE_UNIFORM = false;
    /** Size of one serialized float, in bytes. */
    private static final int BYTES_PER_FLOAT = 4;

    // Fields are package-private in the original and are kept that way because
    // other classes in the package may read them directly.
    float[][] data;
    float _nullTrans;
    int maxDist;
    boolean modelNull;
    int extraIPrevFactors;
    int extraIFactors;
    boolean homogeneous = true;

    /**
     * Returns a deep copy of this table.
     *
     * <p>Every field is copied (including {@code modelNull} and the
     * {@code extraI*} counters) and each row is duplicated at its actual
     * length, so tables populated through {@link #readFields(DataInput)}
     * are copied faithfully as well.
     *
     * @return an independent {@code ATable} with the same contents
     */
    @Override
    public Object clone() {
        ATable copy = new ATable();
        copy.homogeneous = homogeneous;
        copy.maxDist = maxDist;
        copy.modelNull = modelNull;
        copy.extraIPrevFactors = extraIPrevFactors;
        copy.extraIFactors = extraIFactors;
        copy._nullTrans = _nullTrans;
        if (data != null) {
            copy.data = new float[data.length][];
            for (int i = 0; i < data.length; ++i) {
                copy.data[i] = (data[i] == null) ? null : data[i].clone();
            }
        }
        return copy;
    }

    /** Resets every row cell and the null-transition value to zero. */
    public void clear() {
        for (float[] row : data) {
            java.util.Arrays.fill(row, 0.0f);
        }
        _nullTrans = 0.0f;
    }

    /** Creates an empty table with no rows allocated; fill it via {@link #readFields(DataInput)}. */
    public ATable() {}

    /**
     * Allocates a zero-filled table.
     *
     * @param homo                whether a single row is shared by all conditions
     * @param conditioning_values number of rows; asserted to be 1 when {@code homo} is true
     * @param dist                maximum jump distance; fixes the row length in the homogeneous case
     */
    public ATable(boolean homo, int conditioning_values, int dist) {
        maxDist = dist;
        homogeneous = homo;
        modelNull = false;
        extraIPrevFactors = 0;
        extraIFactors = 0;
        if (homogeneous) {
            assert (conditioning_values == 1);
        }
        data = new float[conditioning_values][];
        for (int i = 0; i < conditioning_values; ++i) {
            // Homogeneous rows span [-maxDist, maxDist]; otherwise row i spans [-i, i].
            int width = homogeneous ? maxDist * 2 + 1 : i * 2 + 1;
            data[i] = new float[width];
        }
    }

    /** @return the maximum jump distance this table was configured with */
    public final int getMaxDist() {
        return maxDist;
    }

    /**
     * Looks up the value stored for a jump.
     *
     * @param jump      signed jump, or {@code -1000} for the null transition
     * @param condition row selector; ignored when the table is homogeneous
     * @return the stored value
     * @throws RuntimeException if the jump/condition falls outside the table
     */
    public float get(int jump, char condition) {
        if (jump == NULL_JUMP) {
            return _nullTrans;
        }
        try {
            return homogeneous
                    ? data[0][jump + maxDist]
                    : data[condition][jump + condition];
        } catch (java.lang.ArrayIndexOutOfBoundsException e) {
            // Keep the original message but attach the cause so the stack trace is not lost.
            throw new RuntimeException("Tried access: " + jump + "+" + maxDist + " but dl=" + data.length + " Caught " + e, e);
        }
    }

    /**
     * Looks up a value by raw column index (no offset is applied, despite the
     * first parameter's name). Out-of-range accesses yield {@code 0}.
     *
     * @param jump      column index within the row, or {@code -1000} for the null transition
     * @param condition row selector; ignored when the table is homogeneous
     * @param dummy     unused; only distinguishes this overload
     * @return the stored value, or {@code 0} when the index is out of range
     */
    public float get(int jump, char condition, int dummy) {
        if (jump == NULL_JUMP) {
            return _nullTrans;
        }
        int rowIndex = homogeneous ? 0 : condition;
        if (rowIndex >= data.length) {
            return 0;
        }
        float[] row = data[rowIndex];
        if (jump < 0 || jump >= row.length) {
            return 0;
        }
        return row[jump];
    }

    /**
     * Adds {@code v} to the cell for a signed jump (the row offset is applied here).
     *
     * @param jump      signed jump, or {@code -1000} for the null transition
     * @param condition row selector; ignored when the table is homogeneous
     * @param dummy     unused; only distinguishes this overload
     * @param v         amount to add; zero is a no-op
     */
    public void add(int jump, char condition, int dummy, float v) {
        if (v == 0.0f) {
            return;
        }
        if (jump == NULL_JUMP) {
            _nullTrans += v;
        } else if (homogeneous) {
            data[0][jump + maxDist] += v;
        } else {
            data[condition][jump + condition] += v;
        }
    }

    /**
     * Adds {@code v} to the cell at a raw column index, such as one returned
     * by {@link #getCoord(int, char)}.
     *
     * @param coord     column index within the row, or {@code -1000} for the null transition
     * @param condition row selector; ignored when the table is homogeneous
     * @param v         amount to add; zero is a no-op
     */
    public void add(int coord, char condition, float v) {
        if (v == 0.0f) {
            return;
        }
        if (coord == NULL_JUMP) {
            _nullTrans += v;
        } else if (homogeneous) {
            data[0][coord] += v;
        } else {
            data[condition][coord] += v;
        }
    }

    /**
     * Converts a signed jump into the raw column index used by the
     * coordinate-based accessors.
     *
     * @param jump      signed jump, or {@code -1000} for the null transition
     * @param condition row selector; used as the offset when not homogeneous
     * @return {@code -1000} unchanged, otherwise the jump shifted by the row's offset
     */
    public int getCoord(int jump, char condition) {
        if (jump == NULL_JUMP) {
            return jump;
        }
        // TODO fix (carried over from the original non-homogeneous branch)
        return jump + (homogeneous ? maxDist : condition);
    }

    /**
     * Adds another table into this one cell by cell, including the null transition.
     *
     * @param rhs table to accumulate; must have the same number of rows
     * @throws RuntimeException if the row counts differ
     */
    public void plusEquals(ATable rhs) {
        if (data.length != rhs.data.length) {
            throw new RuntimeException("mismatch lengths!");
        }
        for (int i = 0; i < data.length; i++) {
            float[] row = data[i];
            float[] orow = rhs.data[i];
            assert (row.length == orow.length);
            for (int j = 0; j < row.length; ++j) {
                row[j] += orow[j];
            }
        }
        _nullTrans += rhs._nullTrans;
    }

    // TODO: take alpha as a parameter, add support for VB
    /**
     * Normalizes each row (together with the null transition when
     * {@code modelNull} is set) so that it sums to one, with additive smoothing.
     *
     * <p>A row whose total is not positive is re-seeded with a default shape
     * and the whole table is then normalized once more.
     *
     * @throws RuntimeException if a row must be re-seeded while {@code modelNull}
     *         is set on a non-homogeneous table
     */
    public void normalize() {
        boolean renorm = false;
        for (float[] row : data) {
            float sum = 0;
            if (modelNull) {
                sum = _nullTrans;
            }
            for (float v : row) {
                sum += v;
            }
            if (sum > 0.0f) {
                scaleRow(row, sum);
            } else {
                renorm = true;
                reseedRow(row);
            }
        }
        if (renorm) {
            normalize();
        }
    }

    /** Divides a row (and the null transition) by its total, smoothing if enabled. */
    private void scaleRow(float[] row, float sum) {
        if (SMOOTH) {
            // One pseudo-count per row cell plus one for the null-transition cell.
            sum += SMOOTHING_ALPHA * (float) (row.length + 1);
            if (modelNull) {
                _nullTrans = (_nullTrans + SMOOTHING_ALPHA) / sum;
            } else {
                _nullTrans = 0;
            }
            for (int i = 0; i < row.length; i++) {
                row[i] = (row[i] + SMOOTHING_ALPHA) / sum;
            }
        } else {
            _nullTrans /= sum;
            // Fixed: iterate over the row's own length, not the number of rows.
            for (int i = 0; i < row.length; i++) {
                row[i] /= sum;
            }
        }
    }

    /** Fills an empty-mass row with unnormalized default weights. */
    private void reseedRow(float[] row) {
        if (INITIALIZE_UNIFORM) {
            float up = 1.0f / (float) (row.length + 1);
            java.util.Arrays.fill(row, up);
            if (modelNull) {
                _nullTrans = up;
            } else {
                _nullTrans = 0;
            }
            return;
        }
        int len = (row.length - 1) / 2;
        int center = homogeneous ? maxDist : len;
        for (int i = 0; i < row.length; i++) {
            int offset = i - center;
            // Weight decays with the distance from a jump of +1.
            int ad = offset - 1;
            if (ad > 0) {
                ad *= -1;
            }
            // Extra penalty on the zero jump (for non-homogeneous rows only past index 3).
            if (offset == 0 && (homogeneous || i > 3)) {
                ad -= 3;
            }
            row[i] = (float) Math.exp((double) ad * 0.15);
        }
        if (modelNull & !homogeneous) {
            throw new RuntimeException("Not implemented properly");
        }
        if (modelNull) {
            _nullTrans = (float) row[row.length / 2];
        } else {
            _nullTrans = 0;
        }
    }

    /** Renders every row as {@code P(J=jump) = value} entries, one row per line. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("ATable: maxDist=").append(maxDist).append('\n');
        for (int i = 0; i < data.length; i++) {
            float[] row = data[i];
            sb.append("cond=").append(i);
            sb.append(" NULL-trans=").append(_nullTrans).append('\n');
            // Offset that maps a column index back to a signed jump.
            int md = homogeneous ? maxDist : i;
            for (int j = 0; j < row.length; j++) {
                sb.append(" P(J=").append(j - md).append(") = ").append(row[j]).append('\t');
            }
            sb.append('\n');
        }
        return sb.toString();
    }

    /**
     * Reads the table in the format produced by {@link #write(DataOutput)}.
     * {@code modelNull}, {@code extraIPrevFactors} and {@code extraIFactors}
     * are not part of the stream and keep their current values.
     *
     * @param in source to read from
     * @throws IOException on read failure, or if a length field is negative
     *         or a row's byte length is not a multiple of four
     */
    public void readFields(DataInput in) throws IOException {
        homogeneous = in.readBoolean();
        maxDist = in.readInt();
        int rows = in.readInt();
        if (rows < 0) {
            throw new IOException("Corrupt ATable: negative row count " + rows);
        }
        float[][] table = new float[rows][];
        for (int i = 0; i < rows; ++i) {
            int bbLen = in.readInt();
            if (bbLen < 0 || bbLen % BYTES_PER_FLOAT != 0) {
                throw new IOException("Corrupt ATable: bad byte length " + bbLen + " for row " + i);
            }
            ByteBuffer bb = ByteBuffer.allocate(bbLen);
            in.readFully(bb.array());
            FloatBuffer fb = bb.asFloatBuffer();
            table[i] = new float[bbLen / BYTES_PER_FLOAT];
            fb.get(table[i]);
        }
        data = table;
        _nullTrans = in.readFloat();
    }

    /**
     * Writes {@code homogeneous}, {@code maxDist}, the row count, each row as
     * a byte length followed by its floats, and finally the null transition.
     *
     * @param out sink to write to
     * @throws IOException on write failure
     */
    public void write(DataOutput out) throws IOException {
        out.writeBoolean(homogeneous);
        out.writeInt(maxDist);
        out.writeInt(data.length);
        for (float[] row : data) {
            int bbLen = row.length * BYTES_PER_FLOAT;
            out.writeInt(bbLen);
            ByteBuffer bb = ByteBuffer.allocate(bbLen);
            bb.asFloatBuffer().put(row, 0, row.length);
            out.write(bb.array());
        }
        out.writeFloat(_nullTrans);
    }
}