/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ /* * MetricLearner.java * Copyright (C) 2004 Mikhail Bilenko and Sugato Basu * */ package weka.clusterers.metriclearners; import java.util.*; import weka.core.*; import weka.core.metrics.LearnableMetric; import weka.clusterers.MPCKMeans; /** * A parent class for MPCKMeans metric learners * * @author Mikhail Bilenko (mbilenko@cs.utexas.edu) and Sugato Basu * (sugato@cs.utexas.edu * @version $Revision: 1.2 $ */ public abstract class GDMetricLearner extends MPCKMeansMetricLearner { /** Initial value of gradient descent step parameter */ protected double m_eta = 0.001; public void setEta(double eta) { m_eta = eta; } public void resetLearner() { m_currEta = m_eta; } public double getEta() { return m_eta; } /** Current value of the step parameter */ protected double m_currEta = 0; /** Decay rate of gradient descent eta */ protected double m_etaDecayRate = 0.9; public void setEtaDecayRate(double etaDecayRate) { m_etaDecayRate = etaDecayRate; } public double getEtaDecayRate() { return m_etaDecayRate; } /** The maximum number of GD iterations */ protected int m_maxGDIterations = 20; public void setMaxGDIterations(int maxGDIterations) { m_maxGDIterations = maxGDIterations; } public int getMaxGDIterations() { return m_maxGDIterations; } protected double[] InitRegularizerComponents(double []currentWeights) { double [] regularizerComponents = new double[m_numAttributes]; for (int attr = 0; attr < m_numAttributes; attr++) { if (currentWeights[attr] > 0) { regularizerComponents[attr] = m_regularizerTermWeight * m_metric.getRegularizer().gradient(currentWeights[attr]); } else { regularizerComponents[attr] = 0; } } return regularizerComponents; } /** * Perform gradient step update using the current weights, * the gradients, the regularizers and the current learning rate * Returns the updated weights. **/ protected double[] GDUpdate(double [] currentWeights, double [] gradients, double [] regularizerComponents) { double [] newWeights = new double[m_numAttributes]; for (int attr = 0; attr < m_numAttributes; attr++) { newWeights[attr] = currentWeights[attr] - m_currEta*(gradients[attr] - regularizerComponents[attr]); if (newWeights[attr] <= 0) { System.out.println("Prevented 0/- weight " + ((float)newWeights[attr]) + " for attribute " + m_instances.attribute(attr).name() + ";\tprev=" + ((float)currentWeights[attr]) + ";\tgrad=" + ((float)gradients[attr]) + ";\treg=" + ((float)regularizerComponents[attr])); newWeights[attr] = m_minWeightValue; } } System.out.print("eta=" + (float)m_currEta); m_currEta = m_currEta * m_etaDecayRate; System.out.print(" -> " + (float)m_currEta); // PRINT top weights TreeMap map = new TreeMap(Collections.reverseOrder()); for (int j = 0; j < newWeights.length; j++) { map.put(new Double(newWeights[j]), new Integer(j)); } Iterator it = map.entrySet().iterator(); for (int j=0; j < 5 && it.hasNext(); j++) { Map.Entry entry = (Map.Entry) it.next(); int idx = ((Integer)entry.getValue()).intValue(); System.out.println("\t" + m_instances.attribute(idx).name() + "\t" + (float)currentWeights[idx] + "->" + (float)newWeights[idx] + "\tgradient=" + (float)gradients[idx] + "\tregularizer=" + (float)regularizerComponents[idx]); } // end PRINT top weights return newWeights; } }