/**
* Copyright 2014 Marco Cornolti
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package it.acubelab.smaph.learn;
import it.unipi.di.acube.batframework.metrics.MetricsResultSet;
import it.unipi.di.acube.batframework.systemPlugins.WATAnnotator;
import it.unipi.di.acube.batframework.utils.FreebaseApi;
import it.unipi.di.acube.batframework.utils.WikipediaApiInterface;
import it.acubelab.smaph.SmaphAnnotator;
import it.acubelab.smaph.SmaphAnnotatorDebugger;
import it.acubelab.smaph.SmaphConfig;
import it.cnr.isti.hpc.erd.WikipediaToFreebase;
import java.util.*;
import libsvm.svm;
import libsvm.svm_model;
import libsvm.svm_problem;
import org.apache.commons.lang3.tuple.Triple;
public class GenerateModel {
/**
 * Trains and evaluates one SVM model for every combination of feature set,
 * (gamma, C) pair and (positive, negative) weight pair listed below.
 *
 * For each combination the method writes a {@code .range} file and a
 * {@code .model} file (named via {@link #getModelFileNameBaseEF}), computes
 * metrics on the test examples, and records them. Once every model has been
 * trained, the collected results and the tested grids are printed to
 * standard output; progress is reported on standard error.
 *
 * @param args ignored
 * @throws Exception if any of the underlying project or libsvm calls fails
 */
public static void main(String[] args) throws Exception {
    // Fix the default locale so number formatting does not depend on the host.
    Locale.setDefault(Locale.US);
    final String freebaseKey = "";
    SmaphConfig.setConfigFile("smaph-config.xml");
    final String bingKey = SmaphConfig.getDefaultBingKey();
    WikipediaApiInterface wikiApi = new WikipediaApiInterface("wid.cache",
            "redirect.cache");
    FreebaseApi freebApi = new FreebaseApi(freebaseKey, "freeb.cache");

    // (gamma, C) pairs to try. Values explored previously, kept for reference:
    //   {0.035, 0.5 }, {0.035, 1 }, {0.035, 4 }, {0.035, 8 }, {0.035, 10 },
    //   {0.035, 16 }, {0.714, .5 }, {0.714, 1 }, {0.714, 4 }, {0.714, 8 },
    //   {0.714, 10 }, {0.714, 16 }, {0.9, .5 }, {0.9, 1 }, {0.9, 4 }, {0.9, 8 },
    //   {0.9, 10 }, {0.9, 16 },
    //   { 1.0/15.0, 1 }, { 1.0/27.0, 1 },
    //
    //   {0.01, 1}, {0.01, 5}, {0.01, 10}, {0.03, 1}, {0.03, 5}, {0.03, 10},
    //   {0.044, 1}, {0.044, 5}, {0.044, 10}, {0.06, 1}, {0.06, 5}, {0.06, 10},
    final double[][] gammaCGrid = { { 0.03, 5 } };

    // (positive weight, negative weight) pairs to try.
    // Previously: { 3, 4 }
    final double[][] weightGrid = {
            { 3.8, 3 }, { 3.8, 4 }, { 3.8, 5 }, { 3.8, 6 },
            { 3.8, 7 }, { 3.8, 8 }, { 3.8, 9 }, { 3.8, 10 } };

    // Feature-id sets to try -- MIND THIS when changing the experiment.
    // Previously:
    //   { 1, 2, 3, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25 }
    //   { 1, 2, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18 }
    final Integer[][] featureSets = {
            { 1, 2, 3, 6, 7, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
                    21, 22, 23, 24, 25 } };

    // Experiment knobs -- MIND THESE as well.
    final int wikiSearchTopK = 10;
    final String fileSuffix = "_ANW"; // appended to every model file base name

    WikipediaToFreebase wikiToFreebase = new WikipediaToFreebase("mapdb");
    List<ModelConfigurationResult> results = new Vector<>();

    // Number of models trained per edit-distance threshold (progress output).
    final int modelsPerThreshold = weightGrid.length * featureSets.length
            * gammaCGrid.length;

    // With start == bound == step == 0.7 this loop runs exactly once.
    for (double editDistanceThr = 0.7; editDistanceThr <= 0.7; editDistanceThr += 0.7) {
        SmaphAnnotator bingAnnotator = GenerateTrainingAndTest
                .getDefaultBingAnnotator(wikiApi, wikiToFreebase,
                        editDistanceThr, wikiSearchTopK, bingKey);
        WATAnnotator.setCache("wikisense.cache");
        SmaphAnnotator.setCache(SmaphConfig.getDefaultBingCache());

        // Both gatherers are handed to gatherExamplesTrainingAndDevel and
        // then used as the training / test example sources below.
        BinaryExampleGatherer trainGatherer = new BinaryExampleGatherer();
        BinaryExampleGatherer testGatherer = new BinaryExampleGatherer();
        GenerateTrainingAndTest.gatherExamplesTrainingAndDevel(bingAnnotator,
                trainGatherer, testGatherer, wikiApi, wikiToFreebase,
                freebApi);
        SmaphAnnotator.unSetCache();

        int trained = 0;
        for (Integer[] featureSet : featureSets) {
            // Alternative once used instead of the grid:
            //   gamma = 1.0 / featureSet.length, C = 1 -- MIND THIS
            for (double[] gammaC : gammaCGrid) {
                final double gamma = gammaC[0];
                final double C = gammaC[1];
                for (double[] weights : weightGrid) {
                    final double wPos = weights[0];
                    final double wNeg = weights[1];

                    Vector<Integer> features = new Vector<>(
                            Arrays.asList(featureSet));
                    Triple<svm_problem, double[], double[]> scaledTrain = TuneModel
                            .getScaledTrainProblem(features, trainGatherer);
                    svm_problem trainProblem = scaledTrain.getLeft();
                    // Middle/right arrays of the triple are the scaling
                    // ranges (presumably per-feature minima and maxima --
                    // confirm against TuneModel).
                    double[] rangeMiddle = scaledTrain.getMiddle();
                    double[] rangeRight = scaledTrain.getRight();

                    // Alternative naming scheme:
                    //   getModelFileNameBaseEQF(features.toArray(new Integer[0]), wPos, wNeg)
                    String fileBase = getModelFileNameBaseEF(
                            features.toArray(new Integer[0]), wPos, wNeg,
                            editDistanceThr, gamma, C)
                            + fileSuffix;

                    LibSvmUtils.dumpRanges(rangeMiddle, rangeRight, fileBase
                            + ".range");
                    svm_model model = TuneModel.trainModel(wPos, wNeg,
                            features, trainProblem, gamma, C);
                    svm.svm_save_model(fileBase + ".model", model);

                    MetricsResultSet metrics = TuneModel.ParameterTester
                            .computeMetrics(model, TuneModel
                                    .getScaledTestProblems(features,
                                            testGatherer, rangeMiddle,
                                            rangeRight));
                    int truePos = metrics.getGlobalTp();
                    int falsePos = metrics.getGlobalFp();
                    int falseNeg = metrics.getGlobalFn();
                    float microF1 = metrics.getMicroF1();
                    float macroF1 = metrics.getMacroF1();
                    float macroRecall = metrics.getMacroRecall();
                    float macroPrecision = metrics.getMacroPrecision();
                    int totalExamples = testGatherer.getExamplesCount();
                    // Whatever is not tp/fp/fn is counted as a true negative.
                    int trueNeg = totalExamples - truePos - falsePos
                            - falseNeg;

                    results.add(new ModelConfigurationResult(features, wPos,
                            wNeg, editDistanceThr, truePos, falsePos,
                            falseNeg, trueNeg, microF1, macroF1,
                            macroRecall, macroPrecision));

                    trained++;
                    System.err.printf("Trained %d/%d models.%n", trained,
                            modelsPerThreshold);
                }
            }
        }
    }

    // Report: macro P/R/F1 per model, the weight grid, the readable results,
    // and finally the (gamma, C) grid.
    for (ModelConfigurationResult result : results) {
        System.out.printf("%.5f%%\t%.5f%%\t%.5f%%%n",
                result.getMacroPrecision() * 100,
                result.getMacroRecall() * 100, result.getMacroF1() * 100);
    }
    for (double[] weights : weightGrid) {
        System.out.printf("%.5f\t%.5f%n", weights[0], weights[1]);
    }
    for (ModelConfigurationResult result : results) {
        System.out.println(result.getReadable());
    }
    for (double[] gammaC : gammaCGrid) {
        System.out.printf("%.5f\t%.5f%n", gammaC[0], gammaC[1]);
    }

    WATAnnotator.flush();
}
/**
 * Builds the base path (without extension) of an entity-filter model file.
 *
 * The result has the form
 * {@code models/model_<f1>,<f2>,..._<wPos>_<wNeg>_<editDistance>_<gamma>_<C>}
 * where the feature ids are listed in ascending order and separated by
 * commas. Numbers are always formatted with {@link Locale#US}, so the name
 * is the same on every host and the decimal separator can never be confused
 * with the comma separating feature ids.
 *
 * @param ftrs         feature ids; the array itself is not modified
 * @param wPos         weight of the positive class (5 decimals)
 * @param wNeg         weight of the negative class (5 decimals)
 * @param editDistance edit distance threshold (3 decimals)
 * @param gamma        SVM gamma parameter (8 decimals)
 * @param C            SVM C parameter (8 decimals)
 * @return the file name base, without extension
 */
public static String getModelFileNameBaseEF(Integer[] ftrs, double wPos,
        double wNeg, double editDistance, double gamma, double C) {
    // Sort a copy so the caller's array keeps its order.
    List<Integer> features = new ArrayList<Integer>(Arrays.asList(ftrs));
    Collections.sort(features);

    StringBuilder filename = new StringBuilder("models/model_");
    // Separate by position, not by value: comparing against the last value
    // would drop the comma between repeated maximum ids.
    for (int i = 0; i < features.size(); i++) {
        if (i > 0) {
            filename.append(',');
        }
        filename.append(features.get(i).intValue());
    }
    filename.append(String.format(Locale.US, "_%.5f_%.5f_%.3f_%.8f_%.8f",
            wPos, wNeg, editDistance, gamma, C));
    return filename.toString();
}
/**
 * Builds the base path (without extension) of an "EQ" model file.
 *
 * The result has the form
 * {@code models/EQ_model_<f1>,<f2>,..._<wPos>_<wNeg>} where the feature ids
 * are listed in ascending order and separated by commas. Numbers are always
 * formatted with {@link Locale#US}, so the name is the same on every host
 * and the decimal separator can never be confused with the comma separating
 * feature ids.
 *
 * @param ftrs feature ids; the array itself is not modified
 * @param wPos weight of the positive class (5 decimals)
 * @param wNeg weight of the negative class (5 decimals)
 * @return the file name base, without extension
 */
public static String getModelFileNameBaseEQF(Integer[] ftrs, double wPos,
        double wNeg) {
    // Sort a copy so the caller's array keeps its order.
    List<Integer> features = new ArrayList<Integer>(Arrays.asList(ftrs));
    Collections.sort(features);

    StringBuilder filename = new StringBuilder("models/EQ_model_");
    // Separate by position, not by value: comparing against the last value
    // would drop the comma between repeated maximum ids.
    for (int i = 0; i < features.size(); i++) {
        if (i > 0) {
            filename.append(',');
        }
        filename.append(features.get(i).intValue());
    }
    filename.append(String.format(Locale.US, "_%.5f_%.5f", wPos, wNeg));
    return filename.toString();
}
}