/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package hivemall.mix.server;
import static org.mockito.Mockito.mock;
import hivemall.mix.MixMessage;
import hivemall.mix.MixMessage.MixEventName;
import hivemall.mix.store.PartialAverage;
import hivemall.mix.store.PartialResult;
import hivemall.mix.store.SessionObject;
import hivemall.mix.store.SessionStore;
import hivemall.test.HivemallTestBase;
import io.netty.channel.ChannelHandlerContext;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import org.hamcrest.Description;
import org.hamcrest.TypeSafeMatcher;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
/**
 * Unit tests for {@link MixServerHandler}.
 *
 * <p>The private {@code mix} method is reached via reflection so that the averaging logic can be
 * exercised directly, without standing up a Netty pipeline.
 */
public final class MixServerHandlerTest extends HivemallTestBase {

    /** Feature key used for every message in these tests; its value is irrelevant to mixing. */
    static final Integer dummyFeature = 0;

    /** Rule used to assert the exception raised by the final, invalid mix request. */
    @Rule
    public ExpectedException exception = ExpectedException.none();

    /**
     * Feeds three average-mix messages into {@code MixServerHandler#mix} and checks that the
     * {@link PartialAverage} clock and weight evolve as expected. Then it checks that a message
     * whose update count is zero is rejected.
     */
    @Test
    public void MxiWeightTest() throws NoSuchMethodException, SecurityException,
            IllegalAccessException, IllegalArgumentException, InvocationTargetException {
        ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
        SessionStore session = new SessionStore();
        MixServerHandler handler = new MixServerHandler(session, 4, 1.0f);

        // mix(...) is not public; open it up for direct invocation.
        Method mixMethod = MixServerHandler.class.getDeclaredMethod("mix",
            ChannelHandlerContext.class, MixMessage.class, PartialResult.class, SessionObject.class);
        mixMethod.setAccessible(true);

        SessionObject sessionObj = session.get("dummy");

        // Initially, clock=0. Each message below passes 1 as its last constructor argument
        // (presumably deltaUpdates -- see the expected error message further down), and the
        // expected weights 3, 4, 5 are the running mean of the inputs 3, 5, 7.
        PartialAverage acc = new PartialAverage();

        MixMessage msg1 = new MixMessage(MixEventName.average, dummyFeature, 3.0f, (short) 5, 1);
        mixMethod.invoke(handler, ctx, msg1, acc, sessionObj);
        Assert.assertEquals(1, acc.getClock());
        Assert.assertEquals(3.0, acc.getWeight(1.0f), 0.001);

        MixMessage msg2 = new MixMessage(MixEventName.average, dummyFeature, 5.0f, (short) -1, 1);
        mixMethod.invoke(handler, ctx, msg2, acc, sessionObj);
        Assert.assertEquals(2, acc.getClock());
        Assert.assertEquals(4.0, acc.getWeight(1.0f), 0.001);

        MixMessage msg3 = new MixMessage(MixEventName.average, dummyFeature, 7.0f, (short) 6, 1);
        mixMethod.invoke(handler, ctx, msg3, acc, sessionObj);
        Assert.assertEquals(3, acc.getClock());
        Assert.assertEquals(5.0, acc.getWeight(1.0f), 0.001);

        // Method.invoke wraps whatever mix() throws in an InvocationTargetException, so the
        // expectation is placed on the *cause* rather than on the thrown exception itself.
        exception.expectCause(new CauseMatcher(IllegalArgumentException.class,
            "Illegal deltaUpdates received: 0"));
        MixMessage msg4 = new MixMessage(MixEventName.average, dummyFeature, 0.0f, (short) 0, 0);
        mixMethod.invoke(handler, ctx, msg4, acc, sessionObj);
    }

    /**
     * Matches a {@link Throwable} that is an instance of a given type (or a subtype) and whose
     * message contains a given substring.
     */
    private static final class CauseMatcher extends TypeSafeMatcher<Throwable> {

        private final Class<? extends Throwable> type;
        private final String expectedMessage;

        /**
         * @param type the expected exception type; subclasses also match
         * @param expectedMessage substring that the exception message must contain
         */
        public CauseMatcher(Class<? extends Throwable> type, String expectedMessage) {
            this.type = type;
            this.expectedMessage = expectedMessage;
        }

        @Override
        protected boolean matchesSafely(Throwable item) {
            // The type check must be "item is-a type". The previous form,
            // item.getClass().isAssignableFrom(type), had the direction reversed and accepted
            // supertypes such as RuntimeException.
            if (!type.isInstance(item)) {
                return false;
            }
            // Exceptions may carry a null message; treat that as a mismatch instead of an NPE.
            String message = item.getMessage();
            return message != null && message.contains(expectedMessage);
        }

        @Override
        public void describeTo(Description description) {
            description.appendText("expects type ")
                .appendValue(type)
                .appendText(" and a message ")
                .appendValue(expectedMessage);
        }

        @Override
        protected void describeMismatchSafely(Throwable item, Description mismatchDescription) {
            mismatchDescription.appendText("was ")
                .appendValue(item.getClass())
                .appendText(" with message ")
                .appendValue(item.getMessage());
        }
    }
}