/** * Copyright 2013-2015 Pierre Merienne * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.github.pmerienne.trident.ml.clustering; import static org.junit.Assert.assertEquals; import org.junit.Test; import com.github.pmerienne.trident.ml.clustering.ClusterQuery; import com.github.pmerienne.trident.ml.clustering.ClusterUpdater; import com.github.pmerienne.trident.ml.clustering.KMeans; import com.github.pmerienne.trident.ml.core.Instance; import com.github.pmerienne.trident.ml.preprocessing.InstanceCreator; import com.github.pmerienne.trident.ml.testing.RandomFeaturesForClusteringSpout; import storm.trident.TridentState; import storm.trident.TridentTopology; import storm.trident.operation.BaseFunction; import storm.trident.operation.TridentCollector; import storm.trident.testing.MemoryMapState; import storm.trident.tuple.TridentTuple; import backtype.storm.Config; import backtype.storm.LocalCluster; import backtype.storm.LocalDRPC; import backtype.storm.tuple.Fields; import backtype.storm.tuple.Values; public class ClustererTridentIntegrationTest { @Test public void testInTopology() throws InterruptedException { // Start local cluster LocalCluster cluster = new LocalCluster(); LocalDRPC localDRPC = new LocalDRPC(); try { // Build topology TridentTopology toppology = new TridentTopology(); // Training stream TridentState kmeansState = toppology // Emit tuples with a instance containing an integer as label and 3 // double features named (x0, x1 and x2) .newStream("samples", new RandomFeaturesForClusteringSpout()) // Convert trident tuple to instance .each(new Fields("label", "x0", "x1", "x2"), new InstanceCreator<Integer>(), new Fields("instance")) // Update a 3 classes kmeans .partitionPersist(new MemoryMapState.Factory(), new Fields("instance"), new ClusterUpdater("kmeans", new KMeans(3))); // Cluster stream toppology.newDRPCStream("predict", localDRPC) // Convert DRPC args to instance .each(new Fields("args"), new DRPCArgsToInstance(), new Fields("instance")) // Query kmeans to classify instance .stateQuery(kmeansState, new Fields("instance"), new ClusterQuery("kmeans"), new Fields("prediction")) .project(new Fields("prediction")); cluster.submitTopology(this.getClass().getSimpleName(), new Config(), toppology.build()); Thread.sleep(10000); Integer result11 = extractPrediction(localDRPC.execute("predict", "1.0 1.0 1.0")); Integer result12 = extractPrediction(localDRPC.execute("predict", "0.8 1.1 0.9")); assertEquals(result11, result12); Integer result21 = extractPrediction(localDRPC.execute("predict", "1.0 -1.0 1.0")); Integer result22 = extractPrediction(localDRPC.execute("predict", "0.8 -1.1 0.9")); assertEquals(result21, result22); Integer result31 = extractPrediction(localDRPC.execute("predict", "1.0 -1.0 -1.0")); Integer result32 = extractPrediction(localDRPC.execute("predict", "0.8 -1.1 -0.9")); assertEquals(result31, result32); } finally { cluster.shutdown(); localDRPC.shutdown(); } } protected static Integer extractPrediction(String drpcResult) { return Integer.parseInt(drpcResult.replaceAll("\\[", "").replaceAll("\\]", "")); } public static class DRPCArgsToInstance extends BaseFunction { private static final long serialVersionUID = -2932222000448806586L; @SuppressWarnings("rawtypes") @Override public void execute(TridentTuple tuple, TridentCollector collector) { String[] args = tuple.getString(0).split(" "); double[] features = new double[args.length]; for (int i = 0; i < args.length; i++) { features[i] = Double.parseDouble(args[i]); } Instance<?> instance = new Instance(features); collector.emit(new Values(instance)); } } }