package com.skp.experiment.cf.als.hadoop; import java.io.IOException; import java.nio.ByteBuffer; import java.util.Arrays; import java.util.Iterator; import java.util.Map; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.mapreduce.Mapper; import org.apache.hadoop.util.ToolRunner; import org.apache.mahout.common.AbstractJob; import org.apache.mahout.common.Pair; import org.apache.mahout.common.iterator.sequencefile.PathType; import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterator; import org.apache.mahout.math.Matrix; import org.apache.mahout.math.SequentialAccessSparseVector; import org.apache.mahout.math.Vector; import org.apache.mahout.math.VectorWritable; import org.apache.mahout.math.map.OpenIntObjectHashMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import redis.clients.jedis.Jedis; import redis.clients.jedis.JedisPool; import redis.clients.jedis.JedisPoolConfig; import redis.clients.jedis.Pipeline; import com.google.common.primitives.Ints; import com.skp.experiment.math.als.hadoop.ImplicitFeedbackAlternatingLeastSquaresSolver; public class TestRedisJob extends AbstractJob { private static final String HOST_NAME = "20.20.20.31"; private static final Logger log = LoggerFactory.getLogger(TestRedisJob.class); private ImplicitFeedbackAlternatingLeastSquaresSolver solver; public static void main(String[] args) throws Exception { ToolRunner.run(new TestRedisJob(), args); } @Override public int run(String[] args) throws Exception { addInputOption(); addOutputOption(); addOption("indexSizes", null, "."); addOption("numFeatures", null, ",", String.valueOf(30)); addOption("m", null, ","); Map<String, String> parsedArgs = parseArguments(args); if (parsedArgs == null) { return -1; } long startTime = System.currentTimeMillis(); Map<String, String> indexSizesTmp = ALSMatrixUtil.fetchTextFiles(new Path(getOption("indexSizes")), ",", Arrays.asList(0), Arrays.asList(1)); int numUsers = Integer.parseInt(indexSizesTmp.get("0")) + 1; int numItems = Integer.parseInt(indexSizesTmp.get("1")) + 1; int numFeatures = Integer.parseInt(getOption("numFeatures")); Path M = new Path(getOption("m")); Matrix Y = ALSMatrixUtil.readDenseMatrixByRows(M, getConf(), numItems, numFeatures); Matrix YtransposeY = Y.transpose().times(Y); System.out.println("Y Size: " + Y.numRows() + "\t" + Y.numCols()); solver = new ImplicitFeedbackAlternatingLeastSquaresSolver(numFeatures, 0.065, 40, Y, YtransposeY); SequenceFileDirIterator<IntWritable, VectorWritable> ratings = new SequenceFileDirIterator<IntWritable, VectorWritable>(getInputPath(), PathType.LIST, null, null, true, getConf()); while (ratings.hasNext()) { Pair<IntWritable, VectorWritable> rating = ratings.next(); Vector v = new SequentialAccessSparseVector(rating.getSecond().get()); System.out.println("in" + rating.getFirst() + "\t" + rating.getSecond()); Vector uiOrmj = solver.solve(v); System.out.println(rating.getFirst() + "\t" + uiOrmj); } return 0; } private static boolean flushAll() { log.info("check redis server from local"); JedisPool pool = null; Jedis jedis = null; Pipeline pipeline = null; try { pool = new JedisPool(new JedisPoolConfig(), HOST_NAME); jedis = pool.getResource(); pipeline = jedis.pipelined(); jedis.flushAll(); } catch (Exception e) { return false; } finally { pipeline.exec(); pool.returnResource(jedis); pool.destroy(); } return true; } private static String generateRedisProto(String... args) { StringBuffer sb = new StringBuffer(); sb.append("*" + args.length + "\r\n"); for (String arg : args) { sb.append("$" + arg.getBytes().length + "\r\n"); sb.append(arg + "\r\n"); } return sb.toString(); } public static byte[] toByteArray(double value) { byte[] bytes = new byte[8]; ByteBuffer.wrap(bytes).putDouble(value); return bytes; } private static class Map1 extends Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable> { private static JedisPool pool; private static Jedis jedis; private static Pipeline pipeline; private static long sum = 0; private static long cnt = 0; @Override protected void map(IntWritable key, VectorWritable value, Context context) throws IOException, InterruptedException { Vector v = value.get(); Iterator<Vector.Element> iter = v.iterateNonZero(); long start = System.currentTimeMillis(); pipeline.set(Ints.toByteArray(key.get()), buildOutput(v).getBytes()); sum += System.currentTimeMillis() - start; cnt++; /* while(iter.hasNext()) { Vector.Element e = iter.next(); pipeline.rpush(Ints.toByteArray(key.get()), toByteArray(e.get())); } */ } private static String buildOutput(Vector v) { StringBuffer sb = new StringBuffer(); Iterator<Vector.Element> iter = v.iterateNonZero(); int idx = 0; while (iter.hasNext()) { Vector.Element e = iter.next(); if (idx++ != 0) { sb.append(","); } sb.append(e.get()); } return sb.toString(); } @Override protected void setup(Context context) throws IOException, InterruptedException { JedisPoolConfig jedisConf = new JedisPoolConfig(); jedisConf.setMaxWait(120000); pool = new JedisPool(jedisConf, HOST_NAME); jedis = pool.getResource(); pipeline = jedis.pipelined(); } @Override protected void cleanup(Context context) throws IOException, InterruptedException { log.info("Average access time: " + (sum / (double)cnt)); pipeline.exec(); pool.returnResource(jedis); pool.destroy(); } } private static class Map2 extends Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable> { private static JedisPool pool; private static Jedis jedis; private static long sum = 0; private static long cnt = 0; @Override protected void cleanup(Context context) throws IOException, InterruptedException { pool.returnResource(jedis); pool.destroy(); } @Override protected void map(IntWritable key, VectorWritable value, Context context) throws IOException, InterruptedException { Vector userRatings = value.get(); Iterator<Vector.Element> ratings = userRatings.iterateNonZero(); while (ratings.hasNext()) { Vector.Element r = ratings.next(); long start = System.currentTimeMillis(); String features = jedis.get(r.toString()); sum += System.currentTimeMillis() - start; cnt++; } } @Override protected void setup(Context context) throws IOException, InterruptedException { JedisPoolConfig jedisConf = new JedisPoolConfig(); jedisConf.setMaxWait(120000); pool = new JedisPool(jedisConf, HOST_NAME); jedis = pool.getResource(); context.setStatus("Average access time: " + ((double)sum / cnt)); } } }