/*
 * Copyright (c) 2015 Villu Ruusmann
 *
 * This file is part of JPMML-SkLearn
 *
 * JPMML-SkLearn is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-SkLearn is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-SkLearn. If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.sklearn;

import java.io.File;
import java.io.FileOutputStream;
import java.io.OutputStream;
import java.util.Collections;

import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter;
import com.beust.jcommander.ParameterException;
import org.dmg.pmml.PMML;
import org.jpmml.model.MetroJAXBUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import sklearn.Estimator;
import sklearn.pipeline.Pipeline;
import sklearn2pmml.PMMLPipeline;

/**
 * Command-line entry point that unpickles an object from the input file,
 * turns it into a {@link PMMLPipeline} when it is not one already,
 * encodes that pipeline as {@link PMML}, and marshals the result to the output file.
 */
public class Main {

	@Parameter (
		names = "--help",
		description = "Show the list of configuration options and exit",
		help = true
	)
	private boolean help = false;

	@Parameter (
		names = {"--pkl-pipeline-input", "--pkl-input"},
		description = "Pickle input file",
		required = true
	)
	private File input = null;

	@Parameter (
		names = "--pmml-output",
		description = "PMML output file",
		required = true
	)
	private File output = null;


	/**
	 * Parses the command-line arguments into a fresh {@link Main} instance and runs it.
	 *
	 * <p>
	 * A {@link ParameterException} prints the exception followed by the usage text
	 * to standard error and exits with status {@code -1}.
	 * When the help flag is set, the usage text is printed to standard output
	 * and the process exits with status {@code 0}.
	 * </p>
	 *
	 * @param args The command-line arguments.
	 *
	 * @throws Exception If {@link #run()} fails.
	 */
	public static void main(String... args) throws Exception {
		Main application = new Main();

		JCommander commander = new JCommander(application);
		commander.setProgramName(Main.class.getName());

		try {
			commander.parse(args);
		} catch(ParameterException pe){
			StringBuilder message = new StringBuilder()
				.append(pe.toString())
				.append("\n");

			commander.usage(message);

			System.err.println(message.toString());

			System.exit(-1);
		}

		if(application.help){
			StringBuilder message = new StringBuilder();

			commander.usage(message);

			System.out.println(message.toString());

			System.exit(0);
		}

		application.run();
	}

	/**
	 * Executes the conversion in four stages: unpickle, coerce to a {@link PMMLPipeline}, encode, marshal.
	 *
	 * @throws Exception If any stage fails. Failures of the unpickle, encode and marshal stages
	 * are logged at error level before being rethrown unchanged.
	 * @throws IllegalArgumentException If the unpickled object is neither a {@link PMMLPipeline},
	 * a {@link Pipeline} nor an {@link Estimator}.
	 */
	public void run() throws Exception {
		Object unpickled = unpickle();

		PMMLPipeline pmmlPipeline = toPMMLPipeline(unpickled);

		PMML pmml = encode(pmmlPipeline);

		marshal(pmml);
	}

	/**
	 * Opens a storage over the input file and unpickles its content.
	 */
	private Object unpickle() throws Exception {

		try(Storage storage = PickleUtil.createStorage(this.input)){
			logger.info("Parsing PKL..");

			long startMillis = System.currentTimeMillis();
			Object result = PickleUtil.unpickle(storage);
			long elapsedMillis = System.currentTimeMillis() - startMillis;

			logger.info("Parsed PKL in {} ms.", elapsedMillis);

			return result;
		} catch(Exception e){
			logger.error("Failed to parse PKL", e);

			throw e;
		}
	}

	/**
	 * Returns the object itself when it already is a {@link PMMLPipeline};
	 * otherwise wraps a {@link Pipeline} or an {@link Estimator} into a new one.
	 */
	private static PMMLPipeline toPMMLPipeline(Object object){

		// Checked first, so that an existing PMMLPipeline is never re-wrapped
		if(object instanceof PMMLPipeline){
			return (PMMLPipeline)object;
		}

		Object wrapped;

		if(object instanceof Pipeline){
			// Create a single- or multi-step PMMLPipeline from a Pipeline
			Pipeline pipeline = (Pipeline)object;

			wrapped = new PMMLPipeline()
				.setSteps(pipeline.getSteps());
		} else if(object instanceof Estimator){
			// Create a single-step PMMLPipeline from an Estimator
			Estimator estimator = (Estimator)object;

			Object[] step = new Object[]{"estimator", estimator};

			wrapped = new PMMLPipeline()
				.setSteps(Collections.<Object[]>singletonList(step));
		} else {
			throw new IllegalArgumentException("The object (" + ClassDictUtil.formatClass(object) + ") is not a PMMLPipeline");
		}

		return (PMMLPipeline)wrapped;
	}

	/**
	 * Encodes the pipeline as a PMML object, logging the elapsed time.
	 */
	private static PMML encode(PMMLPipeline pipeline) throws Exception {

		try {
			logger.info("Converting..");

			long startMillis = System.currentTimeMillis();
			PMML result = pipeline.encodePMML();
			long elapsedMillis = System.currentTimeMillis() - startMillis;

			logger.info("Converted in {} ms.", elapsedMillis);

			return result;
		} catch(Exception e){
			logger.error("Failed to convert", e);

			throw e;
		}
	}

	/**
	 * Marshals the PMML object to the output file, logging the elapsed time.
	 */
	private void marshal(PMML pmml) throws Exception {

		try(OutputStream os = new FileOutputStream(this.output)){
			logger.info("Marshalling PMML..");

			long startMillis = System.currentTimeMillis();
			MetroJAXBUtil.marshalPMML(pmml, os);
			long elapsedMillis = System.currentTimeMillis() - startMillis;

			logger.info("Marshalled PMML in {} ms.", elapsedMillis);
		} catch(Exception e){
			logger.error("Failed to marshal PMML", e);

			throw e;
		}
	}

	/**
	 * @return The Pickle input file.
	 */
	public File getInput(){
		return this.input;
	}

	/**
	 * @param input The Pickle input file.
	 */
	public void setInput(File input){
		this.input = input;
	}

	/**
	 * @return The PMML output file.
	 */
	public File getOutput(){
		return this.output;
	}

	/**
	 * @param output The PMML output file.
	 */
	public void setOutput(File output){
		this.output = output;
	}

	private static final Logger logger = LoggerFactory.getLogger(Main.class);
}