/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package hivemall.fm;
import hivemall.fm.Entry.AdaGradEntry;
import hivemall.fm.Entry.FTRLEntry;
import hivemall.fm.FMHyperParameters.FFMHyperParameters;
import hivemall.utils.buffer.HeapBuffer;
import hivemall.utils.collections.Int2LongOpenHashTable;
import hivemall.utils.lang.NumberUtils;
import hivemall.utils.math.MathUtils;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
/**
 * Field-aware factorization machine model whose per-feature parameters live in a
 * {@link HeapBuffer} and are located through an {@link Int2LongOpenHashTable} that maps an
 * integer feature key to the offset of its {@link Entry} inside the buffer.
 *
 * <p>
 * Two kinds of keys are stored in the same table: the plain feature index
 * ({@code Feature#getFeatureIndex()}) used for the linear weight {@code W}, and the key
 * produced by {@code Feature.toIntFeature(x, yField, numFields)} used for the field-aware
 * factors {@code V}.
 *
 * <p>
 * The concrete entry layout ({@link Entry}, {@link AdaGradEntry} or {@link FTRLEntry}) is
 * selected from the {@code _useFTRL} / {@code _useAdaGrad} flags inherited from the parent
 * class, with FTRL taking precedence.
 */
public final class FFMStringFeatureMapModel extends FieldAwareFactorizationMachineModel {

    /** Initial size passed to the key-to-offset hash table. */
    private static final int DEFAULT_MAPSIZE = 65536;

    /** Value of a looked-up offset that is treated as "no entry for this key". */
    private static final long ABSENT = -1L;

    // LEARNING PARAMS

    /** Global bias term. */
    private float _w0;
    /** Feature key -> offset of the entry in {@link #_buf}. */
    @Nonnull
    private final Int2LongOpenHashTable _map;
    /** Backing storage of all entries. */
    private final HeapBuffer _buf;

    // hyperparams

    private final int _numFeatures;
    private final int _numFields;

    // FTRL

    private final float _alpha;
    private final float _beta;
    private final float _lambda1;
    private final float _lambda2;

    /** Number of bytes/units requested from the buffer for each new entry. */
    private final int _entrySize;

    /**
     * Creates an empty model configured from the given hyper parameters.
     *
     * @param params hyper parameters; also forwarded to the parent constructor
     */
    public FFMStringFeatureMapModel(@Nonnull FFMHyperParameters params) {
        super(params);
        this._w0 = 0.f;
        this._map = new Int2LongOpenHashTable(DEFAULT_MAPSIZE);
        this._buf = new HeapBuffer(HeapBuffer.DEFAULT_CHUNK_SIZE);
        this._numFeatures = params.numFeatures;
        this._numFields = params.numFields;
        this._alpha = params.alphaFTRL;
        this._beta = params.betaFTRL;
        this._lambda1 = params.lambda1;
        // "lamdda2" (sic) is the spelling of the field on the params object
        this._lambda2 = params.lamdda2;
        this._entrySize = entrySize(_factor, _useFTRL, _useAdaGrad);
    }

    /**
     * Wraps the current table, buffer and bias into a {@link FFMPredictionModel}.
     *
     * @return a prediction model built from this model's map, buffer, w0, factor count,
     *         feature count and field count
     */
    @Nonnull
    FFMPredictionModel toPredictionModel() {
        return new FFMPredictionModel(_map, _buf, _w0, _factor, _numFeatures, _numFields);
    }

    /**
     * @return the number of keys currently registered in the key-to-offset table
     */
    @Override
    public int getSize() {
        return _map.size();
    }

    /**
     * @return the global bias term
     */
    @Override
    public float getW0() {
        return _w0;
    }

    /**
     * Replaces the global bias term.
     *
     * @param nextW0 the new bias
     */
    @Override
    protected void setW0(float nextW0) {
        this._w0 = nextW0;
    }

    /**
     * Returns the linear weight of a feature without creating an entry for it.
     *
     * @param x the feature whose index is used as lookup key
     * @return the stored weight, or {@code 0.f} when the feature has no entry
     */
    @Override
    public float getW(@Nonnull final Feature x) {
        final Entry entry = getEntry(x.getFeatureIndex());
        return (entry == null) ? 0.f : entry.getW();
    }

    /**
     * Stores the linear weight of a feature, creating and registering a new entry (with
     * freshly initialized factors) when none exists yet.
     *
     * @param x the feature whose index is used as key
     * @param nextWi the weight to store
     */
    @Override
    protected void setW(@Nonnull final Feature x, final float nextWi) {
        final int j = x.getFeatureIndex();

        final Entry entry = getEntry(j);
        if (entry != null) {
            entry.setW(nextWi);
            return;
        }

        final Entry created = newEntry(nextWi, initV());
        _map.put(j, created.getOffset());
    }

    /**
     * Applies one gradient step with L2 regularization to the linear weight of {@code x}:
     * {@code w <- w - eta * (dloss * Xi + 2 * lambdaW * w)}.
     *
     * @param dloss derivative of the loss
     * @param x the feature to update; an entry is created if it does not exist
     * @param eta learning rate
     * @throws IllegalStateException if the updated weight is not finite
     */
    @Override
    void updateWi(final double dloss, @Nonnull final Feature x, final float eta) {
        final double Xi = x.getValue();
        final float gradWi = (float) (dloss * Xi);

        final Entry theta = getEntry(x);
        final float wi = theta.getW();

        final float nextWi = wi - eta * (gradWi + 2.f * _lambdaW * wi);
        if (!NumberUtils.isFinite(nextWi)) {
            throw new IllegalStateException("Got " + nextWi + " for next W[" + x.getFeature()
                    + "]\n" + "Xi=" + Xi + ", gradWi=" + gradWi + ", wi=" + wi + ", dloss=" + dloss
                    + ", eta=" + eta);
        }
        theta.setW(nextWi);
    }

    /**
     * Update Wi using Follow-the-Regularized-Leader.
     *
     * <p>
     * The entry's {@code z} and {@code n} accumulators are updated first. When
     * {@code |z| <= lambda1} the feature's key is removed from the table; otherwise the weight
     * becomes {@code (sign(z) * lambda1 - z) / ((beta + sqrt(n)) / alpha + lambda2)}.
     *
     * @param dloss derivative of the loss
     * @param x the feature to update; an entry is created if it does not exist
     * @param eta learning rate; only reported in the error message, not used in the update
     * @return {@code true} if the weight before or after the update is non-zero
     * @throws IllegalStateException if the updated weight is not finite
     */
    boolean updateWiFTRL(final double dloss, @Nonnull final Feature x, final float eta) {
        final double Xi = x.getValue();
        final float gradWi = (float) (dloss * Xi);

        final Entry theta = getEntry(x);
        final float wi = theta.getW();

        final float z = theta.updateZ(gradWi, _alpha);
        final double n = theta.updateN(gradWi);

        if (Math.abs(z) <= _lambda1) {
            // Only the table key is dropped here; nothing is released from the buffer.
            removeEntry(x);
            return wi != 0;
        }

        final double numerator = MathUtils.sign(z) * _lambda1 - z;
        final double denominator = (_beta + Math.sqrt(n)) / _alpha + _lambda2;
        final float nextWi = (float) (numerator / denominator);
        if (!NumberUtils.isFinite(nextWi)) {
            throw new IllegalStateException("Got " + nextWi + " for next W[" + x.getFeature()
                    + "]\n" + "Xi=" + Xi + ", gradWi=" + gradWi + ", wi=" + wi + ", dloss=" + dloss
                    + ", eta=" + eta + ", n=" + n + ", z=" + z);
        }
        theta.setW(nextWi);
        return (nextWi != 0) || (wi != 0);
    }

    /**
     * Reads one factor of the field-aware vector of {@code x} with respect to {@code yField}.
     * A missing entry is created and registered as a side effect of the read.
     *
     * @param x the feature
     * @param yField the field of the interacting feature
     * @param f index of the factor
     * @return V_x,yField,f
     */
    @Override
    public float getV(@Nonnull final Feature x, final int yField, final int f) {
        final int j = Feature.toIntFeature(x, yField, _numFields);
        return getOrCreateEntry(j).getV(f);
    }

    /**
     * Writes one factor of the field-aware vector of {@code x} with respect to
     * {@code yField}, creating and registering the entry first when it does not exist.
     *
     * @param x the feature
     * @param yField the field of the interacting feature
     * @param f index of the factor
     * @param nextVif the value to store
     */
    @Override
    protected void setV(@Nonnull final Feature x, final int yField, final int f,
            final float nextVif) {
        final int j = Feature.toIntFeature(x, yField, _numFields);
        getOrCreateEntry(j).setV(f, nextVif);
    }

    /**
     * Returns the entry keyed by the feature index of {@code x}, creating and registering it
     * when absent.
     *
     * @param x the feature
     * @return the existing or newly created entry; never {@code null}
     */
    @Override
    protected Entry getEntry(@Nonnull final Feature x) {
        return getOrCreateEntry(x.getFeatureIndex());
    }

    /**
     * Returns the entry keyed by {@code Feature.toIntFeature(x, yField, numFields)}, creating
     * and registering it when absent.
     *
     * @param x the feature
     * @param yField the field of the interacting feature
     * @return the existing or newly created entry; never {@code null}
     */
    @Override
    protected Entry getEntry(@Nonnull final Feature x, final int yField) {
        return getOrCreateEntry(Feature.toIntFeature(x, yField, _numFields));
    }

    /**
     * Removes the key given by the feature index of {@code x} from the table.
     *
     * @param x the feature whose key is removed
     */
    protected void removeEntry(@Nonnull final Feature x) {
        _map.remove(x.getFeatureIndex());
    }

    /**
     * Allocates a new entry and initializes both its weight and its factors. The entry is not
     * registered in the table by this method.
     *
     * @param W initial linear weight
     * @param V initial factors
     * @return the newly allocated entry
     */
    @Nonnull
    protected final Entry newEntry(final float W, @Nonnull final float[] V) {
        final Entry entry = newEntry();
        entry.setW(W);
        entry.setV(V);
        return entry;
    }

    /**
     * Allocates a new entry and initializes its factors only. The entry is not registered in
     * the table by this method.
     *
     * @param V initial factors
     * @return the newly allocated entry
     */
    @Nonnull
    protected final Entry newEntry(@Nonnull final float[] V) {
        final Entry entry = newEntry();
        entry.setV(V);
        return entry;
    }

    /** Looks up {@code key}; on a miss, creates an entry with fresh factors and registers it. */
    @Nonnull
    private Entry getOrCreateEntry(final int key) {
        Entry entry = getEntry(key);
        if (entry == null) {
            // initV() is evaluated before the buffer allocation, as in every former call site
            entry = newEntry(initV());
            _map.put(key, entry.getOffset());
        }
        return entry;
    }

    /** Allocates {@link #_entrySize} from the buffer and wraps it in the configured entry type. */
    @Nonnull
    private Entry newEntry() {
        final long ptr = _buf.allocate(_entrySize);
        return getEntry(ptr);
    }

    /** Returns the entry registered for {@code key}, or {@code null} when the table has none. */
    @Nullable
    private Entry getEntry(final int key) {
        final long ptr = _map.get(key);
        if (ptr == ABSENT) {
            return null;
        }
        return getEntry(ptr);
    }

    /** Wraps the buffer region at {@code ptr} in the entry type selected by the optimizer flags. */
    @Nonnull
    private Entry getEntry(long ptr) {
        if (_useFTRL) {
            return new FTRLEntry(_buf, _factor, ptr);
        } else if (_useAdaGrad) {
            return new AdaGradEntry(_buf, _factor, ptr);
        } else {
            return new Entry(_buf, _factor, ptr);
        }
    }

    /** Size of one entry for the given factor count; FTRL takes precedence over AdaGrad. */
    private static int entrySize(int factors, boolean ftrl, boolean adagrad) {
        if (ftrl) {
            return FTRLEntry.sizeOf(factors);
        } else if (adagrad) {
            return AdaGradEntry.sizeOf(factors);
        } else {
            return Entry.sizeOf(factors);
        }
    }
}