/**
Copyright (C) 2008 Verena Kaynig.
This program is free software; you can redistribute it and/or
modify it under the terms of the GNU General Public License
as published by the Free Software Foundation (http://www.gnu.org/licenses/gpl.txt )
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program; if not, write to the Free Software
Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
**/
/* **************************************************************** *
* Representation of a non linear transform by explicit polynomial
* kernel expansion.
*
* TODO:
* - make different kernels available
* - inverse transform for visualization
* - improve image interpolation
* - apply and applyInPlace should use precalculated transform?
* (What about out of image range pixels?)
*
* Author: Verena Kaynig
* Contact: verena.kaynig@inf.ethz.ch
*
* **************************************************************** */
package lenscorrection;
import ij.ImagePlus;
import ij.io.FileSaver;
import ij.process.ByteProcessor;
import ij.process.ColorProcessor;
import ij.process.FloatProcessor;
import ij.process.ImageProcessor;
import java.awt.Color;
import java.awt.geom.GeneralPath;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.OutputStreamWriter;
import mpicbg.trakem2.transform.NonLinearCoordinateTransform;
import Jama.Matrix;
/**
 * Non-linear 2D coordinate transform represented by an explicit polynomial
 * kernel expansion.
 *
 * <p>The coefficient state ({@code beta}, {@code normMean}, {@code normVar},
 * {@code dimension}, {@code length}, {@code width}, {@code height}) and the
 * helpers {@code kernelExpand}, {@code multiply}, {@code init} and
 * {@code toDataString} are inherited. This class adds fitting, plain-text
 * persistence, a per-pixel lookup table used to warp images, and ImageJ based
 * visualisation of the distortion field.</p>
 *
 * <p>The per-pixel lookup table is built lazily by {@link #transform} and is
 * discarded whenever the coefficients or the normalisation change.</p>
 */
public class NonLinearTransform extends NonLinearCoordinateTransform {

    /** Value stored in both components of a lookup entry whose target lies outside the image. */
    private static final double OUTSIDE = -1;

    /** Constant added to the diagonal of the normal equations in {@link #estimateDistortion}. */
    private static final double DIAGONAL_REGULARIZER = 0.001;

    /** The longer image side is divided by this to obtain the spacing of quiver arrows and grid lines. */
    private static final int GRID_DIVISIONS = 32;

    /** The longer image side is divided by this to obtain the border around the distortion grid image. */
    private static final int BORDER_DIVISOR = 8;

    /** Lookup table {@code transField[x][y] = {x', y'}}, or {@code null} while not built. */
    private double[][][] transField = null;

    /** {@code true} while {@link #transField} has been built for the current state. */
    private boolean precalculated = false;

    /* ------------------------------------------------------------------ */
    /* construction                                                       */
    /* ------------------------------------------------------------------ */

    /**
     * Adopts the given arrays (they are NOT copied) as the state of the transform.
     *
     * @param b  coefficients, one row of two values per expanded feature
     * @param nm per-feature normalisation mean
     * @param nv per-feature normalisation scale
     * @param d  polynomial dimension
     * @param w  image width
     * @param h  image height
     */
    public NonLinearTransform(final double[][] b, final double[] nm, final double[] nv, final int d, final int w, final int h) {
        beta = b;
        normMean = nm;
        normVar = nv;
        dimension = d;
        length = (dimension + 1) * (dimension + 2) / 2;
        width = w;
        height = h;
    }

    /**
     * Creates a transform of the given polynomial dimension with all
     * coefficients zero, normalisation mean 0 and normalisation scale 1.
     *
     * @param d polynomial dimension
     * @param w image width
     * @param h image height
     */
    public NonLinearTransform(final int d, final int w, final int h) {
        dimension = d;
        length = (dimension + 1) * (dimension + 2) / 2;
        beta = new double[length][2];
        resetNormalization();
        width = w;
        height = h;
    }

    /** Creates an empty transform; none of the inherited state is initialised here. */
    public NonLinearTransform() {
    }

    /**
     * Creates a transform from a file written by {@link #save(String)}.
     *
     * @param filename path of the file to read
     */
    public NonLinearTransform(final String filename) {
        this.load(filename);
    }

    /**
     * Creates a transform from a coefficient matrix laid out as returned by
     * {@link #getCoefficients()}: row 0 and 1 hold the two columns of
     * {@code beta}, row 2 {@code normMean}, row 3 {@code normVar}.
     *
     * @param coeffMatrix matrix with four rows of equal length
     * @param w           image width
     * @param h           image height
     */
    public NonLinearTransform(final double[][] coeffMatrix, final int w, final int h) {
        /* the number of expanded features is the row LENGTH, not the number of rows (which is always 4) */
        length = coeffMatrix[0].length;
        beta = new double[length][2];
        normMean = new double[length];
        normVar = new double[length];
        width = w;
        height = h;
        /* inverse of length = (d + 1)(d + 2) / 2 */
        dimension = (int) (-1.5 + Math.sqrt(0.25 + 2 * length));

        for (int i = 0; i < length; i++) {
            beta[i][0] = coeffMatrix[0][i];
            beta[i][1] = coeffMatrix[1][i];
            normMean[i] = coeffMatrix[2][i];
            normVar[i] = coeffMatrix[3][i];
        }
    }

    /* ------------------------------------------------------------------ */
    /* simple accessors                                                   */
    /* ------------------------------------------------------------------ */

    /** @return the polynomial dimension */
    public int getDimension() {
        return dimension;
    }

    /**
     * Sets the polynomial dimension and deletes all dimension dependent
     * properties: coefficients are zeroed, normalisation is reset to
     * mean 0 / scale 1 and the lookup table is discarded.
     *
     * @param dimension new polynomial dimension
     */
    public void setDimension(final int dimension) {
        this.dimension = dimension;
        length = (dimension + 1) * (dimension + 2) / 2;
        beta = new double[length][2];
        resetNormalization();
        invalidateTransField();
    }

    /** @return the number of expanded features, i.e. the number of rows of {@code beta} */
    public int getMinNumMatches() {
        return length;
    }

    /** @return the image width this transform was set up for */
    public int getWidth() {
        return width;
    }

    /** @return the image height this transform was set up for */
    public int getHeight() {
        return height;
    }

    /**
     * Packs the state into a 4 x {@code length} matrix: row 0 and 1 hold the
     * two columns of {@code beta}, row 2 {@code normMean}, row 3 {@code normVar}.
     *
     * @return a newly allocated coefficient matrix
     */
    public double[][] getCoefficients() {
        final double[][] coeffMatrix = new double[4][length];
        for (int i = 0; i < length; i++) {
            coeffMatrix[0][i] = beta[i][0];
            coeffMatrix[1][i] = beta[i][1];
            coeffMatrix[2][i] = normMean[i];
            coeffMatrix[3][i] = normVar[i];
        }
        return coeffMatrix;
    }

    /**
     * Replaces the coefficients (the array is adopted, not copied) and
     * discards the lookup table, which was computed from the old ones.
     *
     * @param b new coefficients
     */
    public void setBeta(final double[][] b) {
        beta = b;
        invalidateTransField();
        //FIXME: test if normMean and normVar are still valid for this beta
    }

    /* ------------------------------------------------------------------ */
    /* fitting                                                            */
    /* ------------------------------------------------------------------ */

    /**
     * Fits the coefficients by regularised least squares:
     * {@code beta = (P'P + 2*lambda*I)^-1 P' y}, where {@code P} is the
     * kernel expansion of {@code x}. The normalisation parameters are
     * re-estimated from {@code x} as a side effect.
     *
     * @param x      source positions, one {x, y} pair per row
     * @param y      target positions, one {x, y} pair per row
     * @param lambda regularisation weight
     */
    public void fit(final double x[][], final double y[][], final double lambda) {
        final double[][] expandedX = kernelExpandMatrixNormalize(x);

        final Matrix phiX = new Matrix(expandedX, expandedX.length, length);
        final Matrix phiXTransp = phiX.transpose();

        final Matrix phiXProduct = phiXTransp.times(phiX);

        /* add 2 * lambda to the diagonal */
        final int l = phiXProduct.getRowDimension();
        final double lambda2 = 2 * lambda;
        for (int i = 0; i < l; ++i) {
            phiXProduct.set(i, i, phiXProduct.get(i, i) + lambda2);
        }

        final Matrix phiXPseudoInverse = phiXProduct.inverse();
        final Matrix phiXProduct2 = phiXPseudoInverse.times(phiXTransp);
        final Matrix betaMatrix = phiXProduct2.times(new Matrix(y, y.length, 2));

        setBeta(betaMatrix.getArray());
    }

    /**
     * Estimates the coefficients from point correspondences and one
     * parameter row per correspondence.
     *
     * <p>Resets {@code beta} and the normalisation, stores {@code w} and
     * {@code h}, re-estimates the normalisation from {@code hack1}, then
     * accumulates a {@code 2s x 2s} system matrix and a {@code 2s} right hand
     * side over all correspondences ({@code s} being the expanded feature
     * count) and solves it with a small constant added to the diagonal.</p>
     *
     * @param hack1           positions used for normalisation and as the
     *                        {@code lambda}-weighted anchor term; at least
     *                        one row is required
     * @param hack2           positions expanded with the normalisation
     *                        obtained from {@code hack1}
     * @param transformParams per correspondence, an array of which entries
     *                        0..5 are read (presumably the parameters of an
     *                        affine transform; ordering to be confirmed
     *                        against the caller)
     * @param lambda          weight of the anchor term
     * @param w               image width
     * @param h               image height
     */
    public void estimateDistortion(final double hack1[][], final double hack2[][], final double transformParams[][], final double lambda, final int w, final int h) {
        beta = new double[length][2];
        resetNormalization();

        width = w;
        height = h;

        /* TODO Find out how to keep some target points fixed (check fit method of NLT which is supposed to be exclusively forward) */
        final double expandedX[][] = kernelExpandMatrixNormalize(hack1);
        final double expandedY[][] = kernelExpandMatrix(hack2);

        final int s = expandedX[0].length;
        Matrix S1 = new Matrix(2 * s, 2 * s);
        Matrix S2 = new Matrix(2 * s, 1);

        for (int i = 0; i < expandedX.length; ++i) {
            /* 1 x s row vectors */
            final Matrix xk_ij = new Matrix(expandedX[i], 1);
            final Matrix xk_ji = new Matrix(expandedY[i], 1);

            final Matrix yk1a = xk_ij.minus(xk_ji.times(transformParams[i][0]));
            final Matrix yk1b = xk_ij.times(0.0).minus(xk_ji.times(-transformParams[i][2]));
            final Matrix yk2a = xk_ij.times(0.0).minus(xk_ji.times(-transformParams[i][1]));
            final Matrix yk2b = xk_ij.minus(xk_ji.times(transformParams[i][3]));

            /* 2 x 2s block matrix [yk1a yk1b; yk2a yk2b] */
            final Matrix y = new Matrix(2, 2 * s);
            y.setMatrix(0, 0, 0, s - 1, yk1a);
            y.setMatrix(0, 0, s, 2 * s - 1, yk1b);
            y.setMatrix(1, 1, 0, s - 1, yk2a);
            y.setMatrix(1, 1, s, 2 * s - 1, yk2b);

            /* 2 x 2s block diagonal matrix [xk_ij 0; 0 xk_ij] */
            final Matrix xk = new Matrix(2, 2 * s);
            xk.setMatrix(0, 0, 0, s - 1, xk_ij);
            xk.setMatrix(1, 1, s, 2 * s - 1, xk_ij);

            final double[] vals = { hack1[i][0], hack1[i][1] };
            final Matrix c = new Matrix(vals, 2);

            final Matrix X = xk.transpose().times(xk).times(lambda);
            final Matrix Y = y.transpose().times(y);

            S1 = S1.plus(Y.plus(X));

            final double trans1 = (transformParams[i][2] * transformParams[i][5] - transformParams[i][0] * transformParams[i][4]);
            final double trans2 = (transformParams[i][1] * transformParams[i][4] - transformParams[i][3] * transformParams[i][5]);
            final double[] trans = { trans1, trans2 };

            final Matrix translation = new Matrix(trans, 2);
            final Matrix YT = y.transpose().times(translation);
            final Matrix XC = xk.transpose().times(c).times(lambda);

            S2 = S2.plus(YT.plus(XC));
        }

        final Matrix regularize = Matrix.identity(S1.getRowDimension(), S1.getColumnDimension());
        /* the 2s solution vector is re-read column-wise as an s x 2 matrix */
        final Matrix newBeta = new Matrix(
                S1.plus(regularize.times(DIAGONAL_REGULARIZER)).inverse().times(S2).getColumnPackedCopy(), s);

        setBeta(newBeta.getArray());
    }

    /**
     * Replaces the coefficients by a least squares fit that maps the
     * transformed positions of {@code range} back onto {@code range}.
     * The normalisation is re-estimated from the transformed positions.
     *
     * @param range sample positions, one {x, y} pair per row
     */
    public void inverseTransform(final double range[][]) {
        Matrix expanded = new Matrix(kernelExpandMatrix(range));
        final Matrix b = new Matrix(beta);

        final Matrix transformed = expanded.times(b);
        expanded = new Matrix(kernelExpandMatrixNormalize(transformed.getArray()));

        final Matrix r = new Matrix(range);
        final Matrix invBeta = expanded.transpose().times(expanded).inverse().times(expanded.transpose()).times(r);
        setBeta(invBeta.getArray());
    }

    /* ------------------------------------------------------------------ */
    /* kernel expansion                                                   */
    /* ------------------------------------------------------------------ */

    /**
     * Re-estimates the normalisation from the given positions and returns
     * their normalised kernel expansion.
     *
     * <p>For every expanded feature, {@code normMean} receives the sample
     * mean and {@code normVar} the sample standard deviation (divisor
     * {@code n - 1}) of the un-normalised expansion. The lookup table is
     * discarded because it depends on the normalisation.</p>
     *
     * @param positions one {x, y} pair per row
     * @return the expansion of {@code positions} under the new normalisation
     */
    public double[][] kernelExpandMatrixNormalize(final double positions[][]) {
        /* expand with the neutral normalisation (mean 0, scale 1) first */
        resetNormalization();
        invalidateTransField();

        final double[][] raw = kernelExpandMatrix(positions);
        final int n = raw.length;

        for (int feature = 0; feature < length; feature++) {
            double mean = 0;
            for (int row = 0; row < n; row++) {
                mean += raw[row][feature];
            }
            mean /= n;

            double squares = 0;
            for (int row = 0; row < n; row++) {
                final double delta = raw[row][feature] - mean;
                squares += delta * delta;
            }

            normMean[feature] = mean;
            normVar[feature] = Math.sqrt(squares / (n - 1));
        }

        return kernelExpandMatrix(positions);
    }

    /**
     * Expands every position with {@code kernelExpand}, using the
     * normalisation parameters already stored in this object.
     *
     * @param positions one {x, y} pair per row
     * @return one expanded feature vector per position
     */
    public double[][] kernelExpandMatrix(final double positions[][]) {
        final double[][] expanded = new double[positions.length][length];
        for (int i = 0; i < positions.length; i++) {
            expanded[i] = kernelExpand(positions[i]);
        }
        return expanded;
    }

    /* ------------------------------------------------------------------ */
    /* image warping                                                      */
    /* ------------------------------------------------------------------ */

    /**
     * Builds the lookup table: for every pixel of the {@code width} x
     * {@code height} image the transformed position is stored, or
     * {@code {-1, -1}} when that position is not inside the image
     * (NaN results are treated as outside).
     */
    void precalculateTransfom() {
        transField = new double[width][height][2];

        for (int x = 0; x < width; x++) {
            for (int y = 0; y < height; y++) {
                final double[] position = { x, y };
                final double[] target = multiply(beta, kernelExpand(position));

                final boolean inside = target[0] >= 0 && target[0] < width
                        && target[1] >= 0 && target[1] < height;

                transField[x][y][0] = inside ? target[0] : OUTSIDE;
                transField[x][y][1] = inside ? target[1] : OUTSIDE;
            }
        }

        precalculated = true;
    }

    /**
     * Warps an image with this transform.
     *
     * <p>Every pixel (x, y) of the {@code width} x {@code height} area whose
     * lookup entry is inside the image receives the value of {@code ip} at
     * the looked-up position truncated to integer coordinates; the same
     * pixel is set to 255 in the mask. The loops use the dimensions of the
     * transform, so {@code ip} is assumed to be at least that large.</p>
     *
     * <p>NOTE(review): for a {@link ColorProcessor} the input is modified by
     * {@code ip.max(0)} before sampling; the reason is not evident from this
     * file and should be confirmed.</p>
     *
     * @param ip source image
     * @return {@code {warped image, mask}}
     */
    public ImageProcessor[] transform(final ImageProcessor ip) {
        if (!precalculated || transField == null) {
            this.precalculateTransfom();
        }

        final ImageProcessor newIp = ip.createProcessor(ip.getWidth(), ip.getHeight());
        if (ip instanceof ColorProcessor) {
            ip.max(0);
        }
        final ImageProcessor maskIp = new ByteProcessor(ip.getWidth(), ip.getHeight());

        for (int x = 0; x < width; x++) {
            for (int y = 0; y < height; y++) {
                final double[] source = transField[x][y];
                if (source[0] == OUTSIDE) {
                    continue;
                }
                newIp.set(x, y, (int) ip.getInterpolatedPixel((int) source[0], (int) source[1]));
                maskIp.set(x, y, 255);
            }
        }
        return new ImageProcessor[] { newIp, maskIp };
    }

    /* ------------------------------------------------------------------ */
    /* persistence and debugging                                          */
    /* ------------------------------------------------------------------ */

    /** Prints coefficients, normalisation parameters and image size to standard output. */
    public void print() {
        System.out.println("beta:");
        for (int i = 0; i < beta.length; i++) {
            for (int j = 0; j < beta[i].length; j++) {
                System.out.print(beta[i][j]);
                System.out.print(" ");
            }
            System.out.println();
        }

        System.out.println("normMean:");
        for (int i = 0; i < normMean.length; i++) {
            System.out.print(normMean[i]);
            System.out.print(" ");
        }
        /* end the value line so that the next heading starts on its own line */
        System.out.println();

        System.out.println("normVar:");
        for (int i = 0; i < normVar.length; i++) {
            System.out.print(normVar[i]);
            System.out.print(" ");
        }
        System.out.println();

        System.out.println("Image size:");
        System.out.println("width: " + width + " height: " + height);
        System.out.println();
    }

    /**
     * Writes the transform as plain text: dimension, number of rows, the
     * coefficient rows, {@code normMean}, {@code normVar} and the image size,
     * each section preceded by a heading line.
     *
     * <p>Failures are reported on standard output only ("File not found!" or
     * "IOException"); the stream is closed in every case.</p>
     *
     * @param filename path of the file to write
     */
    public void save(final String filename) {
        BufferedWriter out = null;
        try {
            out = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(filename)));

            out.write("Kerneldimension");
            out.newLine();
            out.write(Integer.toString(dimension));
            out.newLine();
            out.newLine();

            out.write("number of rows");
            out.newLine();
            out.write(Integer.toString(length));
            out.newLine();
            out.newLine();

            out.write("Coefficients of the transform matrix:");
            out.newLine();
            for (int i = 0; i < length; i++) {
                out.write(Double.toString(beta[i][0]) + " " + Double.toString(beta[i][1]));
                out.newLine();
            }
            out.newLine();

            out.write("normMean:");
            out.newLine();
            for (int i = 0; i < length; i++) {
                out.write(Double.toString(normMean[i]));
                out.newLine();
            }
            out.newLine();

            out.write("normVar: ");
            out.newLine();
            for (int i = 0; i < length; i++) {
                out.write(Double.toString(normVar[i]));
                out.newLine();
            }
            out.newLine();

            out.write("image size: ");
            out.newLine();
            out.write(width + " " + height);
        } catch (final FileNotFoundException e) {
            System.out.println("File not found!");
        } catch (final IOException e) {
            System.out.println("IOException");
        } finally {
            if (out != null) {
                try {
                    out.close();
                } catch (final IOException e) {
                    System.out.println("IOException");
                }
            }
        }
    }

    /**
     * Reads a transform written by {@link #save(String)}.
     *
     * <p>Value pairs may be separated by any run of whitespace. The state of
     * this object is replaced only after the whole file has been parsed, and
     * the lookup table is discarded; on success the new state is printed via
     * {@link #print()}. I/O failures are reported on standard output only
     * ("File not found!" or "IOException") and leave the object unchanged.</p>
     *
     * @param filename path of the file to read
     * @throws NumberFormatException if a number or value pair is malformed
     */
    public void load(final String filename) {
        BufferedReader in = null;
        try {
            in = new BufferedReader(new FileReader(filename));

            in.readLine(); // heading "Kerneldimension"
            final int loadedDimension = Integer.parseInt(in.readLine());
            in.readLine(); // blank line
            in.readLine(); // heading "number of rows"
            final int loadedLength = Integer.parseInt(in.readLine());
            in.readLine(); // blank line
            in.readLine(); // heading of the coefficient section

            final double[][] loadedBeta = new double[loadedLength][2];
            for (int i = 0; i < loadedLength; i++) {
                final String[] pair = splitPair(in.readLine());
                loadedBeta[i][0] = Double.parseDouble(pair[0]);
                loadedBeta[i][1] = Double.parseDouble(pair[1]);
            }

            in.readLine(); // blank line
            in.readLine(); // heading "normMean:"
            final double[] loadedMean = new double[loadedLength];
            for (int i = 0; i < loadedLength; i++) {
                loadedMean[i] = Double.parseDouble(in.readLine());
            }

            in.readLine(); // blank line
            in.readLine(); // heading "normVar: "
            final double[] loadedVar = new double[loadedLength];
            for (int i = 0; i < loadedLength; i++) {
                loadedVar[i] = Double.parseDouble(in.readLine());
            }

            in.readLine(); // blank line
            in.readLine(); // heading "image size: "
            final String[] size = splitPair(in.readLine());
            final int loadedWidth = Integer.parseInt(size[0]);
            final int loadedHeight = Integer.parseInt(size[1]);

            /* everything parsed: commit */
            dimension = loadedDimension;
            length = loadedLength;
            beta = loadedBeta;
            normMean = loadedMean;
            normVar = loadedVar;
            width = loadedWidth;
            height = loadedHeight;
            invalidateTransField();

            print();
        } catch (final FileNotFoundException e) {
            System.out.println("File not found!");
        } catch (final IOException e) {
            System.out.println("IOException");
        } finally {
            if (in != null) {
                try {
                    in.close();
                } catch (final IOException e) {
                    System.out.println("IOException");
                }
            }
        }
    }

    /* ------------------------------------------------------------------ */
    /* visualisation                                                      */
    /* ------------------------------------------------------------------ */

    /**
     * Shows the distortion field in three ImageJ windows: the magnitude of
     * the displacement per pixel, the same image with a quiver overlay, and
     * the transformed sampling grid on an empty image with a border. The
     * quiver image is also saved as {@code QuiverImPs.tif} and "FINISHED" is
     * printed when done.
     */
    public void visualize() {
        final int density = gridDensity();
        final int border = Math.max(width, height) / BORDER_DIVISOR;

        /* number of sampled columns / rows, i.e. indices that are multiples of density */
        final int numVertLines = (width + density - 1) / density;
        final int numHorLines = (height + density - 1) / density;

        /* vertical lines are stored column after column, horizontal lines row after row */
        final double[][] gridTransVert = new double[numVertLines * height][2];
        final double[][] gridTransHor = new double[numHorLines * width][2];

        final FloatProcessor magnitude = new FloatProcessor(width, height);
        final ByteProcessor empty = new ByteProcessor(width + 2 * border, height + 2 * border);
        final GeneralPath quiverField = new GeneralPath();

        float minM = Float.MAX_VALUE, maxM = 0;

        for (int i = 0; i < width; i++) {
            final boolean onVertLine = i % density == 0;
            for (int j = 0; j < height; j++) {
                final boolean onHorLine = j % density == 0;

                final double[] position = { (double) i, (double) j };
                final double[] newPosition = multiply(beta, kernelExpand(position));

                final float m = (float) displacement(position, newPosition);
                magnitude.setf(i, j, m);
                minM = Math.min(minM, m);
                maxM = Math.max(maxM, m);

                if (onVertLine && onHorLine) {
                    drawQuiverField(quiverField, position[0], position[1], newPosition[0], newPosition[1]);
                }

                if (onVertLine) {
                    final int k = (i / density) * height + j;
                    gridTransVert[k][0] = newPosition[0] + border;
                    gridTransVert[k][1] = newPosition[1] + border;
                }

                if (onHorLine) {
                    final int k = (j / density) * width + i;
                    gridTransHor[k][0] = newPosition[0] + border;
                    gridTransHor[k][1] = newPosition[1] + border;
                }
            }
        }

        magnitude.setMinAndMax(minM, maxM);

        final ImagePlus magImg = new ImagePlus("Magnitude of Distortion Field", magnitude);
        magImg.show();

        final ImagePlus quiverImg = new ImagePlus("Quiver Plot of Distortion Field", magnitude);
        quiverImg.show();
        quiverImg.getCanvas().setDisplayList(quiverField, Color.green, null);
        quiverImg.updateAndDraw();

        final GeneralPath gridTrans = new GeneralPath();
        drawGrid(gridTrans, gridTransVert, gridTransVert.length, height);
        drawGrid(gridTrans, gridTransHor, gridTransHor.length, width);

        final ImagePlus gridImgTrans = new ImagePlus("Distortion Grid", empty);
        gridImgTrans.show();
        gridImgTrans.getCanvas().setDisplayList(gridTrans, Color.green, null);
        gridImgTrans.updateAndDraw();

        new FileSaver(quiverImg).saveAsTiff("QuiverImPs.tif");

        System.out.println("FINISHED");
    }

    /**
     * Shows only the displacement magnitude image with a quiver overlay and
     * prints "FINISHED" when done.
     *
     * @param lambda value shown in the window title; not used otherwise
     */
    public void visualizeSmall(final double lambda) {
        final int density = gridDensity();

        final FloatProcessor magnitude = new FloatProcessor(width, height);
        final GeneralPath quiverField = new GeneralPath();

        float minM = Float.MAX_VALUE, maxM = 0;

        for (int i = 0; i < width; i++) {
            for (int j = 0; j < height; j++) {
                final double[] position = { (double) i, (double) j };
                final double[] newPosition = multiply(beta, kernelExpand(position));

                final float m = (float) displacement(position, newPosition);
                magnitude.setf(i, j, m);
                minM = Math.min(minM, m);
                maxM = Math.max(maxM, m);

                if (i % density == 0 && j % density == 0) {
                    drawQuiverField(quiverField, position[0], position[1], newPosition[0], newPosition[1]);
                }
            }
        }

        magnitude.setMinAndMax(minM, maxM);
        final ImagePlus quiverImg = new ImagePlus("Quiver Plot for lambda = " + lambda, magnitude);
        quiverImg.show();
        quiverImg.getCanvas().setDisplayList(quiverField, Color.green, null);
        quiverImg.updateAndDraw();

        System.out.println("FINISHED");
    }

    /**
     * Appends line segments between consecutive points to {@code g}, except
     * between the last point of one run of {@code s} points and the first
     * point of the next run.
     *
     * @param g      path to append to
     * @param points one {x, y} pair per row
     * @param count  number of leading rows of {@code points} to use
     * @param s      number of points per run
     */
    public static void drawGrid(final GeneralPath g, final double[][] points, final int count, final int s) {
        for (int i = 0; i < count - 1; i++) {
            if ((i + 1) % s != 0) {
                g.moveTo((float) points[i][0], (float) points[i][1]);
                g.lineTo((float) points[i + 1][0], (float) points[i + 1][1]);
            }
        }
    }

    /**
     * Appends a single line segment from (x1, y1) to (x2, y2) to {@code qf}.
     *
     * @param qf path to append to
     * @param x1 start x
     * @param y1 start y
     * @param x2 end x
     * @param y2 end y
     */
    public static void drawQuiverField(final GeneralPath qf, final double x1, final double y1, final double x2, final double y2) {
        qf.moveTo((float) x1, (float) y1);
        qf.lineTo((float) x2, (float) y2);
    }

    /* ------------------------------------------------------------------ */
    /* copying                                                            */
    /* ------------------------------------------------------------------ */

    /**
     * Creates a copy by serialising this transform with {@code toDataString}
     * and initialising a new instance from the result.
     *
     * TODO Make this more efficient
     */
    @Override
    final public NonLinearTransform copy() {
        final NonLinearTransform t = new NonLinearTransform();
        t.init(toDataString());
        return t;
    }

    /**
     * Makes this transform a deep copy of {@code nlt}, including its lookup
     * table when one has been built. Arrays that are {@code null} in
     * {@code nlt} become {@code null} here.
     *
     * @param nlt transform to copy from
     */
    public void set(final NonLinearTransform nlt) {
        this.dimension = nlt.dimension;
        this.height = nlt.height;
        this.length = nlt.length;
        this.precalculated = nlt.precalculated;
        this.width = nlt.width;

        /* arrays by deep cloning */
        this.beta = deepCopy(nlt.beta);
        this.normMean = nlt.normMean == null ? null : nlt.normMean.clone();
        this.normVar = nlt.normVar == null ? null : nlt.normVar.clone();

        /* the lookup table only exists after the first transform() */
        if (nlt.transField == null) {
            this.transField = null;
        } else {
            this.transField = new double[nlt.transField.length][][];
            for (int a = 0; a < nlt.transField.length; ++a) {
                this.transField[a] = deepCopy(nlt.transField[a]);
            }
        }
    }

    /* ------------------------------------------------------------------ */
    /* private helpers                                                    */
    /* ------------------------------------------------------------------ */

    /** Resets the normalisation to mean 0 and scale 1 for all {@code length} features. */
    private void resetNormalization() {
        normMean = new double[length]; // zero-initialised by the JVM
        normVar = new double[length];
        for (int i = 0; i < length; i++) {
            normVar[i] = 1;
        }
    }

    /** Discards the lookup table so that the next {@link #transform} rebuilds it. */
    private void invalidateTransField() {
        transField = null;
        precalculated = false;
    }

    /** Spacing of quiver arrows and grid lines; never below 1 so it is safe as a modulus. */
    private int gridDensity() {
        return Math.max(1, Math.max(width, height) / GRID_DIVISIONS);
    }

    /** Euclidean distance between two {x, y} positions. */
    private static double displacement(final double[] from, final double[] to) {
        double m = (from[0] - to[0]) * (from[0] - to[0]);
        m += (from[1] - to[1]) * (from[1] - to[1]);
        return Math.sqrt(m);
    }

    /** Splits a line into exactly two tokens separated by any run of whitespace. */
    private static String[] splitPair(final String line) {
        final String[] tokens = line.trim().split("\\s+");
        if (tokens.length != 2) {
            throw new NumberFormatException("Expected two whitespace-separated values but found: \"" + line + "\"");
        }
        return tokens;
    }

    /** Clones a two-dimensional array row by row; {@code null} stays {@code null}. */
    private static double[][] deepCopy(final double[][] src) {
        if (src == null) {
            return null;
        }
        final double[][] dst = new double[src.length][];
        for (int i = 0; i < src.length; ++i) {
            dst[i] = src[i].clone();
        }
        return dst;
    }
}