package automenta.spacenet.var.graph.map;
import automenta.spacenet.var.Maths;
import automenta.spacenet.var.graph.MemGraph;
import automenta.spacenet.var.map.MapVar;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
/**
 * Associates a scalar with each node of a {@link MemGraph} and offers bulk operations
 * (blur, scale, randomize, focus) over those scalars.
 *
 * <p>Stored values are clamped to {@code [minAllowedValue, maxAllowedValue]} by
 * {@link #set(Object, double)}. {@link #value(Object)} reports the stored value plus a
 * constant bias; {@link #d(Object)} reports the stored value alone.
 *
 * @param <N> node type of the underlying graph
 * @param <E> edge type of the underlying graph
 */
public class ScalarGraphMap<N, E> extends MapVar<N, Double> {

    //TODO the way this functions it is actually a ScalarNodeMap - how to specify scalars for edges?

    /** Constant added to every stored value by {@link #value(Object)}. */
    private final double bias;

    /** Graph whose nodes are the keys of this map. */
    public final MemGraph<N, E> graph;

    /** Upper clamp applied by {@link #set(Object, double)}. */
    double maxAllowedValue = 1.0;

    /** Lower clamp applied by {@link #set(Object, double)}. */
    double minAllowedValue = 0.0;

    /** Largest stored value found by the last range scan; valid only while the cache is fresh. */
    private double currentMaxValue;

    /** Smallest stored value found by the last range scan; valid only while the cache is fresh. */
    private double currentMinValue;

    /** True when stored values changed through this class since the last range scan. */
    private boolean rangeStale = true;

    /** Map size observed by the last range scan; a different size forces a rescan. */
    private int rangeSize = -1;

    /**
     * Creates a map over {@code g} with the given bias and calls {@code clear()}.
     *
     * @param g    graph whose nodes will be mapped to scalars
     * @param bias constant added to stored values by {@link #value(Object)}
     */
    public ScalarGraphMap(MemGraph<N, E> g, double bias) {
        super();
        this.graph = g;
        this.bias = bias;
        clear();
    }

    /**
     * Creates a map over {@code graph} with a bias of {@code 0.0}.
     *
     * <p>The parameter is left as a raw type so that existing call sites keep compiling.
     *
     * @param graph graph whose nodes will be mapped to scalars
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public ScalarGraphMap(MemGraph graph) {
        this(graph, 0.0);
    }

    //TODO public void sharpen(double transferProportion, ...)

    /**
     * Picks one key of this map uniformly at random and adds
     * {@code Maths.random(min, max)} to its stored value. Does nothing when the map is empty.
     *
     * @param min first argument passed to {@code Maths.random}
     * @param max second argument passed to {@code Maths.random}
     */
    public void addRandom(double min, double max) {
        // ArrayList gives constant-time access to the randomly chosen index.
        List<N> keys = new ArrayList<N>(keySet());
        if (keys.isEmpty()) {
            return;
        }
        int index = (int) Math.floor(Math.random() * keys.size());
        add(keys.get(index), Maths.random(min, max));
    }

    /**
     * Blurs around a single node: {@code node} and each of its neighbors is moved toward the
     * average value of {@code node} and its neighbors.
     *
     * @param node               center of the blur
     * @param transferProportion weight of the average in the result; {@code 0} leaves values
     *                           as they are, {@code 1} replaces them with the average
     */
    public void blur(N node, double transferProportion) {
        Collection<N> neighbors = graph.getNeighbors(node);
        double average = averageWithNeighbors(node, neighbors);

        // Collect every result first so all of them derive from the pre-blur values.
        Map<N, Double> pending = new HashMap<N, Double>();
        pending.put(node, blend(value(node), average, transferProportion));
        for (N neighbor : neighbors) {
            pending.put(neighbor, blend(value(neighbor), average, transferProportion));
        }
        applyAll(pending);
    }

    /**
     * Blurs the whole graph: every node is moved toward the average value of itself and its
     * neighbors. All new values are computed from the values present before the call.
     *
     * @param transferProportion weight of the local average in the result
     */
    public void blur(double transferProportion) {
        Map<N, Double> pending = new HashMap<N, Double>();
        for (N node : graph.getNodes()) {
            double average = averageWithNeighbors(node, graph.getNeighbors(node));
            pending.put(node, blend(value(node), average, transferProportion));
        }
        applyAll(pending);
    }

    //TODO blur (= diffuse) restricted to edges accepted by a predicate:
    //     public void blur(double transferProportion, Predicate<E> traverseEdge)

    /** Returns the mean of {@code value(node)} and the values of all {@code neighbors}. */
    private double averageWithNeighbors(N node, Collection<N> neighbors) {
        double sum = value(node);
        for (N neighbor : neighbors) {
            sum += value(neighbor);
        }
        return sum / (neighbors.size() + 1);
    }

    /** Linear mix of {@code previous} and {@code target}, weighting {@code target} by {@code proportion}. */
    private double blend(double previous, double target, double proportion) {
        return (target * proportion) + (previous * (1.0 - proportion));
    }

    /** Writes every entry of {@code pending} through {@link #set(Object, double)}. */
    private void applyAll(Map<N, Double> pending) {
        for (Map.Entry<N, Double> entry : pending.entrySet()) {
            set(entry.getKey(), entry.getValue());
        }
    }

    /**
     * Sets every graph node to {@code Maths.random(min, max)} (subject to clamping).
     *
     * @param min first argument passed to {@code Maths.random}
     * @param max second argument passed to {@code Maths.random}
     */
    public void randomize(double min, double max) {
        // Iterate over a snapshot of the node collection, as the original code did.
        List<N> nodes = new ArrayList<N>(graph.getNodes());
        for (N node : nodes) {
            set(node, Maths.random(min, max));
        }
    }

    /**
     * Stores {@code a} for node {@code n} after clamping it to
     * {@code [minAllowedValue, maxAllowedValue]}.
     *
     * @param n node to update
     * @param a requested value
     */
    public void set(N n, double a) {
        a = Math.min(a, maxAllowedValue);
        a = Math.max(a, minAllowedValue);
        put(n, a);
        // The cached min/max may no longer hold; rescan lazily on the next query.
        rangeStale = true;
    }

    /**
     * Adds {@code dA} to the stored value of {@code n} (subject to clamping).
     *
     * @param n  node to update
     * @param dA amount to add
     */
    public void add(N n, double dA) {
        set(n, d(n) + dA);
    }

    /**
     * Returns the stored value of {@code n} without bias. When nothing is stored yet, the
     * default value is stored for {@code n} and returned.
     *
     * @param n node to look up
     * @return the stored (or newly defaulted) value
     */
    public double d(N n) {
        Double stored = get(n);
        if (stored == null) {
            stored = getDefaultValue();
            put(n, stored);
            rangeStale = true;
        }
        return stored;
    }

    /**
     * Returns the stored value of {@code n} plus the bias.
     *
     * @param n node to look up
     * @return {@code d(n) + getBias()}
     */
    public double value(N n) {
        return d(n) + getBias();
    }

    /** Returns the constant added to stored values by {@link #value(Object)}. */
    public double getBias() {
        return bias;
    }

    /** Returns the value assumed for nodes that have nothing stored ({@code 0.0}). */
    public double getDefaultValue() {
        return 0.0;
    }

    /**
     * Replaces the stored value of every graph node with {@code value(node) * d}
     * (subject to clamping).
     *
     * @param d multiplier
     */
    public void mult(double d) {
        for (N node : graph.getNodes()) {
            set(node, value(node) * d);
        }
    }

    /**
     * Multiplies the stored value of every graph node by {@code d}, never going below
     * {@code minSize} (and still subject to clamping).
     *
     * <p>Nodes with nothing stored are treated as holding the default value; reading them
     * through a bare {@code get} would fail on unboxing {@code null}.
     *
     * @param d       multiplier
     * @param minSize lower bound applied to each product
     */
    public void mult(double d, double minSize) {
        for (N node : graph.getNodes()) {
            double stored = d(node);
            set(node, Math.max(minSize, stored * d));
        }
    }

    /** Returns the graph whose nodes are mapped by this object. */
    public MemGraph<N, E> getGraph() {
        return graph;
    }

    /**
     * Returns a new list of the graph's nodes ordered by {@link #value(Object)}, largest first.
     * The list is a snapshot and is not kept sorted afterwards.
     *
     * @return nodes in descending value order
     */
    public List<N> getNodesSortedNow() {
        List<N> sorted = new ArrayList<N>(graph.getNodes());
        Collections.sort(sorted, new Comparator<N>() {
            @Override public int compare(N a, N b) {
                // Arguments reversed for descending order; Double.compare is a total order,
                // so the comparator contract holds even for NaN.
                return Double.compare(value(b), value(a));
            }
        });
        return sorted;
    }

    /**
     * Returns the smallest stored value (without bias), or {@code 0} when the map is empty.
     *
     * @return current minimum of the stored values
     */
    public double getMin() {
        if (size() == 0) {
            return 0;
        }
        refreshRange();
        return currentMinValue;
    }

    /**
     * Returns the largest stored value (without bias), or {@code 0} when the map is empty.
     *
     * @return current maximum of the stored values
     */
    public double getMax() {
        if (size() == 0) {
            return 0;
        }
        refreshRange();
        return currentMaxValue;
    }

    /**
     * Recomputes the cached minimum and maximum by scanning the stored values, unless the
     * cache is still fresh.
     *
     * <p>The cache is invalidated by {@link #set(Object, double)}, by the default insertion in
     * {@link #d(Object)}, and by any change in map size. A direct {@code put} that replaces an
     * existing value without going through {@code set} is not detected.
     */
    private void refreshRange() {
        int count = size();
        if (!rangeStale && count == rangeSize) {
            return;
        }
        double lowest = Double.POSITIVE_INFINITY;
        double highest = Double.NEGATIVE_INFINITY;
        for (N key : new ArrayList<N>(keySet())) {
            Double stored = get(key);
            if (stored == null) {
                continue;
            }
            // Plain comparisons are false for NaN, so NaN never becomes an extreme.
            if (stored < lowest) {
                lowest = stored;
            }
            if (stored > highest) {
                highest = stored;
            }
        }
        if (lowest > highest) {
            // Nothing comparable was stored; report the same 0 used for an empty map.
            lowest = 0.0;
            highest = 0.0;
        }
        currentMinValue = lowest;
        currentMaxValue = highest;
        rangeSize = count;
        rangeStale = false;
    }

    /**
     * Returns {@code (value(node) - getMin()) / (getMax() - getMin())}, or {@code 0.0} when
     * minimum and maximum coincide (the quotient would otherwise be NaN or infinite).
     *
     * <p>NOTE(review): {@code value(node)} includes the bias while {@code getMin()} and
     * {@code getMax()} do not, so with a non-zero bias the result can leave {@code [0, 1]} -
     * confirm whether that is intended.
     *
     * @param node node to look up
     * @return the node's value rescaled against the current range
     */
    public double valueNormalized(N node) {
        double v = value(node);
        double min = getMin();
        double max = getMax();
        double span = max - min;
        if (span == 0.0) {
            return 0.0;
        }
        return (v - min) / span;
    }

    /**
     * Pushes values apart around the midpoint of the current range: values below the midpoint
     * are scaled by {@code 1 - s}, all others by {@code 1 + s} (subject to clamping).
     *
     * <p>The midpoint is computed once, before any node is changed, so every node is judged
     * against the same threshold.
     *
     * <p>NOTE(review): the comparison uses {@code value(n)}, which includes the bias, against
     * a midpoint derived from unbiased stored values - confirm whether that is intended.
     *
     * @param s strength of the adjustment
     */
    public void focus(double s) {
        List<N> nodes = new ArrayList<N>(graph.getNodes());
        double midpoint = 0.5 * (getMin() + getMax());
        for (N node : nodes) {
            double v = value(node);
            double factor = (v < midpoint) ? (1.0 - s) : (1.0 + s);
            set(node, v * factor);
        }
    }
}