• Main Page
  • Related Pages
  • Modules
  • Namespaces
  • Classes
  • Files
  • File List
  • File Members

src/RSA.cpp

00001 /*
00002  *   This file is part of the Standard Portable Library (SPL).
00003  *
00004  *   SPL is free software: you can redistribute it and/or modify
00005  *   it under the terms of the GNU General Public License as published by
00006  *   the Free Software Foundation, either version 3 of the License, or
00007  *   (at your option) any later version.
00008  *
00009  *   SPL is distributed in the hope that it will be useful,
00010  *   but WITHOUT ANY WARRANTY; without even the implied warranty of
00011  *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00012  *   GNU General Public License for more details.
00013  *
00014  *   You should have received a copy of the GNU General Public License
00015  *   along with SPL.  If not, see <http://www.gnu.org/licenses/>.
00016  */
00017 #include <spl/Int32.h>
00018 #include <spl/Log.h>
00019 #include <spl/math/Math.h>
00020 #include <spl/math/Random.h>
00021 #include <spl/math/RSA.h>
00022 
00023 //static int32 FindPrime(int32 ignore_prime)
00024 //{
00025 //      int32 odd,tmp_odd;
00026 //      
00027 //      Random rand;
00028 //      
00029 //      while(1)
00030 //      {
00031 //              odd = rand.Next();
00032 //              if(odd %2 == 0)
00033 //              {
00034 //                      odd++;
00035 //              }
00036 //
00037 //              if(ignore_prime == odd) 
00038 //              {
00039 //                      continue;
00040 //              }
00041 //
00042 //              int i = 2;
00043 //
00044 //              //test whether 'odd' is prime number or not 
00045 //              for(; i <= odd / 2; i++)
00046 //              {
00047 //                      tmp_odd = odd % i;
00048 //                      if(tmp_odd == 0) 
00049 //                      {
00050 //                              i = -1;
00051 //                              break;
00052 //                      }
00053 //              }//x of 'for' loop 
00054 //              if(i != -1) 
00055 //              {
00056 //                      break;
00057 //              }
00058 //      }//x of 'while' loop 
00059 //      return odd;
00060 //}
00061 
00062 RSAKey::~RSAKey()
00063 {
00064 }
00065 
#if defined(DEBUG) || defined(_DEBUG)
// Debug-only: forwards the memory check to each owned value,
// modulus first and exponent second.
void RSAKey::CheckMem() const
{
    this->m_modulus.CheckMem();
    this->m_exponent.CheckMem();
}

// Debug-only: forwards memory validation to each owned value,
// modulus first and exponent second.
void RSAKey::ValidateMem() const
{
    this->m_modulus.ValidateMem();
    this->m_exponent.ValidateMem();
}
#endif
00079 
00080 RSAPrivateKey::RSAPrivateKey
00081 (
00082         int bitSize,
00083         BigInteger      modulus,
00084         BigInteger      publicExponent,
00085         BigInteger      privateExponent,
00086         BigInteger      p,
00087         BigInteger      q,
00088         BigInteger      dP,
00089         BigInteger      dQ,
00090         BigInteger      qInv
00091  )
00092  : m_bitSize(bitSize), m_modulus(modulus), m_privateExponent(privateExponent), m_publicExponent(publicExponent), m_p(p), m_q(q), m_dP(dP), m_dQ(dQ), m_qInv(qInv)
00093 {
00094 }
00095 
00096 RSAPrivateKey::RSAPrivateKey(int strength, int certainty)
00097 :       m_bitSize(strength), 
00098         m_modulus(), 
00099         m_privateExponent(), 
00100         m_publicExponent(), 
00101         m_p(), 
00102         m_q(), 
00103         m_dP(), 
00104         m_dQ(), 
00105         m_qInv()
00106 {
00107         BigInteger pSub1, qSub1, phi;
00108         Random rand;
00109 
00110         int pbitlength = (strength + 1) / 2;
00111         int qbitlength = (strength - pbitlength);
00112         int mindiffbits = strength / 3;
00113 
00114         m_publicExponent = BigInteger(24, certainty, rand);
00115         //m_publicExponent = *BigInteger::ValueOf(0x10001);
00116 
00117         //
00118         // Generate p, prime and (p-1) relatively prime to e
00119         //
00120         for (;;)
00121         {
00122                 m_p = BigInteger(pbitlength, certainty, rand);
00123                 ASSERT(m_p.GetBitLength() == pbitlength);
00124 
00125                 if (m_p.Mod(m_publicExponent)->Equals(1))
00126                 {
00127                         continue;
00128                 }
00129                 if (!m_p.IsProbablePrime(certainty))
00130                 {
00131                         continue;
00132                 }
00133                 if (m_publicExponent.Gcd(m_p.Subtract(BigInteger::One()))->Equals(1)) 
00134                 {
00135                         break;
00136                 }
00137         }
00138 
00139         //
00140         // Generate a modulus of the required length
00141         //
00142         for (;;)
00143         {
00144                 // Generate q, prime and (q-1) relatively prime to e,
00145                 // and not equal to p
00146                 //
00147                 for (;;)
00148                 {
00149                         m_q = BigInteger(qbitlength, certainty, rand);
00150                         ASSERT(m_q.GetBitLength() == qbitlength);
00151 
00152                         if (m_q.Subtract(m_p)->Abs()->GetBitLength() < mindiffbits)
00153                         {
00154                                 continue;
00155                         }
00156                         if (m_q.Mod(m_publicExponent)->Equals(BigInteger::One()))
00157                         {
00158                                 continue;
00159                         }
00160                         if (! m_q.IsProbablePrime(certainty))
00161                         {
00162                                 continue;
00163                         }
00164                         if (m_publicExponent.Gcd(m_q.Subtract(BigInteger::One()))->Equals(1)) 
00165                         {
00166                                 break;
00167                         }
00168                 }
00169 
00170                 //
00171                 // calculate the modulus
00172                 //
00173                 m_modulus = *m_p.Multiply(m_q);
00174 
00175                 if (m_modulus.GetBitLength() >= strength)
00176                 {
00177                         break;
00178                 }
00179 
00180                 //
00181                 // if we Get here our primes aren't big enough, make the largest
00182                 // of the two p and try again
00183                 //
00184                 m_p = m_p > m_q ? m_p : m_q;
00185         }
00186 
00187         if (m_p.Compare(m_q) < 0)
00188         {
00189                 phi = m_p;
00190                 m_p = m_q;
00191                 m_q = phi;
00192         }
00193         ASSERT(m_p > m_q);
00194 
00195         pSub1 = m_p - 1;
00196         qSub1 = m_q - 1;
00197         phi = pSub1 * qSub1;
00198 
00199         //
00200         // calculate the private exponent
00201         //
00202         m_privateExponent = *m_publicExponent.ModInverse(phi);
00203 
00204         m_dP = *m_privateExponent.Remainder(pSub1);
00205         m_dQ = *m_privateExponent.Remainder(qSub1);
00206         m_qInv = *m_q.ModInverse(m_p);
00207 }
00208 
00209 RSAPrivateKey::~RSAPrivateKey()
00210 {
00211 }
00212 
#if defined(DEBUG) || defined(_DEBUG)
// Debug-only: runs CheckMem() on every key component, in declaration-listing
// order (d, e, n, p, q, dP, dQ, qInv).
void RSAPrivateKey::CheckMem() const
{
    this->m_privateExponent.CheckMem();
    this->m_publicExponent.CheckMem();
    this->m_modulus.CheckMem();
    this->m_p.CheckMem();
    this->m_q.CheckMem();
    this->m_dP.CheckMem();
    this->m_dQ.CheckMem();
    this->m_qInv.CheckMem();
}

// Debug-only: runs ValidateMem() on every key component, in the same order
// as CheckMem().
void RSAPrivateKey::ValidateMem() const
{
    this->m_privateExponent.ValidateMem();
    this->m_publicExponent.ValidateMem();
    this->m_modulus.ValidateMem();
    this->m_p.ValidateMem();
    this->m_q.ValidateMem();
    this->m_dP.ValidateMem();
    this->m_dQ.ValidateMem();
    this->m_qInv.ValidateMem();
}
#endif
00238 
00239 RSAPublic::RSAPublic(int bitSize, BigInteger& mondulus, BigInteger& exponent)
00240 : m_bitSize(bitSize), m_publicKey(bitSize, mondulus, exponent)
00241 {
00242 }
00243 
00244 RSAPublic::RSAPublic(int bitSize)
00245 : m_bitSize(bitSize), m_publicKey()
00246 {
00247 }
00248 
00249 RSAPublic::RSAPublic(const RSAPublic& key)
00250 : m_publicKey(key.m_publicKey), m_bitSize(key.m_bitSize)
00251 {
00252 }
00253 
00254 RSAPublic::RSAPublic(const RSAKey& key, int bitSize)
00255 : m_publicKey(key), m_bitSize(bitSize)
00256 {
00257 }
00258 
00259 RSAPublic::~RSAPublic()
00260 {
00261 }
00262 
00263 RefCountPtr<Array<byte> > RSAPublic::EncryptBinary(Array<byte>& plainText)
00264 {
00265         Vector<byte> encBytes;
00266         RefCountPtr<BigInteger> enc;
00267         Array<byte> bytes;
00268         int byteCount = m_bitSize / 8 + (( m_bitSize % 8 ) == 0 ? 0 : 1 );
00269         int inputByteCount = GetInputBlockSize(true);
00270 
00271         for(int ptPos = 0; ptPos < plainText.Length(); ptPos += inputByteCount)
00272         {
00273                 BigInteger txt(1, plainText, ptPos, inputByteCount);
00274 
00275                 enc = txt.ModPow(m_publicKey.Exponent(), m_publicKey.Modulus());
00276 
00277                 enc->ToByteArrayUnsigned(bytes);
00278                 //if (bytes.Length() != byteCount)
00279                 //{
00280                 //      Log::WriteInfo("bytes.Length() != byteCount %d", bytes.Length());
00281                 //}
00282 
00283                 encBytes.AddRangeWithLeadingPad(bytes, byteCount);
00284 
00285                 ASSERT((encBytes.Count() % byteCount) == 0);
00286         }
00287 
00288         return encBytes.ToArray();
00289 }
00290 
00291 RefCountPtr<String> RSAPublic::EncryptText(Array<byte>& plainText)
00292 {
00293         RefCountPtr<Array<byte> > encBytes = EncryptBinary(plainText);
00294         return String::Base64Encode(encBytes);
00295 }
00296 
00297 RefCountPtr<String> RSAPublic::EncryptText(const String& plainText)
00298 {
00299         Array<byte> bytes = plainText.ToByteArray();
00300         return EncryptText(bytes);
00301 }
00302 
00310 int RSAPublic::GetInputBlockSize(bool forEncryption) const
00311 {
00312         if (forEncryption)
00313         {
00314                 return (m_bitSize - 1) / 8;
00315         }
00316 
00317         return (m_bitSize + 7) / 8;
00318 }
00319 
00320 void RSAPublic::PadInputBuffer(Array<byte>& plainText) const
00321 {
00322         int blockSize = GetInputBlockSize(true);
00323         int mod = blockSize - (plainText.Length() % blockSize);
00324         if (0 == mod)
00325         {
00326                 return;
00327         }
00328         Array<byte> buf(plainText.Length() + mod);
00329         plainText.CopyToBinary(buf);
00330         plainText = buf;
00331 }
00332 
00340 int RSAPublic::GetOutputBlockSize(bool forEncryption) const
00341 {
00342         if (forEncryption)
00343         {
00344                 return (m_bitSize + 7) / 8;
00345         }
00346 
00347         return (m_bitSize - 1) / 8;
00348 }
00349 
00350 //void RSAPublic::ConvertInput
00351 //(
00352 //      Array<byte>& inBuf,
00353 //      int             inOff,
00354 //      int             inLen,
00355 //      BigInteger& output
00356 //)
00357 //{
00358 //      int maxLength = (m_bitSize + 7) / 8;
00359 //
00360 //      if (inLen > maxLength)
00361 //      {
00362 //              throw new InvalidArgumentException("input too large for RSA cipher.");
00363 //      }
00364 //
00365 //      output = BigInteger(1, inBuf, inOff, inLen);
00366 //
00367 //      if (output.Compare(m_publicKey.Modulus()) >= 0)
00368 //      {
00369 //              throw new InvalidArgumentException("input too large for RSA cipher.");
00370 //      }
00371 //}
00372 
00373 //void RSAPublic::ConvertOutput
00374 //(
00375 //      bool forEncryption,
00376 //      BigInteger& result,
00377 //      Array<byte>& output
00378 //)
00379 //{
00380 //      output = result.ToByteArrayUnsigned();
00381 //
00382 //      if (forEncryption)
00383 //      {
00384 //              int outSize = GetOutputBlockSize(forEncryption);
00385 //
00386 //              // TODO To avoid this, create version of BigInteger.ToByteArray that
00387 //              // writes to an existing array
00388 //              if (output.Length() < outSize) // have ended up with less bytes than normal, lengthen
00389 //              {
00390 //                      Array<byte> tmp(outSize);
00391 //                      output.CopyTo(tmp, tmp.Length() - output.Length());
00392 //                      output = tmp;
00393 //              }
00394 //      }
00395 //}
00396 
#if defined(DEBUG) || defined(_DEBUG)
// Debug-only: delegates the memory check to the wrapped public key.
void RSAPublic::CheckMem() const
{
    this->m_publicKey.CheckMem();
}

// Debug-only: delegates memory validation to the wrapped public key.
void RSAPublic::ValidateMem() const
{
    this->m_publicKey.ValidateMem();
}
#endif
00408 
00409 RSA::RSA(int bitSize, int certainty)
00410 : RSAPublic(bitSize), m_key(bitSize, certainty)
00411 {
00412         m_publicKey = m_key.GetPublicKey();
00413 }
00414 
00415 RefCountPtr<Array<byte> > RSA::DecryptBinary(Array<byte>& encText)
00416 {
00417         Vector<byte> ptBytes;
00418         BigInteger enc;
00419         Array<byte> bytes;
00420         int byteCount = m_bitSize / 8 + (( m_bitSize % 8 ) == 0 ? 0 : 1 );
00421         int outputByteCount = GetOutputBlockSize(false);
00422 
00423         ASSERT(m_publicKey.Exponent() != m_key.PrivateExponent());
00424         ASSERT(m_publicKey.Modulus() == m_key.Modulus());
00425         ASSERT((encText.Length() % byteCount) == 0);
00426 
00427         for(int etPos = 0; etPos < encText.Length(); etPos += byteCount)
00428         {
00429                 enc = BigInteger(encText, etPos, byteCount);
00430 
00431                 //BigInteger mP, mQ, h;
00432 
00434                 //mP = (enc.Remainder(m_key.P())).ModPow(m_key.DP(), m_key.P());
00435 
00437                 //mQ = (enc.Remainder(m_key.Q())).ModPow(m_key.DQ(), m_key.Q());
00438 
00440                 //h = mP.Subtract(mQ);
00441                 //h = h.Multiply(m_key.QInv());
00442                 //h = h.Mod(m_key.P());               // Mod (in Java) returns the positive residual
00443 
00445                 //txt = h.Multiply(m_key.Q());
00446                 //txt = txt.Add(mQ);
00447 
00448                 BigIntegerPtr txt = enc.ModPow(m_key.PrivateExponent(), m_key.Modulus());
00449 
00450                 txt->ToByteArrayUnsigned(bytes);
00451 
00452                 if (bytes.Length() > byteCount)
00453                 {
00454                         ASSERT(bytes.Length() == byteCount);
00455                 }
00456 
00457                 ptBytes.AddRangeWithPad(bytes, outputByteCount);
00458         }
00459 
00460         return ptBytes.ToArray();
00461 }
00462 
00463 RefCountPtr<String> RSA::DecryptText(const String& encText)
00464 {
00465         RefCountPtr<Array<byte> > ptBytes = DecryptBinary(*String::Base64Decode(encText));
00466         return RefCountPtr<String>(new String((const char *)ptBytes->Data(), ptBytes->Length()));
00467 }
00468 
// Debug-only memory hooks. The guard matches every other CheckMem/ValidateMem
// in this file (DEBUG or _DEBUG). With only DEBUG tested, a _DEBUG-only build
// would lack these definitions while the base-class versions exist.
#if defined(DEBUG) || defined(_DEBUG)
// Checks the public-key base part, then the private key.
void RSA::CheckMem() const
{
    RSAPublic::CheckMem();
    m_key.CheckMem();
}

// Validates the public-key base part, then the private key.
void RSA::ValidateMem() const
{
    RSAPublic::ValidateMem();
    m_key.ValidateMem();
}
#endif
00482