package edu.umd.hooka.ttables; import java.io.DataInput; import java.io.DataOutput; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.IntBuffer; import java.nio.FloatBuffer; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.log4j.Logger; import org.apache.log4j.Level; import edu.umd.hooka.alignment.IndexedFloatArray; /** * Data structure that stores translation probabilities p(f|e) or * counts c(f,e). The set of values (*,e) associated with a particular * e are stored adjacent to one another in an array. For a given f * in (f,e) the location of the f is found using a binary search. * * Layout: http://www.umiacs.umd.edu/~redpony/ttable-structure.png * * @author redpony * */ public class TTable_monolithic extends TTable implements Cloneable { int[] _ef; // length <= |E|x|F| int[] _e; // length = |E| float[] _values; IndexedFloatArray _nullValues; // length = |F| Path _datapath; FileSystem _fs; private static final Logger myLogger = Logger.getLogger(TTable_monolithic.class); int eLen; int indexLen; public Object clone() { TTable_monolithic res = new TTable_monolithic(); res._ef = _ef.clone(); res._e = _e.clone(); res._values = _values.clone(); res.eLen = eLen; res.indexLen = indexLen; res._nullValues = (IndexedFloatArray)_nullValues.clone(); return res; } public TTable_monolithic() {} public TTable_monolithic(FileSystem fs, Path p) throws IOException { _fs = fs; _datapath = p; this.readFields(_fs.open(_datapath)); } public TTable_monolithic(int[] e, int[] ef, int maxF) { _ef = ef; _e = e; _nullValues = new IndexedFloatArray(maxF + 1); eLen = _e.length; indexLen = _ef.length; _values = new float[indexLen]; } public TTable_monolithic(int[] e, int[] ef, int maxF, FileSystem fs, Path p) { _ef = ef; _e = e; _nullValues = new IndexedFloatArray(maxF + 1); eLen = _e.length; indexLen = _ef.length; _values = new float[indexLen]; _fs = fs; _datapath = p; } public int getMaxF() { return _nullValues.size() - 1; } public int getMaxE() { return _e.length - 1; } int binSearch(int e, int f) { int min = _e[e]; int max = _e[e+1] - 1; while (min <= max) { int mid = (min + max) / 2; if (_ef[mid] > f) max = mid - 1; else if (_ef[mid] < f) min = mid + 1; else return mid; } throw new RuntimeException("Couldn't find (" + f + "," + e +")"); } public void add(int e, int f, float delta) { if (e == 0) _nullValues.add(f, delta); else _values[binSearch(e,f)] += delta; } public void set(int e, int f, float value) { if (e == 0) _nullValues.set(f, value); else _values[binSearch(e,f)] = value; } public void setDistribution(float[] x) { _values = x.clone(); } public void setNullDistribution(float[] x) { myLogger.setLevel(Level.DEBUG); myLogger.debug("Length of input array is " + x.length); _nullValues = new IndexedFloatArray(x.length); _nullValues.set(0, 0.0f); for(int i=1; i<x.length; i++) _nullValues.set(i, x[i]); } /* public void setNullDistribution(float[] x) { _nullValues = (IndexedFloatArray)(x.clone()); } */ public void set(int e, IndexedFloatArray fs) { if (e == 0) { _nullValues.copyFrom(fs); return; } int from = _e[e]; int len = _e[e+1] - from; if (len != fs.size()) throw new RuntimeException("Mismatch lengths: in ttable there are " + fs.size() + " parameters for e="+e+ ", but in the IA there are " + len); fs.copyTo(_values, from); } /* public void plusEquals(int e, IndexedFloatArray fs) { if (e == 0) _nullValues.plusEquals(fs); int from = _e[e]; int len = _e[e+1] - from; if (len != fs.size()) throw new RuntimeException("Mismatch lengths: in ttable there are " + fs.size() + " parameters for e="+e+ ", but in the IA there are " + len); fs.addTo(_values, from); }*/ public float get(int e, int f) { if (e == 0) return _nullValues.get(f); else { int min = _e[e]; int max = _e[e+1] - 1; while (min <= max) { int mid = (min + max) / 2; if (_ef[mid] > f) max = mid - 1; else if (_ef[mid] < f) min = mid + 1; else return _values[mid]; } return 0.0f; } } public void clear() { java.util.Arrays.fill(_values, 0.0f); } public void prune(float threshold) { throw new RuntimeException("Not implemented"); } public void normalize() { _nullValues.normalize(); for (int e = 1; e<_e.length - 1; e++) { int bf = _e[e]; int ef = _e[e+1]; float total = 0.0f; for (int f = bf; f < ef; f++) { total += _values[f]; } // make uniform if (total == 0.0f) { float u = 1.0f/(float)(ef - bf); for (int f = bf; f < ef; f++) _values[f] = u; } else { // normalize for (int f = bf; f < ef; f++) { _values[f] /= total; } } } } public void readFields(DataInput in) throws IOException { int bbLen = in.readInt(); ByteBuffer bb=ByteBuffer.allocate(bbLen); in.readFully(bb.array()); IntBuffer ib = bb.asIntBuffer(); _e = new int[bbLen/4]; ib.get(_e); eLen = _e.length; if (_nullValues == null) _nullValues = new IndexedFloatArray(); _nullValues.readFields(in); bbLen = in.readInt(); bb=ByteBuffer.allocate(bbLen); in.readFully(bb.array()); ib = bb.asIntBuffer(); _ef = new int[bbLen/4]; ib.get(_ef); bb=ByteBuffer.allocate(bbLen); in.readFully(bb.array()); FloatBuffer fb = bb.asFloatBuffer(); _values = new float[bbLen/4]; fb.get(_values); indexLen = _values.length; } public void write(DataOutput out) throws IOException { int bbLen = eLen * 4; out.writeInt(bbLen); ByteBuffer bb=ByteBuffer.allocate(bbLen); IntBuffer ib = bb.asIntBuffer(); ib.put(_e, 0, eLen); out.write(bb.array()); _nullValues.write(out); bbLen = indexLen * 4; out.writeInt(bbLen); bb=ByteBuffer.allocate(bbLen); ib=bb.asIntBuffer(); ib.put(_ef, 0, indexLen); out.write(bb.array()); bb=ByteBuffer.allocate(bbLen); FloatBuffer fb=bb.asFloatBuffer(); fb.put(_values, 0, indexLen); out.write(bb.array()); } public String toString() { StringBuffer sb = new StringBuffer(); sb.append("NULL: ").append(_nullValues.toString()).append("\n"); for (int e = 1; e<_e.length - 1; e++) { int bfi = _e[e]; int efi = _e[e+1]; for (int fi = bfi; fi < efi; fi++) { sb.append("e=").append(e) .append(" f=").append(_ef[fi]).append(" val=").append(_values[fi]).append("\n"); } } return sb.toString(); } @Override public void write() throws IOException { this.write(_fs.create(_datapath)); } }