package edu.umd.hooka.alignment;
import java.io.IOException;
import java.util.Arrays;
import java.util.Iterator;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapred.FileInputFormat;
import org.apache.hadoop.mapred.FileOutputFormat;
import org.apache.hadoop.mapred.JobClient;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.MapReduceBase;
import org.apache.hadoop.mapred.Mapper;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.hadoop.mapred.Reducer;
import org.apache.hadoop.mapred.Reporter;
import org.apache.hadoop.mapred.SequenceFileInputFormat;
import org.apache.hadoop.mapred.SequenceFileOutputFormat;
import edu.umd.hooka.CorpusInfo;
import edu.umd.hooka.PhrasePair;
import edu.umd.hooka.Vocab;
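/**
* Marginal-count job for the general EM training framework for word alignment
* models: a Hadoop MapReduce job that counts, for every word on the E side and
* on the F side of a bitext, the number of phrase pairs in which it occurs.
*/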
public class EFMarginalCounter {
static CorpusInfo corpus =
CorpusInfo.getCorpus(
CorpusInfo.Corpus.ARABIC_5000k);
public static class MarginalMapper extends MapReduceBase
implements Mapper<IntWritable,PhrasePair,IntWritable,IndexedFloatArray> {
static final IntWritable emarginal = new IntWritable(0);
static final IntWritable fmarginal = new IntWritable(1);
int[] makeUnique(int[] x) {
int cur = -1;
int c = 0;
for (int i : x) {
if (i != cur) {
c++;
cur = i;
}
}
int[] res = new int[c];
cur = -1;
c = 0;
for (int i : x) {
if (i != cur) {
res[c] = i;
c++;
cur = i;
}
}
return res;
}
OutputCollector<IntWritable, IndexedFloatArray> output_;
float[] emap = new float[Vocab.MAX_VOCAB_INDEX];
float[] fmap = new float[Vocab.MAX_VOCAB_INDEX];
int maxF = -1;
int maxE = -1;
boolean hasValues = false;
public void map(IntWritable key, PhrasePair value,
OutputCollector<IntWritable,IndexedFloatArray> output,
Reporter reporter) throws IOException {
output_ = output;
int[] es = value.getE().getWords();
int[] fs = value.getF().getWords();
Arrays.sort(es);
Arrays.sort(fs);
es = makeUnique(es);
fs = makeUnique(fs);
if (es[es.length - 1] > maxE) maxE = es[es.length - 1];
if (fs[fs.length - 1] > maxF) maxF = fs[fs.length - 1];
for (int e : es) emap[e] += 1.0f;
for (int f : fs) fmap[f] += 1.0f;
hasValues = true;
}
public IndexedFloatArray makeIFA(float[] map, int max) {
int c = 0;
for (int i = 0; i <= max; i++)
if (map[i] > 0.5f) c++;
int[] ind = new int[c];
float[] vals = new float[c];
c = 0;
for (int i = 0; i <= max; i++)
if (map[i] > 0.5f) {
ind[c] = i;
vals[c] = map[i];
c++;
}
return new IndexedFloatArray(ind, vals);
}
@Override
public void close() {
try {
if (hasValues) {
output_.collect(emarginal, makeIFA(emap, maxE));
output_.collect(fmarginal, makeIFA(fmap, maxF));
}
} catch (IOException e) {
throw new RuntimeException("Caught " + e);
}
}
}
public static class MarginalReducer extends MapReduceBase
implements Reducer<IntWritable,IndexedFloatArray,IntWritable,IndexedFloatArray> {
IntWritable oe = new IntWritable();
public void reduce(IntWritable key, Iterator<IndexedFloatArray> values,
OutputCollector<IntWritable,IndexedFloatArray> output,
Reporter reporter) throws IOException {
IndexedFloatArray sum = new IndexedFloatArray();
while (values.hasNext()) {
sum.plusEqualsMismatchSize(values.next());
}
output.collect(key, sum);
}
}
@SuppressWarnings("deprecation")
public static void computeMarginals(Path bitext, Path outputPath, int mappers) throws IOException {
int reduceTasks = 2;
JobConf conf = new JobConf(EFMarginalCounter.class);
conf.setJobName("EFMarginals");
conf.setInputFormat(SequenceFileInputFormat.class);
conf.setOutputKeyClass(IntWritable.class);
conf.setOutputValueClass(IndexedFloatArray.class);
conf.setMapperClass(MarginalMapper.class);
conf.setReducerClass(MarginalReducer.class);
conf.setNumMapTasks(mappers);
conf.setNumReduceTasks(reduceTasks);
FileInputFormat.setInputPaths(conf, corpus.getBitext());
FileOutputFormat.setOutputPath(conf, outputPath);
conf.setOutputFormat(SequenceFileOutputFormat.class);
JobClient.runJob(conf);
}
public static void main(String[] args) throws IOException {
JobConf conf = new JobConf(EFMarginalCounter.class);
FileSystem fileSys = FileSystem.get(conf);
String sOutputPath="marginals";
Path outputPath = new Path(sOutputPath);
fileSys.delete(outputPath);
computeMarginals(corpus.getBitext(), outputPath, 38);
}
}