/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.clustering.dirichlet.models;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.Iterator;
import com.google.common.base.Splitter;
import org.apache.mahout.clustering.ModelDistribution;
import org.apache.mahout.common.ClassUtils;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
/**
* Simply describes parameters needs to create a {@link org.apache.mahout.clustering.ModelDistribution}.
*/
public final class DistributionDescription {
private final String modelFactory;
private final String modelPrototype;
private final String distanceMeasure;
private final int prototypeSize;
public DistributionDescription(String modelFactory,
String modelPrototype,
String distanceMeasure,
int prototypeSize) {
this.modelFactory = modelFactory;
this.modelPrototype = modelPrototype;
this.distanceMeasure = distanceMeasure;
this.prototypeSize = prototypeSize;
}
public String getModelFactory() {
return modelFactory;
}
public String getModelPrototype() {
return modelPrototype;
}
public String getDistanceMeasure() {
return distanceMeasure;
}
public int getPrototypeSize() {
return prototypeSize;
}
/**
* Create an instance of AbstractVectorModelDistribution from the given command line arguments
*/
public ModelDistribution<VectorWritable> createModelDistribution() {
ClassLoader ccl = Thread.currentThread().getContextClassLoader();
AbstractVectorModelDistribution modelDistribution;
try {
modelDistribution = ClassUtils.instantiateAs(modelFactory, AbstractVectorModelDistribution.class);
Class<? extends Vector> vcl = ccl.loadClass(modelPrototype).asSubclass(Vector.class);
Constructor<? extends Vector> v = vcl.getConstructor(int.class);
modelDistribution.setModelPrototype(new VectorWritable(v.newInstance(prototypeSize)));
if (modelDistribution instanceof DistanceMeasureClusterDistribution) {
DistanceMeasure measure = ClassUtils.instantiateAs(distanceMeasure, DistanceMeasure.class);
((DistanceMeasureClusterDistribution) modelDistribution).setMeasure(measure);
}
} catch (ClassNotFoundException cnfe) {
throw new IllegalStateException(cnfe);
} catch (NoSuchMethodException nsme) {
throw new IllegalStateException(nsme);
} catch (InstantiationException ie) {
throw new IllegalStateException(ie);
} catch (IllegalAccessException iae) {
throw new IllegalStateException(iae);
} catch (InvocationTargetException ite) {
throw new IllegalStateException(ite);
}
return modelDistribution;
}
@Override
public String toString() {
return modelFactory + ',' + modelPrototype + ',' + distanceMeasure + ',' + prototypeSize;
}
public static DistributionDescription fromString(CharSequence s) {
Iterator<String> tokens = Splitter.on(',').split(s).iterator();
String modelFactory = tokens.next();
String modelPrototype = tokens.next();
String distanceMeasure = tokens.next();
int prototypeSize = Integer.parseInt(tokens.next());
return new DistributionDescription(modelFactory, modelPrototype, distanceMeasure, prototypeSize);
}
}