/*******************************************************************************
* Copyright (c) 2010 Haifeng Li
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
package smile.sequence;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.text.ParseException;
import java.util.ArrayList;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import static org.junit.Assert.*;
import smile.data.Attribute;
import smile.data.NominalAttribute;
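/**
* Tests of {@link CRF} training and prediction on the protein and hyphen
* sequence data sets, with both sparse integer and attribute-encoded features.
*
* @author Haifeng Li
*/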
@SuppressWarnings("unused")
public class CRFTest {
/**
 * Sequence data set whose feature values are encoded as doubles through a
 * set of {@link Attribute}s. Filled in by {@code load(String, Attribute[])}.
 */
class Dataset {
    /** Third integer of the data file header (presumably the feature count; confirm against the data format). */
    int p;
    /** Second integer of the data file header (presumably the number of labels; confirm against the data format). */
    int k;
    /** Attributes used to convert the textual feature values into doubles. */
    Attribute[] attributes;
    /** Labels, indexed as {@code y[sequence][position]}. */
    int[][] y;
    /** Feature vectors, indexed as {@code x[sequence][position][feature]}. */
    double[][][] x;
}
/**
 * Sequence data set whose feature values are plain integers.
 * Filled in by {@code load(String)}.
 */
class IntDataset {
    /** Third integer of the data file header (presumably the feature count; confirm against the data format). */
    int p;
    /** Second integer of the data file header (presumably the number of labels; confirm against the data format). */
    int k;
    /** Labels, indexed as {@code y[sequence][position]}. */
    int[][] y;
    /** Feature values, indexed as {@code x[sequence][position][feature]}. */
    int[][][] x;
}
/**
 * Loads a sequence data set whose features are plain integers.
 *
 * <p>The first line is a header of three space-separated integers; the second
 * is stored as {@code k} and the third as {@code p} (the first is not needed
 * here). Every following non-blank line reads
 * {@code <sequence id> <position> <len> <len feature values> <label>}.
 * Consecutive lines sharing a sequence id form one sequence; ids are expected
 * to start at 1.
 *
 * <p>I/O failures and unparsable feature values are reported on
 * {@code System.err} rather than thrown: an unparsable feature stays 0, and an
 * I/O failure ends loading with the sequences completed so far.
 *
 * @param resource test data resource name, resolved by
 *                 {@code smile.data.parser.IOUtils.getTestDataReader}
 * @return the loaded data set; never {@code null}
 * @throws IllegalStateException if the resource has no header line
 */
IntDataset load(String resource) {
    int p = 0;
    int k = 0;
    ArrayList<int[][]> x = new ArrayList<>();
    ArrayList<int[]> y = new ArrayList<>();
    ArrayList<int[]> seq = new ArrayList<>();
    ArrayList<Integer> label = new ArrayList<>();
    int id = 1;

    try (BufferedReader input = smile.data.parser.IOUtils.getTestDataReader(resource)) {
        String header = input.readLine();
        if (header == null) {
            // Fail with a clear message instead of a bare NullPointerException.
            throw new IllegalStateException("Missing header line in " + resource);
        }

        String[] words = header.split(" ");
        k = Integer.parseInt(words[1]);
        p = Integer.parseInt(words[2]);

        String line;
        while ((line = input.readLine()) != null) {
            if (line.trim().isEmpty()) {
                // Tolerate blank lines, e.g. a trailing newline at the end of the file.
                continue;
            }

            words = line.split(" ");
            int seqid = Integer.parseInt(words[0]);
            int len = Integer.parseInt(words[2]);

            int[] feature = new int[len];
            for (int i = 0; i < len; i++) {
                try {
                    feature[i] = Integer.parseInt(words[i + 3]);
                } catch (NumberFormatException ex) {
                    // Best effort: report the bad token and keep the default 0.
                    System.err.println(ex);
                }
            }

            if (seqid != id) {
                // A new sequence starts: store the finished one first.
                flushIntSequence(seq, label, x, y);
                seq.clear();
                label.clear();
                id = seqid;
            }

            seq.add(feature);
            label.add(Integer.valueOf(words[len + 3]));
        }

        // Store the last sequence of the file.
        flushIntSequence(seq, label, x, y);
    } catch (IOException ex) {
        System.err.println(ex);
    }

    IntDataset dataset = new IntDataset();
    dataset.p = p;
    dataset.k = k;
    dataset.x = x.toArray(new int[x.size()][][]);
    dataset.y = y.toArray(new int[y.size()][]);
    return dataset;
}

/** Appends the accumulated positions as one sequence to x/y; an empty accumulation is skipped. */
private static void flushIntSequence(ArrayList<int[]> seq, ArrayList<Integer> label,
                                     ArrayList<int[][]> x, ArrayList<int[]> y) {
    if (seq.isEmpty()) {
        return;
    }

    int[] labels = new int[label.size()];
    for (int i = 0; i < labels.length; i++) {
        labels[i] = label.get(i);
    }

    x.add(seq.toArray(new int[seq.size()][]));
    y.add(labels);
}
/**
 * Loads a sequence data set whose feature values are converted to doubles
 * through {@link Attribute#valueOf(String)}.
 *
 * <p>The first line is a header of three space-separated integers; the second
 * is stored as {@code k} and the third as {@code p} (the first is not needed
 * here). Every following non-blank line reads
 * {@code <sequence id> <position> <len> <len feature values> <label>}.
 * Consecutive lines sharing a sequence id form one sequence; ids are expected
 * to start at 1.
 *
 * <p>I/O failures and unparsable feature values are reported on
 * {@code System.err} rather than thrown: an unparsable feature stays 0, and an
 * I/O failure ends loading with the sequences completed so far.
 *
 * @param resource   test data resource name, resolved by
 *                   {@code smile.data.parser.IOUtils.getTestDataReader}
 * @param attributes attributes used to parse the feature values, or
 *                   {@code null} to create one {@link NominalAttribute}
 *                   ("Attr1", "Attr2", ...) per feature of the first data line
 * @return the loaded data set; never {@code null}
 * @throws IllegalStateException if the resource has no header line
 */
Dataset load(String resource, Attribute[] attributes) {
    int p = 0;
    int k = 0;
    Dataset dataset = new Dataset();
    dataset.attributes = attributes;

    ArrayList<double[][]> x = new ArrayList<>();
    ArrayList<int[]> y = new ArrayList<>();
    ArrayList<double[]> seq = new ArrayList<>();
    ArrayList<Integer> label = new ArrayList<>();
    int id = 1;

    try (BufferedReader input = smile.data.parser.IOUtils.getTestDataReader(resource)) {
        String header = input.readLine();
        if (header == null) {
            // Fail with a clear message instead of a bare NullPointerException.
            throw new IllegalStateException("Missing header line in " + resource);
        }

        String[] words = header.split(" ");
        k = Integer.parseInt(words[1]);
        p = Integer.parseInt(words[2]);

        String line;
        while ((line = input.readLine()) != null) {
            if (line.trim().isEmpty()) {
                // Tolerate blank lines, e.g. a trailing newline at the end of the file.
                continue;
            }

            words = line.split(" ");
            int seqid = Integer.parseInt(words[0]);
            int len = Integer.parseInt(words[2]);

            if (dataset.attributes == null) {
                // No attributes supplied: derive them from the first data line.
                dataset.attributes = new Attribute[len];
                for (int i = 0; i < len; i++) {
                    dataset.attributes[i] = new NominalAttribute("Attr" + (i + 1));
                }
            }

            double[] feature = new double[len];
            for (int i = 0; i < len; i++) {
                try {
                    feature[i] = dataset.attributes[i].valueOf(words[i + 3]);
                } catch (ParseException ex) {
                    // Best effort: report the bad token and keep the default 0.
                    System.err.println(ex);
                }
            }

            if (seqid != id) {
                // A new sequence starts: store the finished one first.
                flushSequence(seq, label, x, y);
                seq.clear();
                label.clear();
                id = seqid;
            }

            seq.add(feature);
            label.add(Integer.valueOf(words[len + 3]));
        }

        // Store the last sequence of the file.
        flushSequence(seq, label, x, y);
    } catch (IOException ex) {
        System.err.println(ex);
    }

    dataset.p = p;
    dataset.k = k;
    dataset.x = x.toArray(new double[x.size()][][]);
    dataset.y = y.toArray(new int[y.size()][]);
    return dataset;
}

/** Appends the accumulated positions as one sequence to x/y; an empty accumulation is skipped. */
private static void flushSequence(ArrayList<double[]> seq, ArrayList<Integer> label,
                                  ArrayList<double[][]> x, ArrayList<int[]> y) {
    if (seq.isEmpty()) {
        return;
    }

    int[] labels = new int[label.size()];
    for (int i = 0; i < labels.length; i++) {
        labels[i] = label.get(i);
    }

    x.add(seq.toArray(new double[seq.size()][]));
    y.add(labels);
}
/** Creates the test instance; no state is initialized here. */
public CRFTest() {
}
/** JUnit class-level setup hook ({@code @BeforeClass}); intentionally empty. */
@BeforeClass
public static void setUpClass() throws Exception {
}
/** JUnit class-level teardown hook ({@code @AfterClass}); intentionally empty. */
@AfterClass
public static void tearDownClass() throws Exception {
}
/**
 * Seeds smile's random number generator before each test.
 */
@Before
public void setUp() {
    // Fixed seed; presumably what keeps the exact error counts asserted by
    // the tests reproducible from run to run.
    final long seed = 54217137L;
    smile.math.Math.setSeed(seed);
}
/** JUnit per-test teardown hook ({@code @After}); intentionally empty. */
@After
public void tearDown() {
}
/**
 * Trains a CRF on the protein data set with sparse integer features and
 * checks the number of mislabeled test positions, first with the model's
 * default decoding (reported as forward-backward) and then with Viterbi
 * decoding enabled.
 */
@Test
public void testLearnProteinSparse() {
    System.out.println("learn protein sparse");

    IntDataset train = load("sequence/sparse.protein.11.train");
    IntDataset test = load("sequence/sparse.protein.11.test");

    CRF.Trainer trainer = new CRF.Trainer(train.p, train.k);
    trainer.setLearningRate(0.3);
    trainer.setMaxNodes(100);
    trainer.setNumTrees(100);

    CRF crf = trainer.train(train.x, train.y);

    // Count the test positions exactly once. Both decoding passes score the
    // same positions, so adding them up in each pass would double the
    // denominator and halve the reported error rates.
    int n = 0;
    for (int i = 0; i < test.x.length; i++) {
        n += test.x[i].length;
    }

    int error = 0;
    for (int i = 0; i < test.x.length; i++) {
        int[] label = crf.predict(test.x[i]);
        for (int j = 0; j < test.x[i].length; j++) {
            if (test.y[i][j] != label[j]) {
                error++;
            }
        }
    }

    int viterbiError = 0;
    crf.setViterbi(true);
    for (int i = 0; i < test.x.length; i++) {
        int[] label = crf.predict(test.x[i]);
        for (int j = 0; j < test.x[i].length; j++) {
            if (test.y[i][j] != label[j]) {
                viterbiError++;
            }
        }
    }

    System.out.format("Protein error (forward-backward) is %d of %d%n", error, n);
    System.out.format("Protein error (forward-backward) rate = %.2f%%%n", 100.0 * error / n);
    System.out.format("Protein error (Viterbi) is %d of %d%n", viterbiError, n);
    System.out.format("Protein error (Viterbi) rate = %.2f%%%n", 100.0 * viterbiError / n);

    assertEquals(1234, error);
    assertEquals(1318, viterbiError);
}
/**
 * Trains a CRF on the hyphen data set with sparse integer features and
 * checks the number of mislabeled test positions, first with the model's
 * default decoding (reported as forward-backward) and then with Viterbi
 * decoding enabled.
 */
@Test
public void testLearnHyphenSparse() {
    System.out.println("learn hyphen sparse");

    IntDataset train = load("sequence/sparse.hyphen.6.train");
    IntDataset test = load("sequence/sparse.hyphen.6.test");

    CRF.Trainer trainer = new CRF.Trainer(train.p, train.k);
    trainer.setLearningRate(1.0);
    trainer.setMaxNodes(100);
    trainer.setNumTrees(100);

    CRF crf = trainer.train(train.x, train.y);

    // Count the test positions exactly once. Both decoding passes score the
    // same positions, so adding them up in each pass would double the
    // denominator and halve the reported error rates.
    int n = 0;
    for (int i = 0; i < test.x.length; i++) {
        n += test.x[i].length;
    }

    int error = 0;
    for (int i = 0; i < test.x.length; i++) {
        int[] label = crf.predict(test.x[i]);
        for (int j = 0; j < test.x[i].length; j++) {
            if (test.y[i][j] != label[j]) {
                error++;
            }
        }
    }

    int viterbiError = 0;
    crf.setViterbi(true);
    for (int i = 0; i < test.x.length; i++) {
        int[] label = crf.predict(test.x[i]);
        for (int j = 0; j < test.x[i].length; j++) {
            if (test.y[i][j] != label[j]) {
                viterbiError++;
            }
        }
    }

    System.out.format("Hypen error (forward-backward) is %d of %d%n", error, n);
    System.out.format("Hypen error (forward-backward) rate = %.2f%%%n", 100.0 * error / n);
    System.out.format("Hypen error (Viterbi) is %d of %d%n", viterbiError, n);
    System.out.format("Hypen error (Viterbi) rate = %.2f%%%n", 100.0 * viterbiError / n);

    assertEquals(470, error);
    assertEquals(478, viterbiError);
}
/**
 * Trains a CRF on the protein data set with attribute-encoded features and
 * checks the number of mislabeled test positions, first with the model's
 * default decoding (reported as forward-backward) and then with Viterbi
 * decoding enabled. The test set is parsed with the attributes created
 * while loading the training set.
 */
@Test
public void testLearnProtein() {
    System.out.println("learn protein");

    Dataset train = load("sequence/sparse.protein.11.train", null);
    Dataset test = load("sequence/sparse.protein.11.test", train.attributes);

    CRF.Trainer trainer = new CRF.Trainer(train.attributes, train.k);
    trainer.setLearningRate(0.3);
    trainer.setMaxNodes(100);
    trainer.setNumTrees(100);

    CRF crf = trainer.train(train.x, train.y);

    // Count the test positions exactly once. Both decoding passes score the
    // same positions, so adding them up in each pass would double the
    // denominator and halve the reported error rates.
    int n = 0;
    for (int i = 0; i < test.x.length; i++) {
        n += test.x[i].length;
    }

    int error = 0;
    for (int i = 0; i < test.x.length; i++) {
        int[] label = crf.predict(test.x[i]);
        for (int j = 0; j < test.x[i].length; j++) {
            if (test.y[i][j] != label[j]) {
                error++;
            }
        }
    }

    int viterbiError = 0;
    crf.setViterbi(true);
    for (int i = 0; i < test.x.length; i++) {
        int[] label = crf.predict(test.x[i]);
        for (int j = 0; j < test.x[i].length; j++) {
            if (test.y[i][j] != label[j]) {
                viterbiError++;
            }
        }
    }

    System.out.format("Protein error (forward-backward) is %d of %d%n", error, n);
    System.out.format("Protein error (forward-backward) rate = %.2f%%%n", 100.0 * error / n);
    System.out.format("Protein error (Viterbi) is %d of %d%n", viterbiError, n);
    System.out.format("Protein error (Viterbi) rate = %.2f%%%n", 100.0 * viterbiError / n);

    assertEquals(1270, error);
    assertEquals(1420, viterbiError);
}
/**
 * Trains a CRF on the hyphen data set with attribute-encoded features and
 * checks the number of mislabeled test positions, first with the model's
 * default decoding (reported as forward-backward) and then with Viterbi
 * decoding enabled. The test set is parsed with the attributes created
 * while loading the training set.
 */
@Test
public void testLearnHyphen() {
    System.out.println("learn hyphen");

    Dataset train = load("sequence/sparse.hyphen.6.train", null);
    Dataset test = load("sequence/sparse.hyphen.6.test", train.attributes);

    CRF.Trainer trainer = new CRF.Trainer(train.attributes, train.k);
    trainer.setLearningRate(1.0);
    trainer.setMaxNodes(100);
    trainer.setNumTrees(100);

    CRF crf = trainer.train(train.x, train.y);

    // Count the test positions exactly once. Both decoding passes score the
    // same positions, so adding them up in each pass would double the
    // denominator and halve the reported error rates.
    int n = 0;
    for (int i = 0; i < test.x.length; i++) {
        n += test.x[i].length;
    }

    int error = 0;
    for (int i = 0; i < test.x.length; i++) {
        int[] label = crf.predict(test.x[i]);
        for (int j = 0; j < test.x[i].length; j++) {
            if (test.y[i][j] != label[j]) {
                error++;
            }
        }
    }

    int viterbiError = 0;
    crf.setViterbi(true);
    for (int i = 0; i < test.x.length; i++) {
        int[] label = crf.predict(test.x[i]);
        for (int j = 0; j < test.x[i].length; j++) {
            if (test.y[i][j] != label[j]) {
                viterbiError++;
            }
        }
    }

    System.out.format("Hypen error (forward-backward) is %d of %d%n", error, n);
    System.out.format("Hypen error (forward-backward) rate = %.2f%%%n", 100.0 * error / n);
    System.out.format("Hypen error (Viterbi) is %d of %d%n", viterbiError, n);
    System.out.format("Hypen error (Viterbi) rate = %.2f%%%n", 100.0 * viterbiError / n);

    assertEquals(473, error);
    assertEquals(478, viterbiError);
}
}