package esl.cuenet.generative;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.Multimap;
import com.google.common.collect.Sets;
import esl.cuenet.generative.structs.ContextNetwork;
import esl.cuenet.generative.structs.NetworkBuildingHelper;
import esl.cuenet.generative.structs.Propagate;
import esl.cuenet.generative.structs.SpaceTimeValueGenerators;
import esl.system.SysLoggerUtils;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.log4j.Logger;
import org.junit.Test;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
/**
 * Exercises {@link Propagate} on networks loaded through {@link NetworkBuildingHelper}.
 *
 * <p>Both tests read a distance file from an absolute path under {@code /data/ranker/real/}, so
 * they only run on a machine where that file exists. Neither test asserts anything; each passes as
 * long as no exception is thrown, and the interesting output goes to the log.
 */
public class PropagationUnitTest {

    static {
        SysLoggerUtils.initLogger();
    }

    /** Path of the distance file handed to every {@link Propagate} instance. */
    private static final String DISTANCE_FILE = "/data/ranker/real/ontology_cuenet.distances.txt";

    /** Number of {@code propagateOnceTable()} rounds each test runs. */
    private static final int PROPAGATION_ROUNDS = 10;

    Logger logger = Logger.getLogger(getClass());

    /**
     * Loads the network from {@code NetworkBuildingHelper.loadForUnitPropagationTest}, prints its
     * tree, runs {@link #PROPAGATION_ROUNDS} propagation rounds and then logs the output of
     * {@code printScores(8, n)} for n = 5, 15, 25 and 35.
     */
    @Test
    public void testAcrossEventsContainingPhotos() throws IOException {
        List<String> locationStrings = newLocationStrings();
        logger.info("Building Network...");
        ContextNetwork network = NetworkBuildingHelper.loadForUnitPropagationTest(locationStrings);
        logger.info(network.count() + " " + locationStrings.size());
        network.printTree(true);

        Propagate propagator = createPreparedPropagator(network, locationStrings);
        runPropagation(propagator);

        // NOTE(review): the meaning of the first printScores argument (8) is not visible here.
        logger.info("----- 5 -----");
        propagator.printScores(8, 5);
        logger.info("----- 15 -----");
        propagator.printScores(8, 15);
        logger.info("------25 -----");
        propagator.printScores(8, 25);
        logger.info("------35 -----");
        propagator.printScores(8, 35);
    }

    /**
     * Loads one dataset via {@code NetworkBuildingHelper.loadOneDatasetForPropagationTest}, which
     * also fills {@code refMap}, runs {@link #PROPAGATION_ROUNDS} propagation rounds, and then, for
     * every tenth key starting at 5, logs the positions returned by {@code findObjectPositions}
     * for that key's references. Finally it logs the collected "max position + 1" values.
     */
    @Test
    public void testOneDataset() throws IOException {
        List<String> locationStrings = newLocationStrings();
        logger.info("Building Network...");
        Multimap<Integer, String> refMap = HashMultimap.create();
        ContextNetwork network =
                NetworkBuildingHelper.loadOneDatasetForPropagationTest(locationStrings, refMap);
        logger.info(network.count() + " " + locationStrings.size());

        Propagate propagator = createPreparedPropagator(network, locationStrings);
        runPropagation(propagator);

        StringBuilder maxes = new StringBuilder();
        // NOTE(review): the bound assumes refMap's keys are dense (0..size-1) — confirm against
        // the loader; sparse keys would be skipped or queried with no references.
        for (int i = 5; i < refMap.keySet().size(); i += 10) {
            logger.info("----- " + i + " -----");
            logger.info("Refs for " + i + " " + refMap.get(i));
            int[] positions = propagator.findObjectPositions(8, i, Lists.newArrayList(refMap.get(i)));
            // Add 1 arithmetically rather than via string concatenation (presumably turning a
            // 0-based position into a 1-based rank).
            maxes.append(' ').append(maxPosition(positions) + 1);
            logger.info(Arrays.toString(positions));
        }
        logger.info("MAXES" + maxes);
    }

    /** Returns a fresh, mutable list holding the single location string {@code "bounds"}. */
    private static List<String> newLocationStrings() {
        List<String> locationStrings = Lists.newArrayList();
        locationStrings.add("bounds");
        return locationStrings;
    }

    /**
     * Builds a {@link Propagate} over {@code network} using {@link #DISTANCE_FILE}, calls
     * {@code show()} on it and prepares it with the identifier set {@code {"64"}}.
     *
     * @param network the network to propagate over
     * @param locationStrings location strings fed to the {@link SpaceTimeValueGenerators}
     * @return the prepared propagator
     */
    private Propagate createPreparedPropagator(ContextNetwork network, List<String> locationStrings)
            throws IOException {
        SpaceTimeValueGenerators stGenerator =
                new SpaceTimeValueGenerators(locationStrings.iterator());
        Propagate propagator = new Propagate(network, DISTANCE_FILE, stGenerator);
        propagator.show();
        propagator.prepare(Sets.newHashSet("64"));
        return propagator;
    }

    /**
     * Runs {@link #PROPAGATION_ROUNDS} rounds of {@code propagateOnceTable()} and logs the value
     * each round returns.
     */
    private void runPropagation(Propagate propagator) throws IOException {
        for (int round = 0; round < PROPAGATION_ROUNDS; round++) {
            double l1delta = propagator.propagateOnceTable();
            logger.info("delta = " + l1delta);
        }
    }

    /**
     * Returns the largest element of {@code positions}.
     *
     * @throws java.util.NoSuchElementException if {@code positions} is empty
     */
    private static int maxPosition(int[] positions) {
        return Collections.max(Arrays.asList(ArrayUtils.toObject(positions)));
    }
}