/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.classifier.bayes;
import java.util.Collection;
import org.apache.hadoop.conf.Configuration;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.SparseMatrix;
import org.apache.mahout.math.map.OpenIntDoubleHashMap;
import org.apache.mahout.math.map.OpenObjectIntHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * {@link Datastore} that keeps a trained Bayes model entirely in memory.
 *
 * <p>The model is loaded in {@link #initialize()} by {@link SequenceFileModelReader#loadModel},
 * which is handed this instance (presumably populating it through the public {@code load*} /
 * {@code set*} methods below). Feature and label strings are mapped to dense integer ids the
 * first time one of those loader methods sees them. Lookups made through the {@code getWeight}
 * methods never assign new ids, so querying features or labels that are not part of the model
 * does not change the model (in particular the reported {@code vocabCount} or {@link #getKeys}).
 */
public class InMemoryBayesDatastore implements Datastore {

  private static final Logger log = LoggerFactory.getLogger(InMemoryBayesDatastore.class);

  /** Id returned by {@link #findId} for a key that is not present in a dictionary. */
  private static final int MISSING_ID = -1;

  /** Feature string to dense feature id (row index into {@link #weightMatrix}). */
  private final OpenObjectIntHashMap<String> featureDictionary = new OpenObjectIntHashMap<String>();
  /** Label string to dense label id (column index into {@link #weightMatrix}). */
  private final OpenObjectIntHashMap<String> labelDictionary = new OpenObjectIntHashMap<String>();
  /** Per-feature {@code sigma_j} values, keyed by feature id. */
  private final OpenIntDoubleHashMap sigmaJ = new OpenIntDoubleHashMap();
  /** Per-label {@code sigma_k} values (served as "labelWeight"), keyed by label id. */
  private final OpenIntDoubleHashMap sigmaK = new OpenIntDoubleHashMap();
  /** Per-label theta normalizer values, keyed by label id. */
  private final OpenIntDoubleHashMap thetaNormalizerPerLabel = new OpenIntDoubleHashMap();
  // NOTE(review): declared as 1 x 0 but written at arbitrary (featureId, labelId) positions; this
  // relies on SparseMatrix.setQuick/getQuick not bounds-checking in the Mahout version in use.
  // Confirm before changing the Matrix implementation.
  private final Matrix weightMatrix = new SparseMatrix(1, 0);
  private final BayesParameters params;
  /** Largest absolute per-label theta normalizer loaded so far; never below the initial 1.0. */
  private double thetaNormalizer = 1.0;
  /** The {@code alpha_i} parameter, read once in the constructor. */
  private final double alphaI;
  /** The {@code sigma_jSigma_k} total, as last set by {@link #setSigmaJSigmaK(double)}. */
  private double sigmaJsigmaK = 1.0;

  /**
   * Creates an empty datastore for the model rooted at {@code params.getBasePath()}.
   *
   * <p>Note that this modifies {@code params}: it registers the input glob of every model part
   * ({@code sigma_j}, {@code sigma_k}, {@code sigma_kSigma_j}, {@code thetaNormalizer},
   * {@code weight}) relative to the base path.
   *
   * @param params model parameters; supplies the base path and optionally {@code alpha_i}
   *        (parsed as a double, defaulting to {@code "1.0"})
   */
  public InMemoryBayesDatastore(BayesParameters params) {
    String basePath = params.getBasePath();
    this.params = params;
    params.set("sigma_j", basePath + "/trainer-weights/Sigma_j/part-*");
    params.set("sigma_k", basePath + "/trainer-weights/Sigma_k/part-*");
    params.set("sigma_kSigma_j", basePath + "/trainer-weights/Sigma_kSigma_j/part-*");
    params.set("thetaNormalizer", basePath + "/trainer-thetaNormalizer/part-*");
    params.set("weight", basePath + "/trainer-tfIdf/trainer-tfIdf/part-*");
    alphaI = Double.parseDouble(params.get("alpha_i", "1.0"));
  }

  /**
   * Loads the model via {@link SequenceFileModelReader#loadModel} using a fresh Hadoop
   * {@link Configuration}, then logs, per label, its raw theta normalizer, the global
   * normalizer, and their ratio.
   */
  @Override
  public void initialize() throws InvalidDatastoreException {
    Configuration conf = new Configuration();
    SequenceFileModelReader.loadModel(this, params, conf);
    for (String label : getKeys("")) {
      double labelTheta = perLabel(thetaNormalizerPerLabel, label);
      log.info("{} {} {} {}", new Object[] {
          label,
          labelTheta,
          thetaNormalizer,
          labelTheta / thetaNormalizer
      });
    }
  }

  /**
   * Returns all labels known to the model. The {@code name} argument is ignored.
   */
  @Override
  public Collection<String> getKeys(String name) throws InvalidDatastoreException {
    return labelDictionary.keys();
  }

  /**
   * Looks up a cell of the {@code "weight"} matrix.
   *
   * <p>When {@code column} is {@code "sigma_j"}, returns the feature's {@code sigma_j} value.
   * Otherwise {@code column} is treated as a label and the (feature, label) weight is returned.
   * Features or labels that are not in the model yield 0.0 and are not registered.
   *
   * @throws InvalidDatastoreException if {@code matrixName} is not {@code "weight"}
   */
  @Override
  public double getWeight(String matrixName, String row, String column) throws InvalidDatastoreException {
    if (!"weight".equals(matrixName)) {
      throw new InvalidDatastoreException("Matrix not found: " + matrixName);
    }
    int featureId = findId(featureDictionary, row);
    if (featureId == MISSING_ID) {
      // A feature absent from the model has no weight. It must not be registered here: that
      // would inflate "vocabCount" (the feature dictionary size) on every lookup of an unseen word.
      return 0.0;
    }
    if ("sigma_j".equals(column)) {
      return sigmaJ.get(featureId);
    }
    int labelId = findId(labelDictionary, column);
    return labelId == MISSING_ID ? 0.0 : weightMatrix.getQuick(featureId, labelId);
  }

  /**
   * Looks up a scalar from one of the named vectors.
   *
   * <ul>
   *   <li>{@code "sumWeight"}: index {@code "sigma_jSigma_k"} or {@code "vocabCount"} (number
   *       of features in the model)</li>
   *   <li>{@code "thetaNormalizer"}: the label's theta normalizer divided by the global one</li>
   *   <li>{@code "params"}: index {@code "alpha_i"}</li>
   *   <li>{@code "labelWeight"}: the label's {@code sigma_k} value</li>
   * </ul>
   * Labels that are not in the model yield 0.0 and are not registered.
   *
   * @throws InvalidDatastoreException for an unknown vector name or an unknown index of the
   *         {@code "sumWeight"} / {@code "params"} vectors
   */
  @Override
  public double getWeight(String vectorName, String index) throws InvalidDatastoreException {
    if ("sumWeight".equals(vectorName)) {
      if ("sigma_jSigma_k".equals(index)) {
        return sigmaJsigmaK;
      }
      if ("vocabCount".equals(index)) {
        return featureDictionary.size();
      }
      throw new InvalidDatastoreException("Unknown index for vector sumWeight: " + index);
    }
    if ("thetaNormalizer".equals(vectorName)) {
      return perLabel(thetaNormalizerPerLabel, index) / thetaNormalizer;
    }
    if ("params".equals(vectorName)) {
      if ("alpha_i".equals(index)) {
        return alphaI;
      }
      throw new InvalidDatastoreException("Unknown index for vector params: " + index);
    }
    if ("labelWeight".equals(vectorName)) {
      return perLabel(sigmaK, index);
    }
    throw new InvalidDatastoreException("Vector not found: " + vectorName);
  }

  /** Returns the id of {@code feature}, assigning the next free id if it is new. Loader use only. */
  private int getFeatureID(String feature) {
    return getOrAssignId(featureDictionary, feature);
  }

  /** Returns the id of {@code label}, assigning the next free id if it is new. Loader use only. */
  private int getLabelID(String label) {
    return getOrAssignId(labelDictionary, label);
  }

  /** Returns the value stored for {@code label} in {@code values}, or 0.0 if the label is unknown. */
  private double perLabel(OpenIntDoubleHashMap values, String label) {
    int labelId = findId(labelDictionary, label);
    return labelId == MISSING_ID ? 0.0 : values.get(labelId);
  }

  /** Returns the id of {@code key} in {@code dictionary}, or {@link #MISSING_ID}; never inserts. */
  private static int findId(OpenObjectIntHashMap<String> dictionary, String key) {
    // containsKey is required because get() cannot distinguish "absent" from the valid id 0.
    return dictionary.containsKey(key) ? dictionary.get(key) : MISSING_ID;
  }

  /** Returns the id of {@code key}, first assigning it the next dense id if it is not present. */
  private static int getOrAssignId(OpenObjectIntHashMap<String> dictionary, String key) {
    if (dictionary.containsKey(key)) {
      return dictionary.get(key);
    }
    int id = dictionary.size();
    dictionary.put(key, id);
    return id;
  }

  /** Stores the (feature, label) weight, registering the feature and label if they are new. */
  public void loadFeatureWeight(String feature, String label, double weight) {
    int fid = getFeatureID(feature);
    int lid = getLabelID(label);
    weightMatrix.setQuick(fid, lid, weight);
  }

  /** Stores the feature's {@code sigma_j} value, registering the feature if it is new. */
  public void setSumFeatureWeight(String feature, double weight) {
    int fid = getFeatureID(feature);
    sigmaJ.put(fid, weight);
  }

  /** Stores the label's {@code sigma_k} value, registering the label if it is new. */
  public void setSumLabelWeight(String label, double weight) {
    int lid = getLabelID(label);
    sigmaK.put(lid, weight);
  }

  /**
   * Stores the label's theta normalizer, registering the label if it is new, and raises the
   * global normalizer to {@code |weight|} if that is larger.
   */
  public void setThetaNormalizer(String label, double weight) {
    int lid = getLabelID(label);
    thetaNormalizerPerLabel.put(lid, weight);
    thetaNormalizer = Math.max(thetaNormalizer, Math.abs(weight));
  }

  /** Sets the {@code sigma_jSigma_k} total returned for {@code getWeight("sumWeight", "sigma_jSigma_k")}. */
  public void setSigmaJSigmaK(double weight) {
    this.sigmaJsigmaK = weight;
  }
}