package com.datascience.core.nominal; import com.datascience.utils.CostMatrix; import com.google.common.math.DoubleMath; import java.util.*; import static com.google.common.base.Preconditions.checkArgument; /** * User: artur * Date: 4/12/13 */ public class PureNominalData { static protected final int MAX_CATEGORY_LENGTH = 100; protected Collection<String> categories; protected boolean fixedPriors; protected Map<String, Double> categoryPriors; protected CostMatrix<String> costMatrix; public Collection<String> getCategories(){ return categories; } public boolean arePriorsFixed(){ return fixedPriors; } public double getCategoryPrior(String name){ return categoryPriors.get(name); } public Map<String, Double> getCategoryPriors(){ return categoryPriors; } public void setCategoryPriors(Collection<CategoryValue> priors){ double priorSum = 0.; Set<String> categoryNames = new HashSet<String>(); for (CategoryValue cv : priors){ priorSum += cv.value; checkArgument(categoryNames.add(cv.categoryName), "CategoryPriors contains two categories with the same name"); checkArgument(cv.value > 0. && cv.value < 1., "Each category prior should be higher than 0 and less than 1"); } checkArgument(priors.size() == categories.size(), "Different number of categories in categoryPriors and categories parameters"); checkArgument(DoubleMath.fuzzyEquals(1., priorSum, 1e-6), "Priors should sum up to 1. or not to be given (therefore we initialize the priors to be uniform across classes)"); fixedPriors = true; categoryPriors = new HashMap<String, Double>(); for (CategoryValue cv : priors){ checkArgument(categories.contains(cv.categoryName), "Categories list does not contain category named %s", cv.categoryName); categoryPriors.put(cv.categoryName, cv.value); } } public CostMatrix<String> getCostMatrix(){ return costMatrix; } public void initialize(Collection<String> categories, Collection<CategoryValue> priors, CostMatrix<String> costMatrix){ checkArgument(categories != null, "There is no categories collection"); checkArgument(categories.size() >= 2, "There should be at least two categories"); for (String c : categories){ checkArgument(c.length() < MAX_CATEGORY_LENGTH, "Category names should be shorter than 50 chars"); } this.categories = new HashSet<String>(); this.categories.addAll(categories); checkArgument(this.categories.size() == categories.size(), "Category names should be different"); fixedPriors = false; if (priors != null){ setCategoryPriors(priors); } if (costMatrix == null) { this.costMatrix = new CostMatrix<String>(); } else { for (String s : costMatrix.getKnownValues()){ checkArgument(this.categories.contains(s), "Categories list does not contain category named %s", s); } this.costMatrix = costMatrix; } for (String c1 : categories) for (String c2 : categories){ if (!this.costMatrix.hasCost(c1, c2)) this.costMatrix.add(c1, c2, c1.equals(c2) ? 0. : 1.); } } public void checkForCategoryExist(String name){ if (!categories.contains(name)) throw new IllegalArgumentException("There is no category named: " + name); } }