// Copyright 2009 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // This package implements RSA encryption as specified in PKCS#1. package rsa // TODO(agl): Add support for PSS padding. import ( "big" "crypto/subtle" "hash" "io" "os" ) var bigZero = big.NewInt(0) var bigOne = big.NewInt(1) // randomPrime returns a number, p, of the given size, such that p is prime // with high probability. func randomPrime(rand io.Reader, bits int) (p *big.Int, err os.Error) { if bits < 1 { err = os.EINVAL } bytes := make([]byte, (bits+7)/8) p = new(big.Int) for { _, err = io.ReadFull(rand, bytes) if err != nil { return } // Don't let the value be too small. bytes[0] |= 0x80 // Make the value odd since an even number this large certainly isn't prime. bytes[len(bytes)-1] |= 1 p.SetBytes(bytes) if big.ProbablyPrime(p, 20) { return } } return } // randomNumber returns a uniform random value in [0, max). func randomNumber(rand io.Reader, max *big.Int) (n *big.Int, err os.Error) { k := (max.BitLen() + 7) / 8 // r is the number of bits in the used in the most significant byte of // max. r := uint(max.BitLen() % 8) if r == 0 { r = 8 } bytes := make([]byte, k) n = new(big.Int) for { _, err = io.ReadFull(rand, bytes) if err != nil { return } // Clear bits in the first byte to increase the probability // that the candidate is < max. bytes[0] &= uint8(int(1< k-2*hash.Size()-2 { err = MessageTooLongError{} return } hash.Write(label) lHash := hash.Sum() hash.Reset() em := make([]byte, k) seed := em[1 : 1+hash.Size()] db := em[1+hash.Size():] copy(db[0:hash.Size()], lHash) db[len(db)-len(msg)-1] = 1 copy(db[len(db)-len(msg):], msg) _, err = io.ReadFull(rand, seed) if err != nil { return } mgf1XOR(db, hash, seed) mgf1XOR(seed, hash, db) m := new(big.Int) m.SetBytes(em) c := encrypt(new(big.Int), pub, m) out = c.Bytes() return } // A DecryptionError represents a failure to decrypt a message. // It is deliberately vague to avoid adaptive attacks. type DecryptionError struct{} func (DecryptionError) String() string { return "RSA decryption error" } // A VerificationError represents a failure to verify a signature. // It is deliberately vague to avoid adaptive attacks. type VerificationError struct{} func (VerificationError) String() string { return "RSA verification error" } // modInverse returns ia, the inverse of a in the multiplicative group of prime // order n. It requires that a be a member of the group (i.e. less than n). func modInverse(a, n *big.Int) (ia *big.Int, ok bool) { g := new(big.Int) x := new(big.Int) y := new(big.Int) big.GcdInt(g, x, y, a, n) if g.Cmp(bigOne) != 0 { // In this case, a and n aren't coprime and we cannot calculate // the inverse. This happens because the values of n are nearly // prime (being the product of two primes) rather than truly // prime. return } if x.Cmp(bigOne) < 0 { // 0 is not the multiplicative inverse of any element so, if x // < 1, then x is negative. x.Add(x, n) } return x, true } // decrypt performs an RSA decryption, resulting in a plaintext integer. If a // random source is given, RSA blinding is used. func decrypt(rand io.Reader, priv *PrivateKey, c *big.Int) (m *big.Int, err os.Error) { // TODO(agl): can we get away with reusing blinds? if c.Cmp(priv.N) > 0 { err = DecryptionError{} return } var ir *big.Int if rand != nil { // Blinding enabled. Blinding involves multiplying c by r^e. // Then the decryption operation performs (m^e * r^e)^d mod n // which equals mr mod n. The factor of r can then be removed // by multipling by the multiplicative inverse of r. var r *big.Int for { r, err = randomNumber(rand, priv.N) if err != nil { return } if r.Cmp(bigZero) == 0 { r = bigOne } var ok bool ir, ok = modInverse(r, priv.N) if ok { break } } bigE := big.NewInt(int64(priv.E)) rpowe := new(big.Int).Exp(r, bigE, priv.N) c.Mul(c, rpowe) c.Mod(c, priv.N) } m = new(big.Int).Exp(c, priv.D, priv.N) if ir != nil { // Unblind. m.Mul(m, ir) m.Mod(m, priv.N) } return } // DecryptOAEP decrypts ciphertext using RSA-OAEP. // If rand != nil, DecryptOAEP uses RSA blinding to avoid timing side-channel attacks. func DecryptOAEP(hash hash.Hash, rand io.Reader, priv *PrivateKey, ciphertext []byte, label []byte) (msg []byte, err os.Error) { k := (priv.N.BitLen() + 7) / 8 if len(ciphertext) > k || k < hash.Size()*2+2 { err = DecryptionError{} return } c := new(big.Int).SetBytes(ciphertext) m, err := decrypt(rand, priv, c) if err != nil { return } hash.Write(label) lHash := hash.Sum() hash.Reset() // Converting the plaintext number to bytes will strip any // leading zeros so we may have to left pad. We do this unconditionally // to avoid leaking timing information. (Although we still probably // leak the number of leading zeros. It's not clear that we can do // anything about this.) em := leftPad(m.Bytes(), k) firstByteIsZero := subtle.ConstantTimeByteEq(em[0], 0) seed := em[1 : hash.Size()+1] db := em[hash.Size()+1:] mgf1XOR(seed, hash, db) mgf1XOR(db, hash, seed) lHash2 := db[0:hash.Size()] // We have to validate the plaintext in constant time in order to avoid // attacks like: J. Manger. A Chosen Ciphertext Attack on RSA Optimal // Asymmetric Encryption Padding (OAEP) as Standardized in PKCS #1 // v2.0. In J. Kilian, editor, Advances in Cryptology. lHash2Good := subtle.ConstantTimeCompare(lHash, lHash2) // The remainder of the plaintext must be zero or more 0x00, followed // by 0x01, followed by the message. // lookingForIndex: 1 iff we are still looking for the 0x01 // index: the offset of the first 0x01 byte // invalid: 1 iff we saw a non-zero byte before the 0x01. var lookingForIndex, index, invalid int lookingForIndex = 1 rest := db[hash.Size():] for i := 0; i < len(rest); i++ { equals0 := subtle.ConstantTimeByteEq(rest[i], 0) equals1 := subtle.ConstantTimeByteEq(rest[i], 1) index = subtle.ConstantTimeSelect(lookingForIndex&equals1, i, index) lookingForIndex = subtle.ConstantTimeSelect(equals1, 0, lookingForIndex) invalid = subtle.ConstantTimeSelect(lookingForIndex&^equals0, 1, invalid) } if firstByteIsZero&lHash2Good&^invalid&^lookingForIndex != 1 { err = DecryptionError{} return } msg = rest[index+1:] return } // leftPad returns a new slice of length size. The contents of input are right // aligned in the new slice. func leftPad(input []byte, size int) (out []byte) { n := len(input) if n > size { n = size } out = make([]byte, size) copy(out[len(out)-n:], input) return }