package org.numenta.nupic.algorithms; import static org.junit.Assert.*; import org.junit.Test; public class ClassificationTest { @Test public void testCopy() { String mon = "Monday"; String tue = "Tuesday"; String wed = "Wednesday"; double monVal = 0.01d; double tueVal = 0.80d; double wedVal = 0.30d; Classification<String> result = new Classification<>(); result.setActualValues(new String[] { mon, tue, wed }); result.setStats(1, new double[] { monVal, tueVal, wedVal }); assertTrue(result.getMostProbableValue(1).equals(tue)); assertNull(result.getMostProbableValue(2)); Classification<String> result2 = result.copy(); assertEquals(result, result2); result2.setStats(1, new double[] { monVal, tueVal, 0.5d }); assertNotEquals(result, result2); } @Test public void testGetMostProbableValue() { String mon = "Monday"; String tue = "Tuesday"; String wed = "Wednesday"; double monVal = 0.01d; double tueVal = 0.80d; double wedVal = 0.30d; Classification<String> result = new Classification<>(); result.setActualValues(new String[] { mon, tue, wed }); result.setStats(1, new double[] { monVal, tueVal, wedVal }); assertTrue(result.getMostProbableValue(1).equals(tue)); assertNull(result.getMostProbableValue(2)); double monVal2 = 0.30d; double tueVal2 = 0.01d; double wedVal2 = 0.29d; result.setStats(3, new double[] { monVal2, tueVal2, wedVal2 }); assertTrue(result.getMostProbableValue(3).equals(mon)); assertNull(result.getMostProbableValue(2)); } @Test public void testGetMostProbableBucketIndex() { String mon = "Monday"; String tue = "Tuesday"; String wed = "Wednesday"; double monVal = 0.01d; double tueVal = 0.80d; double wedVal = 0.30d; Classification<String> result = new Classification<>(); result.setActualValues(new String[] { mon, tue, wed }); result.setStats(1, new double[] { monVal, tueVal, wedVal }); assertTrue(result.getMostProbableBucketIndex(1) == 1); assertTrue(result.getMostProbableBucketIndex(2) == -1); double monVal2 = 0.30d; double tueVal2 = 0.01d; double wedVal2 = 0.29d; result.setStats(3, new double[] { monVal2, tueVal2, wedVal2 }); assertTrue(result.getMostProbableBucketIndex(3) == 0); assertTrue(result.getMostProbableBucketIndex(2) == -1); } @Test public void testGetCorrectStepsCount() { String mon = "Monday"; String tue = "Tuesday"; String wed = "Wednesday"; double monVal = 0.01d; double tueVal = 0.80d; double wedVal = 0.30d; Classification<String> result = new Classification<>(); result.setActualValues(new String[] { mon, tue, wed }); result.setStats(1, new double[] { monVal, tueVal, wedVal }); assertTrue(result.getMostProbableBucketIndex(1) == 1); assertTrue(result.getMostProbableBucketIndex(2) == -1); double monVal2 = 0.30d; double tueVal2 = 0.01d; double wedVal2 = 0.29d; result.setStats(3, new double[] { monVal2, tueVal2, wedVal2 }); assertTrue(result.getMostProbableBucketIndex(3) == 0); assertTrue(result.getMostProbableBucketIndex(2) == -1); assertTrue(result.getStepCount() == 2); } }