/* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.classifier.discriminative;
import org.apache.mahout.math.Vector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Binary linear classifier described by a hyperplane vector, a bias and a threshold.
 * A data point is accepted when its dot product with the hyperplane, plus the bias,
 * is strictly greater than the threshold.
 */
public class LinearModel {

  private static final Logger log = LoggerFactory.getLogger(LinearModel.class);

  /** Weight vector of the hyperplane; reassigned by addDelta and written to by timesDelta. */
  private Vector hyperplane;
  /** Offset added to the dot product before it is compared with the threshold. */
  private double bias;
  /** Value the biased dot product has to exceed for a positive classification. */
  private final double threshold;

  /**
   * Creates a linear model from a hyperplane, a displacement (bias) and a threshold.
   * The vector is kept by reference; no copy is made here.
   * @param hyperplane the weight vector of the hyperplane.
   * @param displacement the initial bias of the model.
   * @param threshold the classification threshold.
   */
  public LinearModel(Vector hyperplane, double displacement, double threshold) {
    this.threshold = threshold;
    this.bias = displacement;
    this.hyperplane = hyperplane;
  }

  /**
   * Creates a linear model with a bias of 0 and a threshold of 0.5.
   * @param hyperplane the weight vector of the hyperplane.
   */
  public LinearModel(Vector hyperplane) {
    this(hyperplane, 0, 0.5);
  }

  /**
   * Decides whether a data point belongs to the class modeled by this instance.
   * When debug logging is enabled, the model, the dot product, the bias and the
   * threshold are logged before the decision is returned.
   * @param dataPoint the data point to classify.
   * @return true if (hyperplane . dataPoint) + bias is strictly greater than the threshold.
   */
  public boolean classify(Vector dataPoint) {
    final double dotProduct = hyperplane.dot(dataPoint);
    if (log.isDebugEnabled()) {
      // The logged "product" is the raw dot product, before the bias is added.
      Object[] logArgs = {this, dotProduct, bias, threshold};
      log.debug("model: {} product: {} Bias: {} threshold: {}", logArgs);
    }
    return dotProduct + bias > threshold;
  }

  /**
   * Replaces the hyperplane with the result of hyperplane.plus(delta).
   * @param delta the vector to add to the hyperplane.
   */
  public void addDelta(Vector delta) {
    Vector shifted = hyperplane.plus(delta);
    hyperplane = shifted;
  }

  /**
   * Renders the model as "Model: ", followed by every hyperplane element, followed by the bias.
   * @return a textual representation of the hyperplane weights and the bias.
   */
  @Override
  public String toString() {
    StringBuilder text = new StringBuilder("Model: ");
    final int dimension = hyperplane.size();
    for (int idx = 0; idx < dimension; idx++) {
      text.append(" ").append(hyperplane.get(idx));
    }
    text.append(" C: ").append(bias);
    return text.toString();
  }

  /**
   * Shifts the bias of the model by adding the given amount to it.
   * @param factor the amount added to the bias (an addend, not a multiplier, despite the name).
   */
  public void shiftBias(double factor) {
    bias = bias + factor;
  }

  /**
   * Multiplies the hyperplane weight at the given index by delta and stores the result
   * back at that index of the current hyperplane vector.
   * @param index the index of the element to update.
   * @param delta the multiplier applied to the element.
   */
  public void timesDelta(int index, double delta) {
    hyperplane.setQuick(index, hyperplane.get(index) * delta);
  }
}