/*-
*******************************************************************************
* Copyright (c) 2011, 2014 Diamond Light Source Ltd.
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the Eclipse Public License v1.0
* which accompanies this distribution, and is available at
* http://www.eclipse.org/legal/epl-v10.html
*
* Contributors:
* Peter Chang - initial API and implementation and/or initial documentation
*******************************************************************************/
package org.eclipse.dawnsci.analysis.dataset.impl;
import java.util.Arrays;
import org.eclipse.january.dataset.Dataset;
import org.eclipse.january.dataset.DatasetFactory;
import org.eclipse.january.dataset.DatasetUtils;
import org.eclipse.january.dataset.DoubleDataset;
import org.eclipse.january.dataset.FloatDataset;
import org.eclipse.january.dataset.PositionIterator;
import org.jtransforms.dct.DoubleDCT_1D;
import org.jtransforms.dct.DoubleDCT_2D;
import org.jtransforms.dct.DoubleDCT_3D;
import org.jtransforms.dct.FloatDCT_1D;
import org.jtransforms.dct.FloatDCT_2D;
import org.jtransforms.dct.FloatDCT_3D;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
// TODO fast path for defaults?
// axes and shape correspondence
/**
* Class to hold methods to compute Discrete Cosine Transforms (DCT)
*
*/
public class DCT {
/**
* Setup the logging facilities
*/
protected static final Logger logger = LoggerFactory.getLogger(DCT.class);
/**
* forward 1D Discrete Cosine Transform (DCT-II)
* @param a dataset
* @return new dataset holding transform
*/
public static Dataset dct(final Dataset a) {
return dct(a, a.getShapeRef()[a.getRank() - 1], -1);
}
/**
* forward 1D Discrete Cosine Transform (DCT-II)
* @param a dataset
* @param n number of points
* @param axis (negative numbers refer to axes from end, eg. -1 is last axis)
* @return new dataset holding transform
*/
public static Dataset dct(final Dataset a, final int n, int axis) {
if (n <= 0) {
logger.error("number of points should be greater than zero");
throw new IllegalArgumentException("number of points should be greater than zero");
}
axis = a.checkAxis(axis);
return dct1d(a, n, axis);
}
/**
* forward 2D Discrete Cosine Transform (DCT-II)
* @param a dataset
* @param s shape of DCT dataset (if null, use whole dataset)
* @param axes for DCT (if null, default as [-2,-1])
* @return new dataset holding transform
*/
public static Dataset dct2(final Dataset a, int[] s, int[] axes) {
int rank = a.getRank();
if (rank < 2) {
logger.error("dataset should be at least 2 dimensional");
throw new IllegalArgumentException("dataset should be at least 2 dimensional");
}
if (axes == null) {
axes = new int[] {rank-2, rank-1};
} else if (axes.length != 2) {
logger.error("axes should have two entries");
throw new IllegalArgumentException("axes should have two entries");
}
if (s == null) {
s = new int[2];
int[] shape = a.getShape();
s[0] = shape[axes[0]];
s[1] = shape[axes[1]];
} else if (s.length < 2) {
logger.error("shape should not have more than 2 dimensions");
throw new IllegalArgumentException("shape should not have more than 2 dimensions");
}
return dctn(a, s, axes);
}
/**
* forward nD Discrete Cosine Transform (DCT-II)
* @param a dataset
* @param s shape of DCT dataset (if null, use whole dataset)
* @param axes for DCT (if null, default as [..., -1])
* @return new dataset holding transform
*/
public static Dataset dctn(final Dataset a, int[] s, int[] axes) {
int[] shape = a.getShape();
int rank = shape.length;
Dataset result = null;
if (s == null) {
if (axes == null) {
s = shape;
axes = new int[rank];
for (int i = 0; i < rank; i++)
axes[i] = i;
} else {
s = new int[axes.length];
Arrays.sort(axes);
for (int i = 0; i < axes.length; i++) {
axes[i] = a.checkAxis(axes[i]);
s[i] = shape[axes[i]];
}
}
} else {
if (s.length > rank) {
logger.error("shape of DCT should not have more dimensions than dataset");
throw new IllegalArgumentException("shape of DCT should not have more dimensions than dataset");
}
if (axes == null) {
axes = new int[s.length];
for (int i = 0; i < s.length; i++)
axes[i] = rank - s.length + i;
} else {
if (s.length != axes.length) {
logger.error("shape of DCT should have same rank as axes");
throw new IllegalArgumentException("shape of DCT should have same rank as axes");
}
}
}
if (s.length > 3) {
logger.error("DCT across more than 3 dimensions are not supported");
throw new IllegalArgumentException("DCT across more than 3 dimensions are not supported");
}
for (int i = 0; i < axes.length; i++) {
if (s[i] <= 0) {
logger.error("dimensions should be greater than zero");
throw new IllegalArgumentException("dimensions should be greater than zero");
}
axes[i] = a.checkAxis(axes[i]);
}
switch (s.length) {
case 1:
result = dct1d(a, s[0], axes[0]);
break;
case 2:
result = dct2d(a, s, axes);
break;
case 3:
result = dct3d(a, s, axes);
break;
}
return result;
}
private static int[] newShape(final int[] shape, final int[] s, final int[] axes) {
int[] nshape = shape.clone();
for (int i = 0; i < s.length; i++) {
nshape[axes[i]] = s[i];
}
return nshape;
}
private static Dataset dct1d(final Dataset a, final int n, final int axis) {
Dataset result = null;
Dataset dest = null;
int[] shape;
PositionIterator pi;
int[] pos;
boolean[] hit;
switch (a.getDType()) {
case Dataset.FLOAT32:
FloatDCT_1D ffft = new FloatDCT_1D(n);
shape = a.getShape().clone();
shape[axis] = n;
result = DatasetFactory.zeros(FloatDataset.class, shape);
dest = DatasetFactory.zeros(FloatDataset.class, new int[] {n});
float[] fdata = (float[]) dest.getBuffer();
pi = a.getPositionIterator(axis);
pos = pi.getPos();
hit = pi.getOmit();
while (pi.hasNext()) {
Arrays.fill(fdata, 0.f);
a.copyItemsFromAxes(pos, hit, dest);
ffft.forward(fdata, true);
result.setItemsOnAxes(pos, hit, fdata);
}
break;
case Dataset.FLOAT64:
DoubleDCT_1D dfft = new DoubleDCT_1D(n);
shape = a.getShape().clone();
shape[axis] = n;
result = DatasetFactory.zeros(DoubleDataset.class, shape);
dest = DatasetFactory.zeros(DoubleDataset.class, n);
double[] ddata = (double[]) dest.getBuffer();
pi = a.getPositionIterator(axis);
pos = pi.getPos();
hit = pi.getOmit();
while (pi.hasNext()) {
Arrays.fill(ddata, 0.);
a.copyItemsFromAxes(pos, hit, dest);
dfft.forward(ddata, true);
result.setItemsOnAxes(pos, hit, ddata);
}
break;
default:
logger.warn("Non-float dataset not yet supported");
break;
}
return result;
}
private static Dataset dct2d(final Dataset a, final int[] s, final int[] axes) {
Dataset result = null;
Dataset dest = null;
PositionIterator pi;
int[] pos;
boolean[] hit;
switch (a.getDType()) {
case Dataset.FLOAT32:
FloatDCT_2D ffft = new FloatDCT_2D(s[0], s[1]);
float[] fdata = null;
result = DatasetFactory.zeros(FloatDataset.class, newShape(a.getShapeRef(), s, axes));
dest = DatasetFactory.zeros(FloatDataset.class, s);
fdata = (float[]) dest.getBuffer();
pi = a.getPositionIterator(axes);
pos = pi.getPos();
hit = pi.getOmit();
while (pi.hasNext()) {
Arrays.fill(fdata, 0.f);
a.copyItemsFromAxes(pos, hit, dest);
ffft.forward(fdata, true);
result.setItemsOnAxes(pos, hit, fdata);
}
break;
case Dataset.FLOAT64:
DoubleDCT_2D dfft = new DoubleDCT_2D(s[0], s[1]);
double[] ddata = null;
result = DatasetFactory.zeros(DoubleDataset.class, newShape(a.getShapeRef(), s, axes));
dest = DatasetFactory.zeros(DoubleDataset.class, s);
ddata = (double[]) dest.getBuffer();
pi = a.getPositionIterator(axes);
pos = pi.getPos();
hit = pi.getOmit();
while (pi.hasNext()) {
Arrays.fill(ddata, 0.);
a.copyItemsFromAxes(pos, hit, dest);
dfft.forward(ddata, true);
result.setItemsOnAxes(pos, hit, ddata);
}
break;
default:
logger.warn("Non-float dataset not yet supported");
break;
}
return result;
}
private static Dataset dct3d(final Dataset a, final int[] s, final int[] axes) {
Dataset result = null;
Dataset dest = null;
PositionIterator pi;
int[] pos;
boolean[] hit;
switch (a.getDType()) {
case Dataset.FLOAT32:
FloatDCT_3D ffft = new FloatDCT_3D(s[0], s[1], s[2]);
float[] fdata = null;
result = DatasetFactory.zeros(FloatDataset.class, newShape(a.getShapeRef(), s, axes));
dest = DatasetFactory.zeros(FloatDataset.class, s);
fdata = (float[]) dest.getBuffer();
pi = a.getPositionIterator(axes);
pos = pi.getPos();
hit = pi.getOmit();
while (pi.hasNext()) {
Arrays.fill(fdata, 0.f);
a.copyItemsFromAxes(pos, hit, dest);
ffft.forward(fdata, true);
result.setItemsOnAxes(pos, hit, fdata);
}
break;
case Dataset.FLOAT64:
DoubleDCT_3D dfft = new DoubleDCT_3D(s[0], s[1], s[2]);
double[] ddata = null;
result = DatasetFactory.zeros(DoubleDataset.class, newShape(a.getShapeRef(), s, axes));
dest = DatasetFactory.zeros(DoubleDataset.class, s);
ddata = (double[]) dest.getBuffer();
pi = a.getPositionIterator(axes);
pos = pi.getPos();
hit = pi.getOmit();
while (pi.hasNext()) {
Arrays.fill(ddata, 0.);
a.copyItemsFromAxes(pos, hit, dest);
dfft.forward(ddata, true);
result.setItemsOnAxes(pos, hit, ddata);
}
break;
default:
logger.warn("Non-float dataset not yet supported");
break;
}
return result;
}
/**
* inverse 1D Discrete Cosine Transform (DCT-III)
* @param a dataset
* @return new dataset holding transform
*/
public static Dataset idct(final Dataset a) {
return idct(a, a.getShapeRef()[a.getRank() - 1], -1);
}
/**
* inverse 1D Discrete Cosine Transform (DCT-III)
* @param a dataset
* @param n number of points
* @param axis (negative numbers refer to axes from end, eg. -1 is last axis)
* @return new dataset holding transform
*/
public static Dataset idct(final Dataset a, final int n, int axis) {
if (n <= 0) {
logger.error("number of points should be greater than zero");
throw new IllegalArgumentException("number of points should be greater than zero");
}
axis = a.checkAxis(axis);
return idct1d(a, n, axis);
}
/**
* inverse 2D Discrete Cosine Transform (DCT-III)
* @param a dataset
* @param s shape of DCT dataset (if null, use whole dataset)
* @param axes for DCT (default as [-2,-1])
* @return new dataset holding transform
*/
public static Dataset idct2(final Dataset a, int[] s, int[] axes) {
int rank = a.getRank();
if (rank < 2) {
logger.error("dataset should be at least 2 dimensional");
throw new IllegalArgumentException("dataset should be at least 2 dimensional");
}
if (axes == null) {
axes = new int[] {rank-2, rank-1};
} else if (axes.length != 2) {
logger.error("axes should have two entries");
throw new IllegalArgumentException("axes should have two entries");
}
if (s == null) {
s = new int[2];
int[] shape = a.getShape();
s[0] = shape[axes[0]];
s[1] = shape[axes[1]];
} else if (s.length < 2) {
logger.error("shape should not have more than 2 dimensions");
throw new IllegalArgumentException("shape should not have more than 2 dimensions");
}
return idctn(a, s, axes);
}
/**
* inverse nD Discrete Cosine Transform (DCT-III)
* @param a dataset
* @param s shape of DCT dataset (if null, use whole dataset)
* @param axes for DCT (if null, default as [..., -1])
* @return new dataset holding transform
*/
public static Dataset idctn(final Dataset a, int[] s, int[] axes) {
int[] shape = a.getShape();
int rank = shape.length;
Dataset result = null;
if (s == null) {
if (axes == null) {
s = shape;
axes = new int[rank];
for (int i = 0; i < rank; i++)
axes[i] = i;
} else {
s = new int[axes.length];
Arrays.sort(axes);
for (int i = 0; i < axes.length; i++) {
axes[i] = a.checkAxis(axes[i]);
s[i] = shape[axes[i]];
}
}
} else {
if (s.length > rank) {
logger.error("shape of DCT should not have more dimensions than dataset");
throw new IllegalArgumentException("shape of DCT should not have more dimensions than dataset");
}
if (axes == null) {
axes = new int[s.length];
for (int i = 0; i < s.length; i++)
axes[i] = rank - s.length + i;
} else {
if (s.length != axes.length) {
logger.error("shape of DCT should have same rank as axes");
throw new IllegalArgumentException("shape of DCT should have same rank as axes");
}
}
}
if (s.length > 3) {
logger.error("DCT across more than 3 dimensions are not supported");
throw new IllegalArgumentException("DCT across more than 3 dimensions are not supported");
}
for (int i = 0; i < axes.length; i++) {
if (s[i] <= 0) {
logger.error("dimensions should be greater than zero");
throw new IllegalArgumentException("dimensions should be greater than zero");
}
axes[i] = a.checkAxis(axes[i]);
}
switch (s.length) {
case 1:
result = idct1d(a, s[0], axes[0]);
break;
case 2:
result = idct2d(a, s, axes);
break;
case 3:
result = idct3d(a, s, axes);
break;
}
return result;
}
private static Dataset idct1d(final Dataset a, final int n, final int axis) {
Dataset result = null;
Dataset dest = null;
int[] shape;
PositionIterator pi;
int[] pos;
boolean[] hit;
switch (a.getDType()) {
case Dataset.FLOAT32:
FloatDCT_1D ffft = new FloatDCT_1D(n);
float[] fdata = null;
shape = a.getShape();
shape[axis] = n;
result = DatasetFactory.zeros(FloatDataset.class, shape);
dest = DatasetFactory.zeros(FloatDataset.class, new int[] {n});
fdata = (float[]) dest.getBuffer();
pi = a.getPositionIterator(axis);
pos = pi.getPos();
hit = pi.getOmit();
while (pi.hasNext()) {
Arrays.fill(fdata, 0.f);
a.copyItemsFromAxes(pos, hit, dest);
ffft.inverse(fdata, true);
result.setItemsOnAxes(pos, hit, fdata);
}
break;
case Dataset.FLOAT64:
DoubleDCT_1D dfft = new DoubleDCT_1D(n);
double[] ddata = null;
shape = a.getShape();
shape[axis] = n;
result = DatasetFactory.zeros(DoubleDataset.class, shape);
dest = DatasetFactory.zeros(DoubleDataset.class, n);
ddata = (double[]) dest.getBuffer();
pi = a.getPositionIterator(axis);
pos = pi.getPos();
hit = pi.getOmit();
while (pi.hasNext()) {
Arrays.fill(ddata, 0.);
a.copyItemsFromAxes(pos, hit, dest);
dfft.inverse(ddata, true);
result.setItemsOnAxes(pos, hit, ddata);
}
break;
default:
logger.warn("Non-complex dataset not yet supported");
break;
}
return result;
}
private static Dataset idct2d(final Dataset a, final int[] s, final int[] axes) {
Dataset result = null;
Dataset dest = null;
PositionIterator pi;
int[] pos;
boolean[] hit;
switch (a.getDType()) {
case Dataset.FLOAT32:
FloatDCT_2D ffft = new FloatDCT_2D(s[0], s[1]);
float[] fdata = null;
result = DatasetFactory.zeros(FloatDataset.class, newShape(a.getShapeRef(), s, axes));
dest = DatasetFactory.zeros(FloatDataset.class, s);
fdata = (float[]) dest.getBuffer();
pi = a.getPositionIterator(axes);
pos = pi.getPos();
hit = pi.getOmit();
while (pi.hasNext()) {
Arrays.fill(fdata, 0.f);
a.copyItemsFromAxes(pos, hit, dest);
ffft.inverse(fdata, true);
result.setItemsOnAxes(pos, hit, fdata);
}
break;
case Dataset.FLOAT64:
DoubleDCT_2D dfft = new DoubleDCT_2D(s[0], s[1]);
double[] ddata = null;
result = DatasetFactory.zeros(DoubleDataset.class, newShape(a.getShapeRef(), s, axes));
dest = DatasetFactory.zeros(DoubleDataset.class, s);
ddata = (double[]) dest.getBuffer();
pi = a.getPositionIterator(axes);
pos = pi.getPos();
hit = pi.getOmit();
while (pi.hasNext()) {
Arrays.fill(ddata, 0.);
a.copyItemsFromAxes(pos, hit, dest);
dfft.inverse(ddata, true);
result.setItemsOnAxes(pos, hit, ddata);
}
break;
default:
logger.warn("Non-complex dataset not yet supported");
break;
}
return result;
}
private static Dataset idct3d(final Dataset a, final int[] s, final int[] axes) {
Dataset result = null;
Dataset dest = null;
PositionIterator pi;
int[] pos;
boolean[] hit;
switch (a.getDType()) {
case Dataset.FLOAT32:
FloatDCT_3D ffft = new FloatDCT_3D(s[0], s[1], s[2]);
float[] fdata = null;
result = DatasetFactory.zeros(FloatDataset.class, newShape(a.getShapeRef(), s, axes));
dest = DatasetFactory.zeros(FloatDataset.class, s);
fdata = (float[]) dest.getBuffer();
pi = a.getPositionIterator(axes);
pos = pi.getPos();
hit = pi.getOmit();
while (pi.hasNext()) {
Arrays.fill(fdata, 0.f);
a.copyItemsFromAxes(pos, hit, dest);
ffft.inverse(fdata, true);
result.setItemsOnAxes(pos, hit, fdata);
}
break;
case Dataset.FLOAT64:
DoubleDCT_3D dfft = new DoubleDCT_3D(s[0], s[1], s[2]);
double[] ddata = null;
result = DatasetFactory.zeros(DoubleDataset.class, newShape(a.getShapeRef(), s, axes));
dest = DatasetFactory.zeros(DoubleDataset.class, s);
ddata = (double[]) dest.getBuffer();
pi = a.getPositionIterator(axes);
pos = pi.getPos();
hit = pi.getOmit();
while (pi.hasNext()) {
Arrays.fill(ddata, 0.);
a.copyItemsFromAxes(pos, hit, dest);
dfft.inverse(ddata, true);
result.setItemsOnAxes(pos, hit, ddata);
}
break;
default:
logger.warn("Non-complex dataset not yet supported");
break;
}
return result;
}
/**
* Shift zero-frequency component to centre of dataset
* @param a
* @param axes (if null, then shift all axes)
* @return shifted dataset
*/
public static Dataset dctshift(final Dataset a, int[] axes) {
int alen;
if (axes == null) {
alen = a.getRank();
axes = new int[alen];
for (int i = 0; i < alen; i++)
axes[i] = i;
} else {
alen = axes.length;
for (int i = 0; i < alen; i++)
axes[i] = a.checkAxis(axes[i]);
}
Dataset result = a;
int[] indices;
for (int i = 0; i < alen; i++) {
int axis = axes[i];
int n = a.getShapeRef()[axis];
int p = (n+1)/2;
logger.info("Shift {} by {}", axis, p);
indices = new int[n];
for (int j = 0; j < n; j++) {
if (j < n - p)
indices[j] = p + j;
else
indices[j] = j - n + p;
}
result = DatasetUtils.take(result, indices, axis);
}
return result;
}
/**
* Shift zero-frequency component to centre of dataset
* @param a
* @param axes (if null, then shift all axes)
* @return shifted dataset
*/
public static Dataset idctshift(final Dataset a, int[] axes) {
int alen;
if (axes == null) {
alen = a.getRank();
axes = new int[alen];
for (int i = 0; i < alen; i++)
axes[i] = i;
} else {
alen = axes.length;
for (int i = 0; i < alen; i++)
axes[i] = a.checkAxis(axes[i]);
}
Dataset result = a;
int[] indices;
for (int i = 0; i < alen; i++) {
int axis = axes[i];
int n = a.getShapeRef()[axis];
int p = n - (n+1)/2;
logger.info("Shift {} by {}", axis, p);
indices = new int[n];
for (int j = 0; j < n; j++) {
if (j < n - p)
indices[j] = p + j;
else
indices[j] = j - n + p;
}
result = DatasetUtils.take(result, indices, axis);
}
return result;
}
}