package org.bouncycastle.math.ec.test;
import java.math.BigInteger;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;
import org.bouncycastle.asn1.x9.ECNamedCurveTable;
import org.bouncycastle.asn1.x9.X9ECParameters;
import org.bouncycastle.crypto.ec.CustomNamedCurves;
import org.bouncycastle.math.ec.ECAlgorithms;
import org.bouncycastle.math.ec.ECCurve;
import org.bouncycastle.math.ec.ECPoint;
public class ECAlgorithmsTest extends TestCase
{
private static final int SCALE = 4;
private static final SecureRandom RND = new SecureRandom();
public void testSumOfMultiplies()
{
X9ECParameters x9 = CustomNamedCurves.getByName("secp256r1");
assertNotNull(x9);
doTestSumOfMultiplies(x9);
}
// TODO Ideally, mark this test not to run by default
public void testSumOfMultipliesComplete()
{
ArrayList x9s = getTestCurves();
Iterator it = x9s.iterator();
while (it.hasNext())
{
X9ECParameters x9 = (X9ECParameters)it.next();
doTestSumOfMultiplies(x9);
}
}
public void testSumOfTwoMultiplies()
{
X9ECParameters x9 = CustomNamedCurves.getByName("secp256r1");
assertNotNull(x9);
doTestSumOfTwoMultiplies(x9);
}
// TODO Ideally, mark this test not to run by default
public void testSumOfTwoMultipliesComplete()
{
ArrayList x9s = getTestCurves();
Iterator it = x9s.iterator();
while (it.hasNext())
{
X9ECParameters x9 = (X9ECParameters)it.next();
doTestSumOfTwoMultiplies(x9);
}
}
private void doTestSumOfMultiplies(X9ECParameters x9)
{
ECPoint[] points = new ECPoint[SCALE];
BigInteger[] scalars = new BigInteger[SCALE];
for (int i = 0; i < SCALE; ++i)
{
points[i] = getRandomPoint(x9);
scalars[i] = getRandomScalar(x9);
}
ECPoint u = x9.getCurve().getInfinity();
for (int i = 0; i < SCALE; ++i)
{
u = u.add(points[i].multiply(scalars[i]));
ECPoint v = ECAlgorithms.sumOfMultiplies(copyPoints(points, i + 1), copyScalars(scalars, i + 1));
ECPoint[] results = new ECPoint[]{ u, v };
x9.getCurve().normalizeAll(results);
assertPointsEqual("ECAlgorithms.sumOfMultiplies is incorrect", results[0], results[1]);
}
}
private void doTestSumOfTwoMultiplies(X9ECParameters x9)
{
ECPoint p = getRandomPoint(x9);
BigInteger a = getRandomScalar(x9);
for (int i = 0; i < SCALE; ++i)
{
ECPoint q = getRandomPoint(x9);
BigInteger b = getRandomScalar(x9);
ECPoint u = p.multiply(a).add(q.multiply(b));
ECPoint v = ECAlgorithms.shamirsTrick(p, a, q, b);
ECPoint w = ECAlgorithms.sumOfTwoMultiplies(p, a, q, b);
ECPoint[] results = new ECPoint[]{ u, v, w };
x9.getCurve().normalizeAll(results);
assertPointsEqual("ECAlgorithms.shamirsTrick is incorrect", results[0], results[1]);
assertPointsEqual("ECAlgorithms.sumOfTwoMultiplies is incorrect", results[0], results[2]);
p = q;
a = b;
}
}
private void assertPointsEqual(String message, ECPoint a, ECPoint b)
{
assertEquals(message, a, b);
}
private ECPoint[] copyPoints(ECPoint[] ps, int len)
{
ECPoint[] result = new ECPoint[len];
System.arraycopy(ps, 0, result, 0, len);
return result;
}
private BigInteger[] copyScalars(BigInteger[] ks, int len)
{
BigInteger[] result = new BigInteger[len];
System.arraycopy(ks, 0, result, 0, len);
return result;
}
private ECPoint getRandomPoint(X9ECParameters x9)
{
return x9.getG().multiply(getRandomScalar(x9));
}
private BigInteger getRandomScalar(X9ECParameters x9)
{
return new BigInteger(x9.getN().bitLength(), RND);
}
private ArrayList getTestCurves()
{
ArrayList x9s = new ArrayList();
Set names = new HashSet(AllTests.enumToList(ECNamedCurveTable.getNames()));
names.addAll(AllTests.enumToList(CustomNamedCurves.getNames()));
Iterator it = names.iterator();
while (it.hasNext())
{
String name = (String)it.next();
X9ECParameters x9 = ECNamedCurveTable.getByName(name);
if (x9 != null)
{
addTestCurves(x9s, x9);
}
x9 = CustomNamedCurves.getByName(name);
if (x9 != null)
{
addTestCurves(x9s, x9);
}
}
return x9s;
}
private void addTestCurves(ArrayList x9s, X9ECParameters x9)
{
ECCurve curve = x9.getCurve();
int[] coords = ECCurve.getAllCoordinateSystems();
for (int i = 0; i < coords.length; ++i)
{
int coord = coords[i];
if (curve.getCoordinateSystem() == coord)
{
x9s.add(x9);
}
else if (curve.supportsCoordinateSystem(coord))
{
ECCurve c = curve.configure().setCoordinateSystem(coord).create();
x9s.add(new X9ECParameters(c, c.importPoint(x9.getG()), x9.getN(), x9.getH()));
}
}
}
public static Test suite()
{
return new TestSuite(ECAlgorithmsTest.class);
}
}