package edu.fudan.ml.types;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.NoSuchElementException;
/**
* 通过二分查找实现的稀疏向量,并实现各种向量运算
* @version 1.0
* @since 1.0
*
*/
public class LinearSparseVector implements Serializable {
private static final long serialVersionUID = 1467092492463327579L;
protected float[] data = new float[0];
protected int[] index = new int[0];
protected int length;
private int increSize = 8;
public LinearSparseVector() {
}
public LinearSparseVector(int init) {
length = 0;
data = new float[init];
index = new int[init];
Arrays.fill(index, Integer.MAX_VALUE);
}
/**
* 将一般数组转换成稀疏表示
*
* @param w 数组
*/
public LinearSparseVector(float[] w) {
for (int i = 0; i < w.length; i++) {
if (Math.abs((w[i]-0f))<Float.MIN_VALUE) {
put(i, w[i]);
}
}
}
/**
* 将一般数组转换成稀疏表示,并增加常数项
* @param w 数组
* @param b 是否设置常数项
*/
public LinearSparseVector(float[] w, boolean b) {
for (int i = 0; i < w.length; i++) {
if (Math.abs((w[i]-0f))<Float.MIN_VALUE) {
put(i, w[i]);
}
}
if(b)
put(w.length,1.0f);
}
public LinearSparseVector(LinearSparseVector sv) {
index = Arrays.copyOf(sv.index, sv.length);
data = Arrays.copyOf(sv.data, sv.length);
length = sv.length;
}
/**
* 向量减法: x-y
*
* @param sv
*/
public void minus(LinearSparseVector sv) {
for (int r = 0; r < sv.length; r++) {
int p = Arrays.binarySearch(index, sv.index[r]);
if (p >= 0) {
data[p] = (float) data[p] - (float) sv.data[r];
} else {
put(sv.index[r], -(float) sv.data[r]);
}
}
}
/**
* 对应位置加上值: x[i] = x[i]+c
*
* @param id
* @param c
*/
public void add(int id, float c) {
int p = Arrays.binarySearch(index, id);
if (p >= 0) {
data[p] = ((float) data[p]) + c;
} else {
put(id, c);
}
}
/**
* 向量加法:x+y
*
* @param sv
*/
public void plus(LinearSparseVector sv) {
plus(sv, 1);
}
/**
* 计算x+y*w
*
* @param sv
* @param w
*/
public void plus(LinearSparseVector sv, float w) {
if (sv == null)
return;
for (int i = 0; i < sv.length; i++) {
float val = (float) sv.data[i] * w;
add(sv.index[i], val);
}
}
public float elementAt(int id) {
float ret = 0;
int p = Arrays.binarySearch(index, id);
if (p >= 0)
ret = (float) data[p];
return ret;
}
public int[] indices() {
return Arrays.copyOfRange(index, 0, length);
}
/**
* 向量点积: x*y
*
* @param sv
* @return 结果
*/
public float dotProduct(LinearSparseVector sv) {
return dotProduct(sv, 0);
}
/**
* 向量点积: x*(y+c)
*
* @param sv
* @return 结果
*/
public float dotProduct(LinearSparseVector sv, float c) {
float product = 0;
for (int i = 0; i < sv.length; i++) {
int p = Arrays.binarySearch(index, sv.index[i]);
if (p >= 0) {
float val = (float) sv.data[i] + c;
val *= (float) data[p];
product += val;
}
}
return product;
}
/**
* A*(B+c)
* @param sv
* @param li
* @param n
* @return 结果
*/
public float dotProduct(LinearSparseVector sv, int li, int n) {
float product = 0;
int z = n * li;
for (int i = 0; i < length; i++) {
int p = Arrays.binarySearch(sv.index, index[i] + z);
if (p >= 0) {
product += (float) data[i] + (float) sv.data[p];
}
}
return product;
}
public void scaleMultiply(float c) {
if (c == 0)
clear();
for (int i = 0; i < length; i++) {
data[i] = (float) data[i] * c;
}
}
public void scaleDivide(float c) {
if (c == 0)
throw new ArithmeticException();
for (int i = 0; i < length; i++) {
data[i] = (float) data[i] / c;
}
}
public float l1Norm() {
float norm = 0;
for (int i = 0; i < length; i++) {
norm += Math.abs((float) data[i]);
}
return norm;
}
public float l2Norm2() {
float norm = 0;
for (int i = 0; i < length; i++) {
float val =data[i];
norm += val * val;
}
return norm;
}
public float l2Norm() {
float norm = 0;
for (int i = 0; i < length; i++) {
float val = data[i];
norm += val * val;
}
return (float) Math.sqrt(norm);
}
public float infinityNorm() {
float norm = 0;
for (int i = 0; i < length; i++) {
float val = Math.abs((float) data[i]);
if (val > norm)
norm = val;
}
return norm;
}
public LinearSparseVector replicate(ArrayList<Integer> list, int dim) {
LinearSparseVector sv = new LinearSparseVector();
for (int i = 0; i < length; i++) {
for (int j = 0; j < list.size(); j++) {
sv.put(index[i] + dim * list.get(j), (float) data[i]);
}
}
return sv;
}
public String toString() {
StringBuffer sb = new StringBuffer();
for (int i = 0; i < length; i++) {
sb.append(index[i]);
sb.append(':');
sb.append(data[i]);
sb.append(' ');
}
return sb.toString();
}
/**
* 计算两个向量距离
*
* @param sv
* @return 距离值
*/
public float euclideanDistance(LinearSparseVector sv) {
float dist = 0.0f;
int r = 0;
for (int i = 0; i < sv.length; i++) {
if (sv.index[i] == index[r]) {
float cur = (float) data[r] - (float) sv.data[i];
dist += cur * cur;
r++;
} else {
float cur = (float) sv.data[i];
dist += cur * cur;
}
}
for (; r < length; r++) {
float cur = (float) data[r];
dist += cur * cur;
}
return dist;
}
public void clear() {
length = 0;
Arrays.fill(index, Integer.MAX_VALUE);
}
public void normalize() {
float norm = l2Norm();
if (norm > 0)
scaleMultiply(1 / norm);
}
public void normalize2() {
float sum = 0;
for (int i = 0; i < length; i++) {
float value = (float) Math.exp(data[i]);
data[i] = value;
sum += value;
}
scaleDivide(sum);
}
public float dotProduct(float[] weights) {
if (index[length - 1] >= weights.length)
throw new IllegalArgumentException();
float product = 0;
for (int i = 0; i < length; i++) {
product += (float) data[i] * weights[index[i]];
}
return product;
}
public float get(int idx) {
int cur = Arrays.binarySearch(index, 0,length,idx);
if (cur >= 0)
return data[cur];
return -1f;
}
public void put(int idx, float value) {
int cur = Arrays.binarySearch(index, 0,length,idx);
if (cur < 0) {
if (length == data.length)
grow();
int p = -cur-1;
System.arraycopy(data, p, data, p+1, length-p);
System.arraycopy(index, p, index, p+1, length-p);
data[p] = value;
index[p] = idx;
length++;
}else {
data[cur] = value;
}
}
/**
* 去掉第idx维特征
* @param idx
* @return
*/
public float remove(int idx) {
float ret = -1f;
int p = Arrays.binarySearch(index,0,length, idx);
if (p >= 0) {
System.arraycopy(data, p+1, data, p, length-p-1);
System.arraycopy(index, p+1, index, p, length-p-1);
length--;
}else{
System.err.println("error");
}
return ret;
}
protected void grow() {
int nSize = data.length+increSize;
float[] nData = new float[nSize];
Arrays.fill(nData, Float.NaN);
System.arraycopy(data, 0, nData, 0, length);
int[] nIndex = new int[nSize];
Arrays.fill(nIndex, Integer.MAX_VALUE);
System.arraycopy(index, 0, nIndex, 0, length);
data = null; index = null;
data = nData;
index = nIndex;
}
public int capacity() {
return data.length;
}
public void compact() {
float[] nData = new float[length];
System.arraycopy(data, 0, nData, 0, length);
int[] nIndex = new int[length];
System.arraycopy(index, 0, nIndex, 0, length);
data = null; index = null;
data = nData;
index = nIndex;
}
/**
* 稀疏元素个数
* @return
*/
public int size() {
return length;
}
public boolean containsKey(int idx) {
int cur = Arrays.binarySearch(index, 0,length,idx);
if (cur < 0)
return false;
else
return true;
}
public Iterator<Integer> iterator() {
return new IndexIterator();
}
protected class IndexIterator implements Iterator<Integer> {
int cur = 0;
public boolean hasNext() {
return (cur < length);
}
public Integer next() {
return index[cur++];
}
public void remove() {
throw new UnsupportedOperationException();
}
}
}