package com.linkedin.thirdeye.rootcause.impl;
import com.linkedin.thirdeye.dataframe.DataFrame;
import com.linkedin.thirdeye.dataframe.DoubleSeries;
import com.linkedin.thirdeye.dataframe.StringSeries;
import com.linkedin.thirdeye.rootcause.Entity;
import com.linkedin.thirdeye.rootcause.Pipeline;
import com.linkedin.thirdeye.rootcause.PipelineContext;
import com.linkedin.thirdeye.rootcause.PipelineResult;
import java.util.Collection;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Implementation of an aggregator that handles the same entity being returned from multiple
 * pipelines by summing the entity's weights. It optionally allows to truncate the input for
 * each input pipeline <i>separately</i> to its top {@code k} elements before aggregation.
 */
public class LinearAggregationPipeline extends Pipeline {
  private static final Logger LOG = LoggerFactory.getLogger(LinearAggregationPipeline.class);

  private static final String PROP_K = "k";
  private static final String PROP_K_DEFAULT = "-1";

  private static final String URN = "urn";
  private static final String SCORE = "score";

  // Per-input top-k truncation bound; negative means "no truncation".
  private final int k;

  /**
   * Constructor for dependency injection
   *
   * @param outputName pipeline output name
   * @param inputNames input pipeline names
   * @param k top k truncation before aggregation ({@code -1} for unbounded)
   */
  public LinearAggregationPipeline(String outputName, Set<String> inputNames, int k) {
    super(outputName, inputNames);
    this.k = k;
  }

  /**
   * Alternate constructor for use by PipelineLoader
   *
   * @param outputName pipeline output name
   * @param inputNames input pipeline names
   * @param properties configuration properties ({@code PROP_K})
   * @throws NumberFormatException if the {@code PROP_K} value is not a valid integer
   */
  public LinearAggregationPipeline(String outputName, Set<String> inputNames, Map<String, String> properties) {
    super(outputName, inputNames);
    this.k = Integer.parseInt(properties.getOrDefault(PROP_K, PROP_K_DEFAULT));
  }

  /**
   * Aggregates entities from all input pipelines by summing scores of entities that share the
   * same URN. Each input is optionally truncated to its top {@code k} entities by score before
   * aggregation. The aggregated result is sorted by descending total score.
   *
   * @param context pipeline execution context holding the per-pipeline input entity sets
   * @return aggregated pipeline result with per-URN summed scores
   */
  @Override
  public PipelineResult run(PipelineContext context) {
    StringSeries.Builder urnBuilder = StringSeries.builder();
    DoubleSeries.Builder scoreBuilder = DoubleSeries.builder();

    for(Map.Entry<String, Set<Entity>> entry : context.getInputs().entrySet()) {
      DataFrame df = toDataFrame(entry.getValue());

      if(this.k >= 0) {
        LOG.info("Truncating '{}' to {} entities (from {})", entry.getKey(), this.k, df.size());
        // tail of an ascending sort keeps the k highest-scoring rows
        df = df.sortedBy(SCORE).tail(this.k);
      }

      LOG.info("{}:\n{}", entry.getKey(), df.toString(50, URN, SCORE));

      urnBuilder.addSeries(df.get(URN));
      scoreBuilder.addSeries(df.getDoubles(SCORE));
    }

    StringSeries urns = urnBuilder.build();
    DoubleSeries scores = scoreBuilder.build();

    DataFrame df = new DataFrame();
    df.addSeries(URN, urns);
    df.addSeries(SCORE, scores);

    // sum scores of identical URNs, then rank by descending aggregate score
    DataFrame grp = df.groupBy(URN).aggregate(SCORE, DoubleSeries.SUM);
    grp = grp.sortedBy(SCORE).reverse();

    return new PipelineResult(context, toEntities(grp, URN, SCORE));
  }

  /**
   * Converts a collection of entities into a DataFrame with URN and score columns.
   *
   * @param entities entities to convert
   * @return DataFrame with one row per entity
   */
  private static DataFrame toDataFrame(Collection<Entity> entities) {
    String[] urns = new String[entities.size()];
    double[] scores = new double[entities.size()];
    int i = 0;
    for(Entity e : entities) {
      urns[i] = e.getUrn();
      scores[i] = e.getScore();
      i++;
    }
    return new DataFrame().addSeries(URN, urns).addSeries(SCORE, scores);
  }

  /**
   * Converts a DataFrame back into a set of entities, reading URN and score from the given
   * columns.
   *
   * @param df source DataFrame
   * @param colUrn name of the URN column
   * @param colScore name of the score column
   * @return set of entities, one per DataFrame row
   */
  private static Set<Entity> toEntities(DataFrame df, String colUrn, String colScore) {
    Set<Entity> entities = new HashSet<>();
    for(int i=0; i<df.size(); i++) {
      entities.add(new Entity(df.getString(colUrn, i), df.getDouble(colScore, i)));
    }
    return entities;
  }
}