/* * Concept profile generation tool suite * Copyright (C) 2015 Biosemantics Group, Erasmus University Medical Center, * Rotterdam, The Netherlands * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published * by the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see <http://www.gnu.org/licenses/> */ package org.erasmusmc.math.vector; import java.io.Serializable; import java.util.Iterator; import org.erasmusmc.collections.MapCursor; import org.erasmusmc.collections.SortedIntList2FloatMap; import org.erasmusmc.collections.SortedIntList2FloatMap.MapEntry; import org.erasmusmc.math.space.IntegerSpace; import org.erasmusmc.math.space.Space; public class SparseVectorInt2Float extends Vector<Integer> implements Serializable { private static final long serialVersionUID = 2830857831592496295L; transient public SortedIntList2FloatMap values; public SparseVectorInt2Float() { values = new SortedIntList2FloatMap(); } public SparseVectorInt2Float(SortedIntList2FloatMap map) { values = map; } public SparseVectorInt2Float(Vector<Integer> vector) { values = new SortedIntList2FloatMap(); set(vector); } public SparseVectorInt2Float sparseElementWiseSparseInnerProduct(SparseVectorInt2Float other) { SortedIntList2FloatMap shorter = other.values; SortedIntList2FloatMap longer = values; SortedIntList2FloatMap resultmap = new SortedIntList2FloatMap(shorter.size() + longer.size()); if (shorter.size() != 0 && longer.size() != 0){ if (shorter.size() > longer.size()) { shorter = longer; longer = other.values; } int longerlowestIndex = 0; int longerhighestIndex = longer.size() - 1; int longerlowest = longer.getKey(longerlowestIndex); int shorterlowestIndex = 0; int shorterhighestIndex = shorter.size() - 1; while (shorterlowestIndex <= shorterhighestIndex) { int key = shorter.getKey(shorterlowestIndex); if (key >= longerlowest) { int index = longer.guidedGetIndexForKey(key, longerlowestIndex, longerhighestIndex + 1); if (index < longer.size()) { longerlowestIndex = index; if (longer.getKey(index) == key) { float product = longer.getValue(index) * shorter.getValue(shorterlowestIndex); resultmap.addEntry(key, product); if (longerlowestIndex < longerhighestIndex - 1) longerlowestIndex++; } longerlowest = longer.getKey(longerlowestIndex); } } shorterlowestIndex++; } } return new SparseVectorInt2Float(resultmap); } /** * Calculates cityblock metric (sum of absolute differences between components of the vectors) * @param other * @return Cityblock */ public double cityBlock(SparseVectorInt2Float other) { Iterator<MapEntry> iterator1 = values.entryIterator(); Iterator<MapEntry> iterator2 = other.values.entryIterator(); double score = 0; MapEntry buffer1 = null; MapEntry buffer2 = null; if (iterator1.hasNext()) buffer1 = iterator1.next(); if (iterator2.hasNext()) buffer2 = iterator2.next(); while (buffer1 != null && buffer2 != null){ if (buffer1.getKey() == buffer2.getKey()){ score += Math.abs(buffer1.getValue() - buffer2.getValue()); if (iterator1.hasNext()) buffer1 = iterator1.next(); else buffer1 = null; if (iterator2.hasNext()) buffer2 = iterator2.next(); else buffer2 = null; } else if (buffer1.getKey() > buffer2.getKey()) { score += Math.abs(buffer2.getValue()); if (iterator2.hasNext()) buffer2 = iterator2.next(); else buffer2 = null; } else { score += Math.abs(buffer1.getValue()); if (iterator1.hasNext()) buffer1 = iterator1.next(); else buffer1 = null; } } if (buffer1 != null){ score += Math.abs(buffer1.getValue()); while (iterator1.hasNext()) score += Math.abs(iterator1.next().getValue()); } if (buffer2 != null){ score += Math.abs(buffer2.getValue()); while (iterator2.hasNext()) score += Math.abs(iterator2.next().getValue()); } return score; } public double sparseInnerProduct(SparseVectorInt2Float other) { SortedIntList2FloatMap shorter = other.values; SortedIntList2FloatMap longer = values; if (shorter.size() > longer.size()) { shorter = longer; longer = other.values; } double innerproduct = 0f; if (shorter.size() > 0) { int longerlowestIndex = 0; int longerhighestIndex = longer.size() - 1; int longerlowest = longer.getKey(longerlowestIndex); int longerhighest = longer.getKey(longerhighestIndex); int shorterlowestIndex = 0; int shorterhighestIndex = shorter.size() - 1; while (shorterlowestIndex <= shorterhighestIndex) { int key = shorter.getKey(shorterlowestIndex); if (key >= longerlowest) { int index = longer.guidedGetIndexForKey(key, longerlowestIndex, longerhighestIndex + 1); if (index < longer.size()) { longerlowestIndex = index; if (longer.getKey(index) == key) { innerproduct += longer.getValue(index) * shorter.getValue(shorterlowestIndex); if (longerlowestIndex < longerhighestIndex - 1) longerlowestIndex++; } longerlowest = longer.getKey(longerlowestIndex); } } shorterlowestIndex++; if (shorterlowestIndex < shorterhighestIndex) { key = shorter.getKey(shorterhighestIndex); if (key <= longerhighest) { int index = longer.guidedGetIndexForKey(key, longerlowestIndex, longerhighestIndex + 1); if (index < longer.size()) { longerhighestIndex = index; if (longer.getKey(index) == key) { innerproduct += longer.getValue(index) * shorter.getValue(shorterhighestIndex); if (longerlowestIndex < longerhighestIndex - 1) longerhighestIndex--; } longerhighest = longer.getKey(longerhighestIndex); } } shorterhighestIndex--; } } } return innerproduct; } /** * superfancy high performance cosine function. Whips Any ass up to now. * * @param other * @return */ public double sparseCosine(SparseVectorInt2Float other) { double result = 0; double denominator = norm() * other.norm(); if (denominator > 0) { double innerproduct = sparseInnerProduct(other); result = innerproduct / denominator; } return result; } public double jaccard(SparseVectorInt2Float other) { double result = 0; double innerproduct = sparseInnerProduct(other); double denominator = getSquaredNorm() + other.getSquaredNorm() - innerproduct; if (denominator > 0) { result = innerproduct / denominator; } return result; } public double dice(SparseVectorInt2Float other) { double result = 0; double innerproduct = sparseInnerProduct(other); double denominator = getSquaredNorm() + other.getSquaredNorm(); if (denominator > 0) { result = 2d * innerproduct / denominator; } return result; } @Override public double get(Integer object) { float value = values.get(object.intValue()); if (!Float.isNaN(value)) return new Float(value).doubleValue(); else return 0f; } @Override public Space<Integer> getSpace() { return new IntegerSpace(); } @Override public int getStoredValueCount() { return values.size(); } @Override public void set(Integer index, double value) { if (value != 0d) values.put(index, new Double(value).floatValue()); else values.remove(index); } public void setFloat(Integer index, float value) { if (value != 0d) values.put(index, value); else values.remove(index); } public float getFloat(Integer index) { float value = values.get(index); if (!Float.isNaN(value)) return value; else return 0; } @Override public void set(Vector<Integer> vector) { values.clear(); VectorCursor<Integer> cursor = vector.getNonzeroCursor(); while (cursor.isValid()) { set(cursor.dimension(), cursor.get()); cursor.next(); } } @Override public void setSpace(Space<Integer> space) { } public VectorCursor<Integer> getCursor() { return new SparseVectorNonzeroCursor(); } public VectorCursor<Integer> getNonzeroCursor() { return new SparseVectorNonzeroCursor(); } public VectorSlaveCursor<Integer> getSlaveCursor() { return new SparseVectorSlaveCursor(); } public Iterator<MapEntry> entryIterator() { return values.entryIterator(); } protected class SparseVectorHandle implements VectorHandle<Integer> { Integer dimension; public Integer dimension() { return dimension; } public int index() { return dimension; } public double get() { float value = values.get(dimension); if (!Float.isNaN(value)) return value; else return 0; } public void set(double value) { if (value != 0d) values.put(dimension, new Double(value).floatValue()); else values.remove(dimension); } } protected class SparseVectorSlaveCursor extends SparseVectorHandle implements VectorSlaveCursor<Integer> { public void synchronize(VectorHandle<Integer> vectorHandle) { dimension = vectorHandle.dimension(); } } protected class SparseVectorNonzeroCursor extends SparseVectorHandle implements VectorCursor<Integer>, Serializable { private static final long serialVersionUID = 2287253547250643918L; MapCursor<Integer, Float> cursor; public SparseVectorNonzeroCursor() { cursor = values.getEntryCursor(); } public boolean isValid() { return cursor.isValid(); } public void next() { cursor.next(); } @Override public Integer dimension() { return cursor.key(); } @Override public double get() { return cursor.value(); } @Override public int index() { return values.getIndexForKey(dimension()); } @Override public void set(double value) { if (value != 0d) cursor.setValue(new Double(value).floatValue()); else cursor.remove(); } } }