package hex.la;
import water.H2O;
import water.fvec.*;
/**
 * Simple linear-algebra operations over an H2O {@link Frame}, treating each column
 * ({@link Vec}) of the frame as a matrix column and each row index as a matrix row.
 *
 * <p>Every element is read individually via {@code Vec.at} and written individually via
 * {@code Vec.set}, so cost grows with the product of the matrix dimensions. Row counts are
 * narrowed to {@code int}; frames with more than {@code Integer.MAX_VALUE} rows are not
 * supported.
 */
public final class Matrix {
  final Frame _x;

  public Matrix(Frame x) { _x = x; }

  /**
   * Matrix multiplication: computes X * y for the wrapped n x m matrix X and an m x p
   * matrix {@code y}.
   *
   * @param y right-hand operand; must have as many rows as X has columns
   * @return a new n x p frame whose column names are {@code y}'s column names
   * @throws IllegalArgumentException if the dimensions are incompatible, or if either
   *     operand contains a factor (enum) column
   */
  public Frame mult(Frame y) {
    int xrows = (int) _x.numRows();
    int xcols = _x.numCols();
    int yrows = (int) y.numRows();
    int ycols = y.numCols();
    if (xcols != yrows)
      throw new IllegalArgumentException("Matrices are not compatible for multiplication: [" + xrows + "x" + xcols + "] * [" + yrows + "x" + ycols + "]. Requires [n x m] * [m x p]");
    Vec[] x_vecs = _x.vecs();
    Vec[] y_vecs = y.vecs();
    checkNoFactors(x_vecs, xcols);
    checkNoFactors(y_vecs, ycols);
    Vec[] output = newOutputVecs(ycols, xrows);
    double[] xrow = new double[xcols];
    for (int i = 0; i < xrows; i++) {
      // Row i of X is reused for every output column, so fetch it once.
      readRow(x_vecs, i, xrow);
      for (int j = 0; j < ycols; j++) {
        Vec yvec = y_vecs[j];
        double d = 0;
        for (int k = 0; k < xcols; k++)
          d += xrow[k] * yvec.at(k);
        output[j].set(i, d);
      }
    }
    // NOTE(review): the result shares y's names array rather than copying it; confirm
    // that Frame does not mutate or take ownership of the array it is given.
    return new Frame(y._names, output);
  }

  /**
   * Outer product: computes X * X^T for the wrapped n x m matrix X.
   *
   * @return a new n x n frame with columns named C1..Cn, where element (i, j) is the dot
   *     product of rows i and j of X
   * @throws IllegalArgumentException if X contains a factor (enum) column
   */
  public Frame outerProd() {
    int xrows = (int) _x.numRows();
    int xcols = _x.numCols();
    Vec[] x_vecs = _x.vecs();
    checkNoFactors(x_vecs, xcols);
    Vec[] output = newOutputVecs(xrows, xrows);
    String[] names = defaultNames(xrows);
    double[] rowI = new double[xcols];
    for (int i = 0; i < xrows; i++) {
      readRow(x_vecs, i, rowI);
      for (int j = 0; j < xrows; j++) {
        double d = 0;
        // Element (i, j) pairs row i with row j, so column k must be read at row j
        // (reading it at row k would make the result independent of j).
        for (int k = 0; k < xcols; k++)
          d += rowI[k] * x_vecs[k].at(j);
        output[j].set(i, d);
      }
    }
    return new Frame(names, output);
  }

  /**
   * Transpose: turns the wrapped n x m matrix X into an m x n frame.
   *
   * @return a new frame with n columns named C1..Cn, each of length m, where output
   *     column i, row j holds X[i][j]
   * @throws RuntimeException the exception produced by {@code H2O.unimpl()} if X contains
   *     a factor (enum) column
   */
  public Frame trans() {
    int xrows = (int) _x.numRows();
    int xcols = _x.numCols();
    Vec[] x_vecs = _x.vecs();
    // Factors cannot be transposed yet: an output column would mix values taken from
    // several input columns, whose domains need not match.
    for (int j = 0; j < xcols; j++) {
      if (x_vecs[j].isEnum())
        throw H2O.unimpl();
    }
    Vec[] output = newOutputVecs(xrows, xcols);
    String[] names = defaultNames(xrows);
    for (int i = 0; i < xrows; i++) {
      for (int j = 0; j < xcols; j++)
        output[i].set(j, x_vecs[j].at(i));
    }
    return new Frame(names, output);
  }

  /**
   * Rejects factor columns before any multiplication.
   *
   * @param vecs columns to check
   * @param ncols number of leading entries of {@code vecs} to check
   * @throws IllegalArgumentException naming the index (within {@code vecs}) of the first
   *     factor (enum) column found
   */
  private static void checkNoFactors(Vec[] vecs, int ncols) {
    for (int j = 0; j < ncols; j++) {
      if (vecs[j].isEnum())
        throw new IllegalArgumentException("Multiplication not meaningful for factor column " + j);
    }
  }

  /**
   * Allocates {@code ncols} new vectors of {@code len} rows each, via {@code Vec.makeSeq}.
   * Callers overwrite every element, so the initial contents do not matter.
   * NOTE(review): results are written with {@code Vec.set} and never explicitly finalized
   * here; confirm whether the Vec API expects a post-write step.
   */
  private static Vec[] newOutputVecs(int ncols, int len) {
    Vec[] out = new Vec[ncols];
    for (int j = 0; j < ncols; j++)
      out[j] = Vec.makeSeq(len);
    return out;
  }

  /** Builds the default column names "C1", "C2", ..., "Cn". */
  private static String[] defaultNames(int n) {
    String[] names = new String[n];
    for (int i = 0; i < n; i++)
      names[i] = "C" + String.valueOf(i + 1);
    return names;
  }

  /** Copies row {@code row} of the first {@code dst.length} columns of {@code vecs} into {@code dst}. */
  private static void readRow(Vec[] vecs, int row, double[] dst) {
    for (int k = 0; k < dst.length; k++)
      dst[k] = vecs[k].at(row);
  }
}