package resa.evaluation.topology.vld;
import backtype.storm.task.OutputCollector;
import backtype.storm.task.TopologyContext;
import backtype.storm.topology.OutputFieldsDeclarer;
import backtype.storm.topology.base.BaseRichBolt;
import backtype.storm.tuple.Fields;
import backtype.storm.tuple.Tuple;
import backtype.storm.tuple.Values;
import org.apache.commons.lang.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import resa.util.ConfigUtil;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import static resa.evaluation.topology.vld.Constant.*;
/**
 * Storm bolt that matches the feature descriptors of a frame against this task's shard of the
 * feature index loaded from the classpath resource {@code /index.txt}.
 *
 * <p>For every input tuple it emits, on {@code STREAM_MATCH_IMAGES}, the frame id together with a
 * flat {@code int[]} laid out as {@code [imageId0, hitCount0, imageId1, hitCount1, ...]}, where a
 * hit is one (query descriptor, index entry) pair whose Euclidean distance is below the configured
 * threshold.
 *
 * <p>Created by ding on 14-7-3.
 */
public class Matcher extends BaseRichBolt {

    private static final Logger LOG = LoggerFactory.getLogger(Matcher.class);
    private static final int[] EMPTY_MATCH = new int[0];
    private static final String INDEX_RESOURCE = "/index.txt";

    /** One parsed index line: a feature descriptor and the ids of the images listed for it. */
    private static final class IndexEntry {
        final byte[] feature;
        final int[] images;

        IndexEntry(byte[] feature, int[] images) {
            this.feature = feature;
            this.images = images;
        }
    }

    // Kept as a list instead of a Map<byte[], int[]>: arrays use identity equals/hashCode, so a
    // map keyed by them offered no lookup benefit, only an arbitrary iteration order.
    private List<IndexEntry> indexEntries;
    private OutputCollector collector;
    private double distThreshold;

    @Override
    public void prepare(Map stormConf, TopologyContext context, OutputCollector collector) {
        indexEntries = new ArrayList<>();
        loadIndex(context.getThisTaskIndex(), context.getComponentTasks(context.getThisComponentId()).size());
        this.collector = collector;
        distThreshold = ConfigUtil.getDouble(stormConf, CONF_FEAT_DIST_THRESHOLD, 100);
    }

    /**
     * Loads this task's share of the index: counting only non-empty lines (0-based), a line is kept
     * when {@code position % totalPieces == index}, so the tasks of this component split the index.
     *
     * <p>Each kept line must contain two whitespace-separated tokens: a comma-separated feature
     * descriptor (each value parsed as a double, truncated to int, and stored as its low 8 bits)
     * and a comma-separated list of integer image ids.
     *
     * @param index       index of this task within its component
     * @param totalPieces number of tasks the index is split across
     * @throws IllegalStateException if the index resource is not on the classpath
     * @throws UncheckedIOException  if reading the resource fails
     */
    private void loadIndex(int index, int totalPieces) {
        InputStream in = getClass().getResourceAsStream(INDEX_RESOURCE);
        if (in == null) {
            throw new IllegalStateException("Index resource not found on classpath: " + INDEX_RESOURCE);
        }
        int count = 0;
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(in, StandardCharsets.UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) {
                // Empty lines are skipped before count is incremented, so they do not shift the
                // round-robin assignment of lines to tasks.
                if (line.isEmpty() || count++ % totalPieces != index) {
                    continue;
                }
                StringTokenizer tokenizer = new StringTokenizer(line);
                String[] tmp = StringUtils.split(tokenizer.nextToken(), ',');
                byte[] feat = new byte[tmp.length];
                for (int i = 0; i < feat.length; i++) {
                    feat[i] = (byte) (((int) Double.parseDouble(tmp[i])) & 0xFF);
                }
                int[] images = Stream.of(StringUtils.split(tokenizer.nextToken(), ','))
                        .mapToInt(Integer::parseInt)
                        .toArray();
                indexEntries.add(new IndexEntry(feat, images));
            }
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to read index resource " + INDEX_RESOURCE, e);
        }
        // Explicit Object[] keeps this compatible with both varargs and pre-varargs SLF4J APIs.
        LOG.info("taskIndex={}, totalIndexPieces={}, totalIndexEntry={}, load={}",
                new Object[]{index, totalPieces, count, indexEntries.size()});
    }

    /**
     * Counts, per image id, how many (query descriptor, index entry) pairs fall within the distance
     * threshold, emits the counts anchored to {@code input}, then acks it.
     */
    @Override
    public void execute(Tuple input) {
        String frameId = input.getStringByField(FIELD_FRAME_ID);
        @SuppressWarnings("unchecked") // the upstream component places a List<byte[]> in this field
        List<byte[]> desc = (List<byte[]>) input.getValueByField(FIELD_FEATURE_DESC);
        Map<Integer, Long> image2Freq = desc.stream()
                .flatMap(imgDesc -> findMatches(imgDesc).stream())
                .flatMap(imgList -> IntStream.of(imgList).boxed())
                .collect(Collectors.groupingBy(i -> i, Collectors.counting()));
        // Flattened as [imageId, count, imageId, count, ...].
        int[] matches = image2Freq.isEmpty() ? EMPTY_MATCH : new int[image2Freq.size() * 2];
        int i = 0;
        for (Map.Entry<Integer, Long> m : image2Freq.entrySet()) {
            matches[i++] = m.getKey();
            matches[i++] = m.getValue().intValue();
        }
        collector.emit(STREAM_MATCH_IMAGES, input, new Values(frameId, matches));
        collector.ack(input);
    }

    /** Returns the image-id lists of every loaded index entry within the distance threshold of {@code desc}. */
    private List<int[]> findMatches(byte[] desc) {
        List<int[]> matches = new ArrayList<>();
        for (IndexEntry entry : indexEntries) {
            if (distance(entry.feature, desc) < distThreshold) {
                matches.add(entry.images);
            }
        }
        return matches;
    }

    /**
     * Euclidean distance between two descriptors, treating each byte as an unsigned value in
     * {@code [0, 255]}. Descriptors of different lengths are not comparable; they are reported as
     * infinitely far apart so they never count as a match.
     */
    private static double distance(byte[] v1, byte[] v2) {
        if (v1.length != v2.length) {
            return Double.POSITIVE_INFINITY;
        }
        double sum = 0;
        for (int i = 0; i < v1.length; i++) {
            // Compare matching coordinates (previously v2[1] was used for every i).
            double d = (v1[i] & 0xFF) - (v2[i] & 0xFF);
            sum += d * d;
        }
        return Math.sqrt(sum);
    }

    @Override
    public void declareOutputFields(OutputFieldsDeclarer declarer) {
        declarer.declareStream(STREAM_MATCH_IMAGES, new Fields(FIELD_FRAME_ID, FIELD_MATCH_IMAGES));
    }
}