package edu.isi.karma.modeling.semantictypes;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import edu.isi.karma.modeling.semantictypes.crfmodelhandler.CRFModelHandler;
import edu.isi.karma.modeling.semantictypes.crfmodelhandler.CRFModelHandler.ColumnFeature;
import edu.isi.karma.rep.HNodePath;
import edu.isi.karma.rep.Worksheet;
import edu.isi.karma.rep.alignment.SemanticType;

/**
 * Background task that trains the CRF model with the examples of the column
 * a semantic type was assigned to, and then registers a fresh
 * {@link CRFColumnModel} (predicted labels and scores) for that column on the
 * worksheet.
 */
public class SemanticTypeTrainingThread implements Runnable {

    private static final Logger logger = LoggerFactory.getLogger(SemanticTypeTrainingThread.class);

    private final CRFModelHandler crfModelHandler;
    private final Worksheet worksheet;
    private final SemanticType newType;

    /**
     * @param crfModelHandler handler used to update the model and to predict labels
     * @param worksheet       worksheet that owns the column being trained on
     * @param newType         semantic type whose HNode id identifies the column and
     *                        whose CRF model label string is used as the training label
     */
    public SemanticTypeTrainingThread(CRFModelHandler crfModelHandler,
            Worksheet worksheet, SemanticType newType) {
        this.crfModelHandler = crfModelHandler;
        this.worksheet = worksheet;
        this.newType = newType;
    }

    /**
     * Trains the CRF model with the column's examples, predicts labels for the
     * same examples, and stores the result as the column's CRF model on the
     * worksheet. If the column cannot be located in the worksheet headers, an
     * error is logged and nothing else is done.
     */
    @Override
    public void run() {
        long start = System.currentTimeMillis();

        // Find the corresponding hNodePath. Used to find examples for training the CRF Model.
        HNodePath currentColumnPath = findColumnPath();
        if (currentColumnPath == null) {
            // Without a path there is nothing to collect examples from; bail out
            // instead of handing a null path to the example collector.
            logger.error("Could not find a column path for HNode id {}; skipping semantic type training.",
                    newType.getHNodeId());
            return;
        }

        // The feature map is passed empty: the code that used to add the
        // column header name is disabled (kept below for reference).
        Map<ColumnFeature, Collection<String>> columnFeatures =
                new HashMap<ColumnFeature, Collection<String>>();

        // Prepare the column name for training
        // String columnName = currentColumnPath.getLeaf().getColumnName();
        // Collection<String> columnNameList = new ArrayList<String>();
        // columnNameList.add(columnName);
        // columnFeatures.put(ColumnFeature.ColumnHeaderName, columnNameList);

        // Train the model with the new type
        ArrayList<String> trainingExamples =
                SemanticTypeUtil.getTrainingExamples(worksheet, currentColumnPath);
        boolean trained = crfModelHandler.addOrUpdateLabel(
                newType.getCrfModelLabelString(), trainingExamples, columnFeatures);
        if (!trained) {
            logger.error("Error occured while training CRF Model.");
        }
        // logger.debug("Using type:" + newType.getDomain().getUri() + "|" + newType.getType().getUri());

        // Add the new CRF column model for this column.
        // NOTE(review): the literal 4 is presumably the number of predictions
        // requested from the handler - confirm against CRFModelHandler.
        ArrayList<String> labels = new ArrayList<String>();
        ArrayList<Double> scores = new ArrayList<Double>();
        boolean predicted = crfModelHandler.predictLabelForExamples(
                trainingExamples, 4, labels, scores, null, columnFeatures);
        if (!predicted) {
            logger.error("Error occured while predicting labels");
        }
        // The column model is registered even when prediction reported a
        // failure, using whatever ended up in the two lists.
        CRFColumnModel newModel = new CRFColumnModel(labels, scores);
        worksheet.getCrfModel().addColumnModel(newType.getHNodeId(), newModel);

        long elapsedTimeMillis = System.currentTimeMillis() - start;
        float elapsedTimeSec = elapsedTimeMillis / 1000F;
        logger.info("Time required for training the semantic type: {}", elapsedTimeSec);

        // long t2 = System.currentTimeMillis();
        // Identify the outliers for the column
        // SemanticTypeUtil.identifyOutliers(worksheet, newTypeString, currentColumnPath,
        //         vWorkspace.getWorkspace().getTagsContainer().getTag(TagName.Outlier),
        //         columnFeatures, crfModelHandler);
        // long t3 = System.currentTimeMillis();
        // logger.info("Identify outliers: "+ (t3-t2));
    }

    /**
     * Looks through all header paths of the worksheet for the one whose leaf
     * has the HNode id of {@code newType}.
     *
     * @return the first matching path, or {@code null} when no path matches
     */
    private HNodePath findColumnPath() {
        List<HNodePath> paths = worksheet.getHeaders().getAllPaths();
        for (HNodePath path : paths) {
            if (path.getLeaf().getId().equals(newType.getHNodeId())) {
                return path;
            }
        }
        return null;
    }
}