/** * Copyright 2014, Emory University * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package edu.emory.clir.clearnlp.classification.trainer; import java.util.Arrays; import java.util.Random; import edu.emory.clir.clearnlp.classification.instance.IntInstance; import edu.emory.clir.clearnlp.classification.model.SparseModel; import edu.emory.clir.clearnlp.classification.model.StringModel; import edu.emory.clir.clearnlp.util.DSUtils; import edu.emory.clir.clearnlp.util.MathUtils; /** * @since 3.0.0 * @author Jinho D. Choi ({@code jinho.choi@emory.edu}) */ abstract public class AbstractOnlineTrainer extends AbstractTrainer { protected double[] d_average; protected Random r_rand; /** @param average if {@code true}, weights are averaged. */ public AbstractOnlineTrainer(SparseModel model, boolean average) { super(TrainerType.ONLINE, model); init(average); } /** @param average if {@code true}, weights are averaged. */ public AbstractOnlineTrainer(StringModel model, int labelCutoff, int featureCutoff, boolean average) { super(TrainerType.ONLINE, model, labelCutoff, featureCutoff); init(average); } private void init(boolean average) { d_average = average ? new double[w_vector.size()] : null; r_rand = new Random(RANDOM_SEED); } public void train() { if (average()) Arrays.fill(d_average, 0); DSUtils.shuffle(l_instances, r_rand); int i, size = getInstanceSize(); for (i=0; i<size; i++) update(getInstance(i), i+1); if (average()) setAverageWeights(size+1); } protected boolean average() { return d_average != null; } private void setAverageWeights(int count) { double c = -MathUtils.reciprocal(count); int i, size = w_vector.size(); for (i=0; i<size; i++) w_vector.add(i, (float)(c*d_average[i])); } abstract protected boolean update(IntInstance instance, int averageCount); }