package com.datascience.gal.dataGenerator; import com.datascience.core.base.AssignedLabel; import com.datascience.core.base.LObject; import com.datascience.core.base.Worker; import com.datascience.utils.CostMatrix; import com.google.gson.Gson; import com.google.gson.GsonBuilder; import com.google.gson.reflect.TypeToken; import org.apache.log4j.Logger; import java.io.*; import java.lang.reflect.Type; import java.util.*; /** * Test data manager is used for managing text files containing test data. * * @author piotr.gnys@10clouds.com */ public class DataManager { /** * Saves collection of test objects to file. * * @param filename * Name of file to with collection will be saved * @param objects * Collection of test objects * @throws IOException * Thrown if program was unable to save objects to file */ private static final String defaultCharsetEncoding = "UTF-8"; public void saveTestObjectsToFile(String filename, TroiaObjectCollection objects) throws IOException { logger.info("Saving test objects to file"); FileOutputStream stream = new FileOutputStream(filename); Writer out = new OutputStreamWriter(stream, defaultCharsetEncoding); for (String object : objects) { out.write(object + '\t' + objects.getCategory(object) + '\n'); } out.close(); } /** * Loads collection of test object from file. File must be correctly * formatted with means that each line should have pairs object-category * separated by tabulator, for example : Object-4 Category-0 Object-5 * Category-1 Object-2 Category-0 * * @param filename * Name of file that contains test objects. * @return Collection of test object generated with data from file. * @throws FileNotFoundException * If there is no file with given name. */ public TroiaObjectCollection loadTestObjectsFromFile(String filename) throws FileNotFoundException { logger.info("Loadin test objects from file"); FileInputStream stream = new FileInputStream(filename); Scanner scanner = new Scanner(stream, defaultCharsetEncoding); String line, objectName, objectCategory; TroiaObjectCollection testObjects = new TroiaObjectCollection(); while (scanner.hasNextLine()) { line = scanner.nextLine(); objectName = line.substring(0, line.indexOf('\t')); objectCategory = line.substring(line.indexOf('\t') + 1, line.length()); testObjects.setCategory(objectName, objectCategory); } scanner.close(); return testObjects; } /** * Saves JSONified collection of artificial workers to file. * * @param filename * Target file * @param workers * Collection of workers that will be saved * @throws IOException * Thrown if program was unable to save workers to file */ public void saveArtificialWorkers(String filename, Collection<ArtificialWorker> workers) throws IOException { logger.info("Saving artificial workers to file"); FileOutputStream stream = new FileOutputStream(filename); Writer out = new OutputStreamWriter(stream, defaultCharsetEncoding); Gson gson = new Gson(); out.write(gson.toJson(workers)); out.close(); } /** * Saves JSONified collection of artificial worker qualities to file. * * @param filename * Target file * @param qualities * Collection of qualities that will be saved * @throws IOException * Thrown if program was unable to save workers to file */ public void saveArtificialWorkerQualities(String filename, Collection<Map<String, Object>> qualities) throws IOException { logger.info("Saving qualities of artificial workers to file"); FileOutputStream stream = new FileOutputStream(filename); Writer out = new OutputStreamWriter(stream, defaultCharsetEncoding); Gson gson = new GsonBuilder().setPrettyPrinting().create(); out.write(gson.toJson(qualities)); out.close(); } /** * Saves JSONified misclassification cost matrix. * * @param filename * Target file * @throws IOException * Thrown if program was unable to save workers to file */ public void saveCostMatrix(String filename, CostMatrix<String> cm, Collection<String> categories) throws IOException { logger.info("Saving cost matrix to file"); FileOutputStream stream = new FileOutputStream(filename); Writer out = new OutputStreamWriter(stream, defaultCharsetEncoding); for (String c0 : categories) { for (String c1 : categories) { out.append(c0 + " " + c1 + " " + cm.getCost(c0, c1) + "\n"); } } out.close(); } /** * Loads fully configured artificial workers from file. This file should * contain JSONified ArtificialWorker classes. File containing those workers * should be generated by "saveArtificialWorkers" function, not manually by * user. For files in with user can define workers you should see * "loadBasicWorkers" function. * * @param filename * Name of file containing JSONified artificial workers * @return Collection of artificial workers fetched from file * @throws FileNotFoundException * If there is no file with given name. */ public Collection<ArtificialWorker> loadArtificialWorkersFromFile( String filename) throws FileNotFoundException { logger.info("Loading artificial workers from file"); FileInputStream stream = new FileInputStream(filename); Scanner scanner = new Scanner(stream, defaultCharsetEncoding); Gson gson = new Gson(); Type collectionType = new TypeToken<Collection<ArtificialWorker>>() {} .getType(); Collection<ArtificialWorker> workers; scanner.useDelimiter("\\Z"); workers = gson.fromJson(scanner.next(),collectionType); scanner.close(); return workers; } /** * Loads basic workers description from file and uses it to generate * collection of artificial workers in environment given as a parameter. * Correct format of file is worker-quality pairs separated by tabulation, * for example Worker1 0.4 Worker2 0.2 Worker3 1 * * @param filename * Name of file with contains basic workers definitions * @param categories * Categoris with workers will be assigning * @return Collection of artificial workers with names and qualities defined * in file * @throws FileNotFoundException * If there is no file with given name. */ public Collection<ArtificialWorker> loadBasicWorkers(String filename, Collection<String> categories) throws FileNotFoundException { logger.info("Loading worker definintion from baisc workers file."); FileInputStream stream = new FileInputStream(filename); Scanner scanner = new Scanner(stream, defaultCharsetEncoding); String line; Collection<ArtificialWorker> workers = new ArrayList<ArtificialWorker>(); DataGenerator genreator = DataGenerator.getInstance(); String workerName; Double workerQuality; while (scanner.hasNextLine()) { line = scanner.nextLine(); workerName = line.substring(0, line.indexOf('\t')); workerQuality = Double.parseDouble(line.substring( line.indexOf('\t') + 1, line.length())); workers.add(genreator.generateArtificialWorker(workerName, workerQuality, categories)); } scanner.close(); return workers; } /** * Saves labels to file * * @param filename * Name of file * @param labels * @throws IOException */ public void saveLabelsToFile(String filename, Collection<AssignedLabel<String>> labels) throws IOException { logger.info("Saving labels to file"); FileOutputStream stream = new FileOutputStream(filename); Writer out = new OutputStreamWriter(stream, defaultCharsetEncoding); for (AssignedLabel<String> label : labels) { out.write(label.getWorker().getName() + '\t' + label.getLobject().getName() + '\t' + label.getLabel() + '\n'); } out.close(); } /** * Loads labels from correctly formatted text file. Format must be * worker_name tabulator object_id tabulator object class For example : * Worker-7 Object-6 Category-1 Worker-8 Object-7 Category-2 Worker-9 * Object-7 Category-2 * * @param filename * @return * @throws FileNotFoundException */ public Collection<AssignedLabel<String>> loadLabelsFromFile(String filename) throws FileNotFoundException { logger.info("Loading labels from file"); FileInputStream stream = new FileInputStream(filename); Scanner scanner = new Scanner(stream, defaultCharsetEncoding); String line; Collection<AssignedLabel<String>> labels = new ArrayList<AssignedLabel<String>>(); while (scanner.hasNextLine()) { line = scanner.nextLine(); labels.add(this.parseLabelFromString(line)); } scanner.close(); return labels; } /** * Parses string formatted as * <workerName><tabulation><objectName><tabulation><categoryName> into * label. * * @param line * @return */ public AssignedLabel<String> parseLabelFromString(String line) { String objectName, objectCategory, workerName; workerName = line.substring(0, line.indexOf('\t')); objectName = line.substring(line.indexOf('\t') + 1, line.lastIndexOf('\t')); objectCategory = line.substring(line.lastIndexOf('\t') + 1, line.length()); return new AssignedLabel<String>(new Worker(workerName), new LObject(objectName), objectCategory); } /** * Saves gold labels to file * * @param filename * @param labels * @throws IOException */ public void saveGoldLabelsToFile(String filename, Collection<LObject<String>> labels) throws IOException { logger.info("Saving gold labels objects to file"); FileOutputStream stream = new FileOutputStream(filename); Writer out = new OutputStreamWriter(stream, defaultCharsetEncoding); for (LObject<String> label : labels) { // out.write(label.getObjectName() + '\t' + label.getCorrectCategory() // + '\n'); out.write(label.getName() + '\t' + label.getGoldLabel() + '\n'); } out.close(); } /** * Loads gold labels from file * * @param filename * @return * @throws FileNotFoundException */ public Collection<LObject<String>> loadGoldLabelsFromFile(String filename) throws FileNotFoundException { logger.info("Loading gold labels from file"); FileInputStream stream = new FileInputStream(filename); Scanner scanner = new Scanner(stream, defaultCharsetEncoding); String line, objectName, objectCategory; Collection<LObject<String>> goldLabels = new ArrayList<LObject<String>>(); while (scanner.hasNextLine()) { line = scanner.nextLine(); objectName = line.substring(0, line.indexOf('\t')); objectCategory = line.substring(line.indexOf('\t') + 1, line.length()); LObject<String> object = new LObject<String>(objectName); object.setGoldLabel(objectCategory); goldLabels.add(object); } scanner.close(); return goldLabels; } /** * Loads evaluation labels from file * * @param filename * @return * @throws FileNotFoundException */ public Collection<LObject<String>> loadEvaluationLabelsFromFile(String filename) throws FileNotFoundException { logger.info("Loading evaluation labels from file"); FileInputStream stream = new FileInputStream(filename); Scanner scanner = new Scanner(stream, defaultCharsetEncoding); String line, objectName, objectCategory; Collection<LObject<String>> evaluationLabels = new ArrayList<LObject<String>>(); while (scanner.hasNextLine()) { line = scanner.nextLine(); objectName = line.substring(0, line.indexOf('\t')); objectCategory = line.substring(line.indexOf('\t') + 1, line.length()); LObject<String> object = new LObject<String>(objectName); object.setEvaluationLabel(objectCategory); evaluationLabels.add(object); } scanner.close(); return evaluationLabels; } /** * Loads category names with probabilities of their occurence from file * with formatting <category_name><tabulation><category_prior> * @param filename Name of file containing categories with probabilities * @return Map that associates category name with probability of category occurence */ public Map<String, Double> loadCategoriesWithProbabilities(String filename) throws FileNotFoundException { logger.info("Loading prior file"); FileInputStream stream = new FileInputStream(filename); Scanner scanner = new Scanner(stream, defaultCharsetEncoding); String line; Map<String, Double> categories = new HashMap<String, Double>(); String categoryName; Double categoryProbability; while (scanner.hasNextLine()) { line = scanner.nextLine(); categoryName = line.substring(0, line.indexOf('\t')); categoryProbability = Double.parseDouble(line.substring( line.indexOf('\t') + 1, line.length())); categories.put(categoryName, categoryProbability); } scanner.close(); return categories; } public void saveTestData(String filename_base, Data data) throws IOException { logger.info("Saving test data"); if(data.getArtificialWorkers()!=null) { this.saveArtificialWorkers(filename_base + ARTIFICIAL_WORKERS_TAG + FILE_EXTENSION, data.getArtificialWorkers()); } if(data.getArtificialWorkerQualities()!=null) { this.saveArtificialWorkerQualities(filename_base + ARTIFICIAL_WORKER_QUALITIES_TAG + FILE_EXTENSION, data.getArtificialWorkerQualities()); } if(data.getCategories()!=null) { this.saveCostMatrix(filename_base + MISCLASSIFICATION_COST_TAG + FILE_EXTENSION, data.getCostMatrix(), data.getCategories()); } if(data.getGoldLabels()!=null) { this.saveGoldLabelsToFile(filename_base + GOLD_LABELS_TAG + FILE_EXTENSION, data.getGoldLabels()); } if(data.getLabels()!=null) { this.saveLabelsToFile(filename_base + LABELS_TAG + FILE_EXTENSION, data.getLabels()); } if(data.getObjectCollection()!=null) { this.saveTestObjectsToFile( filename_base + OBJECTS_TAG + FILE_EXTENSION, data.getObjectCollection()); } } public Data loadTestData(String filename_base) throws FileNotFoundException { Data data = new Data(); data.setArtificialWorkers(this .loadArtificialWorkersFromFile(filename_base + ARTIFICIAL_WORKERS_TAG + FILE_EXTENSION)); data.setGoldLabels(this.loadGoldLabelsFromFile(filename_base + GOLD_LABELS_TAG + FILE_EXTENSION)); data.setLabels(this.loadLabelsFromFile(filename_base + LABELS_TAG + FILE_EXTENSION)); data.setObjectCollection(this.loadTestObjectsFromFile(filename_base + OBJECTS_TAG + FILE_EXTENSION)); data.setWorkers(this.extractWorkerNamesFromLabels(data.getLabels())); data.setCategories(this.extractCategoryNamesFromLabels(data.getLabels())); return data; } public Collection<String> extractWorkerNamesFromLabels( Collection<AssignedLabel<String>> labels) { Collection<String> workers = new ArrayList<String>(); for (AssignedLabel<String> label : labels) { if (!workers.contains(label.getWorker().getName())) { workers.add(label.getWorker().getName()); } } return workers; } public Collection<String> extractCategoryNamesFromLabels( Collection<AssignedLabel<String>> labels) { Collection<String> categories = new ArrayList<String>(); for (AssignedLabel<String> label : labels) { if (!categories.contains(label.getLabel())) { categories.add(label.getLabel()); } } return categories; } public String converToJSON(Data data) { Gson gson = new Gson(); return gson.toJson(data); } public Data loadFromJSON(String jsonifiedData) { Gson gson = new Gson(); return gson.fromJson(jsonifiedData, Data.class); } /** * @return the instance */ public static DataManager getInstance() { return instance; } private static DataManager instance = new DataManager(); private DataManager() { } private static final String ARTIFICIAL_WORKERS_TAG = "_aiworker"; private static final String ARTIFICIAL_WORKER_QUALITIES_TAG = "_aiworker_qual"; private static final String MISCLASSIFICATION_COST_TAG = "_cost"; private static final String LABELS_TAG = "_labels"; private static final String GOLD_LABELS_TAG = "_goldLabels"; private static final String OBJECTS_TAG = "_objects"; private static final String FILE_EXTENSION = ".txt"; private static Logger logger = Logger.getLogger(DataManager.class); }