Back to index...
/*
 * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.  Oracle designates this
 * particular file as subject to the "Classpath" exception as provided
 * by Oracle in the LICENSE file that accompanied this code.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */
package sun.security.util.math.intpoly;
import sun.security.util.math.*;
import java.math.BigInteger;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;
/**
 * A large number polynomial representation using sparse limbs of signed
 * long (64-bit) values. Limb values will always fit within a long, so inputs
 * to multiplication must be less than 32 bits. All IntegerPolynomial
 * implementations allow at most one addition before multiplication. Additions
 * after that will result in an ArithmeticException.
 *
 * The following element operations are branch-free for all subclasses:
 *
 * fixed
 * mutable
 * add
 * additiveInverse
 * multiply
 * square
 * subtract
 * conditionalSwapWith
 * setValue (may branch on high-order byte parameter only)
 * setSum
 * setDifference
 * setProduct
 * setSquare
 * addModPowerTwo
 * asByteArray
 *
 * All other operations may branch in some subclasses.
 *
 */
public abstract class IntegerPolynomial implements IntegerFieldModuloP {
    protected static final BigInteger TWO = BigInteger.valueOf(2);
    protected final int numLimbs;
    private final BigInteger modulus;
    protected final int bitsPerLimb;
    private final long[] posModLimbs;
    private final int maxAdds;
    /**
     * Reduce an IntegerPolynomial representation (a) and store the result
     * in a. Requires that a.length == numLimbs.
     */
    protected abstract void reduce(long[] a);
    /**
     * Multiply an IntegerPolynomial representation (a) with a long (b) and
     * store the result in an IntegerPolynomial representation in a. Requires
     * that a.length == numLimbs.
     */
    protected void multByInt(long[] a, long b) {
        for (int i = 0; i < a.length; i++) {
            a[i] *= b;
        }
        reduce(a);
    }
    /**
     * Multiply two IntegerPolynomial representations (a and b) and store the
     * result in an IntegerPolynomial representation (r). Requires that
     * a.length == b.length == r.length == numLimbs. It is allowed for a and r
     * to be the same array.
     */
    protected abstract void mult(long[] a, long[] b, long[] r);
    /**
     * Multiply an IntegerPolynomial representation (a) with itself and store
     * the result in an IntegerPolynomialRepresentation (r). Requires that
     * a.length == r.length == numLimbs. It is allowed for a and r
     * to be the same array.
     */
    protected abstract void square(long[] a, long[] r);
    IntegerPolynomial(int bitsPerLimb,
                      int numLimbs,
                      int maxAdds,
                      BigInteger modulus) {
        this.numLimbs = numLimbs;
        this.modulus = modulus;
        this.bitsPerLimb = bitsPerLimb;
        this.maxAdds = maxAdds;
        posModLimbs = setPosModLimbs();
    }
    private long[] setPosModLimbs() {
        long[] result = new long[numLimbs];
        setLimbsValuePositive(modulus, result);
        return result;
    }
    protected int getNumLimbs() {
        return numLimbs;
    }
    public int getMaxAdds() {
        return maxAdds;
    }
    @Override
    public BigInteger getSize() {
        return modulus;
    }
    @Override
    public ImmutableElement get0() {
        return new ImmutableElement(false);
    }
    @Override
    public ImmutableElement get1() {
        return new ImmutableElement(true);
    }
    @Override
    public ImmutableElement getElement(BigInteger v) {
        return new ImmutableElement(v);
    }
    @Override
    public SmallValue getSmallValue(int value) {
        int maxMag = 1 << (bitsPerLimb - 1);
        if (Math.abs(value) >= maxMag) {
            throw new IllegalArgumentException(
                "max magnitude is " + maxMag);
        }
        return new Limb(value);
    }
    /**
     * This version of encode takes a ByteBuffer that is properly ordered, and
     * may extract larger values (e.g. long) from the ByteBuffer for better
     * performance. The implementation below only extracts bytes from the
     * buffer, but this method may be overridden in field-specific
     * implementations.
     */
    protected void encode(ByteBuffer buf, int length, byte highByte,
                          long[] result) {
        int numHighBits = 32 - Integer.numberOfLeadingZeros(highByte);
        int numBits = 8 * length + numHighBits;
        int requiredLimbs = (numBits + bitsPerLimb - 1) / bitsPerLimb;
        if (requiredLimbs > numLimbs) {
            long[] temp = new long[requiredLimbs];
            encodeSmall(buf, length, highByte, temp);
            // encode does a full carry/reduce
            System.arraycopy(temp, 0, result, 0, result.length);
        } else {
            encodeSmall(buf, length, highByte, result);
        }
    }
    protected void encodeSmall(ByteBuffer buf, int length, byte highByte,
                               long[] result) {
        int limbIndex = 0;
        long curLimbValue = 0;
        int bitPos = 0;
        for (int i = 0; i < length; i++) {
            long curV = buf.get() & 0xFF;
            if (bitPos + 8 >= bitsPerLimb) {
                int bitsThisLimb = bitsPerLimb - bitPos;
                curLimbValue += (curV & (0xFF >> (8 - bitsThisLimb))) << bitPos;
                result[limbIndex++] = curLimbValue;
                curLimbValue = curV >> bitsThisLimb;
                bitPos = 8 - bitsThisLimb;
            }
            else {
                curLimbValue += curV << bitPos;
                bitPos += 8;
            }
        }
        // one more for the high byte
        if (highByte != 0) {
            long curV = highByte & 0xFF;
            if (bitPos + 8 >= bitsPerLimb) {
                int bitsThisLimb = bitsPerLimb - bitPos;
                curLimbValue += (curV & (0xFF >> (8 - bitsThisLimb))) << bitPos;
                result[limbIndex++] = curLimbValue;
                curLimbValue = curV >> bitsThisLimb;
            }
            else {
                curLimbValue += curV << bitPos;
            }
        }
        if (limbIndex < result.length) {
            result[limbIndex++] = curLimbValue;
        }
        Arrays.fill(result, limbIndex, result.length, 0);
        postEncodeCarry(result);
    }
    protected void encode(byte[] v, int offset, int length, byte highByte,
                          long[] result) {
        ByteBuffer buf = ByteBuffer.wrap(v, offset, length);
        buf.order(ByteOrder.LITTLE_ENDIAN);
        encode(buf, length, highByte, result);
    }
    // Encode does not produce compressed limbs. A simplified carry/reduce
    // operation can be used to compress the limbs.
    protected void postEncodeCarry(long[] v) {
        reduce(v);
    }
    public ImmutableElement getElement(byte[] v, int offset, int length,
                                       byte highByte) {
        long[] result = new long[numLimbs];
        encode(v, offset, length, highByte, result);
        return new ImmutableElement(result, 0);
    }
    protected BigInteger evaluate(long[] limbs) {
        BigInteger result = BigInteger.ZERO;
        for (int i = limbs.length - 1; i >= 0; i--) {
            result = result.shiftLeft(bitsPerLimb)
                .add(BigInteger.valueOf(limbs[i]));
        }
        return result.mod(modulus);
    }
    protected long carryValue(long x) {
        // compressing carry operation
        // if large positive number, carry one more to make it negative
        // if large negative number (closer to zero), carry one fewer
        return (x + (1 << (bitsPerLimb - 1))) >> bitsPerLimb;
    }
    protected void carry(long[] limbs, int start, int end) {
        for (int i = start; i < end; i++) {
            long carry = carryOut(limbs, i);
            limbs[i + 1] += carry;
        }
    }
    protected void carry(long[] limbs) {
        carry(limbs, 0, limbs.length - 1);
    }
    /**
     * Carry out of the specified position and return the carry value.
     */
    protected long carryOut(long[] limbs, int index) {
        long carry = carryValue(limbs[index]);
        limbs[index] -= (carry << bitsPerLimb);
        return carry;
    }
    private void setLimbsValue(BigInteger v, long[] limbs) {
        // set all limbs positive, and then carry
        setLimbsValuePositive(v, limbs);
        carry(limbs);
    }
    protected void setLimbsValuePositive(BigInteger v, long[] limbs) {
        BigInteger mod = BigInteger.valueOf(1 << bitsPerLimb);
        for (int i = 0; i < limbs.length; i++) {
            limbs[i] = v.mod(mod).longValue();
            v = v.shiftRight(bitsPerLimb);
        }
    }
    /**
     * Carry out of the last limb and reduce back in. This method will be
     * called as part of the "finalReduce" operation that puts the
     * representation into a fully-reduced form. It is representation-
     * specific, because representations have different amounts of empty
     * space in the high-order limb. Requires that limbs.length=numLimbs.
     */
    protected abstract void finalCarryReduceLast(long[] limbs);
    /**
     * Convert reduced limbs into a number between 0 and MODULUS-1.
     * Requires that limbs.length == numLimbs. This method only works if the
     * modulus has at most three terms.
     */
    protected void finalReduce(long[] limbs) {
        // This method works by doing several full carry/reduce operations.
        // Some representations have extra high bits, so the carry/reduce out
        // of the high position is implementation-specific. The "unsigned"
        // carry operation always carries some (negative) value out of a
        // position occupied by a negative value. So after a number of
        // passes, all negative values are removed.
        // The first pass may leave a negative value in the high position, but
        // this only happens if something was carried out of the previous
        // position. So the previous position must have a "small" value. The
        // next full carry is guaranteed not to carry out of that position.
        for (int pass = 0; pass < 2; pass++) {
            // unsigned carry out of last position and reduce in to
            // first position
            finalCarryReduceLast(limbs);
            // unsigned carry on all positions
            long carry = 0;
            for (int i = 0; i < numLimbs - 1; i++) {
                limbs[i] += carry;
                carry = limbs[i] >> bitsPerLimb;
                limbs[i] -= carry << bitsPerLimb;
            }
            limbs[numLimbs - 1] += carry;
        }
        // Limbs are positive and all less than 2^bitsPerLimb, and the
        // high-order limb may be even smaller due to the representation-
        // specific carry/reduce out of the high position.
        // The value may still be greater than the modulus.
        // Subtract the max limb values only if all limbs end up non-negative
        // This only works if there is at most one position where posModLimbs
        // is less than 2^bitsPerLimb - 1 (not counting the high-order limb,
        // if it has extra bits that are cleared by finalCarryReduceLast).
        int smallerNonNegative = 1;
        long[] smaller = new long[numLimbs];
        for (int i = numLimbs - 1; i >= 0; i--) {
            smaller[i] = limbs[i] - posModLimbs[i];
            // expression on right is 1 if smaller[i] is nonnegative,
            // 0 otherwise
            smallerNonNegative *= (int) (smaller[i] >> 63) + 1;
        }
        conditionalSwap(smallerNonNegative, limbs, smaller);
    }
    /**
     * Decode the value in v and store it in dst. Requires that v is final
     * reduced. I.e. all limbs in [0, 2^bitsPerLimb) and value in [0, modulus).
     */
    protected void decode(long[] v, byte[] dst, int offset, int length) {
        int nextLimbIndex = 0;
        long curLimbValue = v[nextLimbIndex++];
        int bitPos = 0;
        for (int i = 0; i < length; i++) {
            int dstIndex = i + offset;
            if (bitPos + 8 >= bitsPerLimb) {
                dst[dstIndex] = (byte) curLimbValue;
                curLimbValue = 0;
                if (nextLimbIndex < v.length) {
                    curLimbValue = v[nextLimbIndex++];
                }
                int bitsAdded = bitsPerLimb - bitPos;
                int bitsLeft = 8 - bitsAdded;
                dst[dstIndex] += (curLimbValue & (0xFF >> bitsAdded))
                    << bitsAdded;
                curLimbValue >>= bitsLeft;
                bitPos = bitsLeft;
            } else {
                dst[dstIndex] = (byte) curLimbValue;
                curLimbValue >>= 8;
                bitPos += 8;
            }
        }
    }
    /**
     * Add two IntegerPolynomial representations (a and b) and store the result
     * in an IntegerPolynomialRepresentation (dst). Requires that
     * a.length == b.length == dst.length. It is allowed for a and
     * dst to be the same array.
     */
    protected void addLimbs(long[] a, long[] b, long[] dst) {
        for (int i = 0; i < dst.length; i++) {
            dst[i] = a[i] + b[i];
        }
    }
    /**
     * Branch-free conditional assignment of b to a. Requires that set is 0 or
     * 1, and that a.length == b.length. If set==0, then the values of a and b
     * will be unchanged. If set==1, then the values of b will be assigned to a.
     * The behavior is undefined if swap has any value other than 0 or 1.
     */
    protected static void conditionalAssign(int set, long[] a, long[] b) {
        int maskValue = 0 - set;
        for (int i = 0; i < a.length; i++) {
            long dummyLimbs = maskValue & (a[i] ^ b[i]);
            a[i] = dummyLimbs ^ a[i];
        }
    }
    /**
     * Branch-free conditional swap of a and b. Requires that swap is 0 or 1,
     * and that a.length == b.length. If swap==0, then the values of a and b
     * will be unchanged. If swap==1, then the values of a and b will be
     * swapped. The behavior is undefined if swap has any value other than
     * 0 or 1.
     */
    protected static void conditionalSwap(int swap, long[] a, long[] b) {
        int maskValue = 0 - swap;
        for (int i = 0; i < a.length; i++) {
            long dummyLimbs = maskValue & (a[i] ^ b[i]);
            a[i] = dummyLimbs ^ a[i];
            b[i] = dummyLimbs ^ b[i];
        }
    }
    /**
     * Stores the reduced, little-endian value of limbs in result.
     */
    protected void limbsToByteArray(long[] limbs, byte[] result) {
        long[] reducedLimbs = limbs.clone();
        finalReduce(reducedLimbs);
        decode(reducedLimbs, result, 0, result.length);
    }
    /**
     * Add the reduced number corresponding to limbs and other, and store
     * the low-order bytes of the sum in result. Requires that
     * limbs.length==other.length. The result array may have any length.
     */
    protected void addLimbsModPowerTwo(long[] limbs, long[] other,
                                       byte[] result) {
        long[] reducedOther = other.clone();
        long[] reducedLimbs = limbs.clone();
        finalReduce(reducedOther);
        finalReduce(reducedLimbs);
        addLimbs(reducedLimbs, reducedOther, reducedLimbs);
        // may carry out a value which can be ignored
        long carry = 0;
        for (int i = 0; i < numLimbs; i++) {
            reducedLimbs[i] += carry;
            carry  = reducedLimbs[i] >> bitsPerLimb;
            reducedLimbs[i] -= carry << bitsPerLimb;
        }
        decode(reducedLimbs, result, 0, result.length);
    }
    private abstract class Element implements IntegerModuloP {
        protected long[] limbs;
        protected int numAdds;
        public Element(BigInteger v) {
            limbs = new long[numLimbs];
            setValue(v);
        }
        public Element(boolean v) {
            this.limbs = new long[numLimbs];
            this.limbs[0] = v ? 1l : 0l;
            this.numAdds = 0;
        }
        private Element(long[] limbs, int numAdds) {
            this.limbs = limbs;
            this.numAdds = numAdds;
        }
        private void setValue(BigInteger v) {
            setLimbsValue(v, limbs);
            this.numAdds = 0;
        }
        @Override
        public IntegerFieldModuloP getField() {
            return IntegerPolynomial.this;
        }
        @Override
        public BigInteger asBigInteger() {
            return evaluate(limbs);
        }
        @Override
        public MutableElement mutable() {
            return new MutableElement(limbs.clone(), numAdds);
        }
        protected boolean isSummand() {
            return numAdds < maxAdds;
        }
        @Override
        public ImmutableElement add(IntegerModuloP genB) {
            Element b = (Element) genB;
            if (!(isSummand() && b.isSummand())) {
                throw new ArithmeticException("Not a valid summand");
            }
            long[] newLimbs = new long[limbs.length];
            for (int i = 0; i < limbs.length; i++) {
                newLimbs[i] = limbs[i] + b.limbs[i];
            }
            int newNumAdds = Math.max(numAdds, b.numAdds) + 1;
            return new ImmutableElement(newLimbs, newNumAdds);
        }
        @Override
        public ImmutableElement additiveInverse() {
            long[] newLimbs = new long[limbs.length];
            for (int i = 0; i < limbs.length; i++) {
                newLimbs[i] = -limbs[i];
            }
            ImmutableElement result = new ImmutableElement(newLimbs, numAdds);
            return result;
        }
        protected long[] cloneLow(long[] limbs) {
            long[] newLimbs = new long[numLimbs];
            copyLow(limbs, newLimbs);
            return newLimbs;
        }
        protected void copyLow(long[] limbs, long[] out) {
            System.arraycopy(limbs, 0, out, 0, out.length);
        }
        @Override
        public ImmutableElement multiply(IntegerModuloP genB) {
            Element b = (Element) genB;
            long[] newLimbs = new long[limbs.length];
            mult(limbs, b.limbs, newLimbs);
            return new ImmutableElement(newLimbs, 0);
        }
        @Override
        public ImmutableElement square() {
            long[] newLimbs = new long[limbs.length];
            IntegerPolynomial.this.square(limbs, newLimbs);
            return new ImmutableElement(newLimbs, 0);
        }
        public void addModPowerTwo(IntegerModuloP arg, byte[] result) {
            Element other = (Element) arg;
            if (!(isSummand() && other.isSummand())) {
                throw new ArithmeticException("Not a valid summand");
            }
            addLimbsModPowerTwo(limbs, other.limbs, result);
        }
        public void asByteArray(byte[] result) {
            if (!isSummand()) {
                throw new ArithmeticException("Not a valid summand");
            }
            limbsToByteArray(limbs, result);
        }
    }
    protected class MutableElement extends Element
        implements MutableIntegerModuloP {
        protected MutableElement(long[] limbs, int numAdds) {
            super(limbs, numAdds);
        }
        @Override
        public ImmutableElement fixed() {
            return new ImmutableElement(limbs.clone(), numAdds);
        }
        @Override
        public void conditionalSet(IntegerModuloP b, int set) {
            Element other = (Element) b;
            conditionalAssign(set, limbs, other.limbs);
            numAdds = other.numAdds;
        }
        @Override
        public void conditionalSwapWith(MutableIntegerModuloP b, int swap) {
            MutableElement other = (MutableElement) b;
            conditionalSwap(swap, limbs, other.limbs);
            int numAddsTemp = numAdds;
            numAdds = other.numAdds;
            other.numAdds = numAddsTemp;
        }
        @Override
        public MutableElement setValue(IntegerModuloP v) {
            Element other = (Element) v;
            System.arraycopy(other.limbs, 0, limbs, 0, other.limbs.length);
            numAdds = other.numAdds;
            return this;
        }
        @Override
        public MutableElement setValue(byte[] arr, int offset,
                                       int length, byte highByte) {
            encode(arr, offset, length, highByte, limbs);
            this.numAdds = 0;
            return this;
        }
        @Override
        public MutableElement setValue(ByteBuffer buf, int length,
                                       byte highByte) {
            encode(buf, length, highByte, limbs);
            numAdds = 0;
            return this;
        }
        @Override
        public MutableElement setProduct(IntegerModuloP genB) {
            Element b = (Element) genB;
            mult(limbs, b.limbs, limbs);
            numAdds = 0;
            return this;
        }
        @Override
        public MutableElement setProduct(SmallValue v) {
            int value = ((Limb) v).value;
            multByInt(limbs, value);
            numAdds = 0;
            return this;
        }
        @Override
        public MutableElement setSum(IntegerModuloP genB) {
            Element b = (Element) genB;
            if (!(isSummand() && b.isSummand())) {
                throw new ArithmeticException("Not a valid summand");
            }
            for (int i = 0; i < limbs.length; i++) {
                limbs[i] = limbs[i] + b.limbs[i];
            }
            numAdds = Math.max(numAdds, b.numAdds) + 1;
            return this;
        }
        @Override
        public MutableElement setDifference(IntegerModuloP genB) {
            Element b = (Element) genB;
            if (!(isSummand() && b.isSummand())) {
                throw new ArithmeticException("Not a valid summand");
            }
            for (int i = 0; i < limbs.length; i++) {
                limbs[i] = limbs[i] - b.limbs[i];
            }
            numAdds = Math.max(numAdds, b.numAdds) + 1;
            return this;
        }
        @Override
        public MutableElement setSquare() {
            IntegerPolynomial.this.square(limbs, limbs);
            numAdds = 0;
            return this;
        }
        @Override
        public MutableElement setAdditiveInverse() {
            for (int i = 0; i < limbs.length; i++) {
                limbs[i] = -limbs[i];
            }
            return this;
        }
        @Override
        public MutableElement setReduced() {
            reduce(limbs);
            numAdds = 0;
            return this;
        }
    }
    class ImmutableElement extends Element implements ImmutableIntegerModuloP {
        protected ImmutableElement(BigInteger v) {
            super(v);
        }
        protected ImmutableElement(boolean v) {
            super(v);
        }
        protected ImmutableElement(long[] limbs, int numAdds) {
            super(limbs, numAdds);
        }
        @Override
        public ImmutableElement fixed() {
            return this;
        }
    }
    class Limb implements SmallValue {
        int value;
        Limb(int value) {
            this.value = value;
        }
    }
}
Back to index...