/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.ignite.ml.math.impls.vector; import java.io.IOException; import java.io.ObjectInput; import java.io.ObjectOutput; import java.util.Arrays; import java.util.HashMap; import java.util.Iterator; import java.util.Map; import java.util.NoSuchElementException; import java.util.Spliterator; import java.util.function.Consumer; import java.util.function.IntToDoubleFunction; import org.apache.ignite.lang.IgniteUuid; import org.apache.ignite.ml.math.Matrix; import org.apache.ignite.ml.math.Vector; import org.apache.ignite.ml.math.VectorStorage; import org.apache.ignite.ml.math.exceptions.CardinalityException; import org.apache.ignite.ml.math.exceptions.IndexException; import org.apache.ignite.ml.math.exceptions.UnsupportedOperationException; import org.apache.ignite.ml.math.functions.Functions; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.functions.IgniteDoubleFunction; import org.apache.ignite.ml.math.impls.matrix.MatrixView; import org.jetbrains.annotations.NotNull; /** * This class provides a helper implementation of the {@link Vector} * interface to minimize the effort required to implement it. * Subclasses may override some of the implemented methods if a more * specific or optimized implementation is desirable. */ public abstract class AbstractVector implements Vector { /** Vector storage implementation. */ private VectorStorage sto; /** Meta attribute storage. */ private Map<String, Object> meta = new HashMap<>(); /** Vector's GUID. */ private IgniteUuid guid = IgniteUuid.randomUuid(); /** Cached value for length squared. */ private double lenSq = 0.0; /** Maximum cached element. */ private Element maxElm = null; /** Minimum cached element. */ private Element minElm = null; /** Readonly flag (false by default). */ private boolean readOnly = false; /** Read-only error message. */ private static final String RO_MSG = "Vector is read-only."; /** */ private void ensureReadOnly() { if (readOnly) throw new UnsupportedOperationException(RO_MSG); } /** * @param sto Storage. */ public AbstractVector(VectorStorage sto) { this(false, sto); } /** * @param readOnly Is read only. * @param sto Storage. */ public AbstractVector(boolean readOnly, VectorStorage sto) { assert sto != null; this.readOnly = readOnly; this.sto = sto; } /** * */ public AbstractVector() { // No-op. } /** * Set storage. * * @param sto Storage. */ protected void setStorage(VectorStorage sto) { this.sto = sto; } /** * @param i Index. * @param v Value. */ protected void storageSet(int i, double v) { ensureReadOnly(); sto.set(i, v); // Reset cached values. lenSq = 0.0; maxElm = minElm = null; } /** * @param i Index. * @return Value. */ protected double storageGet(int i) { return sto.get(i); } /** {@inheritDoc} */ @Override public int size() { return sto.size(); } /** * Check index bounds. * * @param idx Index to check. */ protected void checkIndex(int idx) { if (idx < 0 || idx >= sto.size()) throw new IndexException(idx); } /** {@inheritDoc} */ @Override public double get(int idx) { checkIndex(idx); return storageGet(idx); } /** {@inheritDoc} */ @Override public double getX(int idx) { return storageGet(idx); } /** {@inheritDoc} */ @Override public boolean isArrayBased() { return sto.isArrayBased(); } /** {@inheritDoc} */ @Override public Vector sort() { if (isArrayBased()) Arrays.parallelSort(sto.data()); else throw new UnsupportedOperationException(); return this; } /** {@inheritDoc} */ @Override public Vector map(IgniteDoubleFunction<Double> fun) { if (sto.isArrayBased()) { double[] data = sto.data(); Arrays.setAll(data, (idx) -> fun.apply(data[idx])); } else { int len = size(); for (int i = 0; i < len; i++) storageSet(i, fun.apply(storageGet(i))); } return this; } /** {@inheritDoc} */ @Override public Vector map(Vector vec, IgniteBiFunction<Double, Double, Double> fun) { checkCardinality(vec); int len = size(); for (int i = 0; i < len; i++) storageSet(i, fun.apply(storageGet(i), vec.get(i))); return this; } /** {@inheritDoc} */ @Override public Vector map(IgniteBiFunction<Double, Double, Double> fun, double y) { int len = size(); for (int i = 0; i < len; i++) storageSet(i, fun.apply(storageGet(i), y)); return this; } /** * @param idx Index. * @return Value. */ protected Element makeElement(int idx) { checkIndex(idx); return new Element() { /** {@inheritDoc} */ @Override public double get() { return storageGet(idx); } /** {@inheritDoc} */ @Override public int index() { return idx; } /** {@inheritDoc} */ @Override public void set(double val) { storageSet(idx, val); } }; } /** {@inheritDoc} */ @Override public Element minElement() { if (minElm == null) { int minIdx = 0; int len = size(); for (int i = 0; i < len; i++) if (storageGet(i) < storageGet(minIdx)) minIdx = i; minElm = makeElement(minIdx); } return minElm; } /** {@inheritDoc} */ @Override public Element maxElement() { if (maxElm == null) { int maxIdx = 0; int len = size(); for (int i = 0; i < len; i++) if (storageGet(i) > storageGet(maxIdx)) maxIdx = i; maxElm = makeElement(maxIdx); } return maxElm; } /** {@inheritDoc} */ @Override public double minValue() { return minElement().get(); } /** {@inheritDoc} */ @Override public double maxValue() { return maxElement().get(); } /** {@inheritDoc} */ @Override public Vector set(int idx, double val) { checkIndex(idx); storageSet(idx, val); return this; } /** {@inheritDoc} */ @Override public Vector setX(int idx, double val) { storageSet(idx, val); return this; } /** {@inheritDoc} */ @Override public Vector increment(int idx, double val) { checkIndex(idx); storageSet(idx, storageGet(idx) + val); return this; } /** {@inheritDoc} */ @Override public Vector incrementX(int idx, double val) { storageSet(idx, storageGet(idx) + val); return this; } /** * Tests if given value is considered a zero value. * * @param val Value to check. */ protected boolean isZero(double val) { return val == 0.0; } /** {@inheritDoc} */ @Override public double sum() { double sum = 0; int len = size(); for (int i = 0; i < len; i++) sum += storageGet(i); return sum; } /** {@inheritDoc} */ @Override public IgniteUuid guid() { return guid; } /** {@inheritDoc} */ @Override public Iterable<Element> all() { return new Iterable<Element>() { private int idx = 0; /** {@inheritDoc} */ @NotNull @Override public Iterator<Element> iterator() { return new Iterator<Element>() { /** {@inheritDoc} */ @Override public boolean hasNext() { return size() > 0 && idx < size(); } /** {@inheritDoc} */ @Override public Element next() { if (hasNext()) return getElement(idx++); throw new NoSuchElementException(); } }; } }; } /** {@inheritDoc} */ @Override public int nonZeroElements() { int cnt = 0; for (Element ignored : nonZeroes()) cnt++; return cnt; } /** {@inheritDoc} */ @Override public <T> T foldMap(IgniteBiFunction<T, Double, T> foldFun, IgniteDoubleFunction<Double> mapFun, T zeroVal) { T res = zeroVal; int len = size(); for (int i = 0; i < len; i++) res = foldFun.apply(res, mapFun.apply(storageGet(i))); return res; } /** {@inheritDoc} */ @Override public <T> T foldMap(Vector vec, IgniteBiFunction<T, Double, T> foldFun, IgniteBiFunction<Double, Double, Double> combFun, T zeroVal) { checkCardinality(vec); T res = zeroVal; int len = size(); for (int i = 0; i < len; i++) res = foldFun.apply(res, combFun.apply(storageGet(i), vec.getX(i))); return res; } /** {@inheritDoc} */ @Override public Iterable<Element> nonZeroes() { return new Iterable<Element>() { private int idx = 0; private int idxNext = -1; /** {@inheritDoc} */ @NotNull @Override public Iterator<Element> iterator() { return new Iterator<Element>() { @Override public boolean hasNext() { findNext(); return !over(); } @Override public Element next() { if (hasNext()) { idx = idxNext; return getElement(idxNext); } throw new NoSuchElementException(); } private void findNext() { if (over()) return; if (idxNextInitialized() && idx != idxNext) return; if (idxNextInitialized()) idx = idxNext + 1; while (idx < size() && isZero(get(idx))) idx++; idxNext = idx++; } private boolean over() { return idxNext >= size(); } private boolean idxNextInitialized() { return idxNext != -1; } }; } }; } /** {@inheritDoc} */ @Override public Map<String, Object> getMetaStorage() { return meta; } /** {@inheritDoc} */ @Override public Vector assign(double val) { if (sto.isArrayBased()) { ensureReadOnly(); Arrays.fill(sto.data(), val); } else { int len = size(); for (int i = 0; i < len; i++) storageSet(i, val); } return this; } /** {@inheritDoc} */ @Override public Vector assign(double[] vals) { checkCardinality(vals); if (sto.isArrayBased()) { ensureReadOnly(); System.arraycopy(vals, 0, sto.data(), 0, vals.length); lenSq = 0.0; } else { int len = size(); for (int i = 0; i < len; i++) storageSet(i, vals[i]); } return this; } /** {@inheritDoc} */ @Override public Vector assign(Vector vec) { checkCardinality(vec); for (Vector.Element x : vec.all()) storageSet(x.index(), x.get()); return this; } /** {@inheritDoc} */ @Override public Vector assign(IntToDoubleFunction fun) { assert fun != null; if (sto.isArrayBased()) { ensureReadOnly(); Arrays.setAll(sto.data(), fun); } else { int len = size(); for (int i = 0; i < len; i++) storageSet(i, fun.applyAsDouble(i)); } return this; } /** {@inheritDoc} */ @Override public Spliterator<Double> allSpliterator() { return new Spliterator<Double>() { /** {@inheritDoc} */ @Override public boolean tryAdvance(Consumer<? super Double> act) { int len = size(); for (int i = 0; i < len; i++) act.accept(storageGet(i)); return true; } /** {@inheritDoc} */ @Override public Spliterator<Double> trySplit() { return null; // No Splitting. } /** {@inheritDoc} */ @Override public long estimateSize() { return size(); } /** {@inheritDoc} */ @Override public int characteristics() { return ORDERED | SIZED; } }; } /** {@inheritDoc} */ @Override public Spliterator<Double> nonZeroSpliterator() { return new Spliterator<Double>() { /** {@inheritDoc} */ @Override public boolean tryAdvance(Consumer<? super Double> act) { int len = size(); for (int i = 0; i < len; i++) { double val = storageGet(i); if (!isZero(val)) act.accept(val); } return true; } /** {@inheritDoc} */ @Override public Spliterator<Double> trySplit() { return null; // No Splitting. } /** {@inheritDoc} */ @Override public long estimateSize() { return nonZeroElements(); } /** {@inheritDoc} */ @Override public int characteristics() { return ORDERED | SIZED; } }; } /** {@inheritDoc} */ @Override public double dot(Vector vec) { checkCardinality(vec); double sum = 0.0; int len = size(); for (int i = 0; i < len; i++) sum += storageGet(i) * vec.getX(i); return sum; } /** {@inheritDoc} */ @Override public double getLengthSquared() { if (lenSq == 0.0) lenSq = dotSelf(); return lenSq; } /** {@inheritDoc} */ @Override public boolean isDense() { return sto.isDense(); } /** {@inheritDoc} */ @Override public boolean isSequentialAccess() { return sto.isSequentialAccess(); } /** {@inheritDoc} */ @Override public boolean isRandomAccess() { return sto.isRandomAccess(); } /** {@inheritDoc} */ @Override public boolean isDistributed() { return sto.isDistributed(); } /** {@inheritDoc} */ @Override public VectorStorage getStorage() { return sto; } /** {@inheritDoc} */ @Override public Vector viewPart(int off, int len) { return new VectorView(this, off, len); } /** {@inheritDoc} */ @Override public Matrix cross(Vector vec) { Matrix res = likeMatrix(size(), vec.size()); if (res == null) return null; for (Element e : nonZeroes()) { int row = e.index(); res.assignRow(row, vec.times(getX(row))); } return res; } /** {@inheritDoc} */ @Override public Matrix toMatrix(boolean rowLike) { Matrix res = likeMatrix(rowLike ? 1 : size(), rowLike ? size() : 1); if (res == null) return null; if (rowLike) res.assignRow(0, this); else res.assignColumn(0, this); return res; } /** {@inheritDoc} */ @Override public Matrix toMatrixPlusOne(boolean rowLike, double zeroVal) { Matrix res = likeMatrix(rowLike ? 1 : size() + 1, rowLike ? size() + 1 : 1); if (res == null) return null; res.set(0, 0, zeroVal); if (rowLike) new MatrixView(res, 0, 1, 1, size()).assignRow(0, this); else new MatrixView(res, 1, 0, size(), 1).assignColumn(0, this); return res; } /** {@inheritDoc} */ @Override public double getDistanceSquared(Vector vec) { checkCardinality(vec); double thisLenSq = getLengthSquared(); double thatLenSq = vec.getLengthSquared(); double dot = dot(vec); double distEst = thisLenSq + thatLenSq - 2 * dot; if (distEst > 1.0e-3 * (thisLenSq + thatLenSq)) // The vectors are far enough from each other that the formula is accurate. return Math.max(distEst, 0); else return foldMap(vec, Functions.PLUS, Functions.MINUS_SQUARED, 0d); } /** * @param vec Vector to check for valid cardinality. */ protected void checkCardinality(Vector vec) { if (vec.size() != size()) throw new CardinalityException(size(), vec.size()); } /** * @param vec Array to check for valid cardinality. */ protected void checkCardinality(double[] vec) { if (vec.length != size()) throw new CardinalityException(size(), vec.length); } /** * @param arr Array to check for valid cardinality. */ protected void checkCardinality(int[] arr) { if (arr.length != size()) throw new CardinalityException(size(), arr.length); } /** {@inheritDoc} */ @Override public Vector minus(Vector vec) { checkCardinality(vec); Vector cp = copy(); return cp.map(vec, Functions.MINUS); } /** {@inheritDoc} */ @Override public Vector plus(double x) { Vector cp = copy(); return x != 0.0 ? cp.map(Functions.plus(x)) : cp; } /** {@inheritDoc} */ @Override public Vector divide(double x) { Vector cp = copy(); if (x != 1.0) for (Element element : cp.all()) element.set(element.get() / x); return cp; } /** {@inheritDoc} */ @Override public Vector times(double x) { if (x == 0.0) return like(size()); else return copy().map(Functions.mult(x)); } /** {@inheritDoc} */ @Override public Vector times(Vector vec) { checkCardinality(vec); return copy().map(vec, Functions.MULT); } /** {@inheritDoc} */ @Override public Vector plus(Vector vec) { checkCardinality(vec); Vector cp = copy(); return cp.map(vec, Functions.PLUS); } /** {@inheritDoc} */ @Override public Vector logNormalize() { return logNormalize(2.0, Math.sqrt(getLengthSquared())); } /** {@inheritDoc} */ @Override public Vector logNormalize(double power) { return logNormalize(power, kNorm(power)); } /** * @param power Power. * @param normLen Normalized length. * @return logNormalized value. */ private Vector logNormalize(double power, double normLen) { assert !(Double.isInfinite(power) || power <= 1.0); double denominator = normLen * Math.log(power); Vector cp = copy(); for (Element element : cp.all()) element.set(Math.log1p(element.get()) / denominator); return cp; } /** {@inheritDoc} */ @Override public double kNorm(double power) { assert power >= 0.0; // Special cases. if (Double.isInfinite(power)) return foldMap(Math::max, Math::abs, 0d); else if (power == 2.0) return Math.sqrt(getLengthSquared()); else if (power == 1.0) return foldMap(Functions.PLUS, Math::abs, 0d); else if (power == 0.0) return nonZeroElements(); else // Default case. return Math.pow(foldMap(Functions.PLUS, Functions.pow(power), 0d), 1.0 / power); } /** {@inheritDoc} */ @Override public Vector normalize() { return divide(Math.sqrt(getLengthSquared())); } /** {@inheritDoc} */ @Override public Vector normalize(double power) { return divide(kNorm(power)); } /** {@inheritDoc} */ @Override public Vector copy() { return like(size()).assign(this); } /** * @return Result of dot with self. */ protected double dotSelf() { double sum = 0.0; int len = size(); for (int i = 0; i < len; i++) { double v = storageGet(i); sum += v * v; } return sum; } /** {@inheritDoc} */ @Override public Element getElement(int idx) { return makeElement(idx); } /** {@inheritDoc} */ @Override public void writeExternal(ObjectOutput out) throws IOException { out.writeObject(sto); out.writeObject(meta); out.writeObject(guid); out.writeBoolean(readOnly); } /** {@inheritDoc} */ @SuppressWarnings("unchecked") @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { sto = (VectorStorage)in.readObject(); meta = (Map<String, Object>)in.readObject(); guid = (IgniteUuid)in.readObject(); readOnly = in.readBoolean(); } /** {@inheritDoc} */ @Override public void destroy() { sto.destroy(); } /** {@inheritDoc} */ @Override public int hashCode() { int res = 1; res += res * 37 + guid.hashCode(); res += sto == null ? 0 : res * 37 + sto.hashCode(); return res; } /** {@inheritDoc} */ @Override public boolean equals(Object obj) { if (this == obj) return true; if (obj == null || getClass() != obj.getClass()) return false; AbstractVector that = (AbstractVector)obj; return (sto != null ? sto.equals(that.sto) : that.sto == null); } }