/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package hivemall.model;
import hivemall.mix.MixedWeight;
import hivemall.mix.MixedWeight.WeightWithCovar;
import hivemall.mix.MixedWeight.WeightWithDelta;
import hivemall.utils.collections.IntOpenHashMap;
import hivemall.utils.collections.OpenHashMap;
import javax.annotation.Nonnull;
public abstract class AbstractPredictionModel implements PredictionModel {
public static final byte BYTE0 = 0;
protected ModelUpdateHandler handler;
private long numMixed;
private boolean cancelMixRequest;
private IntOpenHashMap<MixedWeight> mixedRequests_i;
private OpenHashMap<Object, MixedWeight> mixedRequests_o;
public AbstractPredictionModel() {
this.numMixed = 0L;
this.cancelMixRequest = false;
}
protected abstract boolean isDenseModel();
@Override
public ModelUpdateHandler getUpdateHandler() {
return handler;
}
@Override
public void configureMix(ModelUpdateHandler handler, boolean cancelMixRequest) {
this.handler = handler;
this.cancelMixRequest = cancelMixRequest;
if (cancelMixRequest) {
if (isDenseModel()) {
this.mixedRequests_i = new IntOpenHashMap<MixedWeight>(327680);
} else {
this.mixedRequests_o = new OpenHashMap<Object, MixedWeight>(327680);
}
}
}
@Override
public final long getNumMixed() {
return numMixed;
}
@Override
public void resetDeltaUpdates(int feature) {
throw new UnsupportedOperationException();
}
protected final void onUpdate(final int feature, final float weight, final float covar,
final short clock, final int deltaUpdates, final boolean hasCovar) {
if (handler != null) {
if (deltaUpdates < 1) {
return;
}
final boolean requestSent;
try {
requestSent = handler.onUpdate(feature, weight, covar, clock, deltaUpdates);
} catch (Exception e) {
throw new RuntimeException(e);
}
if (requestSent) {
if (cancelMixRequest) {
if (hasCovar) {
MixedWeight prevMixed = mixedRequests_i.get(feature);
if (prevMixed == null) {
prevMixed = new WeightWithCovar(weight, covar);
mixedRequests_i.put(feature, prevMixed);
} else {
try {
handler.sendCancelRequest(feature, prevMixed);
} catch (Exception e) {
throw new RuntimeException(e);
}
prevMixed.setWeight(weight);
prevMixed.setCovar(covar);
}
} else {
MixedWeight prevMixed = mixedRequests_i.get(feature);
if (prevMixed == null) {
prevMixed = new WeightWithDelta(weight, deltaUpdates);
mixedRequests_i.put(feature, prevMixed);
} else {
try {
handler.sendCancelRequest(feature, prevMixed);
} catch (Exception e) {
throw new RuntimeException(e);
}
prevMixed.setWeight(weight);
prevMixed.setDeltaUpdates(deltaUpdates);
}
}
}
resetDeltaUpdates(feature);
}
}
}
protected final void onUpdate(final Object feature, final IWeightValue value) {
if (handler != null) {
if (!value.isTouched()) {
return;
}
final float weight = value.get();
final short clock = value.getClock();
final int deltaUpdates = value.getDeltaUpdates();
if (value.hasCovariance()) {
final float covar = value.getCovariance();
final boolean requestSent;
try {
requestSent = handler.onUpdate(feature, weight, covar, clock, deltaUpdates);
} catch (Exception e) {
throw new RuntimeException(e);
}
if (requestSent) {
if (cancelMixRequest) {
MixedWeight prevMixed = mixedRequests_o.get(feature);
if (prevMixed == null) {
prevMixed = new WeightWithCovar(weight, covar);
mixedRequests_o.put(feature, prevMixed);
} else {
try {
handler.sendCancelRequest(feature, prevMixed);
} catch (Exception e) {
throw new RuntimeException(e);
}
prevMixed.setWeight(weight);
prevMixed.setCovar(covar);
}
}
value.setDeltaUpdates(BYTE0);
}
} else {
final boolean requestSent;
try {
requestSent = handler.onUpdate(feature, weight, 1.f, clock, deltaUpdates);
} catch (Exception e) {
throw new RuntimeException(e);
}
if (requestSent) {
if (cancelMixRequest) {
MixedWeight prevMixed = mixedRequests_o.get(feature);
if (prevMixed == null) {
prevMixed = new WeightWithDelta(weight, deltaUpdates);
mixedRequests_o.put(feature, prevMixed);
} else {
try {
handler.sendCancelRequest(feature, prevMixed);
} catch (Exception e) {
throw new RuntimeException(e);
}
prevMixed.setWeight(weight);
prevMixed.setDeltaUpdates(deltaUpdates);
}
}
value.setDeltaUpdates(BYTE0);
}
}
}
}
/**
*
*/
@Override
public void set(@Nonnull Object feature, float weight, float covar, short clock) {
if (hasCovariance()) {
_set(feature, weight, covar, clock);
} else {
_set(feature, weight, clock);
}
numMixed++;
}
protected abstract void _set(@Nonnull Object feature, float weight, short clock);
protected abstract void _set(@Nonnull Object feature, float weight, float covar, short clock);
}