package quickml.data.instances;
import org.javatuples.Pair;
import quickml.data.AttributesMap;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Map;
/**
* Created by alexanderhawk on 10/12/15.
*/
/**
 * A {@link RegressionInstance} that also keeps its attributes as two parallel arrays:
 * the index of the weight each value multiplies, and the value itself.
 *
 * <p>Slot 0 of both arrays is always the bias term (weight index 0, value 1.0). Slots 1..n
 * hold the instance's attributes, in the iteration order of the attributes map. Each one
 * points at the weight index that the supplied name-to-index map assigns to it.
 */
public class SparseRegressionInstance extends RegressionInstance {
    /** Weight index for each stored value; slot 0 is the bias weight index (0). */
    private int[] indicesOfCorrespondingWeights;
    /** Values parallel to {@link #indicesOfCorrespondingWeights}; slot 0 is the bias value 1.0. */
    private double[] values;

    /**
     * Creates an instance via the two-argument {@link RegressionInstance} constructor and builds
     * its sparse representation.
     *
     * @param attributes attribute name to value; every value must be a {@code Double}
     * @param label regression target
     * @param nameToValueIndexMap weight index for every attribute name present in {@code attributes}
     * @throws RuntimeException if an attribute value is not a {@code Double}
     * @throws NullPointerException if an attribute name has no entry in {@code nameToValueIndexMap}
     */
    public SparseRegressionInstance(AttributesMap attributes, Double label, Map<String, Integer> nameToValueIndexMap) {
        super(attributes, label);
        setIndicesAndValues(attributes, nameToValueIndexMap);
    }

    /**
     * Creates a weighted instance via the three-argument {@link RegressionInstance} constructor and
     * builds its sparse representation.
     *
     * @param attributes attribute name to value; every value must be a {@code Double}
     * @param label regression target
     * @param weight instance weight, passed through to the superclass
     * @param nameToValueIndexMap weight index for every attribute name present in {@code attributes}
     * @throws RuntimeException if an attribute value is not a {@code Double}
     * @throws NullPointerException if an attribute name has no entry in {@code nameToValueIndexMap}
     */
    public SparseRegressionInstance(AttributesMap attributes, Double label, double weight, Map<String, Integer> nameToValueIndexMap) {
        super(attributes, label, weight);
        setIndicesAndValues(attributes, nameToValueIndexMap);
    }

    /**
     * Fills the parallel index/value arrays: the bias term in slot 0, then one slot per attribute.
     *
     * <p>NOTE(review): weight index 0 is reserved for the bias here, so {@code nameToIndexMap}
     * presumably maps attribute names to indices starting at 1. An attribute mapped to 0 would
     * share the bias weight. Confirm against the code that builds the map.
     */
    private void setIndicesAndValues(AttributesMap attributes, Map<String, Integer> nameToIndexMap) {
        int size = attributes.size() + 1;
        indicesOfCorrespondingWeights = new int[size];
        values = new double[size];
        // Bias term: constant 1.0 multiplying weight index 0.
        indicesOfCorrespondingWeights[0] = 0;
        values[0] = 1.0;
        // Non-bias terms fill slots 1..n in the attributes map's iteration order.
        int slot = 1;
        for (Map.Entry<String, Serializable> entry : attributes.entrySet()) {
            Serializable value = entry.getValue();
            if (!(value instanceof Double)) {
                throw new RuntimeException("wrong type of values in attributes: attribute '" + entry.getKey()
                        + "' has value of type " + (value == null ? "null" : value.getClass().getName()));
            }
            indicesOfCorrespondingWeights[slot] = requireIndex(nameToIndexMap, entry.getKey());
            values[slot] = (Double) value;
            slot++;
        }
    }

    /**
     * Looks up the weight index for {@code name}.
     *
     * @throws NullPointerException naming the attribute when the map has no entry for it. This is
     *     the same exception type that auto-unboxing a missing entry used to throw, but now with
     *     context.
     */
    private static int requireIndex(Map<String, Integer> nameToIndexMap, String name) {
        Integer index = nameToIndexMap.get(name);
        if (index == null) {
            throw new NullPointerException("no weight index for attribute '" + name + "' in nameToIndexMap");
        }
        return index;
    }

    /**
     * Expands an instance's attributes into a dense array, placing each value at the index that
     * {@code nameToIndexMap} assigns to its name.
     *
     * <p>The array has {@code attributes.size() + 1} slots when {@code useBias} is true (slot 0 set
     * to 1.0), otherwise {@code attributes.size()} slots. Map indices are used exactly as given;
     * no bias offset is added. When {@code useBias} is true, the map presumably already reserves
     * index 0 for the bias (TODO confirm with callers). Otherwise an attribute mapped to 0
     * overwrites the bias value.
     *
     * <p>NOTE(review): the array is sized from this instance's attribute count, not from the map.
     * A mapped index at or beyond that length throws {@link ArrayIndexOutOfBoundsException}. For
     * example, this happens for an instance that lacks some lower-indexed attributes present in
     * the map.
     *
     * @param regressionInstance instance whose attribute values must all be {@code Double}s
     * @param nameToIndexMap attribute name to array index
     * @param useBias whether to reserve slot 0 for a constant 1.0 bias value
     * @return the dense value array
     * @throws NullPointerException if an attribute name is missing from the map or has a null value
     * @throws ClassCastException if an attribute value is not a {@code Double}
     */
    public static double[] getArrayOfValues(RegressionInstance regressionInstance, Map<String, Integer> nameToIndexMap, boolean useBias) {
        AttributesMap attributesMap = regressionInstance.getAttributes();
        int numAttributes = attributesMap.size();
        double[] valuesArray = new double[useBias ? numAttributes + 1 : numAttributes];
        if (useBias) {
            valuesArray[0] = 1.0;
        }
        for (Map.Entry<String, Serializable> attributeEntry : attributesMap.entrySet()) {
            int attributeIndex = requireIndex(nameToIndexMap, attributeEntry.getKey());
            valuesArray[attributeIndex] = (Double) attributeEntry.getValue();
        }
        return valuesArray;
    }

    /** Returns the attributes held by the superclass (behavior unchanged by this override). */
    @Override
    public AttributesMap getAttributes() {
        return super.getAttributes();
    }

    /**
     * Returns the sparse representation as (weight indices, values), with the bias in slot 0.
     *
     * <p>The returned arrays are this instance's internal arrays, not copies. Mutating them
     * changes the instance.
     */
    public Pair<int[], double[]> getSparseAttributes() {
        return new Pair<>(indicesOfCorrespondingWeights, values);
    }

    /**
     * Computes the sum over stored slots of {@code omega[weightIndex] * value}, bias term included.
     *
     * @param omega weight vector; must be long enough to index every stored weight index,
     *     including 0 for the bias
     * @return the dot product of {@code omega} with this instance's sparse feature vector
     */
    public double dotProduct(double[] omega) {
        double result = 0;
        for (int i = 0; i < indicesOfCorrespondingWeights.length; i++) {
            result += omega[indicesOfCorrespondingWeights[i]] * values[i];
        }
        return result;
    }
}