/**
* Copyright 2012 plista GmbH (http://www.plista.com/)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and limitations under the License.
*/
package org.plista.kornakapi.core.storage;
import com.google.common.collect.Sets;
import org.apache.commons.dbcp.BasicDataSource;
import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.impl.common.FastIDSet;
import org.apache.mahout.cf.taste.impl.model.GenericDataModel;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.model.JDBCDataModel;
import org.apache.mahout.cf.taste.model.Preference;
import org.apache.mahout.common.IOUtils;
import org.plista.kornakapi.core.Candidate;
import org.plista.kornakapi.core.config.StorageConfiguration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Set;
/**
 * An implementation of {@link Storage} backed by MySQL.
 *
 * <p>Preferences are stored in {@code taste_preferences} and candidate items in
 * {@code taste_candidates}. All JDBC work goes through the {@link BasicDataSource} handed to the
 * constructor. This class configures that pool and closes it in {@link #close()}.
 */
public class MySqlStorage implements Storage {

  protected final BasicDataSource dataSource;
  private final JDBCDataModel dataModel;

  /** Inserts a preference; on a duplicate key the stored preference value is overwritten. */
  private static final String IMPORT_QUERY =
      "INSERT INTO taste_preferences (user_id, item_id, preference) VALUES (?, ?, ?) " +
      "ON DUPLICATE KEY UPDATE preference = VALUES(preference)";
  private static final String INSERT_CANDIDATE_QUERY =
      "INSERT INTO taste_candidates (label, item_id) VALUES (?, ?)";
  private static final String REMOVE_CANDIDATE_QUERY =
      "DELETE FROM taste_candidates WHERE label = ? AND item_id = ?";
  private static final String REMOVE_ALL_CANDIDATES_QUERY =
      "DELETE FROM taste_candidates WHERE label = ?";
  private static final String GET_CANDIDATES_QUERY =
      "SELECT item_id FROM taste_candidates WHERE label = ?";
  private static final String GET_LABELS = "SELECT DISTINCT label FROM taste_candidates";
  private static final String GET_ITEMSLABEL = "SELECT label FROM taste_candidates WHERE item_id = ?";

  /** SQLState class for integrity constraint violations (MySQL reports duplicate keys as 23000). */
  private static final String INTEGRITY_CONSTRAINT_VIOLATION_CLASS = "23";

  private static final Logger log = LoggerFactory.getLogger(MySqlStorage.class);

  /**
   * Configures {@code dataSource} from {@code storageConf} and builds the JDBC-backed data model
   * over the {@code taste_preferences} and {@code taste_candidates} tables.
   *
   * @param storageConf supplies the JDBC driver class, URL and credentials
   * @param label candidate label handed to {@link LabeledMySQLJDBCDataModel}
   * @param dataSource pool to configure; it is mutated here and closed by {@link #close()}
   */
  public MySqlStorage(StorageConfiguration storageConf, String label, BasicDataSource dataSource) {
    dataSource.setDriverClassName(storageConf.getJdbcDriverClass());
    dataSource.setUrl(storageConf.getJdbcUrl());
    dataSource.setUsername(storageConf.getUsername());
    dataSource.setPassword(storageConf.getPassword());

    //TODO pool sizing and validation settings should be made configurable
    dataSource.setMaxActive(10);
    dataSource.setMinIdle(5);
    dataSource.setInitialSize(5);
    // Idle connections are validated by the evictor rather than on every borrow/return.
    dataSource.setValidationQuery("SELECT 1;");
    dataSource.setTestOnBorrow(false);
    dataSource.setTestOnReturn(false);
    dataSource.setTestWhileIdle(true);
    dataSource.setTimeBetweenEvictionRunsMillis(5000);

    dataModel = new LabeledMySQLJDBCDataModel(dataSource,
        "taste_preferences",
        "user_id",
        "item_id",
        "preference",
        "timestamp",
        "taste_candidates",
        "label",
        label);

    this.dataSource = dataSource;
  }

  /**
   * Returns a snapshot of all preferences, copied out of the database into a
   * {@link GenericDataModel}.
   *
   * @throws IOException if exporting the preferences fails
   */
  @Override
  public DataModel trainingData() throws IOException {
    try {
      return new GenericDataModel(dataModel.exportWithPrefs());
    } catch (TasteException e) {
      throw new IOException(e);
    }
  }

  /** Returns the live JDBC-backed data model itself (not a copy). */
  @Override
  public DataModel recommenderData() throws IOException {
    return dataModel;
  }

  /**
   * Stores a single preference through the JDBC data model.
   *
   * @throws IOException if the data model reports a failure
   */
  @Override
  public void setPreference(long userID, long itemID, float value) throws IOException {
    try {
      dataModel.setPreference(userID, itemID, value);
    } catch (TasteException e) {
      throw new IOException(e);
    }
  }

  /**
   * Upserts all {@code preferences} using JDBC batches of at most {@code batchSize} statements.
   *
   * @param preferences preferences to store; consumed by this call
   * @param batchSize maximum number of statements per executed batch; must be positive
   * @throws IllegalArgumentException if {@code batchSize} is less than 1
   * @throws IOException if any database operation fails
   */
  @Override
  public void batchSetPreferences(Iterator<Preference> preferences, int batchSize) throws IOException {
    checkBatchSize(batchSize);
    Connection conn = null;
    PreparedStatement stmt = null;
    try {
      conn = dataSource.getConnection();
      stmt = conn.prepareStatement(IMPORT_QUERY);

      long recordsQueued = 0;
      int pendingInBatch = 0;
      while (preferences.hasNext()) {
        Preference preference = preferences.next();
        stmt.setLong(1, preference.getUserID());
        stmt.setLong(2, preference.getItemID());
        stmt.setFloat(3, preference.getValue());
        stmt.addBatch();
        recordsQueued++;

        if (++pendingInBatch == batchSize) {
          stmt.executeBatch();
          pendingInBatch = 0;
          log.info("imported {} records in batch", recordsQueued);
        }
      }

      // Flush the trailing, partially filled batch.
      if (pendingInBatch > 0) {
        stmt.executeBatch();
        log.info("imported {} records in batch. done.", recordsQueued);
      }

    } catch (SQLException e) {
      throw new IOException(e);
    } finally {
      IOUtils.quietClose(stmt);
      IOUtils.quietClose(conn);
    }
  }

  /**
   * Adds {@code itemID} as a candidate for {@code label}.
   *
   * <p>An integrity constraint violation (SQLState class 23, typically a duplicate key because the
   * candidate already exists) is logged and ignored. Every other database error is rethrown.
   *
   * @throws IOException if the insert fails for a reason other than a constraint violation
   */
  @Override
  public void addCandidate(String label, long itemID) throws IOException {
    Connection conn = null;
    PreparedStatement stmt = null;
    try {
      conn = dataSource.getConnection();
      stmt = conn.prepareStatement(INSERT_CANDIDATE_QUERY);
      stmt.setString(1, label);
      stmt.setLong(2, itemID);
      stmt.execute();
    } catch (SQLException e) {
      // Previously this decision depended on whether INFO logging was enabled, which silently
      // swallowed connection/SQL failures in typical configurations. Only constraint violations
      // are tolerated now.
      if (isIntegrityConstraintViolation(e)) {
        log.info("candidate {} for label {} not inserted: {}",
            new Object[] { itemID, label, e.getMessage() });
      } else {
        throw new IOException(e);
      }
    } finally {
      IOUtils.quietClose(stmt);
      IOUtils.quietClose(conn);
    }
  }

  /**
   * Inserts all {@code candidates} using JDBC batches of at most {@code batchSize} statements.
   *
   * @param candidates candidates to insert; consumed by this call
   * @param batchSize maximum number of statements per executed batch; must be positive
   * @return the distinct labels of all candidates that were queued
   * @throws IllegalArgumentException if {@code batchSize} is less than 1
   * @throws IOException if any database operation fails
   */
  @Override
  public Iterable<String> batchAddCandidates(Iterator<Candidate> candidates, int batchSize) throws IOException {
    return executeCandidateBatch(INSERT_CANDIDATE_QUERY, candidates, batchSize,
        "imported {} candidates in batch", "imported {} candidates in batch. done.");
  }

  /**
   * Removes the single candidate {@code itemID} from {@code label}.
   *
   * @throws IOException if the delete fails
   */
  @Override
  public void deleteCandidate(String label, long itemID) throws IOException {
    Connection conn = null;
    PreparedStatement stmt = null;
    try {
      conn = dataSource.getConnection();
      stmt = conn.prepareStatement(REMOVE_CANDIDATE_QUERY);
      stmt.setString(1, label);
      stmt.setLong(2, itemID);
      stmt.execute();
    } catch (SQLException e) {
      throw new IOException(e);
    } finally {
      IOUtils.quietClose(stmt);
      IOUtils.quietClose(conn);
    }
  }

  /**
   * Removes every candidate stored under {@code label}.
   *
   * @throws IOException if the delete fails
   */
  @Override
  public void deleteAllCandidates(String label) throws IOException {
    Connection conn = null;
    PreparedStatement stmt = null;
    try {
      conn = dataSource.getConnection();
      stmt = conn.prepareStatement(REMOVE_ALL_CANDIDATES_QUERY);
      stmt.setString(1, label);
      stmt.execute();
    } catch (SQLException e) {
      throw new IOException(e);
    } finally {
      IOUtils.quietClose(stmt);
      IOUtils.quietClose(conn);
    }
  }

  /**
   * Deletes all {@code candidates} using JDBC batches of at most {@code batchSize} statements.
   *
   * @param candidates candidates to delete; consumed by this call
   * @param batchSize maximum number of statements per executed batch; must be positive
   * @return the distinct labels of all candidates that were queued
   * @throws IllegalArgumentException if {@code batchSize} is less than 1
   * @throws IOException if any database operation fails
   */
  @Override
  public Iterable<String> batchDeleteCandidates(Iterator<Candidate> candidates, int batchSize) throws IOException {
    return executeCandidateBatch(REMOVE_CANDIDATE_QUERY, candidates, batchSize,
        "deleted {} candidates in batch", "deleted {} candidates in batch. done.");
  }

  /**
   * Returns the item IDs of all candidates stored under {@code label}.
   *
   * @throws IOException if the query fails
   */
  @Override
  public FastIDSet getCandidates(String label) throws IOException {
    Connection conn = null;
    PreparedStatement stmt = null;
    ResultSet rs = null;
    try {
      FastIDSet candidates = new FastIDSet();
      conn = dataSource.getConnection();
      stmt = conn.prepareStatement(GET_CANDIDATES_QUERY, ResultSet.TYPE_FORWARD_ONLY,
          ResultSet.CONCUR_READ_ONLY);
      stmt.setFetchDirection(ResultSet.FETCH_FORWARD);
      stmt.setFetchSize(1000);
      stmt.setString(1, label);

      rs = stmt.executeQuery();
      while (rs.next()) {
        candidates.add(rs.getLong(1));
      }
      return candidates;
    } catch (SQLException e) {
      throw new IOException(e);
    } finally {
      IOUtils.quietClose(rs, stmt, conn);
    }
  }

  /**
   * Returns every distinct label present in {@code taste_candidates}.
   *
   * @throws IOException if the query fails
   */
  public LinkedList<String> getAllLabels() throws IOException {
    Connection conn = null;
    PreparedStatement stmt = null;
    ResultSet rs = null;
    try {
      LinkedList<String> labels = new LinkedList<String>();
      conn = dataSource.getConnection();
      stmt = conn.prepareStatement(GET_LABELS, ResultSet.TYPE_FORWARD_ONLY,
          ResultSet.CONCUR_READ_ONLY);
      stmt.setFetchDirection(ResultSet.FETCH_FORWARD);
      stmt.setFetchSize(1000);

      rs = stmt.executeQuery();
      while (rs.next()) {
        labels.add(rs.getString(1));
      }
      return labels;
    } catch (SQLException e) {
      throw new IOException(e);
    } finally {
      IOUtils.quietClose(rs, stmt, conn);
    }
  }

  /**
   * Returns the label of the first candidate row found for {@code itemid}.
   *
   * @return the label, or {@code null} if the item is not a candidate under any label. If the item
   *     has several labels, the row returned is whichever the database yields first (no ORDER BY).
   * @throws IOException if the query fails
   */
  public String getItemsLabel(long itemid) throws IOException {
    Connection conn = null;
    PreparedStatement stmt = null;
    ResultSet rs = null;
    try {
      conn = dataSource.getConnection();
      stmt = conn.prepareStatement(GET_ITEMSLABEL, ResultSet.TYPE_FORWARD_ONLY,
          ResultSet.CONCUR_READ_ONLY);
      stmt.setLong(1, itemid);
      stmt.setFetchDirection(ResultSet.FETCH_FORWARD);
      stmt.setFetchSize(10);

      rs = stmt.executeQuery();
      String label = null;
      if (rs.next()) {
        label = rs.getString(1);
      }
      return label;
    } catch (SQLException e) {
      throw new IOException(e);
    } finally {
      IOUtils.quietClose(rs, stmt, conn);
    }
  }

  /**
   * Closes the underlying connection pool.
   *
   * @throws IOException if the pool cannot be closed
   */
  @Override
  public void close() throws IOException {
    try {
      dataSource.close();
    } catch (SQLException e) {
      throw new IOException("Unable to close datasource", e);
    }
  }

  /**
   * Runs {@code query} (binding label and item ID) for every candidate, in JDBC batches of at most
   * {@code batchSize} statements. Shared by {@link #batchAddCandidates} and
   * {@link #batchDeleteCandidates}.
   *
   * @param progressMessage SLF4J pattern logged with the running total after each full batch
   * @param doneMessage SLF4J pattern logged with the final total after the trailing partial batch
   * @return the distinct labels of all candidates that were queued
   */
  private Iterable<String> executeCandidateBatch(String query, Iterator<Candidate> candidates,
      int batchSize, String progressMessage, String doneMessage) throws IOException {
    checkBatchSize(batchSize);
    Set<String> modifiedLabels = Sets.newHashSet();
    Connection conn = null;
    PreparedStatement stmt = null;
    try {
      conn = dataSource.getConnection();
      stmt = conn.prepareStatement(query);

      long recordsQueued = 0;
      int pendingInBatch = 0;
      while (candidates.hasNext()) {
        Candidate candidate = candidates.next();
        modifiedLabels.add(candidate.getLabel());
        stmt.setString(1, candidate.getLabel());
        stmt.setLong(2, candidate.getItemID());
        stmt.addBatch();
        recordsQueued++;

        if (++pendingInBatch == batchSize) {
          stmt.executeBatch();
          pendingInBatch = 0;
          log.info(progressMessage, recordsQueued);
        }
      }

      // Flush the trailing, partially filled batch.
      if (pendingInBatch > 0) {
        stmt.executeBatch();
        log.info(doneMessage, recordsQueued);
      }

    } catch (SQLException e) {
      throw new IOException(e);
    } finally {
      IOUtils.quietClose(stmt);
      IOUtils.quietClose(conn);
    }
    return modifiedLabels;
  }

  /** Rejects non-positive batch sizes before any connection is borrowed from the pool. */
  private static void checkBatchSize(int batchSize) {
    if (batchSize < 1) {
      throw new IllegalArgumentException("batchSize must be positive, was " + batchSize);
    }
  }

  /** Returns whether {@code e} carries an SQLState of class 23 (integrity constraint violation). */
  private static boolean isIntegrityConstraintViolation(SQLException e) {
    String sqlState = e.getSQLState();
    return sqlState != null && sqlState.startsWith(INTEGRITY_CONSTRAINT_VIOLATION_CLASS);
  }
}