/*
* Encog(tm) Core v2.5 - Java Version
* http://www.heatonresearch.com/encog/
* http://code.google.com/p/encog-java/
* Copyright 2008-2010 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package org.encog.persist.persistors;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.StringTokenizer;
import org.encog.engine.network.rbf.RadialBasisFunction;
import org.encog.engine.util.EngineArray;
import org.encog.mathutil.libsvm.svm;
import org.encog.mathutil.libsvm.svm_model;
import org.encog.mathutil.libsvm.svm_node;
import org.encog.mathutil.libsvm.svm_parameter;
import org.encog.neural.networks.layers.RadialBasisFunctionLayer;
import org.encog.neural.networks.svm.SVMNetwork;
import org.encog.parse.tags.read.ReadXML;
import org.encog.parse.tags.write.WriteXML;
import org.encog.persist.EncogPersistedCollection;
import org.encog.persist.EncogPersistedObject;
import org.encog.persist.Persistor;
/**
* Persist a SVM network.
*/
public class SVMNetworkPersistor implements Persistor {
/**
* Constants for the SVM types.
*/
public static final String svm_type_table[] = { "c_svc", "nu_svc",
"one_class", "epsilon_svr", "nu_svr", };
/**
* Constants for the kernel types.
*/
public static final String kernel_type_table[] = { "linear", "polynomial",
"rbf", "sigmoid", "precomputed" };
/**
* The input tag.
*/
public static final String TAG_INPUT = "input";
/**
* The output tag.
*/
public static final String TAG_OUTPUT = "output";
/**
* The models tag.
*/
public static final String TAG_MODELS = "models";
/**
* The data tag.
*/
public static final String TAG_DATA = "Data";
/**
* The row tag.
*/
public static final String TAG_ROW = "Row";
/**
* The model tag.
*/
public static final String TAG_MODEL = "Model";
/**
* The type of SVM this is.
*/
public final static String TAG_TYPE_SVM = "typeSVM";
/**
* The type of kernel to use.
*/
public final static String TAG_TYPE_KERNEL = "typeKernel";
/**
* The degree to use.
*/
public final static String TAG_DEGREE = "degree";
/**
* The gamma to use.
*/
public final static String TAG_GAMMA = "gamma";
/**
* The coefficient.
*/
public final static String TAG_COEF0 = "coef0";
/**
* The number of classes.
*/
public final static String TAG_NUMCLASS = "numClass";
/**
* The total number of cases.
*/
public final static String TAG_TOTALSV = "totalSV";
/**
* The rho to use.
*/
public final static String TAG_RHO = "rho";
/**
* The labels.
*/
public final static String TAG_LABEL = "label";
/**
* The A-probability.
*/
public final static String TAG_PROB_A = "probA";
/**
* The B-probability.
*/
public final static String TAG_PROB_B = "probB";
/**
* The number of support vectors.
*/
public final static String TAG_NSV = "nSV";
/**
* Load the SVM network.
* @param in Where to read it from.
* @return The loaded object.
*/
@Override
public EncogPersistedObject load(ReadXML in) {
SVMNetwork result = null;
int input = -1, output = -1;
final String name = in.getTag().getAttributes().get(
EncogPersistedCollection.ATTRIBUTE_NAME);
final String description = in.getTag().getAttributes().get(
EncogPersistedCollection.ATTRIBUTE_DESCRIPTION);
while (in.readToTag()) {
if (in.is(SVMNetworkPersistor.TAG_INPUT, true)) {
input = Integer.parseInt(in.readTextToTag());
} else if (in.is(SVMNetworkPersistor.TAG_OUTPUT, true)) {
output = Integer.parseInt(in.readTextToTag());
} else if (in.is(SVMNetworkPersistor.TAG_MODELS, true)) {
result = new SVMNetwork(input, output, false);
handleModels(in, result);
} else if (in.is(EncogPersistedCollection.TYPE_SVM, false)) {
break;
}
}
result.setName(name);
result.setDescription(description);
return result;
}
/**
* Load the models.
* @param in Where to read the models from.
* @param network Where the models are read into.
*/
private void handleModels(ReadXML in, SVMNetwork network) {
int index = 0;
while (in.readToTag()) {
if (in.is(SVMNetworkPersistor.TAG_MODEL, true)) {
svm_parameter param = new svm_parameter();
svm_model model = new svm_model();
model.param = param;
network.getModels()[index] = model;
handleModel(in, network.getModels()[index]);
index++;
} else if (in.is(SVMNetworkPersistor.TAG_MODELS, false)) {
break;
}
}
}
/**
* Handle a model.
* @param in Where to read the model from.
* @param model Where to load the model into.
*/
private void handleModel(ReadXML in, svm_model model) {
while (in.readToTag()) {
if (in.is(SVMNetworkPersistor.TAG_TYPE_SVM, true)) {
int i = EngineArray.findStringInArray(
SVMNetworkPersistor.svm_type_table, in.readTextToTag());
model.param.svm_type = i;
} else if (in.is(SVMNetworkPersistor.TAG_DEGREE, true)) {
model.param.degree = Integer.parseInt(in.readTextToTag());
} else if (in.is(SVMNetworkPersistor.TAG_GAMMA, true)) {
model.param.gamma = Double.parseDouble(in.readTextToTag());
} else if (in.is(SVMNetworkPersistor.TAG_COEF0, true)) {
model.param.coef0 = Double.parseDouble(in.readTextToTag());
} else if (in.is(SVMNetworkPersistor.TAG_NUMCLASS, true)) {
model.nr_class = Integer.parseInt(in.readTextToTag());
} else if (in.is(SVMNetworkPersistor.TAG_TOTALSV, true)) {
model.l = Integer.parseInt(in.readTextToTag());
} else if (in.is(SVMNetworkPersistor.TAG_RHO, true)) {
int n = model.nr_class * (model.nr_class - 1) / 2;
model.rho = new double[n];
StringTokenizer st = new StringTokenizer(in.readTextToTag());
for (int i = 0; i < n; i++)
model.rho[i] = Double.parseDouble(st.nextToken());
} else if (in.is(SVMNetworkPersistor.TAG_LABEL, true)) {
int n = model.nr_class;
model.label = new int[n];
StringTokenizer st = new StringTokenizer(in.readTextToTag());
for (int i = 0; i < n; i++)
model.label[i] = Integer.parseInt(st.nextToken());
} else if (in.is(SVMNetworkPersistor.TAG_PROB_A, true)) {
int n = model.nr_class * (model.nr_class - 1) / 2;
model.probA = new double[n];
StringTokenizer st = new StringTokenizer(in.readTextToTag());
for (int i = 0; i < n; i++)
model.probA[i] = Double.parseDouble(st.nextToken());
} else if (in.is(SVMNetworkPersistor.TAG_PROB_B, true)) {
int n = model.nr_class * (model.nr_class - 1) / 2;
model.probB = new double[n];
StringTokenizer st = new StringTokenizer(in.readTextToTag());
for (int i = 0; i < n; i++)
model.probB[i] = Double.parseDouble(st.nextToken());
} else if (in.is(SVMNetworkPersistor.TAG_NSV, true)) {
int n = model.nr_class;
model.nSV = new int[n];
StringTokenizer st = new StringTokenizer(in.readTextToTag());
for (int i = 0; i < n; i++)
model.nSV[i] = Integer.parseInt(st.nextToken());
} else if (in.is(SVMNetworkPersistor.TAG_TYPE_KERNEL, true)) {
int i = EngineArray.findStringInArray(
SVMNetworkPersistor.kernel_type_table, in
.readTextToTag());
model.param.kernel_type = i;
} else if (in.is(SVMNetworkPersistor.TAG_DATA, true)) {
handleData(in, model);
} else if (in.is(SVMNetworkPersistor.TAG_MODEL, false)) {
break;
}
}
}
/**
* Load the data from a model.
* @param in Where to read the data from.
* @param model The model to load data into.
*/
private void handleData(ReadXML in, svm_model model) {
int i = 0;
int m = model.nr_class - 1;
int l = model.l;
model.sv_coef = new double[m][l];
model.SV = new svm_node[l][];
while (in.readToTag()) {
if (in.is(SVMNetworkPersistor.TAG_ROW, true)) {
String line = in.readTextToTag();
StringTokenizer st = new StringTokenizer(line, " \t\n\r\f:");
for (int k = 0; k < m; k++)
model.sv_coef[k][i] = Double.parseDouble(st.nextToken());
int n = st.countTokens() / 2;
model.SV[i] = new svm_node[n];
for (int j = 0; j < n; j++) {
model.SV[i][j] = new svm_node();
model.SV[i][j].index = Integer.parseInt(st.nextToken());
model.SV[i][j].value = Double.parseDouble(st.nextToken());
}
i++;
} else if (in.is(SVMNetworkPersistor.TAG_DATA, false)) {
break;
}
}
}
/**
* Save a model.
* @param out Where to save a model to.
* @param model The model to save to.
*/
public static void saveModel(WriteXML out, svm_model model) {
if (model != null) {
out.beginTag(SVMNetworkPersistor.TAG_MODEL);
svm_parameter param = model.param;
out.addProperty(SVMNetworkPersistor.TAG_TYPE_SVM,
svm_type_table[param.svm_type]);
out.addProperty(SVMNetworkPersistor.TAG_TYPE_KERNEL,
kernel_type_table[param.kernel_type]);
if (param.kernel_type == svm_parameter.POLY) {
out.addProperty(SVMNetworkPersistor.TAG_DEGREE, param.degree);
}
if (param.kernel_type == svm_parameter.POLY
|| param.kernel_type == svm_parameter.RBF
|| param.kernel_type == svm_parameter.SIGMOID) {
out.addProperty(SVMNetworkPersistor.TAG_GAMMA, param.gamma);
}
if (param.kernel_type == svm_parameter.POLY
|| param.kernel_type == svm_parameter.SIGMOID) {
out.addProperty(SVMNetworkPersistor.TAG_COEF0, param.coef0);
}
int nr_class = model.nr_class;
int l = model.l;
out.addProperty(SVMNetworkPersistor.TAG_NUMCLASS, nr_class);
out.addProperty(SVMNetworkPersistor.TAG_TOTALSV, l);
out.addProperty(SVMNetworkPersistor.TAG_RHO, model.rho, nr_class
* (nr_class - 1) / 2);
out.addProperty(SVMNetworkPersistor.TAG_LABEL, model.label,
nr_class);
out.addProperty(SVMNetworkPersistor.TAG_PROB_A, model.probA,
nr_class * (nr_class - 1) / 2);
out.addProperty(SVMNetworkPersistor.TAG_PROB_B, model.probB,
nr_class * (nr_class - 1) / 2);
out.addProperty(SVMNetworkPersistor.TAG_NSV, model.nSV, nr_class);
out.beginTag(SVMNetworkPersistor.TAG_DATA);
double[][] sv_coef = model.sv_coef;
svm_node[][] SV = model.SV;
StringBuilder line = new StringBuilder();
for (int i = 0; i < l; i++) {
line.setLength(0);
for (int j = 0; j < nr_class - 1; j++)
line.append(sv_coef[j][i] + " ");
svm_node[] p = SV[i];
if (param.kernel_type == svm_parameter.PRECOMPUTED)
line.append("0:" + (int) (p[0].value));
else
for (int j = 0; j < p.length; j++)
line.append(p[j].index + ":" + p[j].value + " ");
out.addProperty(SVMNetworkPersistor.TAG_ROW, line.toString());
}
out.endTag();
out.endTag();
}
}
/**
* Save a SVMNetwork.
* @param obj The object to save.
* @param out Where to save it to.
*/
@Override
public void save(EncogPersistedObject obj, WriteXML out) {
PersistorUtil.beginEncogObject(EncogPersistedCollection.TYPE_SVM, out,
obj, true);
final SVMNetwork net = (SVMNetwork) obj;
out.addProperty(SVMNetworkPersistor.TAG_INPUT, net.getInputCount());
out.addProperty(SVMNetworkPersistor.TAG_OUTPUT, net.getOutputCount());
out.beginTag(SVMNetworkPersistor.TAG_MODELS);
for (int i = 0; i < net.getModels().length; i++) {
saveModel(out, net.getModels()[i]);
}
out.endTag();
out.endTag();
}
}