/******************************************************************************* * Copyright (c) 2010 Haifeng Li * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. *******************************************************************************/ package smile.classification; import java.io.BufferedReader; import java.io.IOException; import java.util.ArrayList; import org.junit.After; import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; import org.junit.Test; import static org.junit.Assert.*; /** * * @author Haifeng Li */ public class MaxentTest { class Dataset { int[][] x; int[] y; int p; } @SuppressWarnings("unused") Dataset load(String resource) { int p = 0; ArrayList<int[]> x = new ArrayList<>(); ArrayList<Integer> y = new ArrayList<>(); try (BufferedReader input = smile.data.parser.IOUtils.getTestDataReader(resource)) { String[] words = input.readLine().split(" "); int nseq = Integer.parseInt(words[0]); int k = Integer.parseInt(words[1]); p = Integer.parseInt(words[2]); String line = null; while ((line = input.readLine()) != null) { words = line.split(" "); int seqid = Integer.parseInt(words[0]); int pos = Integer.parseInt(words[1]); int len = Integer.parseInt(words[2]); int[] feature = new int[len]; for (int i = 0; i < len; i++) { feature[i] = Integer.parseInt(words[i+3]); } x.add(feature); y.add(Integer.valueOf(words[len+3])); } } catch (IOException ex) { System.err.println(ex); } Dataset dataset = new Dataset(); dataset.p = p; dataset.x = new int[x.size()][]; dataset.y = new int[y.size()]; for (int i = 0; i < dataset.x.length; i++) { dataset.x[i] = x.get(i); dataset.y[i] = y.get(i); } return dataset; } public MaxentTest() { } @BeforeClass public static void setUpClass() throws Exception { } @AfterClass public static void tearDownClass() throws Exception { } @Before public void setUp() { } @After public void tearDown() { } /** * Test of learn method, of class Maxent. */ @Test public void testLearnProtein() { System.out.println("learn protein"); Dataset train = load("sequence/sparse.protein.11.train"); Dataset test = load("sequence/sparse.protein.11.test"); Maxent maxent = new Maxent(train.p, train.x, train.y, 0.1, 1E-5, 500); int error = 0; for (int i = 0; i < test.x.length; i++) { if (test.y[i] != maxent.predict(test.x[i])) { error++; } } System.out.format("Protein error is %d of %d%n", error, test.x.length); System.out.format("Protein error rate = %.2f%%%n", 100.0 * error / test.x.length); assertEquals(1338, error); } /** * Test of learn method, of class Maxent. */ @Test public void testLearnHyphen() { System.out.println("learn hyphen"); Dataset train = load("sequence/sparse.hyphen.6.train"); Dataset test = load("sequence/sparse.hyphen.6.test"); Maxent maxent = new Maxent(train.p, train.x, train.y, 0.1, 1E-5, 500); int error = 0; for (int i = 0; i < test.x.length; i++) { if (test.y[i] != maxent.predict(test.x[i])) { error++; } } System.out.format("Protein error is %d of %d%n", error, test.x.length); System.out.format("Hyphen error rate = %.2f%%%n", 100.0 * error / test.x.length); assertEquals(765, error); } }