/*
* Seldon -- open source prediction engine
* =======================================
* Copyright 2011-2015 Seldon Technologies Ltd and Rummble Ltd (http://www.seldon.io/)
*
**********************************************************************************************
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
**********************************************************************************************
*/
package io.seldon.vw;
import io.seldon.recommendation.model.ModelManager;
import io.seldon.resources.external.ExternalResourceStreamer;
import io.seldon.resources.external.NewResourceNotifier;
import org.apache.log4j.Logger;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
/**
 * Model manager that loads VW models (header settings plus a feature-to-weight table)
 * and an optional class-id mapping from an external resource location.
 *
 * <p>The manager passes the location pattern {@code "vw"} to the base
 * {@link ModelManager}; each load reads {@code <location>/model} and, when available,
 * {@code <location>/classes.txt}.
 */
@Component
public class VwModelManager extends ModelManager<VwModelManager.VwModel> {

    private static final Logger logger = Logger.getLogger(VwModelManager.class.getName());

    /** Location pattern handed to the base manager. */
    private static final String MF_NEW_LOC_PATTERN = "vw";

    /** Header prefix of the line that carries the number of hash bits. */
    private static final String BITS_PREFIX = "bits:";

    /** Header prefix of the line that carries the VW command line options. */
    private static final String OPTIONS_PREFIX = "options:";

    /** The only VW option this loader understands; its value is stored as {@code oaa}. */
    private static final String OAA_OPTION = "--oaa";

    /** Path used by {@link #main(String[])} when no path argument is supplied. */
    private static final String DEFAULT_MAIN_MODEL_PATH =
            "/home/clive/work/seldon/external_prediction_server/vw/iris/model.txt";

    private final ExternalResourceStreamer featuresFileHandler;

    /**
     * Creates the manager.
     *
     * @param featuresFileHandler source of the model and class-map streams
     * @param notifier            notifier forwarded to the base {@link ModelManager}
     */
    @Autowired
    public VwModelManager(ExternalResourceStreamer featuresFileHandler,
                          NewResourceNotifier notifier) {
        super(notifier, Collections.singleton(MF_NEW_LOC_PATTERN));
        this.featuresFileHandler = featuresFileHandler;
    }

    /**
     * Reads a class-id mapping made of {@code id,className} lines.
     *
     * <p>Only the first comma separates id from name, so names may contain commas.
     * Lines without a comma (or starting with one) are skipped, as are lines whose
     * id is not an integer; the latter are logged so that a stray header row does
     * not abort the whole model load.
     *
     * @param reader source of the mapping lines
     * @return map from numeric class id to class name (possibly empty)
     * @throws IOException if reading from {@code reader} fails
     */
    private Map<Integer, String> loadClassIdMap(BufferedReader reader) throws IOException {
        Map<Integer, String> classIdMap = new HashMap<>();
        String line;
        while ((line = reader.readLine()) != null) {
            int firstComma = line.indexOf(',');
            if (firstComma <= 0) {
                continue;
            }
            String idText = line.substring(0, firstComma).trim();
            try {
                classIdMap.put(Integer.parseInt(idText), line.substring(firstComma + 1));
            } catch (NumberFormatException e) {
                logger.warn("Ignoring classes.txt line with non-numeric class id: " + line);
            }
        }
        return classIdMap;
    }

    /**
     * Parses a textual VW model.
     *
     * <p>Before the weight section the loader looks for a {@code bits:} line and an
     * {@code options:} line; a line starting with {@code ":0"} switches to the weight
     * section, where every non-blank line is read as {@code feature:weight}. When the
     * header does not say otherwise, {@code bits} is 18 and {@code oaa} is 1.
     *
     * @param reader     source of the model text
     * @param classIdMap class-id mapping to attach to the resulting model
     * @return the parsed model
     * @throws IOException if reading fails or the model text is malformed
     */
    private VwModel loadModel(BufferedReader reader, Map<Integer, String> classIdMap) throws IOException {
        Map<Integer, Float> weights = new HashMap<>();
        int oaa = 1;
        int bits = 18;
        boolean insideFeatures = false;
        String line;
        while ((line = reader.readLine()) != null) {
            if (insideFeatures) {
                // Tolerate blank lines (e.g. a trailing newline) inside the weight section.
                if (!line.trim().isEmpty()) {
                    readWeight(line, weights);
                }
            } else if (line.startsWith(BITS_PREFIX)) {
                bits = parseIntField(line.substring(BITS_PREFIX.length()), "bits value", line);
            } else if (line.startsWith(OPTIONS_PREFIX)) {
                oaa = parseOaa(line, oaa);
            } else if (line.startsWith(":0")) {
                insideFeatures = true;
            }
        }
        return new VwModel(bits, oaa, weights, classIdMap);
    }

    /** Extracts the {@code --oaa} value from an options line, warning about every other token. */
    private int parseOaa(String line, int currentOaa) throws IOException {
        int oaa = currentOaa;
        // Trim first: "options: --oaa 3" would otherwise produce a leading empty token.
        String optionText = line.substring(OPTIONS_PREFIX.length()).trim();
        if (optionText.isEmpty()) {
            return oaa;
        }
        String[] options = optionText.split("\\s+");
        for (int i = 0; i < options.length; i++) {
            if (OAA_OPTION.equals(options[i])) {
                if (i + 1 >= options.length) {
                    throw new IOException("VW option --oaa has no value in model line: " + line);
                }
                i++; // consume the option's value
                oaa = parseIntField(options[i], "--oaa value", line);
            } else {
                logger.warn("Unhandled VW option - this model may not behave correctly at prediction time " + options[i]);
            }
        }
        return oaa;
    }

    /** Parses one {@code feature:weight} line and stores it in {@code weights}. */
    private static void readWeight(String line, Map<Integer, Float> weights) throws IOException {
        String[] featureAndWeight = line.split(":");
        if (featureAndWeight.length < 2) {
            throw new IOException("Malformed weight line in VW model: " + line);
        }
        int feature = parseIntField(featureAndWeight[0], "feature index", line);
        float weight;
        try {
            weight = Float.parseFloat(featureAndWeight[1]);
        } catch (NumberFormatException e) {
            throw new IOException("Invalid weight in VW model line: " + line, e);
        }
        weights.put(feature, weight);
    }

    /** Parses an integer, reporting malformed input as an {@link IOException} naming the line. */
    private static int parseIntField(String text, String what, String line) throws IOException {
        try {
            return Integer.parseInt(text.trim());
        } catch (NumberFormatException e) {
            throw new IOException("Invalid " + what + " in VW model line: " + line, e);
        }
    }

    /** Opens the resource at {@code path} as a buffered UTF-8 reader. */
    private BufferedReader openReader(String path) throws IOException {
        return new BufferedReader(new InputStreamReader(
                featuresFileHandler.getResourceStream(path), StandardCharsets.UTF_8));
    }

    /**
     * Loads the model stored under {@code location}.
     *
     * <p>{@code classes.txt} is optional: if it cannot be read, the model is loaded
     * with an empty class-id map. A failure to read or parse the model itself is
     * logged and reported by returning {@code null}.
     *
     * @param location base location containing {@code model} and optionally {@code classes.txt}
     * @param client   client the model belongs to (used for logging only)
     * @return the loaded model, or {@code null} if it could not be loaded
     */
    @Override
    protected VwModel loadModel(String location, String client) {
        logger.info("Reloading VW model for client: " + client);
        try (BufferedReader modelReader = openReader(location + "/model")) {
            Map<Integer, String> classIdMap;
            try (BufferedReader classesReader = openReader(location + "/classes.txt")) {
                classIdMap = loadClassIdMap(classesReader);
            } catch (IOException e) {
                // Best effort by design; keep the reason visible without a full stack trace.
                logger.warn("Found no classes.txt for " + location + " (" + e + ")");
                classIdMap = new HashMap<>();
            }
            VwModel model = loadModel(modelReader, classIdMap);
            logger.info("Loaded VW model from " + location + " for " + client + " with " + model.weights.size() + " weights and " + model.classIdMap.size() + " classes");
            return model;
        } catch (IOException e) {
            // FileNotFoundException is an IOException, so one handler covers both cases.
            logger.error("Couldn't reload model for client " + client + " at " + location, e);
        }
        return null;
    }

    /**
     * A loaded VW model: header settings, feature weights, class-id mapping and the
     * feature hasher built from {@code bits} and {@code oaa}.
     *
     * <p>The maps are stored as passed in (no defensive copy).
     */
    public static class VwModel {
        public final int bits;
        public final int oaa;
        public final Map<Integer, String> classIdMap;
        public final Map<Integer, Float> weights;
        public final VwFeatureHash hasher;

        /**
         * @param bits       value of the model's {@code bits:} header
         * @param oaa        value of the {@code --oaa} option
         * @param weights    feature index to weight
         * @param classIdMap numeric class id to class name
         */
        public VwModel(int bits, int oaa, Map<Integer, Float> weights, Map<Integer, String> classIdMap) {
            this.bits = bits;
            this.oaa = oaa;
            this.weights = weights;
            this.hasher = new VwFeatureHash(bits, oaa);
            this.classIdMap = classIdMap;
        }

        @Override
        public String toString() {
            return "VwModel [bits=" + bits + ", oaa=" + oaa + ", weights="
                    + weights + "]";
        }
    }

    /**
     * Manual smoke test: parses a model file and prints the result.
     *
     * @param args optional; {@code args[0]} is the model file path. Without it the
     *             original hard-coded developer path is used.
     * @throws IOException if the file cannot be read or parsed
     */
    public static void main(String[] args) throws IOException {
        String path = args.length > 0 ? args[0] : DEFAULT_MAIN_MODEL_PATH;
        try (BufferedReader br = new BufferedReader(new InputStreamReader(
                new FileInputStream(path), StandardCharsets.UTF_8))) {
            VwModelManager m = new VwModelManager(null, null);
            VwModel vwModel = m.loadModel(br, new HashMap<Integer, String>());
            System.out.println(vwModel);
        }
    }
}