/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package hivemall.fm;
import hivemall.fm.FMHyperParameters.FFMHyperParameters;
import hivemall.utils.collections.DoubleArray3D;
import hivemall.utils.collections.IntArrayList;
import hivemall.utils.lang.NumberUtils;
import java.util.Arrays;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.hadoop.hive.ql.metadata.HiveException;
/**
 * Base class of Field-aware Factorization Machine (FFM) models.
 *
 * <p>
 * In contrast to the plain factorization machine handled by the super class, every latent factor
 * is looked up by a feature <em>and</em> the field of the feature it interacts with. The
 * field-less accessors {@link #getV(Feature, int)} and {@link #setV(Feature, int, float)} are
 * therefore unsupported and always throw {@link UnsupportedOperationException}.
 *
 * <p>
 * This class provides prediction ({@link #predict(Feature[])}), the update of a single latent
 * factor ({@code updateV}), the pre-computation of the sums needed by that update
 * ({@code sumVfX}) and a human readable dump of a prediction ({@link #varDump(Feature[])}).
 * Storage of the parameters is left to subclasses.
 */
public abstract class FieldAwareFactorizationMachineModel extends FactorizationMachineModel {

    /** Hyper parameters this model was constructed with. */
    @Nonnull
    protected final FFMHyperParameters _params;
    /** Copy of {@code params.eta0_V}; numerator of the AdaGrad learning rate in {@link #etaV}. */
    protected final float _eta0_V;
    /** Copy of {@code params.eps}; added to the squared-gradient sum inside the square root. */
    protected final float _eps;
    /** Copy of {@code params.useAdaGrad}; selects the AdaGrad branch of {@link #etaV}. */
    protected final boolean _useAdaGrad;
    /** Copy of {@code params.useFTRL}; not read by this class itself. */
    protected final boolean _useFTRL;

    /**
     * Creates a model and caches the hyper parameters that are read on the update path.
     *
     * @param params FFM hyper parameters, also passed to the super class constructor
     */
    public FieldAwareFactorizationMachineModel(@Nonnull FFMHyperParameters params) {
        super(params);
        this._params = params;
        this._eta0_V = params.eta0_V;
        this._eps = params.eps;
        this._useAdaGrad = params.useAdaGrad;
        this._useFTRL = params.useFTRL;
    }

    /**
     * Returns the {@code f}-th latent factor of feature {@code x} with respect to the field
     * {@code yField}.
     *
     * @param x feature owning the factor
     * @param yField field of the interacting feature
     * @param f index of the factor; callers in this class pass {@code 0 <= f < _factor}
     * @return the current value of the factor
     */
    public abstract float getV(@Nonnull Feature x, int yField, int f);

    /**
     * Overwrites the {@code f}-th latent factor of feature {@code x} with respect to the field
     * {@code yField}.
     *
     * @param x feature owning the factor
     * @param yField field of the interacting feature
     * @param f index of the factor
     * @param nextVif the new value
     * @deprecated this class never calls this method; {@code updateV} writes factors through the
     *             {@code Entry} returned by {@link #getEntry(Feature, int)} instead.
     */
    @Deprecated
    protected abstract void setV(@Nonnull Feature x, int yField, int f, float nextVif);

    /**
     * Unsupported: a factor cannot be addressed without a field.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public float getV(Feature x, int f) {
        throw new UnsupportedOperationException();
    }

    /**
     * Unsupported: a factor cannot be addressed without a field.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    protected void setV(Feature x, int f, float nextVif) {
        throw new UnsupportedOperationException();
    }

    /**
     * Computes the model output for one example:
     *
     * <pre>
     * w0 + sum_i( w[i] * x[i] )
     *    + sum_{i&lt;j} sum_f( v[i, field(j), f] * v[j, field(i), f] * x[i] * x[j] )
     * </pre>
     *
     * @param x features of the example
     * @return the predicted value, always finite
     * @throws HiveException if the result is NaN or infinite; the message contains the output of
     *         {@link #varDump(Feature[])}
     */
    @Override
    protected final double predict(@Nonnull final Feature[] x) throws HiveException {
        // w0
        double ret = getW0();

        // W
        for (Feature e : x) {
            double xi = e.getValue();
            float wi = getW(e);
            double wx = wi * xi;
            ret += wx;
        }

        // V
        for (int i = 0; i < x.length; i++) {
            final Feature ei = x[i];
            final double xi = ei.getValue();
            final int iField = ei.getField();
            for (int j = i + 1; j < x.length; j++) {
                final Feature ej = x[j];
                final double xj = ej.getValue();
                final int jField = ej.getField();
                for (int f = 0, k = _factor; f < k; f++) {
                    // each side uses the factor it keeps for the *other* feature's field
                    float vijf = getV(ei, jField, f);
                    float vjif = getV(ej, iField, f);
                    ret += vijf * vjif * xi * xj;
                    assert (!Double.isNaN(ret));
                }
            }
        }

        if (!NumberUtils.isFinite(ret)) {
            throw new HiveException("Detected " + ret
                    + " in predict. We recommend to normalize training examples.\n"
                    + "Dumping variables ...\n" + varDump(x));
        }
        return ret;
    }

    /**
     * Applies one gradient step to the {@code f}-th latent factor of {@code x} with respect to
     * the field {@code yField}.
     *
     * <p>
     * The gradient is {@code dloss * x.getValue() * sumViX}; the step additionally subtracts the
     * regularization term {@code 2 * getLambdaV(f) * currentV}, both scaled by the learning rate
     * returned from {@link #etaV}.
     *
     * @param dloss derivative of the loss with respect to the prediction
     * @param x feature whose factor is updated
     * @param yField field of the interacting features
     * @param f index of the factor
     * @param sumViX pre-computed sum for {@code (x, yField, f)}, see {@code sumVfX}
     * @param t iteration counter handed to the learning rate schedule
     * @throws IllegalStateException if the updated factor is NaN or infinite
     */
    void updateV(final double dloss, @Nonnull final Feature x, final int yField, final int f,
            final double sumViX, long t) {
        final double Xi = x.getValue();
        final double h = Xi * sumViX;
        final float gradV = (float) (dloss * h);
        final float lambdaVf = getLambdaV(f);

        final Entry theta = getEntry(x, yField);
        final float currentV = theta.getV(f);
        // NOTE: etaV() has a side effect on theta when AdaGrad is enabled
        final float eta = etaV(theta, t, gradV);
        final float nextV = currentV - eta * (gradV + 2.f * lambdaVf * currentV);
        if (!NumberUtils.isFinite(nextV)) {
            throw new IllegalStateException("Got " + nextV + " for next V" + f + '['
                    + x.getFeatureIndex() + "]\n" + "Xi=" + Xi + ", Vif=" + currentV + ", h=" + h
                    + ", gradV=" + gradV + ", lambdaVf=" + lambdaVf + ", dloss=" + dloss
                    + ", sumViX=" + sumViX);
        }
        theta.setV(f, nextV);
    }

    /**
     * Returns the learning rate to use for an update of a latent factor stored in {@code theta}.
     *
     * <p>
     * With AdaGrad enabled the rate is {@code _eta0_V / sqrt(_eps + gg)} where {@code gg} is the
     * sum of squared gradients read from {@code theta} <em>before</em> {@code grad} is added to
     * it, i.e. the current gradient only influences subsequent calls. Otherwise the rate comes
     * from the schedule {@code _eta} of the super class.
     *
     * @param theta entry holding the factor and, for AdaGrad, its gradient accumulator
     * @param t iteration counter, only used when AdaGrad is disabled
     * @param grad current gradient, only used (accumulated into {@code theta}) for AdaGrad
     * @return the learning rate
     */
    protected final float etaV(@Nonnull final Entry theta, final long t, final float grad) {
        if (_useAdaGrad) {
            double gg = theta.getSumOfSquaredGradientsV();
            theta.addGradientV(grad);
            return (float) (_eta0_V / Math.sqrt(_eps + gg));
        } else {
            return _eta.eta(t);
        }
    }

    /**
     * Pre-computes, for every feature {@code x[i]}, every field in {@code fieldList} and every
     * factor {@code f}, the sum
     *
     * <pre>
     * sum_{j != i, field(x[j]) == yField} v[j, field(x[i]), f] * x[j]
     * </pre>
     *
     * and stores it at {@code (i, index of yField in fieldList, f)}.
     *
     * @param x features of the example
     * @param fieldList fields to compute the sums for; its order defines the second dimension
     * @param cached array to reuse, or {@code null} to allocate a new one
     * @return {@code cached} if it was given, otherwise a newly created array; in both cases
     *         configured with {@code (x.length, fieldList.size(), _factor)} and filled
     * @throws IllegalStateException if any of the sums is NaN or infinite
     */
    @Nonnull
    final DoubleArray3D sumVfX(@Nonnull final Feature[] x, @Nonnull final IntArrayList fieldList,
            @Nullable DoubleArray3D cached) {
        final int xSize = x.length;
        final int fieldSize = fieldList.size();
        final int factors = _factor;

        final DoubleArray3D mdarray;
        if (cached == null) {
            mdarray = new DoubleArray3D();
            mdarray.setSanityCheck(false);
        } else {
            mdarray = cached;
        }
        mdarray.configure(xSize, fieldSize, factors);

        for (int i = 0; i < xSize; i++) {
            for (int fieldIndex = 0; fieldIndex < fieldSize; fieldIndex++) {
                final int yField = fieldList.get(fieldIndex);
                for (int f = 0; f < factors; f++) {
                    double val = sumVfX(x, i, yField, f);
                    mdarray.set(i, fieldIndex, f, val);
                }
            }
        }
        return mdarray;
    }

    /**
     * Computes a single cell of {@link #sumVfX(Feature[], IntArrayList, DoubleArray3D)}: the sum
     * of {@code v[e, field(x[i]), f] * value(e)} over all features {@code e} of {@code x} that
     * belong to {@code yField} and do not share the feature index of {@code x[i]}.
     *
     * @throws IllegalStateException if the sum is NaN or infinite
     */
    private double sumVfX(@Nonnull final Feature[] x, final int i, final int yField,
            final int f) {
        final Feature xi = x[i];
        final int xiFeature = xi.getFeatureIndex();
        final int xiField = xi.getField();

        double ret = 0.d;
        // find all other features whose field matches yField
        for (Feature e : x) {
            if (e.getFeatureIndex() == xiFeature) { // ignore x[i] = e
                continue;
            }
            if (e.getField() == yField) { // multiply x_e and v_e,field(x_i),f
                float Vjf = getV(e, xiField, f);
                // The value of x[i] itself is multiplied in later by updateV(); using it here
                // as well would yield Xi^2 * sum(V) and drop the value of e altogether.
                ret += Vjf * e.getValue();
            }
        }
        if (!NumberUtils.isFinite(ret)) {
            throw new IllegalStateException("Got " + ret + " for sumV[ " + i + "][ " + f + "]X.\n"
                    + "x = " + Arrays.toString(x));
        }
        return ret;
    }

    /**
     * Returns the entry of feature {@code x}.
     *
     * @param x feature to look up
     * @return the entry, never {@code null}
     */
    @Nonnull
    protected abstract Entry getEntry(@Nonnull Feature x);

    /**
     * Returns the entry holding the latent factors of feature {@code x} with respect to the field
     * {@code yField}.
     *
     * @param x feature to look up
     * @param yField field of the interacting feature
     * @return the entry, never {@code null}
     */
    @Nonnull
    protected abstract Entry getEntry(@Nonnull Feature x, int yField);

    /**
     * Renders the computation performed by {@link #predict(Feature[])} for diagnostics.
     *
     * <p>
     * The result consists of the list of feature values, the prediction formula in symbolic form
     * and the same formula with the actual numbers filled in. Rendering stops with
     * {@code " + ... = <value>"} at the first term that makes the running sum NaN or infinite.
     *
     * @param x features of the example
     * @return the multi-line dump
     */
    @Override
    protected final String varDump(@Nonnull final Feature[] x) {
        final StringBuilder buf1 = new StringBuilder(1024); // symbolic formula
        final StringBuilder buf2 = new StringBuilder(1024); // formula with numbers

        // X
        for (int i = 0; i < x.length; i++) {
            Feature e = x[i];
            String j = e.getFeature();
            double xj = e.getValue();
            if (i != 0) {
                buf1.append(", ");
            }
            buf1.append("x[").append(j).append("] = ").append(xj);
        }
        buf1.append("\n");

        // w0
        double ret = getW0();
        buf1.append("predict(x) = w0");
        buf2.append("predict(x) = ").append(ret);

        // W
        for (Feature e : x) {
            String i = e.getFeature();
            double xi = e.getValue();
            float wi = getW(e);
            buf1.append(" + (w[").append(i).append("] * x[").append(i).append("])");
            buf2.append(" + (").append(wi).append(" * ").append(xi).append(')');
            double wx = wi * xi;
            ret += wx;
            if (!NumberUtils.isFinite(ret)) {
                return finishVarDump(buf1, buf2, " + ... = ", ret);
            }
        }

        // V
        for (int i = 0; i < x.length; i++) {
            final Feature ei = x[i];
            final String fi = ei.getFeature();
            final double xi = ei.getValue();
            final int iField = ei.getField();
            for (int j = i + 1; j < x.length; j++) {
                final Feature ej = x[j];
                final String fj = ej.getFeature();
                final double xj = ej.getValue();
                final int jField = ej.getField();
                for (int f = 0, k = _factor; f < k; f++) {
                    float vijf = getV(ei, jField, f);
                    float vjif = getV(ej, iField, f);
                    buf1.append(" + (v[i")
                        .append(fi)
                        .append("-j")
                        .append(jField)
                        .append("-f")
                        .append(f)
                        .append("] * v[j")
                        .append(fj)
                        .append("-i")
                        .append(iField)
                        .append("-f")
                        .append(f)
                        .append("] * x[")
                        .append(fi)
                        .append("] * x[")
                        .append(fj)
                        .append("])");
                    buf2.append(" + (")
                        .append(vijf)
                        .append(" * ")
                        .append(vjif)
                        .append(" * ")
                        .append(xi)
                        .append(" * ")
                        .append(xj)
                        .append(')');
                    ret += vijf * vjif * xi * xj;
                    if (!NumberUtils.isFinite(ret)) {
                        return finishVarDump(buf1, buf2, " + ... = ", ret);
                    }
                }
            }
        }

        return finishVarDump(buf1, buf2, " = ", ret);
    }

    /**
     * Terminates both formula lines of {@link #varDump(Feature[])} with {@code separator} followed
     * by {@code total} and joins them with a line break, symbolic line first.
     */
    @Nonnull
    private static String finishVarDump(@Nonnull final StringBuilder symbolic,
            @Nonnull final StringBuilder numeric, @Nonnull final String separator,
            final double total) {
        return symbolic.append(separator)
                       .append(total)
                       .append('\n')
                       .append(numeric)
                       .append(separator)
                       .append(total)
                       .toString();
    }
}