/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package hivemall.fm; import hivemall.utils.buffer.HeapBuffer; import hivemall.utils.codec.VariableByteCodec; import hivemall.utils.codec.ZigZagLEB128Codec; import hivemall.utils.collections.Int2LongOpenHashTable; import hivemall.utils.collections.IntOpenHashTable; import hivemall.utils.io.CompressionStreamFactory.CompressionAlgorithm; import hivemall.utils.io.IOUtils; import hivemall.utils.lang.ArrayUtils; import hivemall.utils.lang.HalfFloat; import hivemall.utils.lang.ObjectUtils; import java.io.DataInput; import java.io.DataOutput; import java.io.Externalizable; import java.io.IOException; import java.io.ObjectInput; import java.io.ObjectOutput; import java.util.Arrays; import javax.annotation.Nonnull; import javax.annotation.Nullable; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; public final class FFMPredictionModel implements Externalizable { private static final Log LOG = LogFactory.getLog(FFMPredictionModel.class); private static final byte HALF_FLOAT_ENTRY = 1; private static final byte W_ONLY_HALF_FLOAT_ENTRY = 2; private static final byte FLOAT_ENTRY = 3; private static final byte W_ONLY_FLOAT_ENTRY = 4; /** * maps feature to feature weight pointer */ private Int2LongOpenHashTable _map; private HeapBuffer _buf; private double _w0; private int _factors; private int _numFeatures; private int _numFields; public FFMPredictionModel() {}// for Externalizable public FFMPredictionModel(@Nonnull Int2LongOpenHashTable map, @Nonnull HeapBuffer buf, double w0, int factor, int numFeatures, int numFields) { this._map = map; this._buf = buf; this._w0 = w0; this._factors = factor; this._numFeatures = numFeatures; this._numFields = numFields; } public int getNumFactors() { return _factors; } public double getW0() { return _w0; } public int getNumFeatures() { return _numFeatures; } public int getNumFields() { return _numFields; } public int getActualNumFeatures() { return _map.size(); } public long approxBytesConsumed() { int size = _map.size(); // [map] size * (|state| + |key| + |entry|) long bytes = size * (1L + 4L + 4L + (4L * _factors)); int rest = _map.capacity() - size; if (rest > 0) { bytes += rest * 1L; } // w0, factors, numFeatures, numFields, used, size bytes += (8 + 4 + 4 + 4 + 4 + 4); return bytes; } @Nullable private Entry getEntry(final int key) { final long ptr = _map.get(key); if (ptr == -1L) { return null; } return new Entry(_buf, _factors, ptr); } public float getW(@Nonnull final Feature x) { int j = x.getFeatureIndex(); Entry entry = getEntry(j); if (entry == null) { return 0.f; } return entry.getW(); } /** * @return true if V exists */ public boolean getV(@Nonnull final Feature x, @Nonnull final int yField, @Nonnull float[] dst) { int j = Feature.toIntFeature(x, yField, _numFields); Entry entry = getEntry(j); if (entry == null) { return false; } entry.getV(dst); if (ArrayUtils.equals(dst, 0.f)) { return false; // treat as null } return true; } @Override public void writeExternal(@Nonnull ObjectOutput out) throws IOException { out.writeDouble(_w0); final int factors = _factors; out.writeInt(factors); out.writeInt(_numFeatures); out.writeInt(_numFields); int used = _map.size(); out.writeInt(used); final int[] keys = _map.getKeys(); final int size = keys.length; out.writeInt(size); final byte[] states = _map.getStates(); writeStates(states, out); final long[] values = _map.getValues(); final HeapBuffer buf = _buf; final Entry e = new Entry(buf, factors); final float[] Vf = new float[factors]; for (int i = 0; i < size; i++) { if (states[i] != IntOpenHashTable.FULL) { continue; } ZigZagLEB128Codec.writeSignedInt(keys[i], out); e.setOffset(values[i]); writeEntry(e, factors, Vf, out); } // help GC this._map = null; this._buf = null; } private static void writeEntry(@Nonnull final Entry e, final int factors, @Nonnull final float[] Vf, @Nonnull final DataOutput out) throws IOException { final float W = e.getW(); e.getV(Vf); if (ArrayUtils.almostEquals(Vf, 0.f)) { if (HalfFloat.isRepresentable(W)) { out.writeByte(W_ONLY_HALF_FLOAT_ENTRY); out.writeShort(HalfFloat.floatToHalfFloat(W)); } else { out.writeByte(W_ONLY_FLOAT_ENTRY); out.writeFloat(W); } } else if (isRepresentableAsHalfFloat(W, Vf)) { out.writeByte(HALF_FLOAT_ENTRY); out.writeShort(HalfFloat.floatToHalfFloat(W)); for (int i = 0; i < factors; i++) { out.writeShort(HalfFloat.floatToHalfFloat(Vf[i])); } } else { out.writeByte(FLOAT_ENTRY); out.writeFloat(W); IOUtils.writeFloats(Vf, factors, out); } } private static boolean isRepresentableAsHalfFloat(final float W, @Nonnull final float[] Vf) { if (!HalfFloat.isRepresentable(W)) { return false; } for (float V : Vf) { if (!HalfFloat.isRepresentable(V)) { return false; } } return true; } @Nonnull static void writeStates(@Nonnull final byte[] status, @Nonnull final DataOutput out) throws IOException { // write empty states's indexes differentially final int size = status.length; int cardinarity = 0; for (int i = 0; i < size; i++) { if (status[i] != IntOpenHashTable.FULL) { cardinarity++; } } out.writeInt(cardinarity); if (cardinarity == 0) { return; } int prev = 0; for (int i = 0; i < size; i++) { if (status[i] != IntOpenHashTable.FULL) { int diff = i - prev; assert (diff >= 0); VariableByteCodec.encodeUnsignedInt(diff, out); prev = i; } } } @Override public void readExternal(@Nonnull final ObjectInput in) throws IOException, ClassNotFoundException { this._w0 = in.readDouble(); final int factors = in.readInt(); this._factors = factors; this._numFeatures = in.readInt(); this._numFields = in.readInt(); final int used = in.readInt(); final int size = in.readInt(); final int[] keys = new int[size]; final long[] values = new long[size]; final byte[] states = new byte[size]; readStates(in, states); final int entrySize = Entry.sizeOf(factors); int numChunks = (entrySize * used) / HeapBuffer.DEFAULT_CHUNK_BYTES + 1; final HeapBuffer buf = new HeapBuffer(HeapBuffer.DEFAULT_CHUNK_SIZE, numChunks); final Entry e = new Entry(buf, factors); final float[] Vf = new float[factors]; for (int i = 0; i < size; i++) { if (states[i] != IntOpenHashTable.FULL) { continue; } keys[i] = ZigZagLEB128Codec.readSignedInt(in); long ptr = buf.allocate(entrySize); e.setOffset(ptr); readEntry(in, factors, Vf, e); values[i] = ptr; } this._map = new Int2LongOpenHashTable(keys, values, states, used); this._buf = buf; } @Nonnull private static void readEntry(@Nonnull final DataInput in, final int factors, @Nonnull final float[] Vf, @Nonnull Entry dst) throws IOException { final byte type = in.readByte(); switch (type) { case HALF_FLOAT_ENTRY: { float W = HalfFloat.halfFloatToFloat(in.readShort()); dst.setW(W); for (int i = 0; i < factors; i++) { Vf[i] = HalfFloat.halfFloatToFloat(in.readShort()); } dst.setV(Vf); break; } case W_ONLY_HALF_FLOAT_ENTRY: { float W = HalfFloat.halfFloatToFloat(in.readShort()); dst.setW(W); break; } case FLOAT_ENTRY: { float W = in.readFloat(); dst.setW(W); IOUtils.readFloats(in, Vf); dst.setV(Vf); break; } case W_ONLY_FLOAT_ENTRY: { float W = in.readFloat(); dst.setW(W); break; } default: throw new IOException("Unexpected Entry type: " + type); } } @Nonnull static void readStates(@Nonnull final DataInput in, @Nonnull final byte[] status) throws IOException { // read non-empty states differentially final int cardinarity = in.readInt(); Arrays.fill(status, IntOpenHashTable.FULL); int prev = 0; for (int j = 0; j < cardinarity; j++) { int i = VariableByteCodec.decodeUnsignedInt(in) + prev; status[i] = IntOpenHashTable.FREE; prev = i; } } public byte[] serialize() throws IOException { LOG.info("FFMPredictionModel#serialize(): " + _buf.toString()); return ObjectUtils.toCompressedBytes(this, CompressionAlgorithm.lzma2, true); } public static FFMPredictionModel deserialize(@Nonnull final byte[] serializedObj, final int len) throws ClassNotFoundException, IOException { FFMPredictionModel model = new FFMPredictionModel(); ObjectUtils.readCompressedObject(serializedObj, len, model, CompressionAlgorithm.lzma2, true); LOG.info("FFMPredictionModel#deserialize(): " + model._buf.toString()); return model; } }