/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.ml;
import io.airlift.slice.Slice;
import org.testng.annotations.Test;
import static com.facebook.presto.ml.TestUtils.getDataset;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.assertTrue;
/**
 * Serialization tests for {@link ModelUtils}.
 *
 * <p>Each round-trip test trains a model on the dataset from
 * {@code TestUtils.getDataset()}, serializes it, deserializes the result and
 * checks that the restored object is non-null and of the expected class.
 * The final test pins the numeric serialization ID registered for each class.
 */
public class TestModelSerialization
{
    /** A trained {@link SvmClassifier} must come back as an {@link SvmClassifier}. */
    @Test
    public void testSvmClassifier()
    {
        Model restored = trainAndRoundTrip(new SvmClassifier());
        assertNotNull(restored, "deserialization failed");
        assertTrue(restored instanceof SvmClassifier, "deserialized model is not a svm");
    }

    /** A trained {@link SvmRegressor} must come back as an {@link SvmRegressor}. */
    @Test
    public void testSvmRegressor()
    {
        Model restored = trainAndRoundTrip(new SvmRegressor());
        assertNotNull(restored, "deserialization failed");
        assertTrue(restored instanceof SvmRegressor, "deserialized model is not a svm");
    }

    /** A regressor wrapped in a feature transformer must keep its wrapper type. */
    @Test
    public void testRegressorFeatureTransformer()
    {
        Model restored = trainAndRoundTrip(
                new RegressorFeatureTransformer(new SvmRegressor(), new FeatureVectorUnitNormalizer()));
        assertNotNull(restored, "deserialization failed");
        assertTrue(restored instanceof RegressorFeatureTransformer, "deserialized model is not a regressor feature transformer");
    }

    /** A classifier wrapped in a feature transformer must keep its wrapper type. */
    @Test
    public void testClassifierFeatureTransformer()
    {
        Model restored = trainAndRoundTrip(
                new ClassifierFeatureTransformer(new SvmClassifier(), new FeatureVectorUnitNormalizer()));
        assertNotNull(restored, "deserialization failed");
        assertTrue(restored instanceof ClassifierFeatureTransformer, "deserialized model is not a classifier feature transformer");
    }

    /** A {@link StringClassifierAdapter} around a transformed classifier must keep its adapter type. */
    @Test
    public void testVarcharClassifierAdapter()
    {
        ClassifierFeatureTransformer inner =
                new ClassifierFeatureTransformer(new SvmClassifier(), new FeatureVectorUnitNormalizer());
        Model restored = trainAndRoundTrip(new StringClassifierAdapter(inner));
        assertNotNull(restored, "deserialization failed");
        assertTrue(restored instanceof StringClassifierAdapter, "deserialized model is not a varchar classifier adapter");
    }

    /**
     * Pins the ID that {@code ModelUtils.MODEL_SERIALIZATION_IDS} maps each class to.
     *
     * <p>NOTE(review): presumably these IDs are written into serialized models, so
     * changing one would break reading previously stored data - confirm against
     * {@code ModelUtils} before editing the expected values.
     */
    @Test
    public void testSerializationIds()
    {
        assertEquals(serializationId(SvmClassifier.class), 1);
        assertEquals(serializationId(SvmRegressor.class), 2);
        assertEquals(serializationId(FeatureVectorUnitNormalizer.class), 3);
        assertEquals(serializationId(ClassifierFeatureTransformer.class), 4);
        assertEquals(serializationId(RegressorFeatureTransformer.class), 5);
        assertEquals(serializationId(FeatureUnitNormalizer.class), 6);
        assertEquals(serializationId(StringClassifierAdapter.class), 7);
    }

    /**
     * Trains {@code model} on the shared test dataset, serializes it and
     * returns whatever {@link ModelUtils#deserialize} produces from the bytes.
     */
    private static Model trainAndRoundTrip(Model model)
    {
        model.train(getDataset());
        Slice bytes = ModelUtils.serialize(model);
        return ModelUtils.deserialize(bytes);
    }

    /**
     * Looks up the serialization ID registered for {@code type}.
     * The unboxing cast fails with a {@link NullPointerException} when no ID is registered.
     */
    private static int serializationId(Class<?> type)
    {
        return (int) ModelUtils.MODEL_SERIALIZATION_IDS.get(type);
    }
}