/** * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.mahout.cf.taste.impl.model.cassandra; import com.google.common.base.Preconditions; import me.prettyprint.cassandra.model.HColumnImpl; import me.prettyprint.cassandra.serializers.BytesArraySerializer; import me.prettyprint.cassandra.serializers.FloatSerializer; import me.prettyprint.cassandra.serializers.LongSerializer; import me.prettyprint.cassandra.service.OperationType; import me.prettyprint.hector.api.Cluster; import me.prettyprint.hector.api.ConsistencyLevelPolicy; import me.prettyprint.hector.api.HConsistencyLevel; import me.prettyprint.hector.api.Keyspace; import me.prettyprint.hector.api.beans.ColumnSlice; import me.prettyprint.hector.api.beans.HColumn; import me.prettyprint.hector.api.factory.HFactory; import me.prettyprint.hector.api.mutation.Mutator; import me.prettyprint.hector.api.query.ColumnQuery; import me.prettyprint.hector.api.query.CountQuery; import me.prettyprint.hector.api.query.SliceQuery; import org.apache.mahout.cf.taste.common.NoSuchItemException; import org.apache.mahout.cf.taste.common.NoSuchUserException; import org.apache.mahout.cf.taste.common.Refreshable; import org.apache.mahout.cf.taste.common.TasteException; import org.apache.mahout.cf.taste.impl.common.Cache; import org.apache.mahout.cf.taste.impl.common.FastIDSet; import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator; import org.apache.mahout.cf.taste.impl.common.Retriever; import org.apache.mahout.cf.taste.impl.model.GenericItemPreferenceArray; import org.apache.mahout.cf.taste.impl.model.GenericUserPreferenceArray; import org.apache.mahout.cf.taste.model.DataModel; import org.apache.mahout.cf.taste.model.PreferenceArray; import java.io.Closeable; import java.util.Collection; import java.util.List; import java.util.concurrent.atomic.AtomicReference; /** * <p>A {@link DataModel} based on a Cassandra keyspace. By default it uses keyspace "recommender" but this * can be configured. Create the keyspace before using this class; this can be done on the Cassandra command * line with a command linke {@code create keyspace recommender;}.</p> * * <p>Within the keyspace, this model uses four column families:</p> * * <p>First, it uses a column family called "users". This is keyed by the user ID as an 8-byte long. * It contains a column for every preference the user expresses. The column name is item ID, again as * an 8-byte long, and value is a floating point value represnted as an IEEE 32-bit floating poitn value.</p> * * <p>It uses an analogous column family called "items" for the same data, but keyed by item ID rather * than user ID. In this column family, column names are user IDs instead.</p> * * <p>It uses a column family called "userIDs" as well, with an identical schema. It has one row under key * 0. IT contains a column for every user ID in th emodel. It has no values.</p> * * <p>Finally it also uses an analogous column family "itemIDs" containing item IDs.</p> * * <p>Each of these four column families needs to be created ahead of time. Again the * Cassandra CLI can be used to do so, with commands like {@code create column family users;}.</p> * * <p>Note that this thread uses a long-lived Cassandra client which will run until terminated. You * must {@link #close()} this implementation when done or the JVM will not terminate.</p> * * <p>This implementation still relies heavily on reading data into memory and caching, * as it remains too data-intensive to be effective even against Cassandra. It will take some time to * "warm up" as the first few requests will block loading user and item data into caches. This is still going * to send a great deal of query traffic to Cassandra. It would be advisable to employ caching wrapper * classes in your implementation, like {@link org.apache.mahout.cf.taste.impl.recommender.CachingRecommender} * or {@link org.apache.mahout.cf.taste.impl.similarity.CachingItemSimilarity}.</p> */ public final class CassandraDataModel implements DataModel, Closeable { /** Default Cassandra host. Default: localhost */ private static final String DEFAULT_HOST = "localhost"; /** Default Cassandra port. Default: 9160 */ private static final int DEFAULT_PORT = 9160; /** Default Cassandra keyspace. Default: recommender */ private static final String DEFAULT_KEYSPACE = "recommender"; static final String USERS_CF = "users"; static final String ITEMS_CF = "items"; static final String USER_IDS_CF = "userIDs"; static final String ITEM_IDS_CF = "itemIDs"; private static final long ID_ROW_KEY = 0L; private static final byte[] EMPTY = new byte[0]; private final Cluster cluster; private final Keyspace keyspace; private final Cache<Long,PreferenceArray> userCache; private final Cache<Long,PreferenceArray> itemCache; private final Cache<Long,FastIDSet> itemIDsFromUserCache; private final Cache<Long,FastIDSet> userIDsFromItemCache; private final AtomicReference<Integer> userCountCache; private final AtomicReference<Integer> itemCountCache; /** * Uses the standard Cassandra host and port (localhost:9160), and keyspace name ("recommender"). */ public CassandraDataModel() { this(DEFAULT_HOST, DEFAULT_PORT, DEFAULT_KEYSPACE); } /** * @param host Cassandra server host name * @param port Cassandra server port * @param keyspaceName name of Cassandra keyspace to use */ public CassandraDataModel(String host, int port, String keyspaceName) { Preconditions.checkNotNull(host); Preconditions.checkArgument(port > 0); Preconditions.checkNotNull(keyspaceName); cluster = HFactory.getOrCreateCluster(CassandraDataModel.class.getSimpleName(), host + ':' + port); keyspace = HFactory.createKeyspace(keyspaceName, cluster); keyspace.setConsistencyLevelPolicy(new OneConsistencyLevelPolicy()); userCache = new Cache<Long,PreferenceArray>(new UserPrefArrayRetriever(), 1 << 20); itemCache = new Cache<Long,PreferenceArray>(new ItemPrefArrayRetriever(), 1 << 20); itemIDsFromUserCache = new Cache<Long,FastIDSet>(new ItemIDsFromUserRetriever(), 1 << 20); userIDsFromItemCache = new Cache<Long,FastIDSet>(new UserIDsFromItemRetriever(), 1 << 20); userCountCache = new AtomicReference<Integer>(null); itemCountCache = new AtomicReference<Integer>(null); } @Override public LongPrimitiveIterator getUserIDs() { SliceQuery<Long,Long,?> query = buildNoValueSliceQuery(USER_IDS_CF); query.setKey(ID_ROW_KEY); FastIDSet userIDs = new FastIDSet(); for (HColumn<Long,?> userIDColumn : query.execute().get().getColumns()) { userIDs.add(userIDColumn.getName()); } return userIDs.iterator(); } @Override public PreferenceArray getPreferencesFromUser(long userID) throws TasteException { return userCache.get(userID); } @Override public FastIDSet getItemIDsFromUser(long userID) throws TasteException { return itemIDsFromUserCache.get(userID); } @Override public LongPrimitiveIterator getItemIDs() { SliceQuery<Long,Long,?> query = buildNoValueSliceQuery(ITEM_IDS_CF); query.setKey(ID_ROW_KEY); FastIDSet itemIDs = new FastIDSet(); for (HColumn<Long,?> itemIDColumn : query.execute().get().getColumns()) { itemIDs.add(itemIDColumn.getName()); } return itemIDs.iterator(); } @Override public PreferenceArray getPreferencesForItem(long itemID) throws TasteException { return itemCache.get(itemID); } @Override public Float getPreferenceValue(long userID, long itemID) { ColumnQuery<Long,Long,Float> query = HFactory.createColumnQuery(keyspace, LongSerializer.get(), LongSerializer.get(), FloatSerializer.get()); query.setColumnFamily(USERS_CF); query.setKey(userID); query.setName(itemID); HColumn<Long,Float> column = query.execute().get(); return column == null ? null : column.getValue(); } @Override public Long getPreferenceTime(long userID, long itemID) { ColumnQuery<Long,Long,?> query = HFactory.createColumnQuery(keyspace, LongSerializer.get(), LongSerializer.get(), BytesArraySerializer.get()); query.setColumnFamily(USERS_CF); query.setKey(userID); query.setName(itemID); HColumn<Long,?> result = query.execute().get(); return result == null ? null : result.getClock(); } @Override public int getNumItems() { Integer itemCount = itemCountCache.get(); if (itemCount == null) { CountQuery<Long,Long> countQuery = HFactory.createCountQuery(keyspace, LongSerializer.get(), LongSerializer.get()); countQuery.setKey(ID_ROW_KEY); countQuery.setColumnFamily(ITEM_IDS_CF); countQuery.setRange(null, null, Integer.MAX_VALUE); itemCount = countQuery.execute().get(); itemCountCache.set(itemCount); } return itemCount; } @Override public int getNumUsers() { Integer userCount = userCountCache.get(); if (userCount == null) { CountQuery<Long,Long> countQuery = HFactory.createCountQuery(keyspace, LongSerializer.get(), LongSerializer.get()); countQuery.setKey(ID_ROW_KEY); countQuery.setColumnFamily(USER_IDS_CF); countQuery.setRange(null, null, Integer.MAX_VALUE); userCount = countQuery.execute().get(); userCountCache.set(userCount); } return userCount; } @Override public int getNumUsersWithPreferenceFor(long itemID) throws TasteException { /* CountQuery<Long,Long> query = HFactory.createCountQuery(keyspace, LongSerializer.get(), LongSerializer.get()); query.setColumnFamily(ITEMS_CF); query.setKey(itemID); query.setRange(null, null, Integer.MAX_VALUE); return query.execute().get(); */ return userIDsFromItemCache.get(itemID).size(); } @Override public int getNumUsersWithPreferenceFor(long itemID1, long itemID2) throws TasteException { FastIDSet userIDs1 = userIDsFromItemCache.get(itemID1); FastIDSet userIDs2 = userIDsFromItemCache.get(itemID2); return userIDs1.size() < userIDs2.size() ? userIDs2.intersectionSize(userIDs1) : userIDs1.intersectionSize(userIDs2); } @Override public void setPreference(long userID, long itemID, float value) { if (Float.isNaN(value)) { value = 1.0f; } long now = System.currentTimeMillis(); Mutator<Long> mutator = HFactory.createMutator(keyspace, LongSerializer.get()); HColumn<Long,Float> itemForUsers = new HColumnImpl<Long,Float>(LongSerializer.get(), FloatSerializer.get()); itemForUsers.setName(itemID); itemForUsers.setClock(now); itemForUsers.setValue(value); mutator.addInsertion(userID, USERS_CF, itemForUsers); HColumn<Long,Float> userForItems = new HColumnImpl<Long,Float>(LongSerializer.get(), FloatSerializer.get()); userForItems.setName(userID); userForItems.setClock(now); userForItems.setValue(value); mutator.addInsertion(itemID, ITEMS_CF, userForItems); HColumn<Long,byte[]> userIDs = new HColumnImpl<Long,byte[]>(LongSerializer.get(), BytesArraySerializer.get()); userIDs.setName(userID); userIDs.setClock(now); userIDs.setValue(EMPTY); mutator.addInsertion(ID_ROW_KEY, USER_IDS_CF, userIDs); HColumn<Long,byte[]> itemIDs = new HColumnImpl<Long,byte[]>(LongSerializer.get(), BytesArraySerializer.get()); itemIDs.setName(itemID); itemIDs.setClock(now); itemIDs.setValue(EMPTY); mutator.addInsertion(ID_ROW_KEY, ITEM_IDS_CF, itemIDs); mutator.execute(); } @Override public void removePreference(long userID, long itemID) { Mutator<Long> mutator = HFactory.createMutator(keyspace, LongSerializer.get()); mutator.addDeletion(userID, USERS_CF, itemID, LongSerializer.get()); mutator.addDeletion(itemID, ITEMS_CF, userID, LongSerializer.get()); mutator.execute(); // Not deleting from userIDs, itemIDs though } /** * @return true */ @Override public boolean hasPreferenceValues() { return true; } /** * @return Float#NaN */ @Override public float getMaxPreference() { return Float.NaN; } /** * @return Float#NaN */ @Override public float getMinPreference() { return Float.NaN; } @Override public void refresh(Collection<Refreshable> alreadyRefreshed) { userCache.clear(); itemCache.clear(); userIDsFromItemCache.clear(); itemIDsFromUserCache.clear(); userCountCache.set(null); itemCountCache.set(null); } @Override public String toString() { return "CassandraDataModel[" + keyspace + ']'; } @Override public void close() { HFactory.shutdownCluster(cluster); } private SliceQuery<Long,Long,byte[]> buildNoValueSliceQuery(String cf) { SliceQuery<Long,Long,byte[]> query = HFactory.createSliceQuery(keyspace, LongSerializer.get(), LongSerializer.get(), BytesArraySerializer.get()); query.setColumnFamily(cf); query.setRange(null, null, false, Integer.MAX_VALUE); return query; } private SliceQuery<Long,Long,Float> buildValueSliceQuery(String cf) { SliceQuery<Long,Long,Float> query = HFactory.createSliceQuery(keyspace, LongSerializer.get(), LongSerializer.get(), FloatSerializer.get()); query.setColumnFamily(cf); query.setRange(null, null, false, Integer.MAX_VALUE); return query; } private static final class OneConsistencyLevelPolicy implements ConsistencyLevelPolicy { @Override public HConsistencyLevel get(OperationType op) { return HConsistencyLevel.ONE; } @Override public HConsistencyLevel get(OperationType op, String cfName) { return HConsistencyLevel.ONE; } } private final class UserPrefArrayRetriever implements Retriever<Long, PreferenceArray> { @Override public PreferenceArray get(Long userID) throws TasteException { SliceQuery<Long,Long,Float> query = buildValueSliceQuery(USERS_CF); query.setKey(userID); ColumnSlice<Long,Float> result = query.execute().get(); if (result == null) { throw new NoSuchUserException(userID); } List<HColumn<Long,Float>> itemIDColumns = result.getColumns(); if (itemIDColumns.isEmpty()) { throw new NoSuchUserException(userID); } int size = itemIDColumns.size(); PreferenceArray prefs = new GenericUserPreferenceArray(size); prefs.setUserID(0, userID); for (int i = 0; i < size; i++) { HColumn<Long,Float> itemIDColumn = itemIDColumns.get(i); prefs.setItemID(i, itemIDColumn.getName()); prefs.setValue(i, itemIDColumn.getValue()); } return prefs; } } private final class ItemPrefArrayRetriever implements Retriever<Long, PreferenceArray> { @Override public PreferenceArray get(Long itemID) throws TasteException { SliceQuery<Long,Long,Float> query = buildValueSliceQuery(ITEMS_CF); query.setKey(itemID); ColumnSlice<Long,Float> result = query.execute().get(); if (result == null) { throw new NoSuchItemException(itemID); } List<HColumn<Long,Float>> userIDColumns = result.getColumns(); if (userIDColumns.isEmpty()) { throw new NoSuchItemException(itemID); } int size = userIDColumns.size(); PreferenceArray prefs = new GenericItemPreferenceArray(size); prefs.setItemID(0, itemID); for (int i = 0; i < size; i++) { HColumn<Long,Float> userIDColumn = userIDColumns.get(i); prefs.setUserID(i, userIDColumn.getName()); prefs.setValue(i, userIDColumn.getValue()); } return prefs; } } private final class UserIDsFromItemRetriever implements Retriever<Long, FastIDSet> { @Override public FastIDSet get(Long itemID) throws TasteException { SliceQuery<Long,Long,byte[]> query = buildNoValueSliceQuery(ITEMS_CF); query.setKey(itemID); ColumnSlice<Long,byte[]> result = query.execute().get(); if (result == null) { throw new NoSuchItemException(itemID); } List<HColumn<Long,byte[]>> columns = result.getColumns(); FastIDSet userIDs = new FastIDSet(columns.size()); for (HColumn<Long,?> userIDColumn : columns) { userIDs.add(userIDColumn.getName()); } return userIDs; } } private final class ItemIDsFromUserRetriever implements Retriever<Long, FastIDSet> { @Override public FastIDSet get(Long userID) throws TasteException { SliceQuery<Long,Long,byte[]> query = buildNoValueSliceQuery(USERS_CF); query.setKey(userID); FastIDSet itemIDs = new FastIDSet(); ColumnSlice<Long,byte[]> result = query.execute().get(); if (result == null) { throw new NoSuchUserException(userID); } List<HColumn<Long,byte[]>> columns = result.getColumns(); if (columns.isEmpty()) { throw new NoSuchUserException(userID); } for (HColumn<Long,?> itemIDColumn : columns) { itemIDs.add(itemIDColumn.getName()); } return itemIDs; } } }