/******************************************************************************* * Copyright (c) 2010 Haifeng Li * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. *******************************************************************************/ package smile.association; import java.io.BufferedReader; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Set; import org.junit.After; import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; import org.junit.Test; import static org.junit.Assert.*; import smile.math.Math; /** * * @author Haifeng Li */ @SuppressWarnings("unused") public class ARMTest { int[][] itemsets = { {1, 3}, {2}, {4}, {2, 3, 4}, {2, 3}, {2, 3}, {1, 2, 3, 4}, {1, 3}, {1, 2, 3}, {1, 2, 3} }; public ARMTest() { } @BeforeClass public static void setUpClass() throws Exception { } @AfterClass public static void tearDownClass() throws Exception { } @Before public void setUp() { } @After public void tearDown() { } /** * Test of learn method, of class ARM. */ @Test public void testLearn() { System.out.println("learn"); ARM instance = new ARM(itemsets, 3); instance.learn(0.5, System.out); List<AssociationRule> rules = instance.learn(0.5); assertEquals(9, rules.size()); assertEquals(0.6, rules.get(0).support, 1E-2); assertEquals(0.75, rules.get(0).confidence, 1E-2); assertEquals(1, rules.get(0).antecedent.length); assertEquals(3, rules.get(0).antecedent[0]); assertEquals(1, rules.get(0).consequent.length); assertEquals(2, rules.get(0).consequent[0]); assertEquals(0.3, rules.get(4).support, 1E-2); assertEquals(0.6, rules.get(4).confidence, 1E-2); assertEquals(1, rules.get(4).antecedent.length); assertEquals(1, rules.get(4).antecedent[0]); assertEquals(1, rules.get(4).consequent.length); assertEquals(2, rules.get(4).consequent[0]); assertEquals(0.3, rules.get(8).support, 1E-2); assertEquals(0.6, rules.get(8).confidence, 1E-2); assertEquals(1, rules.get(8).antecedent.length); assertEquals(1, rules.get(8).antecedent[0]); assertEquals(2, rules.get(8).consequent.length); assertEquals(3, rules.get(8).consequent[0]); assertEquals(2, rules.get(8).consequent[1]); } /** * Test of learn method, of class ARM. */ @Test public void testLearnPima() { System.out.println("pima"); List<int[]> dataList = new ArrayList<>(1000); try { BufferedReader input = smile.data.parser.IOUtils.getTestDataReader("transaction/pima.D38.N768.C2"); String line; for (int nrow = 0; (line = input.readLine()) != null; nrow++) { if (line.trim().isEmpty()) { continue; } String[] s = line.split(" "); int[] point = new int[s.length]; for (int i = 0; i < s.length; i++) { point[i] = Integer.parseInt(s[i]); } dataList.add(point); } } catch (IOException ex) { System.err.println(ex); } int[][] data = dataList.toArray(new int[dataList.size()][]); int n = Math.max(data); System.out.format("%d transactions, %d items%n", data.length, n); ARM instance = new ARM(data, 20); long numRules = instance.learn(0.9, System.out); System.out.format("%d association rules discovered%n", numRules); assertEquals(6803, numRules); assertEquals(6803, instance.learn(0.9).size()); } /** * Test of learn method, of class ARM. */ @Test public void testLearnKosarak() { System.out.println("kosarak"); List<int[]> dataList = new ArrayList<>(1000); try { BufferedReader input = smile.data.parser.IOUtils.getTestDataReader("transaction/kosarak.dat"); String line; for (int nrow = 0; (line = input.readLine()) != null; nrow++) { if (line.trim().isEmpty()) { continue; } String[] s = line.split(" "); Set<Integer> items = new HashSet<>(); for (int i = 0; i < s.length; i++) { items.add(Integer.parseInt(s[i])); } int j = 0; int[] point = new int[items.size()]; for (int i : items) { point[j++] = i; } dataList.add(point); } } catch (IOException ex) { System.err.println(ex); } int[][] data = dataList.toArray(new int[dataList.size()][]); int n = Math.max(data); System.out.format("%d transactions, %d items%n", data.length, n); ARM instance = new ARM(data, 0.003); long numRules = instance.learn(0.5, System.out); System.out.format("%d association rules discovered%n", numRules); assertEquals(17932, numRules); } }