/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package hivemall.fm;
import hivemall.common.EtaEstimator;
import hivemall.fm.FactorizationMachineModel.VInitScheme;
import hivemall.utils.lang.Primitives;
import javax.annotation.Nonnull;
import org.apache.commons.cli.CommandLine;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
/**
 * Mutable holder for the hyper parameters read from a parsed {@link CommandLine}.
 *
 * <p>Fields are package-private and carry their default values; calling
 * {@link #processOptions(CommandLine)} overwrites them with whatever the command line provides.
 */
class FMHyperParameters {

    /** Default initial learning rate handed to {@link EtaEstimator#get}. */
    private static final float DEFAULT_ETA0 = 0.05f;

    // -------------------------------------
    // Model parameters

    boolean classification = false;
    int factors = 5;

    // regularization
    float lambda = 0.01f;
    // NOTE: processOptions() uses the parsed `lambda` (not these initial values) as the
    // fallback for the three specific lambdas below.
    float lambdaW0 = 0.01f;
    float lambdaW = 0.01f;
    float lambdaV = 0.01f;

    // V initialization
    double sigma = 0.1d;
    /** -1 is a sentinel meaning "not specified"; it is replaced by System.nanoTime(). */
    long seed = -1L;
    VInitScheme vInit;

    // regression
    // Double.MIN_VALUE is the smallest POSITIVE double, so it must not be used as the
    // "no lower bound" default; -Double.MAX_VALUE is the most negative finite double.
    double minTarget = -Double.MAX_VALUE;
    double maxTarget = Double.MAX_VALUE;

    // learning rate
    EtaEstimator eta;

    // feature hashing
    /** -1 means "not specified". */
    int numFeatures = -1;

    // -------------------------------------
    // non-model parameters

    int iters = 1;
    boolean conversionCheck = true;
    double convergenceRate = 0.005d;

    // adaptive regularization
    boolean adaptiveReglarization = false;
    float validationRatio = 0.05f;
    int validationThreshold = 1000;

    boolean parseFeatureAsInt = false;

    FMHyperParameters() {}

    @Override
    public String toString() {
        return "FMHyperParameters [classification=" + classification + ", factors=" + factors
                + ", lambda=" + lambda + ", lambdaW0=" + lambdaW0 + ", lambdaW=" + lambdaW
                + ", lambdaV=" + lambdaV + ", sigma=" + sigma + ", seed=" + seed + ", vInit="
                + vInit + ", minTarget=" + minTarget + ", maxTarget=" + maxTarget + ", eta=" + eta
                + ", numFeatures=" + numFeatures + ", iters=" + iters + ", conversionCheck="
                + conversionCheck + ", convergenceRate=" + convergenceRate
                + ", adaptiveReglarization=" + adaptiveReglarization + ", validationRatio="
                + validationRatio + ", validationThreshold=" + validationThreshold
                + ", parseFeatureAsInt=" + parseFeatureAsInt + "]";
    }

    /**
     * Populates this object from the given command line, keeping the current field value
     * wherever an option is absent.
     *
     * @param cl parsed command line
     * @throws UDFArgumentException if {@code validation_ratio} is outside {@code [0, 1)}
     */
    void processOptions(@Nonnull CommandLine cl) throws UDFArgumentException {
        this.classification = cl.hasOption("classification");
        this.factors = Primitives.parseInt(cl.getOptionValue("factors"), factors);

        // The generic lambda is parsed first because it is the fallback of the specific ones.
        this.lambda = Primitives.parseFloat(cl.getOptionValue("lambda"), lambda);
        this.lambdaW0 = Primitives.parseFloat(cl.getOptionValue("lambda_w0"), lambda);
        this.lambdaW = Primitives.parseFloat(cl.getOptionValue("lambda_w"), lambda);
        this.lambdaV = Primitives.parseFloat(cl.getOptionValue("lambda_v"), lambda);

        this.sigma = Primitives.parseDouble(cl.getOptionValue("sigma"), sigma);
        this.seed = Primitives.parseLong(cl.getOptionValue("seed"), seed);
        if (seed == -1L) {
            // No seed given (or -1 given explicitly): derive one from the clock.
            this.seed = System.nanoTime();
        }
        this.vInit = instantiateVInit(cl, factors, seed, classification);

        this.minTarget = Primitives.parseDouble(cl.getOptionValue("min_target"), minTarget);
        this.maxTarget = Primitives.parseDouble(cl.getOptionValue("max_target"), maxTarget);
        this.eta = EtaEstimator.get(cl, DEFAULT_ETA0);
        this.numFeatures = Primitives.parseInt(cl.getOptionValue("num_features"), numFeatures);

        this.iters = Primitives.parseInt(cl.getOptionValue("iterations"), iters);
        this.conversionCheck = !cl.hasOption("disable_cvtest");
        this.convergenceRate =
                Primitives.parseDouble(cl.getOptionValue("cv_rate"), convergenceRate);

        // The option key spelling ("regularizaion") is kept as-is: it has to match the
        // option registered by the caller, which is not visible from this file.
        this.adaptiveReglarization = cl.hasOption("adaptive_regularizaion");
        this.validationRatio =
                Primitives.parseFloat(cl.getOptionValue("validation_ratio"), validationRatio);
        if (validationRatio < 0.f || validationRatio >= 1.f) {
            throw new UDFArgumentException(
                "validation_ratio should be in range [0, 1): " + validationRatio);
        }
        this.validationThreshold = Primitives.parseInt(
            cl.getOptionValue("validation_threshold"), validationThreshold);

        this.parseFeatureAsInt = cl.hasOption("int_feature");
    }

    /**
     * Resolves and configures the initialization scheme of V from the {@code init_v},
     * {@code max_init_value} and {@code min_init_stddev} options.
     *
     * @param cl parsed command line
     * @param factor number of factors; also bounds the standard deviation from below (1/factor)
     * @param seed seed passed to {@code VInitScheme#initRandom}
     * @param classification selects the default scheme: gaussian when true, random otherwise
     * @return the configured scheme
     */
    @Nonnull
    private static VInitScheme instantiateVInit(@Nonnull CommandLine cl, int factor, long seed,
            final boolean classification) {
        String vInitOpt = cl.getOptionValue("init_v");
        float maxInitValue = Primitives.parseFloat(cl.getOptionValue("max_init_value"), 0.5f);
        double initStdDev = Primitives.parseDouble(cl.getOptionValue("min_init_stddev"), 0.1d);

        VInitScheme defaultInit = classification ? VInitScheme.gaussian : VInitScheme.random;
        VInitScheme vInit = VInitScheme.resolve(vInitOpt, defaultInit);
        // NOTE(review): gaussian/random look like shared constants, so these setters may
        // mutate state shared across callers - confirm against VInitScheme.
        vInit.setMaxInitValue(maxInitValue);
        // The requested value is only a lower bound candidate; never go below 1/factor.
        initStdDev = Math.max(initStdDev, 1.0d / factor);
        vInit.setInitStdDev(initStdDev);
        vInit.initRandom(factor, seed);
        return vInit;
    }

    /**
     * Hyper parameters for FFM; extends the base set with field, AdaGrad and FTRL settings.
     */
    public static final class FFMHyperParameters extends FMHyperParameters {

        /** Largest number of hashing bits accepted by {@code -feature_hashing}. */
        private static final int MAX_HASH_BITS = 31;

        // FFM hyper parameters
        boolean globalBias = false;
        boolean linearCoeff = true;

        // feature hashing
        int numFields = Feature.DEFAULT_NUM_FIELDS;

        // adagrad
        boolean useAdaGrad = true;
        float eta0_V = 1.f;
        float eps = 1.f;

        // FTRL
        boolean useFTRL = true;
        float alphaFTRL = 0.1f; // Learning Rate
        float betaFTRL = 1.f; // Smoothing parameter for AdaGrad
        float lambda1 = 0.1f; // L1 Regularization
        // Field name keeps its historical misspelling because other classes reference it.
        float lamdda2 = 0.01f; // L2 Regularization

        FFMHyperParameters() {
            super();
        }

        /**
         * Parses the base options first and then the FFM specific ones.
         *
         * @param cl parsed command line
         * @throws UDFArgumentException if {@code int_feature} is given, if
         *         {@code feature_hashing} is outside {@code [18,31]}, if {@code num_fields}
         *         is not greater than 1, if {@code alphaFTRL} is 0, or if the base options
         *         are rejected
         */
        @Override
        void processOptions(@Nonnull CommandLine cl) throws UDFArgumentException {
            super.processOptions(cl);
            if (cl.hasOption("int_feature")) {
                throw new UDFArgumentException("int_feature option is not supported yet for FFM");
            }

            this.globalBias = cl.hasOption("global_bias");
            this.linearCoeff = !cl.hasOption("no_coeff");

            // feature hashing: only derive numFeatures when it was not given explicitly
            if (numFeatures == -1) {
                int hashbits = Primitives.parseInt(cl.getOptionValue("feature_hashing"),
                    Feature.DEFAULT_FEATURE_BITS);
                if (hashbits < 18 || hashbits > MAX_HASH_BITS) {
                    throw new UDFArgumentException(
                        "-feature_hashing MUST be in range [18,31]: " + hashbits);
                }
                // 1 << 31 overflows to Integer.MIN_VALUE, which would yield a negative
                // feature count, so saturate at the largest representable int instead.
                // NOTE(review): for 31 bits the size is then 2^31 - 1 rather than a power
                // of two - confirm that the consumers of numFeatures do not rely on that.
                this.numFeatures =
                        (hashbits == MAX_HASH_BITS) ? Integer.MAX_VALUE : (1 << hashbits);
            }
            this.numFields = Primitives.parseInt(cl.getOptionValue("num_fields"), numFields);
            if (numFields <= 1) {
                throw new UDFArgumentException("-num_fields MUST be greater than 1: " + numFields);
            }

            // adagrad
            this.useAdaGrad = !cl.hasOption("disable_adagrad");
            this.eta0_V = Primitives.parseFloat(cl.getOptionValue("eta0_V"), eta0_V);
            this.eps = Primitives.parseFloat(cl.getOptionValue("eps"), eps);

            // FTRL
            this.useFTRL = !cl.hasOption("disable_ftrl");
            this.alphaFTRL = Primitives.parseFloat(cl.getOptionValue("alphaFTRL"), alphaFTRL);
            if (alphaFTRL == 0.f) {
                throw new UDFArgumentException("-alphaFTRL SHOULD NOT be 0");
            }
            this.betaFTRL = Primitives.parseFloat(cl.getOptionValue("betaFTRL"), betaFTRL);
            this.lambda1 = Primitives.parseFloat(cl.getOptionValue("lambda1"), lambda1);

            // The original code looked up the misspelled key "lamdda2" only. Prefer the
            // spelling that matches "lambda1" and keep the legacy key as a fallback so
            // that either registration keeps working (getOptionValue returns null for a
            // key that is absent).
            String lambda2Opt = cl.getOptionValue("lambda2");
            if (lambda2Opt == null) {
                lambda2Opt = cl.getOptionValue("lamdda2");
            }
            this.lamdda2 = Primitives.parseFloat(lambda2Opt, lamdda2);
        }

        @Override
        public String toString() {
            return "FFMHyperParameters [globalBias=" + globalBias + ", linearCoeff=" + linearCoeff
                    + ", numFields=" + numFields + ", useAdaGrad=" + useAdaGrad + ", eta0_V="
                    + eta0_V + ", eps=" + eps + ", useFTRL=" + useFTRL + ", alphaFTRL=" + alphaFTRL
                    + ", betaFTRL=" + betaFTRL + ", lambda1=" + lambda1 + ", lamdda2=" + lamdda2
                    + "], " + super.toString();
        }
    }
}