/*
* Copyright (C) 2014 by Array Systems Computing Inc. http://www.array.ca
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the Free
* Software Foundation; either version 3 of the License, or (at your option)
* any later version.
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License along
* with this program; if not, see http://www.gnu.org/licenses/
*/
package org.esa.s1tbx.insar.gpf;
import Jama.Matrix;
import Jama.SingularValueDecomposition;
import com.bc.ceres.core.ProgressMonitor;
import org.esa.snap.core.datamodel.Band;
import org.esa.snap.core.datamodel.Product;
import org.esa.snap.core.datamodel.ProductData;
import org.esa.snap.core.datamodel.VirtualBand;
import org.esa.snap.core.dataop.downloadable.StatusProgressMonitor;
import org.esa.snap.core.gpf.Operator;
import org.esa.snap.core.gpf.OperatorException;
import org.esa.snap.core.gpf.OperatorSpi;
import org.esa.snap.core.gpf.Tile;
import org.esa.snap.core.gpf.annotations.OperatorMetadata;
import org.esa.snap.core.gpf.annotations.Parameter;
import org.esa.snap.core.gpf.annotations.SourceProduct;
import org.esa.snap.core.gpf.annotations.TargetProduct;
import org.esa.snap.core.util.ProductUtils;
import org.esa.snap.core.util.math.MathUtils;
import org.esa.snap.engine_utilities.gpf.ThreadManager;
import org.esa.snap.engine_utilities.gpf.TileIndex;
import org.esa.snap.engine_utilities.util.ResourceUtils;
import java.awt.Desktop;
import java.awt.Dimension;
import java.awt.Rectangle;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
/**
* The operator performs principle component analysis for user selected master/slave pairs.
*/
@OperatorMetadata(alias = "Principle-Components",
description = "Principle Component Analysis",
category = "Raster/Image Analysis",
version = "1.0",
authors = "Jun Lu, Luis Veci",
copyright = "Copyright (C) 2014 by Array Systems Computing Inc.")
public class PCAOp extends Operator {
@SourceProduct
private Product sourceProduct;
@TargetProduct
private Product targetProduct;
@Parameter(description = "The list of source bands.", alias = "sourceBands",
rasterDataNodeType = Band.class, label = "Source Bands")
private String[] sourceBandNames;
@Parameter(valueSet = {EIGENVALUE_THRESHOLD, NUMBER_EIGENVALUES},
defaultValue = EIGENVALUE_THRESHOLD, label = "Select Eigenvalues By:")
private String selectEigenvaluesBy = EIGENVALUE_THRESHOLD;
@Parameter(description = "The threshold for selecting eigenvalues", interval = "(0, 100]", defaultValue = "100",
label = "Eigenvalue Threshold (%)")
private double eigenvalueThreshold = 100.0;
@Parameter(description = "The number of PCA images output", interval = "(0, 100]", defaultValue = "1",
label = "Number Of PCA Images")
private int numPCA = 1;
@Parameter(description = "Show the eigenvalues", defaultValue = "1", label = "Show Eigenvalues")
private Boolean showEigenvalues = false;
@Parameter(description = "Subtract mean image", defaultValue = "1", label = "Subtract Mean Image")
private Boolean subtractMeanImage = false;
private boolean statsCalculated = false;
private int numOfPixels = 0; // total number of pixel values
private int numOfSourceBands = 0; // number of user selected bands
private double[] sum = null; // summation of pixel values for each band
private double[][] sumCross = null; // summation of the dot product of each band and the master band
private double[] mean = null; // mean of pixel values for each band
private double[][] meanCross = null;// mean of the dot product of each band and the master band
public static final String EIGENVALUE_THRESHOLD = "Eigenvalue Threshold";
public static final String NUMBER_EIGENVALUES = "Number of Eigenvalues";
private static final String meanImageBandName = "Mean_Image";
private double totalEigenvalues; // summation of all eigenvalues
private boolean pcaImageComputed = false;
private double[][] eigenVectorMatrices = null; // eigenvector matrices for all slave bands
private double[] eigenValues = null; // eigenvalues for all slave bands
private double[] minPCA = null; // min value for first and second PCA images for all master/slave band pairs
/**
* Default constructor. The graph processing framework
* requires that an operator has a default constructor.
*/
public PCAOp() {
}
/**
* Initializes this operator and sets the one and only target product.
* <p>The target product can be either defined by a field of type {@link Product} annotated with the
* {@link TargetProduct TargetProduct} annotation or
* by calling {@link #setTargetProduct} method.</p>
* <p>The framework calls this method after it has created this operator.
* Any client code that must be performed before computation of tile data
* should be placed here.</p>
*
* @throws OperatorException If an error occurs during operator initialisation.
* @see #getTargetProduct()
*/
@Override
public void initialize() throws OperatorException {
try {
if (selectEigenvaluesBy.equals(NUMBER_EIGENVALUES) && numPCA > sourceBandNames.length) {
throw new OperatorException("The number of eigenvalues should not be greater than the number of selected bands");
}
createTargetProduct();
addSelectedBands();
setInitialValues();
} catch (Throwable e) {
throw new OperatorException(e);
}
}
/**
* Set initial values to some internal variables.
*/
private void setInitialValues() {
mean = new double[numOfSourceBands];
meanCross = new double[numOfSourceBands][numOfSourceBands];
sum = new double[numOfSourceBands];
sumCross = new double[numOfSourceBands][numOfSourceBands];
for (int i = 0; i < numOfSourceBands; i++) {
sum[i] = 0.0;
mean[i] = 0.0;
for (int j = 0; j < numOfSourceBands; j++) {
sumCross[i][j] = 0.0;
meanCross[i][j] = 0.0;
}
}
numOfPixels = sourceProduct.getSceneRasterWidth() * sourceProduct.getSceneRasterHeight();
}
/**
* Create target product.
*/
void createTargetProduct() {
targetProduct = new Product(sourceProduct.getName(),
sourceProduct.getProductType(),
sourceProduct.getSceneRasterWidth(),
sourceProduct.getSceneRasterHeight());
ProductUtils.copyMetadata(sourceProduct, targetProduct);
ProductUtils.copyTiePointGrids(sourceProduct, targetProduct);
ProductUtils.copyFlagCodings(sourceProduct, targetProduct);
ProductUtils.copyGeoCoding(sourceProduct, targetProduct);
ProductUtils.copyMasks(sourceProduct, targetProduct);
ProductUtils.copyVectorData(sourceProduct, targetProduct);
targetProduct.setStartTime(sourceProduct.getStartTime());
targetProduct.setEndTime(sourceProduct.getEndTime());
targetProduct.setDescription(sourceProduct.getDescription());
}
/**
* Add user selected slave bands to target product.
*/
private void addSelectedBands() {
// if no source band is selected by user, then select all bands
if (sourceBandNames == null || sourceBandNames.length == 0) {
final Band[] bands = sourceProduct.getBands();
final List<String> bandNameList = new ArrayList<>(sourceProduct.getNumBands());
for (Band band : bands) {
bandNameList.add(band.getName());
}
sourceBandNames = bandNameList.toArray(new String[bandNameList.size()]);
}
numOfSourceBands = sourceBandNames.length;
if (numOfSourceBands <= 1) {
throw new OperatorException("For PCA, more than one band should be selected");
}
// add PCA bands in target product
final Band sourcerBand = sourceProduct.getBand(sourceBandNames[0]);
if (sourcerBand == null) {
throw new OperatorException("Source band not found: " + sourcerBand);
}
if (selectEigenvaluesBy.equals(EIGENVALUE_THRESHOLD)) {
numPCA = numOfSourceBands;
}
final int imageWidth = sourcerBand.getRasterWidth();
final int imageHeight = sourcerBand.getRasterHeight();
final String unit = sourcerBand.getUnit();
for (int i = 0; i < numPCA; i++) {
final String targetBandName = "PC" + i;
final Band targetBand = new Band(targetBandName, ProductData.TYPE_FLOAT32, imageWidth, imageHeight);
targetBand.setUnit(unit);
targetProduct.addBand(targetBand);
}
if (subtractMeanImage) {
createMeanImageVirtualBand(sourceProduct, sourceBandNames, meanImageBandName);
}
}
/**
* Create mean image as a virtual band from user selected bands.
*
* @param sourceProduct The source product.
* @param sourceBandNames The user selected band names.
* @param meanImageBandName The mean image band name.
*/
private static void createMeanImageVirtualBand(final Product sourceProduct,
final String[] sourceBandNames, final String meanImageBandName) {
if (sourceProduct.getBand(meanImageBandName) != null) {
return;
}
boolean isFirstBand = true;
String unit = "";
String expression = "( ";
for (String bandName : sourceBandNames) {
if (isFirstBand) {
expression += bandName;
unit = sourceProduct.getBand(bandName).getUnit();
isFirstBand = false;
} else {
expression += " + " + bandName;
}
}
expression += " ) / " + sourceBandNames.length;
final VirtualBand band = new VirtualBand(meanImageBandName,
ProductData.TYPE_FLOAT32,
sourceProduct.getSceneRasterWidth(),
sourceProduct.getSceneRasterHeight(),
expression);
band.setUnit(unit);
band.setDescription("Mean image");
sourceProduct.addBand(band);
}
/**
* Called by the framework in order to compute a tile for the given target band.
* <p>The default implementation throws a runtime exception with the message "not implemented".</p>
*
* @param targetTileMap The target tiles associated with all target bands to be computed.
* @param targetRectangle The rectangle of target tile.
* @param pm A progress monitor which should be used to determine computation cancelation requests.
* @throws OperatorException If an error occurs during computation of the target raster.
*/
@Override
public void computeTileStack(Map<Band, Tile> targetTileMap, Rectangle targetRectangle, ProgressMonitor pm)
throws OperatorException {
try {
final int x0 = targetRectangle.x;
final int y0 = targetRectangle.y;
final int w = targetRectangle.width;
final int h = targetRectangle.height;
//System.out.println("x0 = " + x0 + ", y0 = " + y0 + ", w = " + w + ", h = " + h);
if (!statsCalculated) {
calculateStatistics();
}
final ProductData[] bandsRawSamples = new ProductData[numOfSourceBands];
for (int i = 0; i < numOfSourceBands; i++) {
bandsRawSamples[i] =
getSourceTile(sourceProduct.getBand(sourceBandNames[i]), targetRectangle).getRawSamples();
}
for (int i = 0; i < numPCA; i++) {
final Band targetBand = targetProduct.getBand("PC" + i);
final Tile targetTile = targetTileMap.get(targetBand);
final ProductData trgData = targetTile.getDataBuffer();
final TileIndex targetIndex = new TileIndex(targetTile);
int index;
int k = 0;
for (int y = y0; y < y0 + h; y++) {
targetIndex.calculateStride(y);
for (int x = x0; x < x0 + w; x++) {
index = targetIndex.getIndex(x);
double vPCA = 0.0;
for (int j = 0; j < numOfSourceBands; j++) {
vPCA += bandsRawSamples[j].getElemDoubleAt(k) * eigenVectorMatrices[j][i];
}
k++;
trgData.setElemDoubleAt(index, vPCA - minPCA[i]);
}
}
}
} catch (Throwable e) {
throw new OperatorException(e);
} finally {
pm.done();
}
pcaImageComputed = true;
}
private synchronized void calculateStatistics() {
if (statsCalculated) {
return;
}
final Dimension tileSize = new Dimension(256, 256);
final Rectangle[] tileRectangles = getAllTileRectangles(sourceProduct, tileSize);
processStatistics(tileRectangles);
processMin(tileRectangles);
statsCalculated = true;
}
/**
* Get an array of rectangles for all source tiles of the image
*
* @param sourceProduct the input product
* @param tileSize the rect sizes
* @return Array of rectangles
*/
private static Rectangle[] getAllTileRectangles(final Product sourceProduct, final Dimension tileSize) {
final int rasterHeight = sourceProduct.getSceneRasterHeight();
final int rasterWidth = sourceProduct.getSceneRasterWidth();
final Rectangle boundary = new Rectangle(rasterWidth, rasterHeight);
final int tileCountX = MathUtils.ceilInt(boundary.width / (double) tileSize.width);
final int tileCountY = MathUtils.ceilInt(boundary.height / (double) tileSize.height);
final Rectangle[] rectangles = new Rectangle[tileCountX * tileCountY];
int index = 0;
for (int tileY = 0; tileY < tileCountY; tileY++) {
for (int tileX = 0; tileX < tileCountX; tileX++) {
final Rectangle tileRectangle = new Rectangle(tileX * tileSize.width,
tileY * tileSize.height,
tileSize.width,
tileSize.height);
final Rectangle intersection = boundary.intersection(tileRectangle);
rectangles[index] = intersection;
index++;
}
}
return rectangles;
}
private void processStatistics(final Rectangle[] tileRectangles) {
final StatusProgressMonitor status = new StatusProgressMonitor(StatusProgressMonitor.TYPE.SUBTASK);
status.beginTask("Computing Statistics... ", tileRectangles.length);
final ThreadManager threadManager = new ThreadManager();
try {
for (final Rectangle rectangle : tileRectangles) {
Thread worker = new Thread() {
final ProductData[] bandsRawSamples = new ProductData[numOfSourceBands];
final double[] tileSum = new double[numOfSourceBands];
final double[][] tileSumCross = new double[numOfSourceBands][numOfSourceBands];
@Override
public void run() {
for (int i = 0; i < numOfSourceBands; i++) {
bandsRawSamples[i] =
getSourceTile(sourceProduct.getBand(sourceBandNames[i]), rectangle).getRawSamples();
}
if (subtractMeanImage) {
final ProductData meanBandRawSamples =
getSourceTile(sourceProduct.getBand(meanImageBandName), rectangle).getRawSamples();
computeTileStatisticsWithMeanImageSubstract(numOfSourceBands,
bandsRawSamples, meanBandRawSamples, tileSum, tileSumCross);
} else {
computeTileStatisticsWithoutMeanImageSubstract(numOfSourceBands,
bandsRawSamples, tileSum, tileSumCross);
}
synchronized (sum) {
computeImageStatistics(tileSum, tileSumCross);
}
}
};
threadManager.add(worker);
status.worked(1);
}
threadManager.finish();
completeStatistics();
} catch (Throwable e) {
throw new OperatorException(e);
} finally {
status.done();
}
}
private void processMin(final Rectangle[] tileRectangles) {
final StatusProgressMonitor status = new StatusProgressMonitor(StatusProgressMonitor.TYPE.SUBTASK);
status.beginTask("Computing Min... ", tileRectangles.length);
final ThreadManager threadManager = new ThreadManager();
try {
initializeMin();
for (final Rectangle rectangle : tileRectangles) {
Thread worker = new Thread() {
final double[] tileMinPCA = new double[numOfSourceBands];
final ProductData[] bandsRawSamples = new ProductData[numOfSourceBands];
@Override
public void run() {
for (int i = 0; i < numOfSourceBands; i++) {
bandsRawSamples[i] =
getSourceTile(sourceProduct.getBand(sourceBandNames[i]), rectangle).getRawSamples();
}
final int n = bandsRawSamples[0].getNumElems();
Arrays.fill(tileMinPCA, Double.MAX_VALUE);
for (int i = 0; i < numPCA; i++) {
for (int k = 0; k < n; k++) {
double vPCA = 0.0;
for (int j = 0; j < numOfSourceBands; j++) {
vPCA += bandsRawSamples[j].getElemDoubleAt(k) * eigenVectorMatrices[j][i];
}
if (vPCA < tileMinPCA[i])
tileMinPCA[i] = vPCA;
}
}
synchronized (minPCA) {
computePCAMin(tileMinPCA);
}
}
};
threadManager.add(worker);
status.worked(1);
}
threadManager.finish();
} catch (Throwable e) {
throw new OperatorException(e);
} finally {
status.done();
}
}
/**
* Compute summation and cross-summation for all bands for a given tile.
*
* @param numOfSourceBands numnber of bands
* @param bandsRawSamples The raw data for all bands for the given tile.
* @param tileSum The summation for all bands for the given tile.
* @param tileSumCross The cross-summation for all bands for the given tile.
*/
private static void computeTileStatisticsWithoutMeanImageSubstract(final int numOfSourceBands,
final ProductData[] bandsRawSamples, final double[] tileSum, final double[][] tileSumCross) {
Arrays.fill(tileSum, 0.0);
final int n = bandsRawSamples[0].getNumElems();
double vi, vj;
for (int i = 0; i < numOfSourceBands; i++) {
Arrays.fill(tileSumCross[i], 0.0);
for (int j = 0; j <= i; j++) {
//System.out.println("i = " + i + ", j = " + j);
if (j < i) {
for (int k = 0; k < n; k++) {
vi = bandsRawSamples[i].getElemDoubleAt(k);
vj = bandsRawSamples[j].getElemDoubleAt(k);
tileSumCross[i][j] += vi * vj;
}
} else { // j == i
for (int k = 0; k < n; k++) {
vi = bandsRawSamples[i].getElemDoubleAt(k);
tileSum[i] += vi;
tileSumCross[i][j] += vi * vi;
}
}
}
}
}
/**
* Compute summation and cross-summation for all bands for a given tile with mean image substracted.
*
* @param numOfSourceBands numnber of bands
* @param bandsRawSamples The raw data for all bands for the given tile.
* @param meanBandRawSamples The raw data for the band of mean image for the given tile.
* @param tileSum The summation for all bands for the given tile.
* @param tileSumCross The cross-summation for all bands for the given tile.
*/
private static void computeTileStatisticsWithMeanImageSubstract(
final int numOfSourceBands,
final ProductData[] bandsRawSamples, final ProductData meanBandRawSamples,
final double[] tileSum, final double[][] tileSumCross) {
Arrays.fill(tileSum, 0.0);
final int n = bandsRawSamples[0].getNumElems();
double vi, vj, vm;
for (int i = 0; i < numOfSourceBands; i++) {
Arrays.fill(tileSumCross[i], 0.0);
for (int j = 0; j <= i; j++) {
//System.out.println("i = " + i + ", j = " + j);
if (j < i) {
for (int k = 0; k < n; k++) {
vm = meanBandRawSamples.getElemDoubleAt(k);
vi = bandsRawSamples[i].getElemDoubleAt(k) - vm;
vj = bandsRawSamples[j].getElemDoubleAt(k) - vm;
tileSumCross[i][j] += vi * vj;
}
} else { // j == i
for (int k = 0; k < n; k++) {
vm = meanBandRawSamples.getElemDoubleAt(k);
vi = bandsRawSamples[i].getElemDoubleAt(k) - vm;
tileSum[i] += vi;
tileSumCross[i][j] += vi * vi;
}
}
}
}
}
/**
* Compute summation and cross-summation for the whole image.
*
* @param tileSum The summation computed for each tile.
* @param tileSumCross The cross-summation computed for each tile.
*/
private void computeImageStatistics(final double[] tileSum, final double[][] tileSumCross) {
for (int i = 0; i < numOfSourceBands; i++) {
for (int j = 0; j <= i; j++) {
if (j < i) {
sumCross[i][j] += tileSumCross[i][j];
} else { // j == i
sum[i] += tileSum[i];
sumCross[i][j] += tileSumCross[i][j];
}
}
}
}
private void completeStatistics() {
for (int i = 0; i < numOfSourceBands; i++) {
mean[i] = sum[i] / numOfPixels;
for (int j = 0; j <= i; j++) {
meanCross[i][j] = sumCross[i][j] / numOfPixels;
if (j != i) {
meanCross[j][i] = meanCross[i][j];
}
}
}
}
/////////////
// Min
/**
* Set initial values to some internal variables.
*/
private void initializeMin() {
minPCA = new double[numOfSourceBands];
for (int i = 0; i < numOfSourceBands; i++) {
minPCA[i] = Double.MAX_VALUE;
}
computeEigenDecompositionOfCovarianceMatrix();
}
/**
* Compute minimum values for all PCA images.
*
* @param tileMinPCA The minimum values for all PCA images for a given tile.
*/
private void computePCAMin(final double[] tileMinPCA) {
for (int i = 0; i < numPCA; i++) {
if (tileMinPCA[i] < minPCA[i]) {
minPCA[i] = tileMinPCA[i];
}
}
}
/**
* Compute covariance matrices and perform EVD on each of them.
*/
private void computeEigenDecompositionOfCovarianceMatrix() {
eigenVectorMatrices = new double[numOfSourceBands][numOfSourceBands];
eigenValues = new double[numOfSourceBands];
final double[][] cov = new double[numOfSourceBands][numOfSourceBands];
for (int i = 0; i < numOfSourceBands; i++) {
for (int j = 0; j < numOfSourceBands; j++) {
cov[i][j] = meanCross[i][j] - mean[i] * mean[j];
}
}
final Matrix Cov = new Matrix(cov);
final SingularValueDecomposition Svd = Cov.svd(); // Cov = USV'
final Matrix S = Svd.getS();
final Matrix U = Svd.getU();
//final Matrix V = Svd.getV();
totalEigenvalues = 0.0;
for (int i = 0; i < numOfSourceBands; i++) {
eigenValues[i] = S.get(i, i);
totalEigenvalues += eigenValues[i];
for (int j = 0; j < numOfSourceBands; j++) {
eigenVectorMatrices[i][j] = U.get(i, j);
}
}
if (selectEigenvaluesBy.equals(EIGENVALUE_THRESHOLD)) {
double sum = 0.0;
for (int i = 0; i < numOfSourceBands; i++) {
sum += eigenValues[i];
if (sum / totalEigenvalues >= eigenvalueThreshold) {
numPCA = i + 1;
break;
}
}
}
}
/**
* Compute statistics for the whole image.
*/
@Override
public void dispose() {
if (!pcaImageComputed) {
return;
}
createReportFile();
}
private void createReportFile() {
final File reportFile = new File(ResourceUtils.getReportFolder(), sourceProduct.getName() + "_pca_report.txt");
try {
final FileOutputStream out = new FileOutputStream(reportFile);
// Connect print stream to the output stream
final PrintStream p = new PrintStream(out);
p.println();
p.println("User Selected Bands: ");
for (int i = 0; i < numOfSourceBands; i++) {
p.println(" " + sourceBandNames[i]);
}
p.println();
if (selectEigenvaluesBy.equals(EIGENVALUE_THRESHOLD)) {
p.println("User Input Eigenvalue Threshold: " + eigenvalueThreshold + " %");
p.println();
}
p.println("Number of PCA Images Output: " + numPCA);
p.println();
p.println("Normalized Eigenvalues: ");
for (int i = 0; i < numOfSourceBands; i++) {
p.println(" " + eigenValues[i]);
}
p.println();
p.close();
if (showEigenvalues) {
Desktop.getDesktop().edit(reportFile);
}
} catch (IOException exc) {
throw new OperatorException(exc);
}
}
/**
* The SPI is used to register this operator in the graph processing framework
* via the SPI configuration file
* {@code META-INF/services/org.esa.snap.core.gpf.OperatorSpi}.
* This class may also serve as a factory for new operator instances.
*
* @see OperatorSpi#createOperator()
* @see OperatorSpi#createOperator(java.util.Map, java.util.Map)
*/
public static class Spi extends OperatorSpi {
public Spi() {
super(PCAOp.class);
}
}
}