/*******************************************************************************
* $Id: $
* Copyright (c) 2009-2010 Tim Tiemens.
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the GNU Lesser Public License v2.1
* which accompanies this distribution, and is available at
* http://www.gnu.org/licenses/old-licenses/lgpl-2.1.html
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
*
* Contributors:
* Tim Tiemens - initial API and implementation
******************************************************************************/
package com.aegiswallet.helpers.secretshare.math;
import com.aegiswallet.helpers.secretshare.SecretShareException;
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;
/**
* "Easy" implementation of linear equation solver.
* <p/>
* <p/>
* Example: given 3 equations like:
* 1491 = 83a + 32b + 22c
* 329 = 5a + 13b + 22c
* 122 = 3a + 2b + 19c
* The goal is to solve for a, b, and c [3 unknowns, 3 equations].
* <p/>
* The problem above is encoded into a matrix of numbers like:
* 1491 83 32 22
* 329 5 13 22
* 122 3 2 19
* and stored in this class as a List of Row objects.
* <p/>
* Then, this class can "solve" it into reduced form (1s on the diagonal), giving:
* 11 1 0 0
* 16 0 1 0
* 3 0 0 1
* Which in turn means that a=11, b=16 and c=3
* <p/>
* <p/>
* This implementation is called "easy" because it is really straight-forward,
* but also really inefficient.
* Values are "canceled" by cross-multiplying the two rows.
* e.g. to cancel an "8" using a "4", it would be enough to multiply the "4" row by 2, then subtract.
* Instead, this implementation multiplies the "8" row by 4 and the "4" row by 8, then subtracts,
* because that always works and does not require computing the least common multiple.
* <p/>
* This implementation is also "easy" since it doesn't use any lin-eq library.
* There are a lot of those libraries available: it turns out it was easier
* to write this class than to figure out how to use the horrible APIs they presented.
* [The ones with good APIs didn't support BigInteger]
*
* @author tiemens
*/
public class EasyLinearEquation {
    // ==================================================
    // class static data
    // ==================================================

    // want to turn on debug? See EasyLinearEquationUT.enableLogging()
    public static Logger logger = Logger.getLogger(EasyLinearEquation.class.getName());

    // ==================================================
    // class static methods
    // ==================================================

    /**
     * Convert a matrix of 'int's into a matrix of BigIntegers.
     * Each row is converted on its own, so the result has exactly the same
     * shape as the input.
     *
     * @param inMatrix the values to convert
     * @return a new matrix holding the same values as BigIntegers
     */
    public static BigInteger[][] convertIntToBigInteger(int[][] inMatrix) {
        BigInteger[][] cvt = new BigInteger[inMatrix.length][];
        for (int i = 0; i < inMatrix.length; i++) {
            final int[] inRow = inMatrix[i];
            cvt[i] = new BigInteger[inRow.length];
            for (int c = 0; c < inRow.length; c++) {
                cvt[i][c] = BigInteger.valueOf(inRow[c]);
            }
        }
        return cvt;
    }

    // ==================================================
    // instance data
    // ==================================================

    // one Row per equation; column 0 of each Row is the constant
    private final List<Row> rows;

    // 'modulus' can be null, which means do not perform mod() on values
    private final BigInteger modulus;

    // ==================================================
    // factories
    // ==================================================

    /**
     * Create solver for polynomial equations.
     * <p/>
     * Polynomial linear equations are a special case, because the C,x,x^2,x^3 coefficients
     * can be turned into the rows we need by being given:
     * a) which "x" values were used
     * b) what "constant" values were computed
     * Row i becomes: f(x_i), 1, x_i, x_i^2, ... so unknown #1 is "C".
     * This information happens to be exactly what the holder of a "Secret" in
     * "Secret Sharing" has been given.
     * So this factory can be used to recover the secret if
     * given enough of the secrets.
     *
     * @param xarray    the "X" values
     * @param fofxarray the "f(x)" values
     * @return instance for solving this special case
     * @throws SecretShareException if the arrays differ in length or are empty
     */
    public static EasyLinearEquation createForPolynomial(final BigInteger[] xarray,
                                                         final BigInteger[] fofxarray) {
        if (xarray.length != fofxarray.length) {
            throw new SecretShareException("Unequal length arrays are not allowed");
        }
        final int numRows = xarray.length;
        final int numCols = xarray.length + 1;
        BigInteger[][] cvt = new BigInteger[numRows][];
        for (int row = 0; row < numRows; row++) {
            cvt[row] = new BigInteger[numCols];
            fillInPolynomial(cvt[row], fofxarray[row], xarray[row]);
        }
        return create(cvt);
    }

    /**
     * Convenience factory to create an instance with 'int's instead of BigIntegers.
     *
     * @param inMatrix given in 'int's
     * @return instance
     * @throws SecretShareException if the matrix is empty or ragged
     */
    public static EasyLinearEquation create(int[][] inMatrix) {
        return create(convertIntToBigInteger(inMatrix));
    }

    /**
     * Most typical factory, for BigInteger arrays.
     *
     * @param inMatrix given in BigIntegers, where the first column is the constant
     *                 and all the other columns are the variables
     * @return instance
     * @throws SecretShareException if the matrix has no rows, or the rows are
     *                              not all the same width
     */
    public static EasyLinearEquation create(BigInteger[][] inMatrix) {
        if (inMatrix.length == 0) {
            throw new SecretShareException("At least one row is required");
        }
        final int width = inMatrix[0].length;
        List<Row> rows = new ArrayList<Row>(inMatrix.length);
        for (BigInteger[] row : inMatrix) {
            if (width != row.length) {
                throw new SecretShareException("All rows must be " +
                                               width + " wide");
            }
            rows.add(Row.create(row));
        }
        return new EasyLinearEquation(rows);
    }

    // ==================================================
    // constructors
    // ==================================================

    private EasyLinearEquation(final List<Row> inRows) {
        this(inRows, null);
    }

    private EasyLinearEquation(final List<Row> inRows,
                               final BigInteger inModulus) {
        rows = new ArrayList<Row>(inRows);
        modulus = inModulus;
    }

    /**
     * Create a copy of this solver whose divisions are done modulo primeModulus.
     * Primality is not checked; it is the caller's responsibility.
     *
     * @param primeModulus the modulus to use; must be positive
     * @return new solver over the same equations
     * @throws SecretShareException if primeModulus is null or not positive
     */
    public EasyLinearEquation createWithPrimeModulus(BigInteger primeModulus) {
        if (primeModulus == null) {
            throw new SecretShareException("modulus cannot be null");
        }
        if (primeModulus.signum() <= 0) {
            // BigInteger.mod() rejects non-positive moduli, so fail early and clearly
            throw new SecretShareException("modulus must be positive, was " + primeModulus);
        }
        return new EasyLinearEquation(this.rows, primeModulus);
    }

    // ==================================================
    // public methods
    // ==================================================

    /**
     * Solve the equations.
     * <p/>
     * Forward elimination makes the matrix upper-triangular (swapping rows
     * when a pivot is zero), then back-substitution, starting from the bottom
     * row, reduces every row to "constant, 0 .. 1 .. 0". When a modulus was
     * configured (see createWithPrimeModulus) the divisions are done modulo
     * that value and every answer is in the range [0, modulus).
     *
     * @return the solution; getAnswer(i) is the value of unknown i (1-based)
     * @throws SecretShareException if a row cannot be reduced (e.g. the system is singular)
     */
    public EasySolve solve() {
        final List<Row> solverows = new ArrayList<Row>(rows);
        final int maxindex = solverows.size();
        debugRows("Initial rows", solverows, modulus);

        for (int workrowindex = 0; workrowindex < maxindex; workrowindex++) {
            final int columnIndexToCancel = workrowindex + 1;
            if (workrowindex + 1 < maxindex) {
                // a zero pivot cannot cancel anything, and makes Row.cancelColumn fail
                moveNonZeroPivotIntoPlace(solverows, workrowindex, columnIndexToCancel);
            }
            final Row otherrow = solverows.get(workrowindex);
            for (int fixindex = workrowindex + 1; fixindex < maxindex; fixindex++) {
                Row cancelrowr = solverows.get(fixindex).cancelColumn(columnIndexToCancel,
                                                                      otherrow,
                                                                      modulus);
                solverows.set(fixindex, cancelrowr);
            }
            debugRows("after workrowindex=" + workrowindex + " finished", solverows, modulus);
        }
        debugRows("after all loops", solverows, modulus);

        //
        // the matrix should look like this now:
        //
        //   33  a b c
        //  -51  0 d e
        //  -13  0 0 f
        // so, start at the bottom, and solve and cancel the other direction:
        for (int workrowindex = maxindex - 1; workrowindex >= 0; workrowindex--) {
            final int columnIndexToCancel = workrowindex + 1;
            final Row reducedToOne = solverows.get(workrowindex).solveThisRow(modulus);
            if (logger.isLoggable(Level.FINE)) {
                logger.fine("reverse, index=" + workrowindex + " is " + reducedToOne.debugRow());
            }
            solverows.set(workrowindex, reducedToOne);
            for (int fixindex = workrowindex - 1; fixindex >= 0; fixindex--) {
                if (logger.isLoggable(Level.FINER)) {
                    logger.finer(" going to cancel fixindex=" + fixindex + " is " +
                                 solverows.get(fixindex).debugRow() + " using row " +
                                 reducedToOne.debugRow());
                }
                Row cancelrowr = solverows.get(fixindex).cancelColumn(columnIndexToCancel,
                                                                      reducedToOne,
                                                                      modulus);
                solverows.set(fixindex, cancelrowr);
            }
            debugRows("After reverse loopindex=" + workrowindex + " finished", solverows, modulus);
        }

        //
        // the matrix should look like this now:
        //
        //   3  1 0 0
        //  -5  0 1 0
        //  -3  0 0 1
        // answers[0] is deliberately null: index 0 is the constant column
        BigInteger[] answers = new BigInteger[maxindex + 1];
        answers[0] = null;
        for (int i = 1; i < answers.length; i++) {
            answers[i] = solverows.get(i - 1).getColumn(0);
        }
        return new EasySolve(answers);
    }

    // ==================================================
    // non public methods
    // ==================================================

    /**
     * If the row at pivotIndex has a zero in 'column', swap it with the first
     * lower row that has a non-zero value there. Reordering equations does not
     * change the solution. If no such row exists, the list is left unchanged.
     */
    private static void moveNonZeroPivotIntoPlace(final List<Row> solverows,
                                                  final int pivotIndex,
                                                  final int column) {
        if (!solverows.get(pivotIndex).isColumnZero(column)) {
            return;
        }
        for (int candidate = pivotIndex + 1; candidate < solverows.size(); candidate++) {
            if (!solverows.get(candidate).isColumnZero(column)) {
                final Row zeroPivot = solverows.get(pivotIndex);
                solverows.set(pivotIndex, solverows.get(candidate));
                solverows.set(candidate, zeroPivot);
                return;
            }
        }
    }

    private static void debugRows(String where,
                                  List<Row> solverows,
                                  BigInteger modulus) {
        // want to turn on debug? See EasyLinearEquationUT.enableLogging()
        if (logger.isLoggable(Level.FINE)) {
            logger.fine(where + " (modulus=" + modulus + ")");
            for (Row row : solverows) {
                logger.fine(row.debugRow());
            }
        }
    }

    /**
     * @param array       to fill with values
     * @param theconstant the value of the "C"
     * @param x           the "x" value used
     */
    private static void fillInPolynomial(BigInteger[] array,
                                         BigInteger theconstant,
                                         BigInteger x) {
        // what was f(x) becomes our constant:
        array[0] = theconstant;
        // what was "C" becomes an unknown, with the coefficient "1"
        array[1] = BigInteger.ONE;
        // the other coefficients are x, x^2, x^3, x^4 etc:
        BigInteger current = BigInteger.ONE;
        for (int i = 2; i < array.length; i++) {
            current = current.multiply(x);
            array[i] = current;
        }
    }

    // ==================================================
    // nested classes
    // ==================================================

    /**
     * The result of solve(). Answers are indexed 1..n; index 0 is the
     * constant column and is not an answer.
     */
    public static class EasySolve {
        private final BigInteger[] answers;

        public EasySolve(BigInteger[] inAnswers) {
            answers = inAnswers.clone();
        }

        /**
         * @param i index of the unknown, in the range 1..n
         * @return the value of unknown i
         * @throws SecretShareException if i is out of range
         */
        public BigInteger getAnswer(int i) {
            if (i < 0) {
                throw new SecretShareException("Answer index cannot be negative: " + i);
            }
            if (i == 0) {
                throw new SecretShareException("Answer index 0 is the constant, not an answer." +
                                               " Use range 1-n [not 0-n-1]");
            }
            if (i >= answers.length) {
                throw new SecretShareException("Answer index " + i + " is too large." +
                                               " Use range 1-" + (answers.length - 1));
            }
            return answers[i];
        }
    }

    /**
     * One equation: column 0 is the constant, columns 1..n are the
     * coefficients of the unknowns. Arithmetic methods return a Row and never
     * modify this one.
     */
    private static class Row {
        private final BigInteger[] cols;

        public static Row create(BigInteger[] in) {
            return new Row(in);
        }

        public Row(Row copy) {
            this(copy.cols);
        }

        private Row(BigInteger[] in) {
            // defensive copy: later changes to 'in' must not affect this row
            cols = in.clone();
        }

        /**
         * Divide this row so that its single non-zero coefficient becomes 1.
         *
         * @param useModulus modulus for the division, or null for plain integer division
         * @return new row with col[0] divided and one-and-only-one other column non-zero;
         *         all the other columns are ZERO
         * @throws SecretShareException if none, or more than one, of columns 1..n are non-zero
         */
        public Row solveThisRow(final BigInteger useModulus) {
            // Determine non-zero column:
            int nonZeroColumn = -1;
            for (int col = 1; col < cols.length; col++) {
                if (isColumnZero(col)) {
                    continue;
                }
                if (nonZeroColumn != -1) {
                    logger.severe("Row cannot be solved:\n" + debugRow());
                    throw new SecretShareException("Two columns are non-zero, c=" +
                                                   nonZeroColumn + " and c=" + col);
                }
                nonZeroColumn = col;
            }
            if (nonZeroColumn == -1) {
                throw new SecretShareException("No non-zero column found in row; error");
            }

            // Only the constant and the non-zero column need dividing; the loop
            // above already established that every other coefficient is zero.
            final BigInteger divideby = cols[nonZeroColumn];
            Row ret = new Row(this);
            ret.cols[0] = divideNormallyOrModulus(cols[0], divideby, useModulus);
            ret.cols[nonZeroColumn] = divideNormallyOrModulus(divideby, divideby, useModulus);
            return ret;
        }

        /**
         * Divide original by divideby.
         * <p/>
         * Without a modulus this is plain (truncating) BigInteger division.
         * With a modulus, the result r is in [0, useModulus) and satisfies
         * r * divideby == original (mod useModulus):
         * 1) if the division is exact, r = (original / divideby) mod useModulus;
         * 2) otherwise, if divideby is invertible modulo useModulus,
         * r = original * divideby^-1 mod useModulus.
         * An earlier version searched for an exact division by repeatedly adding
         * the modulus to 'original'. That gave the same value as (2), but gave up
         * after 10000 additions, so large divisors failed.
         *
         * @throws SecretShareException if the division is inexact and divideby
         *                              has no inverse modulo useModulus
         */
        private static BigInteger divideNormallyOrModulus(final BigInteger original,
                                                          final BigInteger divideby,
                                                          final BigInteger useModulus) {
            if (useModulus == null) {
                return original.divide(divideby);
            }
            final BigInteger[] quotientAndRemainder = original.divideAndRemainder(divideby);
            if (quotientAndRemainder[1].signum() == 0) {
                return quotientAndRemainder[0].mod(useModulus);
            }
            if (!divideby.gcd(useModulus).equals(BigInteger.ONE)) {
                throw new SecretShareException("Cannot divide " + original + " by " + divideby +
                                               ": divisor has no inverse modulo " + useModulus);
            }
            return original.multiply(divideby.modInverse(useModulus)).mod(useModulus);
        }

        private boolean isColumnZero(int index) {
            return getColumn(index).signum() == 0;
        }

        public String debugRow() {
            StringBuilder sb = new StringBuilder();
            for (int c = 0; c < cols.length; c++) {
                if (c > 0) {
                    sb.append(',');
                }
                sb.append(cols[c].toString());
            }
            return sb.toString();
        }

        public BigInteger getColumn(final int index) {
            return cols[index];
        }

        /**
         * Cancel column 'index' by cross-multiplying:
         * the result is +/-(otherrow[index] * this - this[index] * otherrow).
         * When useModulus is non-null and the resulting constant (column 0) is
         * negative, the whole row is negated so that the constant is not negative.
         *
         * @param index      of column to cancel (range 1-to-n)
         * @param otherrow   to use for the cancel operation; its column 'index'
         *                   must be non-zero unless ours already is
         * @param useModulus the modulus in use, or null
         * @return row with column value set to "0" (this row itself if it already was)
         * @throws SecretShareException if otherrow cannot cancel the column
         */
        public Row cancelColumn(final int index,
                                final Row otherrow,
                                final BigInteger useModulus) {
            // special case: our col[index] is already zero:
            if (this.isColumnZero(index)) {
                return this;
            }
            final boolean samesign = this.sameSign(index, otherrow);

            BigInteger mult = this.getColumn(index);
            if (samesign) {
                mult = mult.negate();
            }
            final Row cancel = otherrow.multiplyConstant(mult);

            Row usethis = this.multiplyConstant(otherrow.getColumn(index));
            if (!samesign) {
                usethis = usethis.negate();
            }

            // The two column values have opposite signs unless otherrow's value is zero.
            if (usethis.sameSign(index, cancel)) {
                throw new SecretShareException("prog error this(" + index + ")=" +
                                               this.getColumn(index) +
                                               " other(" + index + ")=" +
                                               cancel.getColumn(index));
            }
            Row ret = usethis.add(cancel);
            if (useModulus != null && ret.cols[0].signum() == -1) {
                // keep cols[0] positive
                ret = ret.negate();
            }
            return ret;
        }

        public boolean sameSign(final int index,
                                final Row other) {
            return sameSign(this.getColumn(index), other.getColumn(index));
        }

        private static boolean sameSign(final BigInteger one,
                                        final BigInteger other) {
            // ZERO counts as positive
            final boolean onePositive = one.signum() >= 0;
            final boolean otherPositive = other.signum() >= 0;
            final boolean ret = (onePositive == otherPositive);
            if (logger.isLoggable(Level.FINEST)) {
                logger.finest(" samesign=" + ret + " on one=" + one +
                              " other=" + other);
            }
            return ret;
        }

        public Row multiplyConstant(final BigInteger mult) {
            Row ret = new Row(this);
            for (int c = 0; c < cols.length; c++) {
                ret.cols[c] = this.cols[c].multiply(mult);
            }
            return ret;
        }

        public Row add(final Row add) {
            Row ret = new Row(this);
            for (int c = 0; c < cols.length; c++) {
                ret.cols[c] = this.cols[c].add(add.getColumn(c));
            }
            return ret;
        }

        public Row negate() {
            Row ret = new Row(this);
            for (int c = 0; c < cols.length; c++) {
                ret.cols[c] = this.cols[c].negate();
            }
            return ret;
        }
    }
}