package com.datascience.serialization.json; import com.datascience.core.base.CategoryPair; import com.datascience.core.nominal.CategoryValue; import com.datascience.core.stats.MatrixValue; import com.datascience.core.stats.MultinomialConfusionMatrix; import com.google.gson.*; import java.lang.reflect.Type; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; import java.util.Map; /** * @Author: konrad */ public class MultinominalConfusionMatrixJSON { public static class ConfusionMatrixDeserializer implements JsonDeserializer<MultinomialConfusionMatrix> { @Override public MultinomialConfusionMatrix deserialize(JsonElement json, Type type, JsonDeserializationContext context) throws JsonParseException { JsonObject jobject = (JsonObject) json; Collection<String> categories = context.deserialize(jobject.get("categories"), JSONUtils.stringSetType); Collection<MatrixValue<String>> matrixValues = context.deserialize(jobject.get("matrix"), JSONUtils.matrixValuesCollectionType); Map<CategoryPair, Double> matrix = new HashMap<CategoryPair, Double>(); for (MatrixValue<String> mv : matrixValues){ matrix.put(new CategoryPair(mv.from, mv.to), mv.value); } Collection<CategoryValue> rowDenominatorValues = context.deserialize(jobject.get("rowDenominator"), JSONUtils.categoryValuesCollectionType); Map<String, Double> rowDenominator = new HashMap<String, Double>(); for (CategoryValue cv : rowDenominatorValues) { rowDenominator.put(cv.categoryName, cv.value); } return new MultinomialConfusionMatrix(categories, matrix, rowDenominator); } } public static class ConfusionMatrixSerializer implements JsonSerializer<MultinomialConfusionMatrix> { @Override public JsonElement serialize(MultinomialConfusionMatrix arg0, Type arg1, JsonSerializationContext arg2) { JsonObject ret = new JsonObject(); Collection<CategoryValue> cp = new ArrayList<CategoryValue>(arg0.rowDenominator.size()); for (Map.Entry<String, Double> e : arg0.rowDenominator.entrySet()){ cp.add(new CategoryValue(e.getKey(), e.getValue())); } ret.add("rowDenominator", arg2.serialize(cp)); Collection<MatrixValue> mv = new ArrayList<MatrixValue>(arg0.getMatrix().size()); for (Map.Entry<CategoryPair, Double> e : arg0.getMatrix().entrySet()){ mv.add(new MatrixValue(e.getKey().from, e.getKey().to, e.getValue())); } ret.add("matrix", arg2.serialize(mv)); ret.add("categories", arg2.serialize(arg0.getCategories())); return ret; } } }