package mikera.vectorz.ops;
import mikera.arrayz.INDArray;
import mikera.vectorz.AVector;
import mikera.vectorz.Op;
import mikera.vectorz.util.DoubleArrays;
public final class Linear extends ALinearOp {
private final double factor;
private final double constant;
private Linear(double factor, double constant) {
this.factor=factor;
this.constant=constant;
}
public static ALinearOp create(double factor, double constant) {
if (factor==0.0) {
return Constant.create(constant);
}
if (factor==1.0) {
if (constant==0.0) return Identity.INSTANCE;
return Offset.create(constant);
}
if ((factor==-1.0)&& (constant==0.0)) {
// TODO: special negate class?
}
return new Linear(factor,constant);
}
@Override
public double apply(double x) {
return (factor*x)+constant;
}
@Override
public double applyInverse(double y) {
return (y-constant)/factor;
}
@Override
public void applyTo(INDArray v) {
v.scaleAdd(factor,constant);
}
@Override
public void applyTo(AVector v) {
v.scaleAdd(factor,constant);
}
@Override
public void applyTo(double[] data) {
DoubleArrays.scaleAdd(data, factor, constant);
}
@Override
public void applyTo(double[] data, int start,int length) {
DoubleArrays.scaleAdd(data, start, length, factor,constant);
}
@Override
public double getFactor() {
return factor;
}
@Override
public double getConstant() {
return constant;
}
@Override
public double averageValue() {
return constant;
}
@Override
public boolean hasDerivative() {
return true;
}
@Override
public double derivative(double x) {
return factor;
}
@Override
public double derivativeForOutput(double y) {
return factor;
}
@Override
public Op getDerivativeOp() {
return Constant.create(getFactor());
}
@Override
public boolean hasInverse() {
return true;
}
@Override
public ALinearOp getInverse() {
return Linear.create(1.0/factor, -constant/factor);
}
public Op compose(ALinearOp op) {
return Linear.create(factor*op.getFactor(),factor*op.getConstant()+constant);
}
@Override
public Op compose(Op op) {
if (op instanceof ALinearOp) {
return compose((ALinearOp) op);
}
return super.compose(op);
}
}