001: package org.bouncycastle.crypto.engines;
002:
003: import org.bouncycastle.crypto.CipherParameters;
004: import org.bouncycastle.crypto.DataLengthException;
005: import org.bouncycastle.crypto.params.ParametersWithRandom;
006: import org.bouncycastle.crypto.params.RSAKeyParameters;
007: import org.bouncycastle.crypto.params.RSAPrivateCrtKeyParameters;
008:
009: import java.math.BigInteger;
010:
011: /**
012: * this does your basic RSA algorithm.
013: */
014: class RSACoreEngine {
015: private RSAKeyParameters key;
016: private boolean forEncryption;
017:
018: /**
019: * initialise the RSA engine.
020: *
021: * @param forEncryption true if we are encrypting, false otherwise.
022: * @param param the necessary RSA key parameters.
023: */
024: public void init(boolean forEncryption, CipherParameters param) {
025: if (param instanceof ParametersWithRandom) {
026: ParametersWithRandom rParam = (ParametersWithRandom) param;
027:
028: key = (RSAKeyParameters) rParam.getParameters();
029: } else {
030: key = (RSAKeyParameters) param;
031: }
032:
033: this .forEncryption = forEncryption;
034: }
035:
036: /**
037: * Return the maximum size for an input block to this engine.
038: * For RSA this is always one byte less than the key size on
039: * encryption, and the same length as the key size on decryption.
040: *
041: * @return maximum size for an input block.
042: */
043: public int getInputBlockSize() {
044: int bitSize = key.getModulus().bitLength();
045:
046: if (forEncryption) {
047: return (bitSize + 7) / 8 - 1;
048: } else {
049: return (bitSize + 7) / 8;
050: }
051: }
052:
053: /**
054: * Return the maximum size for an output block to this engine.
055: * For RSA this is always one byte less than the key size on
056: * decryption, and the same length as the key size on encryption.
057: *
058: * @return maximum size for an output block.
059: */
060: public int getOutputBlockSize() {
061: int bitSize = key.getModulus().bitLength();
062:
063: if (forEncryption) {
064: return (bitSize + 7) / 8;
065: } else {
066: return (bitSize + 7) / 8 - 1;
067: }
068: }
069:
070: public BigInteger convertInput(byte[] in, int inOff, int inLen) {
071: if (inLen > (getInputBlockSize() + 1)) {
072: throw new DataLengthException(
073: "input too large for RSA cipher.");
074: } else if (inLen == (getInputBlockSize() + 1) && !forEncryption) {
075: throw new DataLengthException(
076: "input too large for RSA cipher.");
077: }
078:
079: byte[] block;
080:
081: if (inOff != 0 || inLen != in.length) {
082: block = new byte[inLen];
083:
084: System.arraycopy(in, inOff, block, 0, inLen);
085: } else {
086: block = in;
087: }
088:
089: BigInteger res = new BigInteger(1, block);
090: if (res.compareTo(key.getModulus()) >= 0) {
091: throw new DataLengthException(
092: "input too large for RSA cipher.");
093: }
094:
095: return res;
096: }
097:
098: public byte[] convertOutput(BigInteger result) {
099: byte[] output = result.toByteArray();
100:
101: if (forEncryption) {
102: if (output[0] == 0 && output.length > getOutputBlockSize()) // have ended up with an extra zero byte, copy down.
103: {
104: byte[] tmp = new byte[output.length - 1];
105:
106: System.arraycopy(output, 1, tmp, 0, tmp.length);
107:
108: return tmp;
109: }
110:
111: if (output.length < getOutputBlockSize()) // have ended up with less bytes than normal, lengthen
112: {
113: byte[] tmp = new byte[getOutputBlockSize()];
114:
115: System.arraycopy(output, 0, tmp, tmp.length
116: - output.length, output.length);
117:
118: return tmp;
119: }
120: } else {
121: if (output[0] == 0) // have ended up with an extra zero byte, copy down.
122: {
123: byte[] tmp = new byte[output.length - 1];
124:
125: System.arraycopy(output, 1, tmp, 0, tmp.length);
126:
127: return tmp;
128: }
129: }
130:
131: return output;
132: }
133:
134: public BigInteger processBlock(BigInteger input) {
135: if (key instanceof RSAPrivateCrtKeyParameters) {
136: //
137: // we have the extra factors, use the Chinese Remainder Theorem - the author
138: // wishes to express his thanks to Dirk Bonekaemper at rtsffm.com for
139: // advice regarding the expression of this.
140: //
141: RSAPrivateCrtKeyParameters crtKey = (RSAPrivateCrtKeyParameters) key;
142:
143: BigInteger p = crtKey.getP();
144: BigInteger q = crtKey.getQ();
145: BigInteger dP = crtKey.getDP();
146: BigInteger dQ = crtKey.getDQ();
147: BigInteger qInv = crtKey.getQInv();
148:
149: BigInteger mP, mQ, h, m;
150:
151: // mP = ((input mod p) ^ dP)) mod p
152: mP = (input.remainder(p)).modPow(dP, p);
153:
154: // mQ = ((input mod q) ^ dQ)) mod q
155: mQ = (input.remainder(q)).modPow(dQ, q);
156:
157: // h = qInv * (mP - mQ) mod p
158: h = mP.subtract(mQ);
159: h = h.multiply(qInv);
160: h = h.mod(p); // mod (in Java) returns the positive residual
161:
162: // m = h * q + mQ
163: m = h.multiply(q);
164: m = m.add(mQ);
165:
166: return m;
167: } else {
168: return input.modPow(key.getExponent(), key.getModulus());
169: }
170: }
171: }
|