/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package hivemall.mf; import hivemall.UDTFWithOptions; import hivemall.common.ConversionState; import hivemall.common.EtaEstimator; import hivemall.mf.FactorizedModel.RankInitScheme; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.io.FileUtils; import hivemall.utils.io.NioFixedSegment; import hivemall.utils.lang.NumberUtils; import hivemall.utils.lang.Primitives; import hivemall.utils.math.MathUtils; import java.io.File; import java.io.IOException; import java.nio.ByteBuffer; import java.util.ArrayList; import javax.annotation.Nonnegative; import javax.annotation.Nonnull; import javax.annotation.Nullable; import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.Options; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; import org.apache.hadoop.io.FloatWritable; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.mapred.Counters.Counter; import org.apache.hadoop.mapred.Reporter; /** * Bayesian Personalized Ranking from Implicit Feedback. */ @Description(name = "train_bprmf", value = "_FUNC_(INT user, INT posItem, INT negItem [, String options])" + " - Returns a relation <INT i, FLOAT Pi, FLOAT Qi [, FLOAT Bi]>") public final class BPRMatrixFactorizationUDTF extends UDTFWithOptions implements RatingInitilizer { private static final Log LOG = LogFactory.getLog(OnlineMatrixFactorizationUDTF.class); private static final int RECORD_BYTES = (Integer.SIZE + Integer.SIZE + Integer.SIZE) / 8; // Option variables /** The number of latent factors */ protected int factor; /** The regularization factors */ protected float regU, regI, regJ; /** The regularization factor for Bias. reg / 2.0 by the default. */ protected float regBias; /** Whether to use bias clause */ protected boolean useBiasClause; /** The number of iterations */ protected int iterations; protected LossFunction lossFunction; /** Initialization strategy of rank matrix */ protected RankInitScheme rankInit; /** Learning rate */ protected EtaEstimator etaEstimator; // Variable managing status of learning /** The number of processed training examples */ protected long count; protected ConversionState cvState; // Model itself protected FactorizedModel model; // Input OIs and Context protected PrimitiveObjectInspector userOI; protected PrimitiveObjectInspector posItemOI; protected PrimitiveObjectInspector negItemOI; // Used for iterations protected NioFixedSegment fileIO; protected ByteBuffer inputBuf; private long lastWritePos; private float[] uProbe, iProbe, jProbe; public BPRMatrixFactorizationUDTF() { this.factor = 10; this.regU = 0.0025f; this.regI = 0.0025f; this.regJ = 0.00125f; this.regBias = 0.01f; this.useBiasClause = true; this.iterations = 30; } public enum LossFunction { sigmoid, logistic, lnLogistic; @Nonnull public static LossFunction resolve(@Nullable String name) { if (name == null) { return lnLogistic; } if (name.equalsIgnoreCase("lnLogistic")) { return lnLogistic; } else if (name.equalsIgnoreCase("logistic")) { return logistic; } else if (name.equalsIgnoreCase("sigmoid")) { return sigmoid; } else { throw new IllegalArgumentException("Unexpected loss function: " + name); } } } @Override protected Options getOptions() { Options opts = new Options(); opts.addOption("k", "factor", true, "The number of latent factor [default: 10]"); opts.addOption("iter", "iterations", true, "The number of iterations [default: 30]"); opts.addOption("loss", "loss_function", true, "Loss function [default: lnLogistic, logistic, sigmoid]"); // initialization opts.addOption("rankinit", true, "Initialization strategy of rank matrix [random, gaussian] (default: random)"); opts.addOption("maxval", "max_init_value", true, "The maximum initial value in the rank matrix [default: 1.0]"); opts.addOption("min_init_stddev", true, "The minimum standard deviation of initial rank matrix [default: 0.1]"); // regularization opts.addOption("reg", "lambda", true, "The regularization factor [default: 0.0025]"); opts.addOption("reg_u", "reg_user", true, "The regularization factor for user [default: 0.0025 (reg)]"); opts.addOption("reg_i", "reg_item", true, "The regularization factor for positive item [default: 0.0025 (reg)]"); opts.addOption("reg_j", true, "The regularization factor for negative item [default: 0.00125 (reg_i/2) ]"); // bias opts.addOption("reg_bias", true, "The regularization factor for bias clause [default: 0.01]"); opts.addOption("disable_bias", "no_bias", false, "Turn off bias clause"); // learning rates opts.addOption("eta", true, "The initial learning rate [default: 0.001]"); opts.addOption("eta0", true, "The initial learning rate [default 0.3]"); opts.addOption("t", "total_steps", true, "The total number of training examples"); opts.addOption("power_t", true, "The exponent for inverse scaling learning rate [default 0.1]"); opts.addOption("boldDriver", "bold_driver", false, "Whether to use Bold Driver for learning rate [default: false]"); // conversion check opts.addOption("disable_cv", "disable_cvtest", false, "Whether to disable convergence check [default: enabled]"); opts.addOption("cv_rate", "convergence_rate", true, "Threshold to determine convergence [default: 0.005]"); return opts; } @Override protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException { CommandLine cl = null; String lossFuncName = null; String rankInitOpt = null; float maxInitValue = 1.f; double initStdDev = 0.1d; boolean conversionCheck = true; double convergenceRate = 0.005d; if (argOIs.length >= 4) { String rawArgs = HiveUtils.getConstString(argOIs[3]); cl = parseOptions(rawArgs); this.factor = Primitives.parseInt(cl.getOptionValue("factor"), factor); this.iterations = Primitives.parseInt(cl.getOptionValue("iterations"), iterations); if (iterations < 1) { throw new UDFArgumentException( "'-iterations' must be greater than or equals to 1: " + iterations); } lossFuncName = cl.getOptionValue("loss_function"); float reg = Primitives.parseFloat(cl.getOptionValue("reg"), 0.0025f); this.regU = Primitives.parseFloat(cl.getOptionValue("reg_u"), reg); this.regI = Primitives.parseFloat(cl.getOptionValue("reg_i"), reg); this.regJ = Primitives.parseFloat(cl.getOptionValue("reg_j"), regI / 2.f); this.regBias = Primitives.parseFloat(cl.getOptionValue("reg_bias"), regBias); rankInitOpt = cl.getOptionValue("rankinit"); maxInitValue = Primitives.parseFloat(cl.getOptionValue("max_init_value"), 1.f); initStdDev = Primitives.parseDouble(cl.getOptionValue("min_init_stddev"), 0.1d); conversionCheck = !cl.hasOption("disable_cvtest"); convergenceRate = Primitives.parseDouble(cl.getOptionValue("cv_rate"), convergenceRate); this.useBiasClause = !cl.hasOption("no_bias"); } this.lossFunction = LossFunction.resolve(lossFuncName); this.rankInit = RankInitScheme.resolve(rankInitOpt); rankInit.setMaxInitValue(maxInitValue); initStdDev = Math.max(initStdDev, 1.0d / factor); rankInit.setInitStdDev(initStdDev); this.etaEstimator = EtaEstimator.get(cl); this.cvState = new ConversionState(conversionCheck, convergenceRate); return cl; } @Override public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { if (argOIs.length != 3 && argOIs.length != 4) { throw new UDFArgumentException( getClass().getSimpleName() + " takes 3 or 4 arguments: INT user, INT posItem, INT negItem [, CONSTANT STRING options]"); } this.userOI = HiveUtils.asIntCompatibleOI(argOIs[0]); this.posItemOI = HiveUtils.asIntCompatibleOI(argOIs[1]); this.negItemOI = HiveUtils.asIntCompatibleOI(argOIs[2]); processOptions(argOIs); this.model = new FactorizedModel(this, factor, rankInit); this.count = 0L; this.lastWritePos = 0L; this.uProbe = new float[factor]; this.iProbe = new float[factor]; this.jProbe = new float[factor]; if (mapredContext != null && iterations > 1) { // invoke only at task node (initialize is also invoked in compilation) final File file; try { file = File.createTempFile("hivemall_bprmf", ".sgmt"); file.deleteOnExit(); if (!file.canWrite()) { throw new UDFArgumentException("Cannot write a temporary file: " + file.getAbsolutePath()); } } catch (IOException ioe) { throw new UDFArgumentException(ioe); } catch (Throwable e) { throw new UDFArgumentException(e); } this.fileIO = new NioFixedSegment(file, RECORD_BYTES, false); this.inputBuf = ByteBuffer.allocateDirect(65536); // 64 KiB } ArrayList<String> fieldNames = new ArrayList<String>(); ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); fieldNames.add("idx"); fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); fieldNames.add("Pu"); fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableFloatObjectInspector)); fieldNames.add("Qi"); fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableFloatObjectInspector)); if (useBiasClause) { fieldNames.add("Bi"); fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector); } return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); } @Override public void process(Object[] args) throws HiveException { assert (args.length >= 3) : args.length; int u = PrimitiveObjectInspectorUtils.getInt(args[0], userOI); int i = PrimitiveObjectInspectorUtils.getInt(args[1], posItemOI); int j = PrimitiveObjectInspectorUtils.getInt(args[2], negItemOI); validateInput(u, i, j); beforeTrain(count, u, i, j); count++; train(u, i, j); } protected void beforeTrain(final long rowNum, final int u, final int i, final int j) throws HiveException { if (inputBuf != null) { assert (fileIO != null); final ByteBuffer buf = inputBuf; int remain = buf.remaining(); if (remain < RECORD_BYTES) { writeBuffer(buf, fileIO, lastWritePos); this.lastWritePos = rowNum; } buf.putInt(u); buf.putInt(i); buf.putInt(j); } } protected void train(final int u, final int i, final int j) { Rating[] user = model.getUserVector(u, true); Rating[] itemI = model.getItemVector(i, true); Rating[] itemJ = model.getItemVector(j, true); copyToProbe(user, uProbe); copyToProbe(itemI, iProbe); copyToProbe(itemJ, jProbe); double x_uij = predict(u, i, uProbe, iProbe) - predict(u, j, uProbe, jProbe); final double dloss = dloss(x_uij, lossFunction); final float eta = eta(); for (int k = 0, size = factor; k < size; k++) { float w_uf = uProbe[k]; float h_if = iProbe[k]; float h_jf = jProbe[k]; updateUserRating(user[k], w_uf, h_if, h_jf, dloss, eta); updateItemRating(itemI[k], w_uf, h_if, dloss, eta, regI); // positive item updateItemRating(itemJ[k], w_uf, h_jf, -dloss, eta, regJ); // negative item } if (useBiasClause) { updateBias(i, j, dloss, eta); } } protected double predict(final int user, final int item, @Nonnull final float[] userProbe, @Nonnull final float[] itemProbe) { double ret = model.getItemBias(item); for (int k = 0, size = factor; k < size; k++) { ret += userProbe[k] * itemProbe[k]; } if (!NumberUtils.isFinite(ret)) { throw new IllegalStateException("Detected " + ret + " in predict where user=" + user + " and item=" + item); } return ret; } protected double dloss(final double x, @Nonnull final LossFunction loss) { switch (loss) { case sigmoid: { return 1.d / (1.d + Math.exp(x)); } case logistic: { double sigmoid = MathUtils.sigmoid(x); return sigmoid * (1.d - sigmoid); } case lnLogistic: { double ex = Math.exp(-x); return ex / (1.d + ex); } default: { throw new IllegalStateException("Unexpectd loss function: " + loss); } } } protected float eta() { return etaEstimator.eta(count); } protected void updateUserRating(final Rating rating, final float w_uf, final float h_if, final float h_jf, final double dloss, final float eta) { double grad = dloss * (h_if - h_jf) - regU * w_uf; float delta = (float) (eta * grad); float newWeight = w_uf + delta; if (!NumberUtils.isFinite(newWeight)) { throw new IllegalStateException("Detected " + newWeight + " for w_uf"); } rating.setWeight(newWeight); cvState.incrLoss(regU * w_uf * w_uf); } protected void updateItemRating(final Rating rating, final float w_uf, final float h_f, final double dloss, final float eta, final float reg) { double grad = dloss * w_uf - reg * h_f; float delta = (float) (eta * grad); float newWeight = h_f + delta; if (!NumberUtils.isFinite(newWeight)) { throw new IllegalStateException("Detected " + newWeight + " for h_f"); } rating.setWeight(newWeight); cvState.incrLoss(reg * h_f * h_f); } protected void updateBias(final int i, final int j, final double dloss, final float eta) { float Bi = model.getItemBias(i); double Gi = dloss - regBias * Bi; Bi += eta * Gi; if (!NumberUtils.isFinite(Bi)) { throw new IllegalStateException("Detected " + Bi + " for Bi"); } model.setItemBias(i, Bi); cvState.incrLoss(regBias * Bi * Bi); float Bj = model.getItemBias(j); double Gj = -dloss - regBias * Bj; Bj += eta * Gj; if (!NumberUtils.isFinite(Bj)) { throw new IllegalStateException("Detected " + Bj + " for Bj"); } model.setItemBias(j, Bj); cvState.incrLoss(regBias * Bj * Bj); } @Override public void close() throws HiveException { if (model != null) { if (count == 0) { this.model = null; // help GC return; } if (iterations > 1) { runIterativeTraining(iterations); } final IntWritable idx = new IntWritable(); final FloatWritable[] Pu = HiveUtils.newFloatArray(factor, 0.f); final FloatWritable[] Qi = HiveUtils.newFloatArray(factor, 0.f); final FloatWritable Bi = useBiasClause ? new FloatWritable() : null; final Object[] forwardObj = new Object[] {idx, Pu, Qi, Bi}; int numForwarded = 0; for (int i = model.getMinIndex(), maxIdx = model.getMaxIndex(); i <= maxIdx; i++) { idx.set(i); Rating[] userRatings = model.getUserVector(i); if (userRatings == null) { forwardObj[1] = null; } else { forwardObj[1] = Pu; copyTo(userRatings, Pu); } Rating[] itemRatings = model.getItemVector(i); if (itemRatings == null) { forwardObj[2] = null; } else { forwardObj[2] = Qi; copyTo(itemRatings, Qi); } if (useBiasClause) { Bi.set(model.getItemBias(i)); } forward(forwardObj); numForwarded++; } this.model = null; // help GC LOG.info("Forwarded the prediction model of " + numForwarded + " rows. [lastLosses=" + cvState.getCumulativeLoss() + ", #trainingExamples=" + count + "]"); } } private final void runIterativeTraining(@Nonnegative final int iterations) throws HiveException { final ByteBuffer inputBuf = this.inputBuf; final NioFixedSegment fileIO = this.fileIO; assert (inputBuf != null); assert (fileIO != null); final long numTrainingExamples = count; final Reporter reporter = getReporter(); final Counter iterCounter = (reporter == null) ? null : reporter.getCounter( "hivemall.mf.BPRMatrixFactorization$Counter", "iteration"); try { if (lastWritePos == 0) {// run iterations w/o temporary file if (inputBuf.position() == 0) { return; // no training example } inputBuf.flip(); int iter = 2; for (; iter <= iterations; iter++) { reportProgress(reporter); setCounterValue(iterCounter, iter); while (inputBuf.remaining() > 0) { int u = inputBuf.getInt(); int i = inputBuf.getInt(); int j = inputBuf.getInt(); // invoke train count++; train(u, i, j); } cvState.multiplyLoss(0.5d); cvState.logState(iter, eta()); if (cvState.isConverged(iter, numTrainingExamples)) { break; } if (cvState.isLossIncreased()) { etaEstimator.update(1.1f); } else { etaEstimator.update(0.5f); } inputBuf.rewind(); } LOG.info("Performed " + Math.min(iter, iterations) + " iterations of " + NumberUtils.formatNumber(numTrainingExamples) + " training examples on memory (thus " + NumberUtils.formatNumber(count) + " training updates in total) "); } else {// read training examples in the temporary file and invoke train for each example // write training examples in buffer to a temporary file if (inputBuf.position() > 0) { writeBuffer(inputBuf, fileIO, lastWritePos); } else if (lastWritePos == 0) { return; // no training example } try { fileIO.flush(); } catch (IOException e) { throw new HiveException("Failed to flush a file: " + fileIO.getFile().getAbsolutePath(), e); } if (LOG.isInfoEnabled()) { File tmpFile = fileIO.getFile(); LOG.info("Wrote " + numTrainingExamples + " records to a temporary file for iterative training: " + tmpFile.getAbsolutePath() + " (" + FileUtils.prettyFileSize(tmpFile) + ")"); } // run iterations int iter = 2; for (; iter <= iterations; iter++) { setCounterValue(iterCounter, iter); inputBuf.clear(); long seekPos = 0L; while (true) { reportProgress(reporter); // TODO prefetch // writes training examples to a buffer in the temporary file final int bytesRead; try { bytesRead = fileIO.read(seekPos, inputBuf); } catch (IOException e) { throw new HiveException("Failed to read a file: " + fileIO.getFile().getAbsolutePath(), e); } if (bytesRead == 0) { // reached file EOF break; } assert (bytesRead > 0) : bytesRead; seekPos += bytesRead; // reads training examples from a buffer inputBuf.flip(); int remain = inputBuf.remaining(); assert (remain > 0) : remain; for (; remain >= RECORD_BYTES; remain -= RECORD_BYTES) { int u = inputBuf.getInt(); int i = inputBuf.getInt(); int j = inputBuf.getInt(); // invoke train count++; train(u, i, j); } inputBuf.compact(); } cvState.multiplyLoss(0.5d); cvState.logState(iter, eta()); if (cvState.isConverged(iter, numTrainingExamples)) { break; } if (cvState.isLossIncreased()) { etaEstimator.update(1.1f); } else { etaEstimator.update(0.5f); } } LOG.info("Performed " + Math.min(iter, iterations) + " iterations of " + NumberUtils.formatNumber(numTrainingExamples) + " training examples using a secondary storage (thus " + NumberUtils.formatNumber(count) + " training updates in total)"); } } finally { // delete the temporary file and release resources try { fileIO.close(true); } catch (IOException e) { throw new HiveException("Failed to close a file: " + fileIO.getFile().getAbsolutePath(), e); } this.inputBuf = null; this.fileIO = null; } } @Override public Rating newRating(float v) { return new Rating(v); } // ---------------------------------------------- // static utility methods private static void validateInput(int u, int i, int j) throws HiveException { if (u < 0) { throw new HiveException("Illegal u index: " + u); } if (i < 0) { throw new HiveException("Illegal i index: " + i); } if (j < 0) { throw new HiveException("Illegal j index: " + j); } } private static void writeBuffer(@Nonnull final ByteBuffer srcBuf, @Nonnull final NioFixedSegment dst, final long lastWritePos) throws HiveException { // TODO asynchronous write in the background srcBuf.flip(); try { dst.writeRecords(lastWritePos, srcBuf); } catch (IOException e) { throw new HiveException("Exception causes while writing records to : " + lastWritePos, e); } srcBuf.clear(); } @Nonnull private final void copyToProbe(@Nonnull final Rating[] rating, @Nonnull float[] probe) { for (int k = 0, size = factor; k < size; k++) { probe[k] = rating[k].getWeight(); } } private static void copyTo(@Nonnull final Rating[] rating, @Nonnull final FloatWritable[] dst) { for (int k = 0, size = rating.length; k < size; k++) { float w = rating[k].getWeight(); dst[k].set(w); } } }