/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.ignite.ml.math.impls;
import java.util.Collection;
import java.util.Collections;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.affinity.Affinity;
import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.ScanQuery;
import org.apache.ignite.cluster.ClusterGroup;
import org.apache.ignite.cluster.ClusterNode;
import org.apache.ignite.lang.IgniteBiTuple;
import org.apache.ignite.lang.IgniteCallable;
import org.apache.ignite.lang.IgnitePredicate;
import org.apache.ignite.lang.IgniteRunnable;
import org.apache.ignite.lang.IgniteUuid;
import org.apache.ignite.ml.math.KeyMapper;
import org.apache.ignite.ml.math.ValueMapper;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.functions.IgniteConsumer;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.impls.storage.matrix.SparseDistributedMatrixStorage;
/**
* Distribution-related misc. support.
*/
public class CacheUtils {
/**
* Cache entry support.
*
* @param <K>
* @param <V>
*/
public static class CacheEntry<K, V> {
/** */
private Cache.Entry<K, V> entry;
/** */
private IgniteCache<K, V> cache;
/**
* @param entry Original cache entry.
* @param cache Cache instance.
*/
CacheEntry(Cache.Entry<K, V> entry, IgniteCache<K, V> cache) {
this.entry = entry;
this.cache = cache;
}
/**
*
*
*/
public Cache.Entry<K, V> entry() {
return entry;
}
/**
*
*
*/
public IgniteCache<K, V> cache() {
return cache;
}
}
/**
* Gets local Ignite instance.
*/
public static Ignite ignite() {
return Ignition.localIgnite();
}
/**
* @param cacheName Cache name.
* @param k Key into the cache.
* @param <K> Key type.
* @return Cluster group for given key.
*/
public static <K> ClusterGroup groupForKey(String cacheName, K k) {
return ignite().cluster().forNode(ignite().affinity(cacheName).mapKeyToNode(k));
}
/**
* @param cacheName Cache name.
* @param keyMapper {@link KeyMapper} to validate cache key.
* @param valMapper {@link ValueMapper} to obtain double value for given cache key.
* @param <K> Cache key object type.
* @param <V> Cache value object type.
* @return Sum of the values obtained for valid keys.
*/
public static <K, V> double sum(String cacheName, KeyMapper<K> keyMapper, ValueMapper<V> valMapper) {
Collection<Double> subSums = fold(cacheName, (CacheEntry<K, V> ce, Double acc) -> {
if (keyMapper.isValid(ce.entry().getKey())) {
double v = valMapper.toDouble(ce.entry().getValue());
return acc == null ? v : acc + v;
}
else
return acc;
});
return sum(subSums);
}
/**
* @param matrixUuid Matrix UUID.
* @return Sum obtained using sparse logic.
*/
public static <K, V> double sparseSum(IgniteUuid matrixUuid) {
Collection<Double> subSums = fold(SparseDistributedMatrixStorage.ML_CACHE_NAME, (CacheEntry<IgniteBiTuple<Integer, IgniteUuid>, Map<Integer, Double>> ce, Double acc) -> {
Cache.Entry<IgniteBiTuple<Integer, IgniteUuid>, Map<Integer, Double>> entry = ce.entry();
if (entry.getKey().get2().equals(matrixUuid)) {
Map<Integer, Double> map = entry.getValue();
double sum = sum(map.values());
return acc == null ? sum : acc + sum;
} else
return acc;
}, key -> key.get2().equals(matrixUuid));
return sum(subSums);
}
/**
* @param c {@link Collection} of double values to sum.
* @return Sum of the values.
*/
private static double sum(Collection<Double> c) {
double sum = 0.0;
for (double d : c)
sum += d;
return sum;
}
/**
* @param cacheName Cache name.
* @param keyMapper {@link KeyMapper} to validate cache key.
* @param valMapper {@link ValueMapper} to obtain double value for given cache key.
* @param <K> Cache key object type.
* @param <V> Cache value object type.
* @return Minimum value for valid keys.
*/
public static <K, V> double min(String cacheName, KeyMapper<K> keyMapper, ValueMapper<V> valMapper) {
Collection<Double> mins = fold(cacheName, (CacheEntry<K, V> ce, Double acc) -> {
if (keyMapper.isValid(ce.entry().getKey())) {
double v = valMapper.toDouble(ce.entry().getValue());
if (acc == null)
return v;
else
return Math.min(acc, v);
}
else
return acc;
});
return Collections.min(mins);
}
/**
* @param matrixUuid Matrix UUID.
* @return Minimum value obtained using sparse logic.
*/
public static <K, V> double sparseMin(IgniteUuid matrixUuid) {
Collection<Double> mins = fold(SparseDistributedMatrixStorage.ML_CACHE_NAME, (CacheEntry<IgniteBiTuple<Integer, IgniteUuid>, Map<Integer, Double>> ce, Double acc) -> {
Cache.Entry<IgniteBiTuple<Integer, IgniteUuid>, Map<Integer, Double>> entry = ce.entry();
if (entry.getKey().get2().equals(matrixUuid)) {
Map<Integer, Double> map = entry.getValue();
double min = Collections.min(map.values());
if (acc == null)
return min;
else
return Math.min(acc, min);
} else
return acc;
}, key -> key.get2().equals(matrixUuid));
return Collections.min(mins);
}
/**
* @param matrixUuid Matrix UUID.
* @return Maximum value obtained using sparse logic.
*/
public static <K, V> double sparseMax(IgniteUuid matrixUuid) {
Collection<Double> maxes = fold(SparseDistributedMatrixStorage.ML_CACHE_NAME, (CacheEntry<IgniteBiTuple<Integer, IgniteUuid>, Map<Integer, Double>> ce, Double acc) -> {
Cache.Entry<IgniteBiTuple<Integer, IgniteUuid>, Map<Integer, Double>> entry = ce.entry();
if (entry.getKey().get2().equals(matrixUuid)) {
Map<Integer, Double> map = entry.getValue();
double max = Collections.max(map.values());
if (acc == null)
return max;
else
return Math.max(acc, max);
} else
return acc;
}, key -> key.get2().equals(matrixUuid));
return Collections.max(maxes);
}
/**
* @param cacheName Cache name.
* @param keyMapper {@link KeyMapper} to validate cache key.
* @param valMapper {@link ValueMapper} to obtain double value for given cache key.
* @param <K> Cache key object type.
* @param <V> Cache value object type.
* @return Maximum value for valid keys.
*/
public static <K, V> double max(String cacheName, KeyMapper<K> keyMapper, ValueMapper<V> valMapper) {
Collection<Double> maxes = fold(cacheName, (CacheEntry<K, V> ce, Double acc) -> {
if (keyMapper.isValid(ce.entry().getKey())) {
double v = valMapper.toDouble(ce.entry().getValue());
if (acc == null)
return v;
else
return Math.max(acc, v);
}
else
return acc;
});
return Collections.max(maxes);
}
/**
* @param cacheName Cache name.
* @param keyMapper {@link KeyMapper} to validate cache key.
* @param valMapper {@link ValueMapper} to obtain double value for given cache key.
* @param mapper Mapping {@link IgniteFunction}.
* @param <K> Cache key object type.
* @param <V> Cache value object type.
*/
public static <K, V> void map(String cacheName, KeyMapper<K> keyMapper, ValueMapper<V> valMapper,
IgniteFunction<Double, Double> mapper) {
foreach(cacheName, (CacheEntry<K, V> ce) -> {
K k = ce.entry().getKey();
if (keyMapper.isValid(k))
// Actual assignment.
ce.cache().put(k, valMapper.fromDouble(mapper.apply(valMapper.toDouble(ce.entry().getValue()))));
});
}
/**
* @param matrixUuid Matrix UUID.
* @param mapper Mapping {@link IgniteFunction}.
*/
public static <K, V> void sparseMap(IgniteUuid matrixUuid, IgniteFunction<Double, Double> mapper) {
foreach(SparseDistributedMatrixStorage.ML_CACHE_NAME, (CacheEntry<IgniteBiTuple<Integer, IgniteUuid>, Map<Integer, Double>> ce) -> {
IgniteBiTuple k = ce.entry().getKey();
Map<Integer, Double> v = ce.entry().getValue();
for (Map.Entry<Integer, Double> e : v.entrySet())
e.setValue(mapper.apply(e.getValue()));
ce.cache().put(k, v);
}, key -> key.get2().equals(matrixUuid));
}
/**
* @param cacheName Cache name.
* @param fun An operation that accepts a cache entry and processes it.
* @param <K> Cache key object type.
* @param <V> Cache value object type.
*/
public static <K, V> void foreach(String cacheName, IgniteConsumer<CacheEntry<K, V>> fun) {
foreach(cacheName, fun, null);
}
/**
* @param cacheName Cache name.
* @param fun An operation that accepts a cache entry and processes it.
* @param keyFilter Cache keys filter.
* @param <K> Cache key object type.
* @param <V> Cache value object type.
*/
public static <K, V> void foreach(String cacheName, IgniteConsumer<CacheEntry<K, V>> fun, IgnitePredicate<K> keyFilter) {
bcast(cacheName, () -> {
Ignite ignite = Ignition.localIgnite();
IgniteCache<K, V> cache = ignite.getOrCreateCache(cacheName);
int partsCnt = ignite.affinity(cacheName).partitions();
// Use affinity in filter for scan query. Otherwise we accept consumer in each node which is wrong.
Affinity affinity = ignite.affinity(cacheName);
ClusterNode locNode = ignite.cluster().localNode();
// Iterate over all partitions. Some of them will be stored on that local node.
for (int part = 0; part < partsCnt; part++) {
int p = part;
// Iterate over given partition.
// Query returns an empty cursor if this partition is not stored on this node.
for (Cache.Entry<K, V> entry : cache.query(new ScanQuery<K, V>(part,
(k, v) -> affinity.mapPartitionToNode(p) == locNode && (keyFilter == null || keyFilter.apply(k)))))
fun.accept(new CacheEntry<>(entry, cache));
}
});
}
/**
* <b>Currently fold supports only commutative operations.<b/>
*
* @param cacheName Cache name.
* @param folder Fold function operating over cache entries.
* @param <K> Cache key object type.
* @param <V> Cache value object type.
* @param <A> Fold result type.
* @return Fold operation result.
*/
public static <K, V, A> Collection<A> fold(String cacheName, IgniteBiFunction<CacheEntry<K, V>, A, A> folder) {
return fold(cacheName, folder, null);
}
/**
* <b>Currently fold supports only commutative operations.<b/>
*
* @param cacheName Cache name.
* @param folder Fold function operating over cache entries.
* @param <K> Cache key object type.
* @param <V> Cache value object type.
* @param <A> Fold result type.
* @return Fold operation result.
*/
public static <K, V, A> Collection<A> fold(String cacheName, IgniteBiFunction<CacheEntry<K, V>, A, A> folder, IgnitePredicate<K> keyFilter) {
return bcast(cacheName, () -> {
Ignite ignite = Ignition.localIgnite();
IgniteCache<K, V> cache = ignite.getOrCreateCache(cacheName);
int partsCnt = ignite.affinity(cacheName).partitions();
// Use affinity in filter for ScanQuery. Otherwise we accept consumer in each node which is wrong.
Affinity affinity = ignite.affinity(cacheName);
ClusterNode locNode = ignite.cluster().localNode();
A a = null;
// Iterate over all partitions. Some of them will be stored on that local node.
for (int part = 0; part < partsCnt; part++) {
int p = part;
// Iterate over given partition.
// Query returns an empty cursor if this partition is not stored on this node.
for (Cache.Entry<K, V> entry : cache.query(new ScanQuery<K, V>(part,
(k, v) -> affinity.mapPartitionToNode(p) == locNode && (keyFilter == null || keyFilter.apply(k)))))
a = folder.apply(new CacheEntry<>(entry, cache), a);
}
return a;
});
}
/**
* @param cacheName Cache name.
* @param run {@link Runnable} to broadcast to cache nodes for given cache name.
*/
public static void bcast(String cacheName, IgniteRunnable run) {
ignite().compute(ignite().cluster().forCacheNodes(cacheName)).broadcast(run);
}
/**
* @param cacheName Cache name.
* @param call {@link IgniteCallable} to broadcast to cache nodes for given cache name.
* @param <A> Type returned by the callable.
*/
public static <A> Collection<A> bcast(String cacheName, IgniteCallable<A> call) {
return ignite().compute(ignite().cluster().forCacheNodes(cacheName)).broadcast(call);
}
}