package org.numenta.nupic.encoders; import java.util.*; import gnu.trove.list.TIntList; import gnu.trove.list.array.TIntArrayList; import org.joda.time.DateTime; import org.junit.Test; import org.numenta.nupic.util.ArrayUtils; import org.numenta.nupic.util.MinMax; import org.numenta.nupic.util.Tuple; import static org.junit.Assert.*; @SuppressWarnings("unchecked") public class DateEncoderTest { private DateEncoder de; private DateEncoder.Builder builder; private DateTime dt; private int[] expected; private int[] bits; private void setUp() { // 3 bits for season, 1 bit for day of week, 3 for weekend, 5 for time of day // use of forced is not recommended, used here for readability. builder = DateEncoder.builder(); de = builder.season(3) .dayOfWeek(1) .weekend(3) .timeOfDay(5).build(); //in the middle of fall, Thursday, not a weekend, afternoon - 4th Nov, 2010, 14:55 dt = new DateTime(2010, 11, 4, 14, 55); DateTime comparison = new DateTime(2010, 11, 4, 13, 55); bits = de.encode(dt); int[] comparisonBits = de.encode(comparison); System.out.println(Arrays.toString(bits)); System.out.println(Arrays.toString(comparisonBits)); // //dt.getMillis(); // season is aaabbbcccddd (1 bit/month) # TODO should be <<3? // should be 000000000111 (centered on month 11 - Nov) int[] seasonExpected = {0,0,0,0,0,0,0,0,0,1,1,1}; // week is MTWTFSS // contrary to local time documentation, Monday = 0 (for python // datetime.datetime.timetuple() int[] dayOfWeekExpected = {0,0,0,1,0,0,0}; // not a weekend, so it should be "False" int[] weekendExpected = {1,1,1,0,0,0}; // time of day has radius of 4 hours and w of 5 so each bit = 240/5 min = 48min // 14:55 is minute 14*60 + 55 = 895; 895/48 = bit 18.6 // should be 30 bits total (30 * 48 minutes = 24 hours) int[] timeOfDayExpected = {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,0,0,0}; expected = ArrayUtils.concatAll(seasonExpected, dayOfWeekExpected, weekendExpected, timeOfDayExpected); } private void initDE() { de = builder.build(); } /** * Creating date encoder instance */ @Test public void testDateEncoder() { setUp(); initDE(); List<Tuple> descs = de.getDescription(); assertNotNull(descs); // should be [("season", 0), ("day of week", 12), ("weekend", 19), ("time of day", 25)] List<Tuple> expectedDescs = new ArrayList<>(Arrays.asList( new Tuple("season", 0), new Tuple("day of week", 12), new Tuple("weekend", 19), new Tuple("time of day", 25) )); assertEquals(expectedDescs.size(), descs.size()); for (int i = 0; i < expectedDescs.size(); ++i) { Tuple desc = descs.get(i); assertNotNull(desc); assertEquals(expectedDescs.get(i), desc); } assertArrayEquals(expected, bits); System.out.println(); de.pprintHeader(""); de.pprint(bits, ""); System.out.println(); } // TODO Current implementation of DateEncoder throws at invalid Date, // but testMissingValues in Python expects it to encode it as all-zero bits: // def testMissingValues(self): // '''missing values''' // mvOutput = self._e.encode(SENTINEL_VALUE_FOR_MISSING_DATA) // self.assertEqual(sum(mvOutput), 0) /** * Decoding date */ @Test public void testDecoding() { setUp(); initDE(); //TODO Why null is needed? Tuple decoded = de.decode(bits, null); Map<String, RangeList> fieldsMap = (Map<String, RangeList>)decoded.get(0); List<String> fieldsOrder = (List<String>)decoded.get(1); assertNotNull(fieldsMap); assertNotNull(fieldsOrder); assertEquals(4, fieldsMap.size()); Map<String, Double> expectedMap = new HashMap<>(); expectedMap.put("season", 305.0); expectedMap.put("time of day", 14.4); expectedMap.put("day of week", 3.0); expectedMap.put("weekend", 0.0); for (String key : expectedMap.keySet()) { double expected = expectedMap.get(key); RangeList actual = fieldsMap.get(key); assertEquals(1, actual.size()); MinMax minmax = actual.getRange(0); assertEquals(expected, minmax.min(), de.getResolution()); assertEquals(expected, minmax.max(), de.getResolution()); } System.out.println(decoded.toString()); System.out.println(String.format("decodedToStr=>%s", de.decodedToStr(decoded))); } /** * Check topDownCompute */ @Test public void testTopDownCompute() { setUp(); initDE(); List<Encoding> topDown = de.topDownCompute(bits); List<Double> expectedList = Arrays.asList(320.25, 3.5, .167, 14.8); for (int i = 0; i < topDown.size(); i++) { Encoding r = topDown.get(i); double actual = (double)r.getValue(); double expected = expectedList.get(i); assertEquals(expected, actual, 4.0); } } /** * Check bucket index support */ @Test public void testBucketIndexSupport() { setUp(); initDE(); int[] bucketIndices = de.getBucketIndices(dt); System.out.println(String.format("bucket indices: %s", Arrays.toString(bucketIndices))); List<Encoding> bucketInfo = de.getBucketInfo(bucketIndices); List<Double> expectedList = Arrays.asList(320.25, 3.5, .167, 14.8); TIntList encodings = new TIntArrayList(); for (int i = 0; i < bucketInfo.size(); i++) { Encoding r = bucketInfo.get(i); double actual = (double)r.getValue(); double expected = expectedList.get(i); assertEquals(expected, actual, 4.0); encodings.addAll(r.getEncoding()); } assertArrayEquals(expected, encodings.toArray()); } /** * look at holiday more carefully because of the smooth transition */ @Test public void testHoliday() { //use of forced is not recommended, used here for readability, see ScalarEncoder DateEncoder e = DateEncoder.builder().holiday(5).forced(true).build(); int [] holiday = new int[]{0,0,0,0,0,1,1,1,1,1}; int [] notholiday = new int[]{1,1,1,1,1,0,0,0,0,0}; int [] holiday2 = new int[]{0,0,0,1,1,1,1,1,0,0}; DateTime d = new DateTime(2010, 12, 25, 4, 55); //System.out.println(String.format("1:%s", Arrays.toString(e.encode(d)))); assertArrayEquals(holiday, e.encode(d)); d = new DateTime(2008, 12, 27, 4, 55); //System.out.println(String.format("2:%s", Arrays.toString(e.encode(d)))); assertArrayEquals(notholiday, e.encode(d)); d = new DateTime(1999, 12, 26, 8, 0); //System.out.println(String.format("3:%s", Arrays.toString(e.encode(d)))); assertArrayEquals(holiday2, e.encode(d)); d = new DateTime(2011, 12, 24, 16, 0); //System.out.println(String.format("4:%s", Arrays.toString(e.encode(d)))); assertArrayEquals(holiday2, e.encode(d)); } /** * Test weekend encoder */ @Test public void testWeekend() { //use of forced is not recommended, used here for readability, see ScalarEncoder DateEncoder e = DateEncoder.builder().customDays(21, Arrays.asList( "sat", "sun", "fri" )).forced(true).build(); DateEncoder mon = DateEncoder.builder().customDays(21, Arrays.asList( "Monday" )).forced(true).build(); DateEncoder e2 = DateEncoder.builder().weekend(21, 1).forced(true).build(); //DateTime d = new DateTime(1988,5,29,20,0); DateTime d = new DateTime(1988,5,29,20,0); assertArrayEquals(e.encode(d), e2.encode(d)); for (int i = 0; i < 300; i++) { DateTime curDate = d.plusDays(i + 1); assertArrayEquals(e.encode(curDate), e2.encode(curDate)); //Make sure Tuple decoded = mon.decode(mon.encode(curDate), null); Map<String, RangeList> fieldsMap = (Map<String, RangeList>)decoded.get(0); List<String> fieldsOrder = (List<String>)decoded.get(1); assertNotNull(fieldsMap); assertNotNull(fieldsOrder); assertEquals(1, fieldsMap.size()); RangeList range = fieldsMap.get("Monday"); assertEquals(1, range.size()); assertEquals(1, ((List<MinMax>)range.get(0)).size()); MinMax minmax = range.getRange(0); System.out.println("DateEncoderTest.testWeekend(): minmax.min() = " + minmax.min()); if(minmax.min() == 1.0) { assertEquals(1, curDate.getDayOfWeek()); } else { assertNotEquals(1, curDate.getDayOfWeek()); } } } }