/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.ignite.ml.math.impls.storage.matrix; import it.unimi.dsi.fastutil.ints.Int2DoubleOpenHashMap; import it.unimi.dsi.fastutil.ints.Int2DoubleRBTreeMap; import java.io.IOException; import java.io.ObjectInput; import java.io.ObjectOutput; import java.util.Map; import java.util.Set; import java.util.stream.Collectors; import java.util.stream.IntStream; import org.apache.ignite.IgniteCache; import org.apache.ignite.Ignition; import org.apache.ignite.cache.CacheAtomicityMode; import org.apache.ignite.cache.CacheMode; import org.apache.ignite.cache.CachePeekMode; import org.apache.ignite.cache.CacheWriteSynchronizationMode; import org.apache.ignite.configuration.CacheConfiguration; import org.apache.ignite.lang.IgniteBiTuple; import org.apache.ignite.lang.IgniteUuid; import org.apache.ignite.ml.math.MatrixStorage; import org.apache.ignite.ml.math.StorageConstants; import org.apache.ignite.ml.math.impls.CacheUtils; import org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix; /** * {@link MatrixStorage} implementation for {@link SparseDistributedMatrix}. */ public class SparseDistributedMatrixStorage extends CacheUtils implements MatrixStorage, StorageConstants { /** Cache name used for all instances of {@link SparseDistributedMatrixStorage}.*/ public static final String ML_CACHE_NAME = "ML_SPARSE_MATRICES_CONTAINER"; /** Amount of rows in the matrix. */ private int rows; /** Amount of columns in the matrix. */ private int cols; /** Row or column based storage mode. */ private int stoMode; /** Random or sequential access mode. */ private int acsMode; /** Matrix uuid. */ private IgniteUuid uuid; /** Actual distributed storage. */ private IgniteCache< IgniteBiTuple<Integer, IgniteUuid> /* Row or column index with matrix uuid. */, Map<Integer, Double> /* Map-based row or column. */ > cache = null; /** * */ public SparseDistributedMatrixStorage() { // No-op. } /** * @param rows Amount of rows in the matrix. * @param cols Amount of columns in the matrix. * @param stoMode Row or column based storage mode. * @param acsMode Random or sequential access mode. */ public SparseDistributedMatrixStorage(int rows, int cols, int stoMode, int acsMode) { assert rows > 0; assert cols > 0; assertAccessMode(acsMode); assertStorageMode(stoMode); this.rows = rows; this.cols = cols; this.stoMode = stoMode; this.acsMode = acsMode; cache = newCache(); uuid = IgniteUuid.randomUuid(); } /** * Create new ML cache if needed. */ private IgniteCache<IgniteBiTuple<Integer, IgniteUuid>, Map<Integer, Double>> newCache() { CacheConfiguration<IgniteBiTuple<Integer, IgniteUuid>, Map<Integer, Double>> cfg = new CacheConfiguration<>(); // Write to primary. cfg.setWriteSynchronizationMode(CacheWriteSynchronizationMode.PRIMARY_SYNC); // Atomic transactions only. cfg.setAtomicityMode(CacheAtomicityMode.ATOMIC); // No eviction. cfg.setEvictionPolicy(null); // No copying of values. cfg.setCopyOnRead(false); // Cache is partitioned. cfg.setCacheMode(CacheMode.PARTITIONED); // Random cache name. cfg.setName(ML_CACHE_NAME); IgniteCache<IgniteBiTuple<Integer, IgniteUuid>, Map<Integer, Double>> cache = Ignition.localIgnite().getOrCreateCache(cfg); return cache; } /** * * */ public IgniteCache<IgniteBiTuple<Integer, IgniteUuid>, Map<Integer, Double>> cache() { return cache; } /** * * */ public int accessMode() { return acsMode; } /** * * */ public int storageMode() { return stoMode; } /** {@inheritDoc} */ @Override public double get(int x, int y) { if (stoMode == ROW_STORAGE_MODE) return matrixGet(x, y); else return matrixGet(y, x); } /** {@inheritDoc} */ @Override public void set(int x, int y, double v) { if (stoMode == ROW_STORAGE_MODE) matrixSet(x, y, v); else matrixSet(y, x, v); } /** * Distributed matrix get. * * @param a Row or column index. * @param b Row or column index. * @return Matrix value at (a, b) index. */ private double matrixGet(int a, int b) { // Remote get from the primary node (where given row or column is stored locally). return ignite().compute(groupForKey(ML_CACHE_NAME, a)).call(() -> { IgniteCache<IgniteBiTuple<Integer, IgniteUuid>, Map<Integer, Double>> cache = Ignition.localIgnite().getOrCreateCache(ML_CACHE_NAME); // Local get. Map<Integer, Double> map = cache.localPeek(getCacheKey(a), CachePeekMode.PRIMARY); if (map == null) map = cache.get(getCacheKey(a)); return (map == null || !map.containsKey(b)) ? 0.0 : map.get(b); }); } /** * Distributed matrix set. * * @param a Row or column index. * @param b Row or column index. * @param v New value to set. */ private void matrixSet(int a, int b, double v) { // Remote set on the primary node (where given row or column is stored locally). ignite().compute(groupForKey(ML_CACHE_NAME, a)).run(() -> { IgniteCache<IgniteBiTuple<Integer, IgniteUuid>, Map<Integer, Double>> cache = Ignition.localIgnite().getOrCreateCache(ML_CACHE_NAME); // Local get. Map<Integer, Double> map = cache.localPeek(getCacheKey(a), CachePeekMode.PRIMARY); if (map == null) { map = cache.get(getCacheKey(a)); //Remote entry get. if (map == null) map = acsMode == SEQUENTIAL_ACCESS_MODE ? new Int2DoubleRBTreeMap() : new Int2DoubleOpenHashMap(); } if (v != 0.0) map.put(b, v); else if (map.containsKey(b)) map.remove(b); // Local put. cache.put(getCacheKey(a), map); }); } /** Build cache key for row/column. */ private IgniteBiTuple<Integer, IgniteUuid> getCacheKey(int idx){ return new IgniteBiTuple<>(idx, uuid); } /** {@inheritDoc} */ @Override public int columnSize() { return cols; } /** {@inheritDoc} */ @Override public int rowSize() { return rows; } /** {@inheritDoc} */ @Override public void writeExternal(ObjectOutput out) throws IOException { out.writeInt(rows); out.writeInt(cols); out.writeInt(acsMode); out.writeInt(stoMode); out.writeObject(uuid); out.writeUTF(cache.getName()); } /** {@inheritDoc} */ @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { rows = in.readInt(); cols = in.readInt(); acsMode = in.readInt(); stoMode = in.readInt(); uuid = (IgniteUuid)in.readObject(); cache = ignite().getOrCreateCache(in.readUTF()); } /** {@inheritDoc} */ @Override public boolean isSequentialAccess() { return acsMode == SEQUENTIAL_ACCESS_MODE; } /** {@inheritDoc} */ @Override public boolean isDense() { return false; } /** {@inheritDoc} */ @Override public boolean isRandomAccess() { return acsMode == RANDOM_ACCESS_MODE; } /** {@inheritDoc} */ @Override public boolean isDistributed() { return true; } /** {@inheritDoc} */ @Override public boolean isArrayBased() { return false; } /** Delete all data from cache. */ @Override public void destroy() { Set<IgniteBiTuple<Integer, IgniteUuid>> keyset = IntStream.range(0, rows).mapToObj(this::getCacheKey).collect(Collectors.toSet()); cache.clearAll(keyset); } /** {@inheritDoc} */ @Override public int hashCode() { int res = 1; res = res * 37 + cols; res = res * 37 + rows; res = res * 37 + acsMode; res = res * 37 + stoMode; res = res * 37 + uuid.hashCode(); res = res * 37 + cache.hashCode(); return res; } /** {@inheritDoc} */ @Override public boolean equals(Object obj) { if (this == obj) return true; if (obj == null || getClass() != obj.getClass()) return false; SparseDistributedMatrixStorage that = (SparseDistributedMatrixStorage)obj; return rows == that.rows && cols == that.cols && acsMode == that.acsMode && stoMode == that.stoMode && uuid.equals(that.uuid) && (cache != null ? cache.equals(that.cache) : that.cache == null); } /** */ public IgniteUuid getUUID() { return uuid; } }