package org.deeplearning4j.nn.simple.multiclass; import org.junit.Test; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.ops.transforms.Transforms; import java.util.List; import static org.junit.Assert.assertEquals; import static org.junit.Assume.assumeNotNull; /** * Created by agibsonccc on 4/28/17. */ public class RankClassificationResultTest { @Test public void testOutcome() { RankClassificationResult result = new RankClassificationResult(Transforms.sigmoid(Nd4j.linspace(1,4,4)).reshape(2,2)); assumeNotNull(result.getLabels()); assertEquals("1",result.maxOutcomeForRow(0)); assertEquals("1",result.maxOutcomeForRow(1)); List<String> maxOutcomes = result.maxOutcomes(); assertEquals(2,result.maxOutcomes().size()); for(int i = 0; i < 2; i++) { assertEquals("1",maxOutcomes.get(i)); } } }