/*
 *    Apache License
 *    Version 2.0, January 2004
 *    http://www.apache.org/licenses/
 *
 *    Copyright 2013 Aurelian Tutuianu
 *    Copyright 2014 Aurelian Tutuianu
 *    Copyright 2015 Aurelian Tutuianu
 *    Copyright 2016 Aurelian Tutuianu
 *
 *    Licensed under the Apache License, Version 2.0 (the "License");
 *    you may not use this file except in compliance with the License.
 *    You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *    Unless required by applicable law or agreed to in writing, software
 *    distributed under the License is distributed on an "AS IS" BASIS,
 *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *    See the License for the specific language governing permissions and
 *    limitations under the License.
 *
 */

package rapaio.experiment.sandbox;

import rapaio.core.RandomSource;
import rapaio.core.distributions.Normal;
import rapaio.core.stat.Maximum;
import rapaio.core.stat.Minimum;
import rapaio.data.*;
import rapaio.datasets.Datasets;
import rapaio.experiment.grid.MeshGrid1D;
import rapaio.graphics.opt.ColorGradient;
import rapaio.graphics.plot.Plot;
import rapaio.graphics.plot.plotcomp.MeshContour;
import rapaio.graphics.plot.plotcomp.Points;
import rapaio.ml.classifier.CFit;
import rapaio.ml.classifier.Classifier;
import rapaio.ml.classifier.ensemble.CForest;
import rapaio.printer.Summary;

import java.io.IOException;
import java.net.URISyntaxException;

import static rapaio.graphics.Plotter.*;
import static rapaio.sys.WS.draw;

/**
 * Sandbox demo which trains a random forest on a two-variable data set and draws
 * the fitted density of one class as filled contour bands, with the training
 * points plotted on top.
 * <p>
 * The iris data set is used only as a container: after filtering, the two numeric
 * variables are overwritten with samples drawn from normal distributions.
 * <p>
 * Created by <a href="mailto:padreati@yahoo.com">Aurelian Tutuianu</a> at 1/23/15.
 */
@Deprecated
public class IrisContour {

    /** Name of the variable used for the horizontal axis (position 0 in the mapped frame). */
    private static final String X_NAME = "petal-length";

    /** Name of the variable used for the vertical axis (position 1 in the mapped frame). */
    private static final String Y_NAME = "sepal-width";

    /** Name of the target variable (position 2 in the mapped frame). */
    private static final String TARGET_NAME = "class";

    /**
     * Runs the demo: seeds the random source, prepares the data, trains the forest,
     * scores a regular grid and draws the resulting contour plot.
     *
     * @param args command line arguments, not used
     * @throws IOException        declared by the data set loading call
     * @throws URISyntaxException declared by the data set loading call
     */
    public static void main(String[] args) throws IOException, URISyntaxException {
        // Seed first: the sampling below and the forest training rely on the shared random source.
        RandomSource.setSeed(1);

        Frame iris = loadFilteredIris();
        overwriteWithGaussianSamples(iris);
        Summary.printSummary(iris);

        Classifier c = CForest.newRF().withMCols(1).withRuns(1_000);
        c.train(iris, TARGET_NAME);

        Numeric x = axisValues(iris, X_NAME, 0.1);
        Numeric y = axisValues(iris, Y_NAME, 0.2);
        MeshGrid1D mesh = new MeshGrid1D(x, y);

        int xCount = x.rowCount();
        int yCount = y.rowCount();

        CFit gridFit = c.fit(buildGridFrame(mesh, xCount, yCount));
        c.fit(iris).printSummary();
        fillMesh(mesh, gridFit.firstDensity(), xCount, yCount);

        draw(buildPlot(mesh, iris));
    }

    /**
     * Loads the iris data set, keeps only the two plotted variables and the target,
     * and drops every row whose target index equals 3.
     * NOTE(review): index 3 is presumably the third species level - confirm against the data set.
     *
     * @return mapped frame with variables in the order x, y, target
     */
    private static Frame loadFilteredIris() throws IOException, URISyntaxException {
        return Datasets.loadIrisDataset()
                .mapVars(X_NAME, Y_NAME, TARGET_NAME)
                .stream()
                .filter(s -> s.index(2) != 3)
                .toMappedFrame();
    }

    /**
     * Replaces the two numeric variables of every row with normal samples.
     * Rows with target index 1 receive unshifted samples; all other rows
     * receive samples shifted by 4 on the first variable and by 9 on the second.
     *
     * @param df frame to modify in place; variables 0 and 1 are overwritten
     */
    private static void overwriteWithGaussianSamples(Frame df) {
        Normal g1 = new Normal(0, 2);
        Normal g2 = new Normal(0, 5);
        for (int i = 0; i < df.rowCount(); i++) {
            boolean firstClass = df.index(i, 2) == 1;
            // Sampling order (g1 then g2, row by row) is kept so seeded runs stay reproducible.
            df.setValue(i, 0, (firstClass ? 0 : 4) + g1.sampleNext());
            df.setValue(i, 1, (firstClass ? 0 : 9) + g2.sampleNext());
        }
    }

    /**
     * Builds an evenly spaced sequence from the minimum to the maximum of a variable.
     *
     * @param df   frame holding the variable
     * @param name name of the variable; also used as name of the result
     * @param step distance between consecutive values
     * @return named numeric sequence
     */
    private static Numeric axisValues(Frame df, String name, double step) {
        Var values = df.var(name);
        return Numeric.seq(Minimum.from(values).value(), Maximum.from(values).value(), step).withName(name);
    }

    /**
     * Builds the frame of all grid points, one row per (i, j) cell, ordered with
     * the x index in the outer loop and the y index in the inner loop.
     *
     * @param mesh   grid providing the axis values
     * @param xCount number of values on the x axis
     * @param yCount number of values on the y axis
     * @return frame with two variables named like the training inputs
     */
    private static Frame buildGridFrame(MeshGrid1D mesh, int xCount, int yCount) {
        Numeric xs = Numeric.empty().withName(X_NAME);
        Numeric ys = Numeric.empty().withName(Y_NAME);
        for (int i = 0; i < xCount; i++) {
            for (int j = 0; j < yCount; j++) {
                xs.addValue(mesh.getX().value(i));
                ys.addValue(mesh.getY().value(j));
            }
        }
        return SolidFrame.byVars(xs, ys);
    }

    /**
     * Copies column 1 of the density frame into the mesh, walking the cells in the
     * same order used by {@link #buildGridFrame(MeshGrid1D, int, int)}.
     *
     * @param mesh    grid to fill
     * @param density density frame with one row per grid point
     * @param xCount  number of values on the x axis
     * @param yCount  number of values on the y axis
     */
    private static void fillMesh(MeshGrid1D mesh, Frame density, int xCount, int yCount) {
        int pos = 0;
        for (int i = 0; i < xCount; i++) {
            for (int j = 0; j < yCount; j++) {
                mesh.setValue(i, j, density.value(pos, 1));
                pos++;
            }
        }
    }

    /**
     * Assembles the plot: one contour band per consecutive pair of thresholds,
     * followed by the training points colored by target.
     *
     * @param mesh grid holding the fitted density values
     * @param df   training frame with variables in the order x, y, target
     * @return plot ready to be drawn
     */
    private static Plot buildPlot(MeshGrid1D mesh, Frame df) {
        Plot p = new Plot();

        double[] thresholds = Numeric.seq(0, 1, 0.02).stream().mapToDouble().toArray();
        // Open the last band upwards so the top threshold acts as an unbounded limit.
        thresholds[thresholds.length - 1] = Double.POSITIVE_INFINITY;

        ColorGradient gradient = ColorGradient.newHueGradient(thresholds);
        for (int i = 0; i < thresholds.length - 1; i++) {
            p.add(new MeshContour(mesh.compute(thresholds[i], thresholds[i + 1]),
                    false, true, lwd(0.1f), color(gradient.getColor(i))));
        }
        p.add(new Points(df.var(0), df.var(1), color(df.var(2)), pch(2)));
        return p;
    }
}