/* * ARX: Powerful Data Anonymization * Copyright 2012 - 2017 Fabian Prasser, Florian Kohlmayer and contributors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.deidentifier.arx.aggregates.classification; import org.apache.mahout.classifier.sgd.ElasticBandPrior; import org.apache.mahout.classifier.sgd.L1; import org.apache.mahout.classifier.sgd.L2; import org.apache.mahout.classifier.sgd.OnlineLogisticRegression; import org.apache.mahout.classifier.sgd.PriorFunction; import org.apache.mahout.classifier.sgd.UniformPrior; import org.apache.mahout.math.DenseVector; import org.apache.mahout.math.Vector; import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder; import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder; import org.deidentifier.arx.ARXLogisticRegressionConfiguration; import org.deidentifier.arx.DataHandleInternal; /** * Implements a classifier * @author Fabian Prasser */ public class MultiClassLogisticRegression implements ClassificationMethod { /** Config */ private final ARXLogisticRegressionConfiguration config; /** Encoder */ private final ConstantValueEncoder interceptEncoder; /** Instance */ private final OnlineLogisticRegression lr; /** Specification */ private final ClassificationDataSpecification specification; /** Encoder */ private final StaticWordValueEncoder wordEncoder; /** * Creates a new instance * @param specification * @param config */ public MultiClassLogisticRegression(ClassificationDataSpecification specification, ARXLogisticRegressionConfiguration config) { // Store this.config = config; this.specification = specification; // Prepare classifier PriorFunction prior = null; switch (config.getPriorFunction()) { case ELASTIC_BAND: prior = new ElasticBandPrior(); break; case L1: prior = new L1(); break; case L2: prior = new L2(); break; case UNIFORM: prior = new UniformPrior(); break; default: throw new IllegalArgumentException("Unknown prior function"); } this.lr = new OnlineLogisticRegression(this.specification.classMap.size(), config.getVectorLength(), prior); // Configure this.lr.learningRate(config.getLearningRate()); this.lr.alpha(config.getAlpha()); this.lr.lambda(config.getLambda()); this.lr.stepOffset(config.getStepOffset()); this.lr.decayExponent(config.getDecayExponent()); // Prepare encoders this.interceptEncoder = new ConstantValueEncoder("intercept"); this.wordEncoder = new StaticWordValueEncoder("feature"); // Configure this.lr.learningRate(1); this.lr.alpha(1); this.lr.lambda(0.000001); this.lr.stepOffset(10000); this.lr.decayExponent(0.2); } @Override public ClassificationResult classify(DataHandleInternal features, int row) { return new MultiClassLogisticRegressionClassificationResult(lr.classifyFull(encodeFeatures(features, row)), specification.classMap); } @Override public void close() { lr.close(); } @Override public void train(DataHandleInternal features, DataHandleInternal clazz, int row) { lr.train(encodeClass(clazz, row), encodeFeatures(features, row)); } /** * Encodes a class * @param handle * @param row * @return */ private int encodeClass(DataHandleInternal handle, int row) { return specification.classMap.get(handle.getValue(row, specification.classIndex, true)); } /** * Encodes a feature * @param handle * @param row * @return */ private Vector encodeFeatures(DataHandleInternal handle, int row) { // Prepare DenseVector vector = new DenseVector(config.getVectorLength()); interceptEncoder.addToVector("1", vector); // Special case where there are no features if (specification.featureIndices.length == 0) { wordEncoder.addToVector("Feature:1", 1, vector); return vector; } // TODO: Consider difference between continuous and categorical // For each attribute for (int index : specification.featureIndices) { // Obtain data String name = "Attribute-"+index; String value = handle.getValue(row, index, true); wordEncoder.addToVector(name + ":" + value, 1, vector); } // Return return vector; } }