/** * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.mahout.classifier.bayes; import java.util.Map; import com.google.common.collect.Maps; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.DoubleWritable; import org.apache.mahout.classifier.bayes.mapreduce.common.BayesConstants; import org.apache.mahout.common.Pair; import org.apache.mahout.common.Parameters; import org.apache.mahout.common.StringTuple; import org.apache.mahout.common.iterator.sequencefile.PathType; import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * This Class reads the different interim files created during the Training stage as well as the Model File * during testing. */ public final class SequenceFileModelReader { private static final Logger log = LoggerFactory.getLogger(SequenceFileModelReader.class); private SequenceFileModelReader() { } public static void loadModel(InMemoryBayesDatastore datastore, Parameters params, Configuration conf) { loadFeatureWeights(datastore, new Path(params.get("sigma_j")), conf); loadLabelWeights(datastore, new Path(params.get("sigma_k")), conf); loadSumWeight(datastore, new Path(params.get("sigma_kSigma_j")), conf); loadThetaNormalizer(datastore, new Path(params.get("thetaNormalizer")), conf); loadWeightMatrix(datastore, new Path(params.get("weight")), conf); } public static void loadWeightMatrix(InMemoryBayesDatastore datastore, Path pathPattern, Configuration conf) { // the key is label,feature for (Pair<StringTuple,DoubleWritable> record : new SequenceFileDirIterable<StringTuple,DoubleWritable>(pathPattern, PathType.GLOB, null, null, true, conf)) { StringTuple key = record.getFirst(); DoubleWritable value = record.getSecond(); datastore.loadFeatureWeight(key.stringAt(2), key.stringAt(1), value.get()); } } public static void loadFeatureWeights(InMemoryBayesDatastore datastore, Path pathPattern, Configuration conf) { // the key is either _label_ or label,feature long count = 0; for (Pair<StringTuple,DoubleWritable> record : new SequenceFileDirIterable<StringTuple,DoubleWritable>(pathPattern, PathType.GLOB, null, null, true, conf)) { // Sum of weights for a Feature StringTuple key = record.getFirst(); DoubleWritable value = record.getSecond(); if (key.stringAt(0).equals(BayesConstants.FEATURE_SUM)) { datastore.setSumFeatureWeight(key.stringAt(1), value.get()); if (++count % 50000 == 0) { log.info("Read {} feature weights", count); } } } } public static void loadLabelWeights(InMemoryBayesDatastore datastore, Path pathPattern, Configuration conf) { long count = 0; for (Pair<StringTuple,DoubleWritable> record : new SequenceFileDirIterable<StringTuple,DoubleWritable>(pathPattern, PathType.GLOB, null, null, true, conf)) { // Sum of weights in a Label StringTuple key = record.getFirst(); DoubleWritable value = record.getSecond(); if (key.stringAt(0).equals(BayesConstants.LABEL_SUM)) { datastore.setSumLabelWeight(key.stringAt(1), value.get()); if (++count % 10000 == 0) { log.info("Read {} label weights", count); } } } } public static void loadThetaNormalizer(InMemoryBayesDatastore datastore, Path pathPattern, Configuration conf) { long count = 0; for (Pair<StringTuple,DoubleWritable> record : new SequenceFileDirIterable<StringTuple,DoubleWritable>(pathPattern, PathType.GLOB, null, null, true, conf)) { StringTuple key = record.getFirst(); DoubleWritable value = record.getSecond(); // Sum of weights in a Label if (key.stringAt(0).equals(BayesConstants.LABEL_THETA_NORMALIZER)) { datastore.setThetaNormalizer(key.stringAt(1), value.get()); if (++count % 50000 == 0) { log.info("Read {} theta norms", count); } } } } public static void loadSumWeight(InMemoryBayesDatastore datastore, Path pathPattern, Configuration conf) { // the key is _label for (Pair<StringTuple,DoubleWritable> record : new SequenceFileDirIterable<StringTuple,DoubleWritable>(pathPattern, PathType.GLOB, null, null, true, conf)) { StringTuple key = record.getFirst(); DoubleWritable value = record.getSecond(); if (key.stringAt(0).equals(BayesConstants.TOTAL_SUM)) { // Sum of weights for all Features and all Labels datastore.setSigmaJSigmaK(value.get()); log.info("{}", value.get()); } } } public static Map<String,Double> readLabelSums(Path pathPattern, Configuration conf) { Map<String,Double> labelSum = Maps.newHashMap(); // the key is either _label_ or label,feature for (Pair<StringTuple,DoubleWritable> record : new SequenceFileDirIterable<StringTuple,DoubleWritable>(pathPattern, PathType.GLOB, null, null, true, conf)) { StringTuple key = record.getFirst(); DoubleWritable value = record.getSecond(); if (key.stringAt(0).equals(BayesConstants.LABEL_SUM)) { // Sum of counts of labels labelSum.put(key.stringAt(1), value.get()); } } return labelSum; } public static Map<String,Double> readLabelDocumentCounts(Path pathPattern, Configuration conf) { Map<String,Double> labelDocumentCounts = Maps.newHashMap(); // the key is either _label_ or label,feature for (Pair<StringTuple,DoubleWritable> record : new SequenceFileDirIterable<StringTuple,DoubleWritable>(pathPattern, PathType.GLOB, null, null, true, conf)) { StringTuple key = record.getFirst(); DoubleWritable value = record.getSecond(); // Count of Documents in a Label if (key.stringAt(0).equals(BayesConstants.LABEL_COUNT)) { labelDocumentCounts.put(key.stringAt(1), value.get()); } } return labelDocumentCounts; } public static double readSigmaJSigmaK(Path pathPattern, Configuration conf) { Map<String,Double> weightSum = Maps.newHashMap(); for (Pair<StringTuple,DoubleWritable> record : new SequenceFileDirIterable<StringTuple,DoubleWritable>(pathPattern, PathType.GLOB, null, null, true, conf)) { StringTuple key = record.getFirst(); DoubleWritable value = record.getSecond(); if (weightSum.size() > 1) { throw new IllegalStateException("Incorrect Sum File"); } else if (key.stringAt(0).equals(BayesConstants.TOTAL_SUM)) { weightSum.put(BayesConstants.TOTAL_SUM, value.get()); } } return weightSum.get(BayesConstants.TOTAL_SUM); } public static double readVocabCount(Path pathPattern, Configuration conf) { Map<String,Double> weightSum = Maps.newHashMap(); for (Pair<StringTuple,DoubleWritable> record : new SequenceFileDirIterable<StringTuple,DoubleWritable>(pathPattern, PathType.GLOB, null, null, true, conf)) { if (weightSum.size() > 1) { throw new IllegalStateException("Incorrect vocabCount File"); } StringTuple key = record.getFirst(); DoubleWritable value = record.getSecond(); if (key.stringAt(0).equals(BayesConstants.FEATURE_SET_SIZE)) { weightSum.put(BayesConstants.FEATURE_SET_SIZE, value.get()); } } return weightSum.get(BayesConstants.FEATURE_SET_SIZE); } }