/*
* Seldon -- open source prediction engine
* =======================================
* Copyright 2011-2015 Seldon Technologies Ltd and Rummble Ltd (http://www.seldon.io/)
*
**********************************************************************************************
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
**********************************************************************************************
*/
package io.seldon.recommendation.userdimensionmapping;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.log4j.Logger;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.seldon.recommendation.model.ModelManager;
import io.seldon.resources.external.ExternalResourceStreamer;
import io.seldon.resources.external.NewResourceNotifier;
/**
 * Loads and serves per-client user dimension mappings.
 *
 * <p>Each client's model is read from an external resource registered under
 * {@link #USER_DIM_MODEL_LOC_PATTERN}. The resource contains one JSON object per line, each
 * deserialized into a {@link DimensionMapping}. At request time,
 * {@link #getMappedDimensionsByUser(String, Set, String)} rewrites a user's requested dimensions
 * using that user's mapping.
 */
@Component
public class UserDimensionMappingModelManager extends ModelManager<UserDimensionMappingModelManager.UserDimensionMappingModel> {

    private static final Logger logger = Logger.getLogger(UserDimensionMappingModelManager.class.getName());

    /** Shared JSON reader; Jackson's ObjectMapper is thread-safe once configured. */
    private static final ObjectMapper MAPPER = new ObjectMapper();

    public static final String USER_DIM_MODEL_LOC_PATTERN = "user_dim_model";

    private final ExternalResourceStreamer dimensionsFileHandler;

    /**
     * Most recently loaded model per client. Written by {@link #loadModel} and read by
     * {@link #getMappedDimensionsByUser}. A concurrent map is used because reloads and lookups
     * may run on different threads (NOTE(review): confirm threading model of ModelManager).
     * Null keys are not permitted, so both access paths guard against a null client.
     */
    private final Map<String, UserDimensionMappingModel> client_userDimensionMappingModel = new ConcurrentHashMap<>();

    @Autowired
    public UserDimensionMappingModelManager(ExternalResourceStreamer dimensionsFileHandler, NewResourceNotifier notifier) {
        super(notifier, Collections.singleton(USER_DIM_MODEL_LOC_PATTERN));
        this.dimensionsFileHandler = dimensionsFileHandler;
        logger.info("Initializing");
    }

    /**
     * Reads the mapping resource at {@code location} and makes it the active model for
     * {@code client}.
     *
     * <p>On any failure the error is logged, the client's previously stored mapping is discarded,
     * and {@code null} is returned.
     *
     * @param location resource location passed to the {@link ExternalResourceStreamer}
     * @param client   client whose mappings are being (re)loaded
     * @return the loaded model, or {@code null} if loading failed
     */
    @Override
    protected UserDimensionMappingModel loadModel(String location, String client) {
        logger.info("Reloading user dimensions for client[" + client + "]");
        UserDimensionMappingModel model = null;
        // try-with-resources closes the stream even when parsing fails part-way through;
        // the resource is JSON, so decode it as UTF-8 rather than the platform default charset.
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(dimensionsFileHandler.getResourceStream(location), StandardCharsets.UTF_8))) {
            model = loadUserDimensionMapping(reader);
            int num_users = (model != null) ? model.userToDimensions.size() : 0;
            logger.info("Loaded user dimensions for client[" + client + "] users[" + num_users + "]");
        } catch (Exception e) {
            logger.error("Couldn't reload user dimensions for client " + client, e);
        }
        if (client == null) {
            return model; // the concurrent map rejects null keys; nothing to register
        }
        if (model != null) {
            client_userDimensionMappingModel.put(client, model);
        } else {
            client_userDimensionMappingModel.remove(client);
        }
        return model;
    }

    /**
     * Parses newline-delimited JSON {@link DimensionMapping} records into a model.
     *
     * <p>Blank lines and JSON {@code null} records are skipped. If several records share a
     * {@code client_user_id}, the last one wins.
     *
     * @param reader source of JSON lines; not closed by this method
     * @return a model keyed by {@code client_user_id}; never {@code null}
     * @throws IOException if reading fails or a line cannot be parsed as a mapping
     */
    public UserDimensionMappingModel loadUserDimensionMapping(BufferedReader reader) throws IOException {
        Map<String, DimensionMapping> userToDimensions = new HashMap<String, DimensionMapping>();
        String line;
        while ((line = reader.readLine()) != null) {
            if (line.trim().isEmpty()) {
                continue; // tolerate blank lines, e.g. an extra newline at end of file
            }
            DimensionMapping userDimensionMapping = MAPPER.readValue(line, DimensionMapping.class);
            if (userDimensionMapping != null) { // a literal "null" line deserializes to null
                userToDimensions.put(userDimensionMapping.client_user_id, userDimensionMapping);
            }
        }
        return new UserDimensionMappingModel(userToDimensions);
    }

    /**
     * Rewrites {@code dimensions} using the mapping configured for {@code client_user_id}.
     *
     * <p>Every input dimension contained in the user's {@code dims_in} is replaced by all of the
     * user's {@code dims_out}; other dimensions pass through unchanged. A record missing
     * {@code dims_in} matches nothing; a record missing {@code dims_out} maps matched dimensions
     * to nothing. If the client has no model, the user has no mapping, or {@code dimensions} is
     * {@code null}, the input {@code dimensions} instance is returned unchanged.
     *
     * @param client         client whose model to consult; may be {@code null}
     * @param dimensions     requested dimensions
     * @param client_user_id user whose mapping to apply
     * @return a new set of mapped dimensions, or {@code dimensions} itself when no mapping applies
     */
    public Set<Integer> getMappedDimensionsByUser(String client, Set<Integer> dimensions, String client_user_id) {
        // Avoid building debug strings on this per-request path unless they will be emitted.
        final boolean debug = logger.isDebugEnabled();
        if (debug) {
            logger.debug("dimensions in: " + dimensions);
        }
        UserDimensionMappingModel userDimensionMappingModel =
                (client == null) ? null : client_userDimensionMappingModel.get(client);
        if (userDimensionMappingModel == null) {
            if (debug) {
                logger.debug(String.format("No mappings for client[%s]", client));
                logger.debug("dimensions out: " + dimensions);
            }
            return dimensions; // no mappings for this client so return input
        }
        DimensionMapping dimensionMapping = userDimensionMappingModel.userToDimensions.get(client_user_id);
        if (dimensionMapping == null) {
            if (debug) {
                logger.debug(String.format("No mappings for client_user_id[%s]", client_user_id));
                logger.debug("dimensions out: " + dimensions);
            }
            return dimensions; // no mappings for this userid so return input
        }
        if (dimensions == null) {
            return null; // nothing to map
        }
        // Records deserialized from JSON may omit either set; treat a missing set as empty.
        Set<Integer> dimsIn = (dimensionMapping.dims_in != null)
                ? dimensionMapping.dims_in : Collections.<Integer>emptySet();
        Set<Integer> dimsOut = (dimensionMapping.dims_out != null)
                ? dimensionMapping.dims_out : Collections.<Integer>emptySet();
        Set<Integer> mapped_dimensions = new HashSet<Integer>();
        for (Integer dimension : dimensions) {
            if (dimsIn.contains(dimension)) {
                mapped_dimensions.addAll(dimsOut);
            } else {
                mapped_dimensions.add(dimension);
            }
        }
        if (debug) {
            logger.debug("dimensions out: " + mapped_dimensions);
        }
        return mapped_dimensions;
    }

    /**
     * One user's mapping record, bound from a JSON object with the fields
     * {@code client_user_id}, {@code dims_in} and {@code dims_out} via the setters below.
     */
    public static class DimensionMapping {
        String client_user_id;
        Set<Integer> dims_in;
        Set<Integer> dims_out;

        /** No-arg constructor required for JSON deserialization. */
        public DimensionMapping() {
        }

        public String getClient_user_id() {
            return client_user_id;
        }

        public void setClient_user_id(String client_user_id) {
            this.client_user_id = client_user_id;
        }

        public Set<Integer> getDims_in() {
            return dims_in;
        }

        public void setDims_in(Set<Integer> dims_in) {
            this.dims_in = dims_in;
        }

        public Set<Integer> getDims_out() {
            return dims_out;
        }

        public void setDims_out(Set<Integer> dims_out) {
            this.dims_out = dims_out;
        }

        @Override
        public String toString() {
            return String.format("{client_user_id:%s, dims_in:%s, dims_out:%s}", client_user_id, dims_in, dims_out);
        }
    }

    /** A client's loaded mappings, keyed by {@code client_user_id}. */
    public static class UserDimensionMappingModel {
        final Map<String, DimensionMapping> userToDimensions;

        public UserDimensionMappingModel(Map<String, DimensionMapping> userToDimensions) {
            this.userToDimensions = userToDimensions;
        }
    }
}