package rainbownlp.machinelearning.convertor; import java.io.File; import java.io.FileWriter; import java.io.IOException; import java.util.ArrayList; import java.util.List; import org.hibernate.Session; import rainbownlp.core.FeatureValuePair; import rainbownlp.core.TimeStamp; import rainbownlp.machinelearning.ExpectedClassEnum; import rainbownlp.machinelearning.MLExample; import rainbownlp.machinelearning.MLExampleFeature; import rainbownlp.util.ConfigurationUtil; import rainbownlp.util.FileUtil; import rainbownlp.util.HibernateUtil; public class SVMLightFormatConvertor { static int numClassRatio = ConfigurationUtil.getValueInteger("numClassesRatio"); public final static String OUTPUT_FILE_EXTENSION = "txt"; public static String writeToFile(List<Integer> exampleIdsToWrite,String taskName) throws IOException { String filePath = getFilePath(taskName); // if(new File(filePath).exists()) return filePath; FileWriter file_writer = new FileWriter(filePath); int counter = 0; for(Integer example_id : exampleIdsToWrite) { counter++; FileUtil.logLine(null, "example_id: "+example_id); MLExample example = MLExample.getExampleById(example_id); String SVMLightFormatLine = getArtifactAttributes(example,taskName).trim(); if(example.getExpectedClass() == null){ FileUtil.logLine(FileUtil.DEBUG_FILE, "expected class is null!"); continue; } Double expectedClass = example.getNumericExpectedClass()+1; //convert to 1-base (e.g. 0 -> 1) FileUtil.logLine(null, "expected class: "+expectedClass); SVMLightFormatLine = expectedClass + " " + SVMLightFormatLine; file_writer.write(SVMLightFormatLine + "\n"); file_writer.flush(); FileUtil.logLine(FileUtil.DEBUG_FILE, "SVMLightFormatLine: "+SVMLightFormatLine); FileUtil.logLine(null, "example wrote: "+counter+"/"+exampleIdsToWrite.size()); // HibernateUtil.clearLoaderSession(); } file_writer.flush(); file_writer.close(); return filePath; } private static String getFilePath(String taskName) { String fold = (ConfigurationUtil.crossFoldCurrent>0)?("Fold"+ConfigurationUtil.crossFoldCurrent):""; String usedFeatures = ""; List<FeatureValuePair> features = FeatureValuePair.getAllFeatures(); for(FeatureValuePair fvp : features){ if(excludeAttributeIds.contains(fvp.getFeatureName())) continue; if(onlyIncludeAttributes.size()>0 && !onlyIncludeAttributes.contains(fvp.getFeatureName())) continue; usedFeatures += "-" + fvp.getTempFeatureIndex(); } String fileName = ConfigurationUtil.getValue("TempFolder")+ fold+ "SVMMultiClass-"+(ConfigurationUtil.TrainingMode?"train":"test")+"-"+ taskName+usedFeatures+"."+OUTPUT_FILE_EXTENSION; return fileName; } public static String writeToFileBinary(List<Integer> exampleIdsToWrite, String taskName) throws IOException { String filePath = getFilePath(taskName); // if(new File(filePath).exists()) return filePath; // if(Setting.TrainingMode) // FeatureValuePair.resetIndexes(); FileWriter file_writer = new FileWriter(filePath); int[] positiveNegativeCount = new int[]{0,0}; int counter = 0; for(Integer example_id : exampleIdsToWrite) { counter++; TimeStamp.setStart("Writing example : "+example_id); MLExample example = MLExample.getExampleById(example_id); TimeStamp.setStart("Getting features for "+example_id); String SVMLightFormatLine = getArtifactAttributes(example,taskName).trim(); TimeStamp.setEnd("Getting features for "+example_id); if(example.getExpectedClass() == null){ FileUtil.logLine(FileUtil.DEBUG_FILE, "expected class is null!"); continue; } Double expectedClass = example.getNumericExpectedClass()+1; //convert to 1-base (e.g. 0 -> 1) FileUtil.logLine(FileUtil.DEBUG_FILE, "expected class: "+expectedClass); if(ConfigurationUtil.TrainingMode && expectedClass == ExpectedClassEnum.BOOLEAN_NO.ordinal() && positiveNegativeCount[0] > (positiveNegativeCount[1]+1)*numClassRatio) continue; if(expectedClass==ExpectedClassEnum.BOOLEAN_NO.ordinal()){ expectedClass = -1D; positiveNegativeCount[0]++; }else{ expectedClass = 1D; positiveNegativeCount[1]++; } SVMLightFormatLine = expectedClass + " " + SVMLightFormatLine; file_writer.write(SVMLightFormatLine + "\n"); file_writer.flush(); FileUtil.logLine(FileUtil.DEBUG_FILE, "SVMLightFormatLine: "+SVMLightFormatLine); // HibernateUtil.clearLoaderSession(); TimeStamp.setEnd("Writing example : "+example_id); FileUtil.logLine(null, "example wrote: "+counter+"/"+exampleIdsToWrite.size()); } file_writer.flush(); file_writer.close(); return filePath; } public static List<String> excludeAttributeIds = new ArrayList<String>(); public static List<String> onlyIncludeAttributes = new ArrayList<String>(); public static ArrayList<String> usedFeatureNames = new ArrayList<String>(); private static String getArtifactAttributes(MLExample example, String taskName) { FileUtil.logLine(FileUtil.DEBUG_FILE, "getArtifactAttributes(): creating feature string for example: "+example.getExampleId()); String SVMLightFormatLine = ""; Session old_session = MLExample.hibernateSession; MLExample.hibernateSession = HibernateUtil.sessionFactory.openSession(); List<MLExampleFeature> features = example.getExampleFeatures(); for(int i=0;i<features.size();i++) { MLExampleFeature feature = features.get(i); FeatureValuePair fvp = feature.getFeatureValuePair(); if(excludeAttributeIds.contains(fvp.getFeatureName())) { FileUtil.logLine(FileUtil.DEBUG_FILE, "getArtifactAttributes(): skipping the feature, excludeAttributeIds includes feature: "+fvp.getFeatureName()); continue; } if(onlyIncludeAttributes.size()>0 && !onlyIncludeAttributes.contains(fvp.getFeatureName())) { FileUtil.logLine(FileUtil.DEBUG_FILE, "getArtifactAttributes(): skipping the feature, onlyIncludeAttributes is not empty and it doesn't include feature: "+fvp.getFeatureName()); continue; } int featureIndex = fvp.getTempFeatureIndex(); if(ConfigurationUtil.TrainingMode && featureIndex==Integer.MAX_VALUE) { featureIndex = FeatureValuePair.getMaxIndex() + 1; if(featureIndex==0) featureIndex++; fvp.setTempFeatureIndex(featureIndex); HibernateUtil.save(fvp); } if(featureIndex==Integer.MAX_VALUE || featureIndex==-1){ FileUtil.logLine(FileUtil.DEBUG_FILE, "getArtifactAttributes(): skipping the feature, feature index is not set for feature: "+fvp.getFeatureName()); continue; } Double numericValue = 0.0; if(fvp.getFeatureValueAuxiliary()!=null) numericValue = Double.parseDouble(fvp.getFeatureValueAuxiliary()); else numericValue = Double.parseDouble(fvp.getFeatureValue()); if (numericValue!=null && numericValue != 0 && !Double.isNaN(numericValue) && !Double.isInfinite(numericValue)) { // double maxVal = getAttributeMaxValue(attribute_id); // numericValue = numericValue/maxVal; // FileUtil.logLine(FileUtil.DEBUG_FILE, fvp.getFeatureName()); if (!usedFeatureNames.contains(fvp.getFeatureName())) { FileUtil.logLine("/tmp/featuresUsed", fvp.getFeatureName()); usedFeatureNames.add(fvp.getFeatureName()); } SVMLightFormatLine += featureIndex + ":" + numericValue + " "; } HibernateUtil.clearLoaderSession(); } MLExample.hibernateSession.clear(); MLExample.hibernateSession.close(); MLExample.hibernateSession = old_session; FileUtil.logLine(FileUtil.DEBUG_FILE, "getArtifactAttributes(): feature string for example: "+example.getExampleId()+" -> "+SVMLightFormatLine); return SVMLightFormatLine; } }