package org.deeplearning4j.eval;
import lombok.EqualsAndHashCode;
import org.deeplearning4j.berkeley.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
import java.io.IOException;
import java.io.Serializable;
import java.util.List;
/**
* BaseEvaluation implement common evaluation functionality (for time series, etc) for {@link Evaluation},
* {@link RegressionEvaluation}, {@link ROC}, {@link ROCMultiClass} etc.
*
* @author Alex Black
*/
@EqualsAndHashCode
public abstract class BaseEvaluation<T extends BaseEvaluation> implements IEvaluation<T> {
private static ObjectMapper objectMapper = new ObjectMapper();
private static ObjectMapper yamlMapper = new ObjectMapper(new YAMLFactory());
@Override
public void evalTimeSeries(INDArray labels, INDArray predicted) {
evalTimeSeries(labels, predicted, null);
}
@Override
public void evalTimeSeries(INDArray labels, INDArray predictions, INDArray labelsMask) {
Pair<INDArray, INDArray> pair = EvaluationUtils.extractNonMaskedTimeSteps(labels, predictions, labelsMask);
INDArray labels2d = pair.getFirst();
INDArray predicted2d = pair.getSecond();
eval(labels2d, predicted2d);
}
@Override
public void eval(INDArray labels, INDArray networkPredictions, List<? extends Serializable> recordMetaData) {
eval(labels, networkPredictions);
}
@Override
public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray) {
if (maskArray == null) {
eval(labels, networkPredictions);
return;
}
if (labels.rank() == 3 && maskArray.rank() == 2) {
//Per-output masking
evalTimeSeries(labels, networkPredictions, maskArray);
return;
}
throw new UnsupportedOperationException(
this.getClass().getSimpleName() + " does not support per-output masking");
}
/**
* @return
*/
@Override
public String toJson() {
try {
return objectMapper.writeValueAsString(this);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
/**
* @return
*/
@Override
public String toYaml() {
try {
return yamlMapper.writeValueAsString(this);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
/**
*
* @param json
* @param clazz
* @param <T>
* @return
*/
public static <T extends BaseEvaluation> T fromYaml(String json,Class<T> clazz) {
try {
return yamlMapper.readValue(json,clazz);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
/**
*
* @param json
* @param clazz
* @param <T>
* @return
*/
public static <T extends BaseEvaluation> T fromJson(String json,Class<T> clazz) {
try {
return objectMapper.readValue(json,clazz);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
@Override
public String toString() {
return stats();
}
}