RSA Implementation in Haskell

a. Introduction

RSA (Rivest-Shamir-Adleman) is one of the most widely used public-key cryptosystems for secure data transmission. Its security rests on the difficulty of factoring the product of two large prime numbers. RSA is used in secure communications, digital signatures, and encryption.

In RSA, each user generates a pair of keys: a public key and a private key. The public key is used for encrypting messages and can be shared with anyone, while the private key is used for decrypting messages and must be kept secret. The relationship between the public and private keys ensures that only the intended recipient can decrypt the message.

b. Algorithm Explanation

Key Generation:

  1. Select two distinct large prime numbers $p$ and $q$:
    • These primes should be large to ensure security.
  2. Compute $n = p \times q$:
    • The product $n$ forms part of both the public and private keys.
  3. Calculate the totient function $\phi(n) = (p-1) \times (q-1)$:
    • The totient function is crucial for determining the key pair.
  4. Choose an integer $e$ such that $1 < e < \phi(n)$ and $\gcd(e, \phi(n)) = 1$:
    • The number $e$ is the public exponent, typically chosen as 65537 for efficiency and security.
  5. Compute $d$ as the modular multiplicative inverse of $e$ modulo $\phi(n)$:
    • $d$ is the private exponent, ensuring that $e \times d \equiv 1 \mod \phi(n)$.

The public key is $(e, n)$ and the private key is $(d, n)$.
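
As a small worked example of the five steps, take the toy primes $p = 61$ and $q = 53$ with public exponent $e = 17$. These values are assumptions chosen only for illustration and are far too small to be secure:

```latex
\begin{aligned}
n &= 61 \times 53 = 3233 \\
\phi(n) &= (61 - 1) \times (53 - 1) = 60 \times 52 = 3120 \\
\gcd(17, 3120) &= 1 \\
d &= 2753, \quad \text{since } 17 \times 2753 = 46801 = 1 + 15 \times 3120
\end{aligned}
```

With these numbers the public key is $(17, 3233)$ and the private key is $(2753, 3233)$.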

Proof of Correctness:

  1. Encryption: $c = m^e \mod n$
  2. Decryption: $m = c^d \mod n$

To show that decryption works correctly, we need to prove that $m = (m^e)^d \mod n$ for any message $m$ with $0 \le m < n$.

Given $e \times d \equiv 1 \mod \phi(n)$, there exists an integer $k$ such that $e \times d = 1 + k \times \phi(n)$. Thus:

$(m^e)^d = m^{e \times d} = m^{1 + k \times \phi(n)} = m \times (m^{\phi(n)})^k$

By Euler’s theorem, $m^{\phi(n)} \equiv 1 \mod n$ for any $m$ coprime with $n$, leading to:

$m \times (m^{\phi(n)})^k \equiv m \times 1^k \equiv m \mod n$

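Thus, the decryption recovers the original message $m$. When $m$ shares a factor with $n$, Euler’s theorem does not apply directly, but the same congruence still holds: it can be checked modulo $p$ and modulo $q$ separately and the two results combined with the Chinese remainder theorem.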
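
The claim can also be checked exhaustively for a small modulus. The sketch below assumes the toy key $p = 61$, $q = 53$, $e = 17$, $d = 2753$; these are illustrative values, and the names toyN, toyE, toyD, roundTrips and allRoundTrip are hypothetical and not part of the implementation that follows. It tests every $m$ with $0 \le m < n$, including those that are not coprime with $n$:

```haskell
-- Toy parameters, assumed for illustration only: n = 61 * 53.
toyN, toyE, toyD :: Integer
toyN = 3233
toyE = 17
toyD = 2753

-- Encrypt then decrypt a single residue and compare with the input.
roundTrips :: Integer -> Bool
roundTrips m = decrypted == m
  where
    ciphertext = (m ^ toyE) `mod` toyN
    decrypted  = (ciphertext ^ toyD) `mod` toyN

-- True when every residue modulo toyN survives the round trip.
allRoundTrip :: Bool
allRoundTrip = all roundTrips [0 .. toyN - 1]
```

Evaluating allRoundTrip in GHCi should give True. The plain (^) operator is only practical here because the numbers are tiny.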

c. Implementation

Let’s break down the Haskell implementation into multiple parts.

Extended Euclidean Algorithm

The Extended Euclidean Algorithm is used to find the greatest common divisor (GCD) of two integers and express it in the form of a linear combination of these integers. This is also used to compute the modular inverse.

extendedGcd :: Integral a => a -> a -> (a, a, a)
extendedGcd 0 b = (b, 0, 1)
extendedGcd a b =
    let (g, s, t) = extendedGcd (b `mod` a) a
    in (g, t - (b `div` a) * s, s)

Modular Inverse

The modular inverse $a^{-1}$ is computed using the Extended Euclidean Algorithm. It satisfies $a \times a^{-1} \equiv 1 \mod m$ and exists only when $\gcd(a, m) = 1$; modInv does not check this condition itself.

modInv :: Integral a => a -> a -> a
modInv a m = let (_, x, _) = extendedGcd a m in x `mod` m
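
Because modInv returns a value even when no inverse exists, a caller may want a checked variant. The following is a minimal sketch under that assumption, not part of the original listing: modInvMaybe is a hypothetical name, and extendedGcd is repeated so that the block stands alone.

```haskell
-- Same extended Euclidean algorithm as above: returns (g, s, t)
-- with a * s + b * t == g, where g = gcd(a, b).
extendedGcd :: Integral a => a -> a -> (a, a, a)
extendedGcd 0 b = (b, 0, 1)
extendedGcd a b =
    let (g, s, t) = extendedGcd (b `mod` a) a
    in (g, t - (b `div` a) * s, s)

-- Hypothetical checked variant of modInv:
-- Nothing when the modulus is not positive or gcd(a, m) /= 1.
modInvMaybe :: Integral a => a -> a -> Maybe a
modInvMaybe a m
    | m <= 0    = Nothing
    | g == 1    = Just (x `mod` m)
    | otherwise = Nothing
  where
    (g, x, _) = extendedGcd (a `mod` m) m
```

For instance, extendedGcd 17 3120 evaluates to (1, -367, 2), so modInvMaybe 17 3120 gives Just 2753, while modInvMaybe 6 9 gives Nothing because gcd(6, 9) = 3.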

Prime Number Generation

A simple primality test and a random prime number generator are used to select the primes $p$ and $q$.

generatePrime :: IO Integer
generatePrime = randomRIO (2^32, 2^33) >>= \n -> if isPrime n then return n else generatePrime

notDivisible :: Integer -> Integer -> Bool
notDivisible n x = n `mod` x /= 0

isPrime :: Integer -> Bool
isPrime n = n > 1 && all (notDivisible n) [2..((floor . sqrt . fromIntegral $ n))]
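
Trial division is adequate for 33-bit candidates but scales poorly with larger primes. As one possible alternative, here is a sketch of the Miller-Rabin test using the first twelve primes as fixed bases, a choice generally cited as deterministic for all n below 2^64. This is a swapped-in technique, not the method used in this implementation, and isPrimeMR, powModLocal and decompose are hypothetical names.

```haskell
-- Modular exponentiation by repeated squaring (local helper).
powModLocal :: Integer -> Integer -> Integer -> Integer
powModLocal _ 0 m = 1 `mod` m
powModLocal b e m
    | even e    = let h = powModLocal b (e `div` 2) m in (h * h) `mod` m
    | otherwise = (b * powModLocal b (e - 1) m) `mod` m

-- Write a positive even number as 2^s * d with d odd.
decompose :: Integer -> (Int, Integer)
decompose = go 0
  where
    go s d
        | even d    = go (s + 1) (d `div` 2)
        | otherwise = (s, d)

-- Miller-Rabin with fixed bases; assumed deterministic only for n < 2^64.
isPrimeMR :: Integer -> Bool
isPrimeMR n
    | n < 2                            = False
    | n `elem` bases                   = True
    | any (\p -> n `mod` p == 0) bases = False
    | otherwise                        = all passes bases
  where
    bases  = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37]
    (s, d) = decompose (n - 1)
    -- A base passes if a^d is 1 or n-1, or a later squaring reaches n-1.
    passes a =
        let x       = powModLocal a d n
            squares = take (s - 1) (drop 1 (iterate (\y -> (y * y) `mod` n) x))
        in x == 1 || x == n - 1 || (n - 1) `elem` squares
```

Swapping isPrimeMR in for isPrime inside generatePrime would leave the rest of the code unchanged, since both have type Integer -> Bool.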

Key Generation

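The key generation process involves generating two large prime numbers, computing their product and totient, and finding the public and private exponents. Generation is retried if the two primes coincide or if $e$ is not coprime with $\phi(n)$, because the resulting key pair would not decrypt correctly in those cases.

generateKeys :: IO ((Integer, Integer), (Integer, Integer))
generateKeys = do
    p <- generatePrime
    q <- generatePrime
    let n = p * q
        phi = (p - 1) * (q - 1)
        e = 65537
        d = modInv e phi
    -- Retry when p == q or e has no inverse modulo phi.
    if p == q || gcd e phi /= 1
        then generateKeys
        else return ((e, n), (d, n))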

Efficient Modular Exponentiation

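Modular exponentiation is used to efficiently compute $a^b \mod m$. This implementation uses recursive exponentiation by squaring, so the number of multiplications grows with the bit length of the exponent rather than with its value.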

powMod :: Integer -> Integer -> Integer -> Integer
powMod _ 0 _ = 1
powMod base exp modulus
    | even exp = let half = powMod base (exp `div` 2) modulus in (half * half) `mod` modulus
    | otherwise = (base * powMod base (exp - 1) modulus) `mod` modulus
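
The recursive definition above can also be written iteratively, scanning the exponent bits from least significant upward, which is the form usually called the right-to-left binary method. The sketch below is one way to do that; powModRL is a hypothetical name, and the exponent is assumed to be non-negative.

```haskell
-- Iterative right-to-left binary exponentiation.
-- Invariant: the final answer equals acc * b^e (mod modulus).
powModRL :: Integer -> Integer -> Integer -> Integer
powModRL base ex modulus = go (base `mod` modulus) ex (1 `mod` modulus)
  where
    go _ 0 acc = acc
    go b e acc
        | odd e     = go b' e' ((acc * b) `mod` modulus)  -- current bit is 1
        | otherwise = go b' e' acc                        -- current bit is 0
      where
        b' = (b * b) `mod` modulus  -- square the base for the next bit
        e' = e `div` 2              -- drop the lowest bit
```

For example, powModRL 4 13 497 evaluates to 445. The accumulator makes the recursion a tail call, so the stack depth does not grow with the exponent.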

Encrypt and Decrypt Functions

The encryption function raises the message to the power of the public exponent, and the decryption function raises the ciphertext to the power of the private exponent, both modulo $n$. The bytesToLong and longToBytes functions convert a message into an integer and back for encryption and decryption.

bytesToLong :: String -> Integer
bytesToLong bytes = foldl (\acc byte -> (acc `shiftL` 8) + fromIntegral (ord byte)) 0 bytes

longToBytes :: Integer -> String
longToBytes 0 = ""
longToBytes n = longToBytes (n `div` 2^8) ++ [chr (fromIntegral (n `mod` 2^8))]

encrypt :: String -> (Integer, Integer) -> Integer
encrypt m (e, n) = powMod (bytesToLong m) e n

decrypt :: Integer -> (Integer, Integer) -> String
decrypt c (d, n) = longToBytes ( powMod c d n )
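
The two conversion functions treat a string as a big-endian base-256 number. The sketch below repeats them so that it stands alone and adds a hypothetical helper, survivesRoundTrip, to show which strings convert back unchanged:

```haskell
import Data.Bits (shiftL)
import Data.Char (chr, ord)

-- Same conversions as above, repeated so the block stands alone.
bytesToLong :: String -> Integer
bytesToLong = foldl (\acc byte -> (acc `shiftL` 8) + fromIntegral (ord byte)) 0

longToBytes :: Integer -> String
longToBytes 0 = ""
longToBytes n = longToBytes (n `div` 2^8) ++ [chr (fromIntegral (n `mod` 2^8))]

-- Hypothetical helper: does a string survive the conversion round trip?
survivesRoundTrip :: String -> Bool
survivesRoundTrip s = longToBytes (bytesToLong s) == s
```

Here bytesToLong "AB" is 65 * 256 + 66 = 16706, and survivesRoundTrip "Hello Wo" is True. The round trip fails for a string that begins with a NUL character, because leading zero bytes vanish in the integer, and for characters with code points above 255, which do not fit in one 8-bit slot.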

Encrypt and Decrypt Lists of Messages

Using Foldable and Traversable type classes, we can encrypt and decrypt lists of messages.

encryptMessages :: Foldable t => t String -> (Integer, Integer) -> [Integer]
encryptMessages messages publicKey = foldMap (\m -> [encrypt m publicKey]) messages

decryptMessages :: Traversable t => t Integer -> (Integer, Integer) -> IO (t String)
decryptMessages messages privateKey = traverse (pure . flip decrypt privateKey) messages

Encrypt and Decrypt for Longer Messages

The encrypt and decrypt functions only work for messages of at most 8 characters. A longer message, such as 123456789, makes encryption or decryption go wrong: $n$ is 65 or 66 bits long, while a string of more than 8 characters encodes to an integer that can need more than 64 bits and may therefore exceed $n$. RSA can only recover values smaller than $n$, so the algorithm does not work as expected for such inputs.
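
The effect can be reproduced with a deliberately tiny key. The sketch below assumes the toy key $n = 3233$, $e = 17$, $d = 2753$; these are illustrative values only, and roundTrip is a hypothetical helper that uses plain integer arithmetic instead of powMod. Decryption returns the message reduced modulo $n$, so any value at or above $n$ is lost:

```haskell
-- Toy key, assumed for illustration only: n = 61 * 53.
toyN, toyE, toyD :: Integer
toyN = 3233
toyE = 17
toyD = 2753

-- Encrypt and then decrypt an integer message with the toy key.
roundTrip :: Integer -> Integer
roundTrip m = (c ^ toyD) `mod` toyN
  where
    c = (m ^ toyE) `mod` toyN
```

With this key, roundTrip 65 (the encoding of "A") returns 65, but roundTrip 16706 (the encoding of "AB", which exceeds $n$) returns 541, which is 16706 mod 3233.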

So I split the message into chunks of 8 characters (64 bits) and encrypt/decrypt each chunk separately:

-- splitMessage splits the message into chunks of 8 characters (64 bits)
splitMessage :: String -> [String]
splitMessage [] = []
splitMessage message = let (chunk, rest) = splitAt 8 message in chunk : splitMessage rest

-- mergeChunks merges the encrypted chunks into one integer, 66 bits per chunk
mergeChunks :: [Integer] -> Integer
mergeChunks bytes = foldl (\acc byte -> (acc `shiftL` 66) + fromIntegral byte) 0 bytes

-- splitChunks splits that integer back into encrypted chunks so that decrypt2 can decrypt them
splitChunks :: Integer -> [Integer]
splitChunks 0 = []
splitChunks n = splitChunks (n `div` (2^66)) ++ [fromIntegral (n `mod` (2^66))]

-- Because the length of the message is limited by the modulus n,
-- we can split the message into chunks of size 8 and encrypt each chunk separately.
encrypt2 :: String -> (Integer, Integer) -> Integer
encrypt2 m (e, n) = mergeChunks (foldMap (\c -> [encrypt c (e, n)]) (splitMessage m))

decrypt2 :: Integer -> (Integer, Integer) -> String
decrypt2 m (d, n) = foldMap (\c -> decrypt c (d, n)) (splitChunks m)
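
The packing step works because every ciphertext chunk is smaller than $n$, which is below 2^66, so each chunk fits in a fixed 66-bit slot. The sketch below repeats the three helpers so that it stands alone and can be tried without generating keys; the sample chunk values in the note that follows are arbitrary small integers, not real ciphertexts.

```haskell
import Data.Bits (shiftL)

-- Same helpers as above, repeated so the block stands alone.
splitMessage :: String -> [String]
splitMessage [] = []
splitMessage message = let (chunk, rest) = splitAt 8 message in chunk : splitMessage rest

-- Pack chunks into one integer, 66 bits per chunk, first chunk highest.
mergeChunks :: [Integer] -> Integer
mergeChunks = foldl (\acc chunk -> (acc `shiftL` 66) + chunk) 0

-- Unpack the 66-bit slots again, highest slot first.
splitChunks :: Integer -> [Integer]
splitChunks 0 = []
splitChunks n = splitChunks (n `div` (2^66)) ++ [n `mod` (2^66)]
```

Here splitMessage "Hello World" gives ["Hello Wo", "rld"], and splitChunks (mergeChunks [5, 7, 9]) gives back [5, 7, 9]. One caveat of this encoding is that leading zero chunks are dropped: splitChunks (mergeChunks [0, 7]) gives [7]. A ciphertext chunk is 0 only when the plaintext chunk encodes to 0.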

Main Function

The main function demonstrates generating keys, encrypting a message with encrypt2, and decrypting it with decrypt2.

main :: IO ()
main = do
    (publicKey, privateKey) <- generateKeys
    putStrLn $ "Public Key: " ++ show publicKey
    putStrLn $ "Private Key: " ++ show privateKey
    let messages = "Hello World"
    let encryptedMessages = encrypt2 messages publicKey
    putStrLn $ "Encrypted Messages: " ++ show encryptedMessages
    let decryptedMessages = decrypt2 encryptedMessages privateKey
    putStrLn $ "Decrypted Messages: " ++ decryptedMessages

d. Appendix: Code Listing

Here is the complete Haskell code listing for the RSA implementation.

import Data.Monoid
import Data.Bits (shiftL, (.|.))
import Data.Foldable (Foldable, foldMap)
import Data.Traversable (Traversable, traverse)
import Control.Applicative
import System.Random (randomRIO)
import Data.Char (ord, chr)

-- Extended Euclidean algorithm
-- using for modular inverse
-- as + bt = gcd(a, b)
-- input: a, b
-- output: (gcd(a, b), s, t)
extendedGcd :: Integral a => a -> a -> (a, a, a)
extendedGcd 0 b = (b, 0, 1)
extendedGcd a b =
    let (g, s, t) = extendedGcd (b `mod` a) a
    in (g, t - (b `div` a) * s, s)

-- Modular inverse
-- a * a^-1 ≡ 1 (mod m)
-- input: a, m
-- output: a^-1
modInv :: Integral a => a -> a -> a
modInv a m = let (_, x, _) = extendedGcd a m in x `mod` m

-- RSA key generation
-- retries when p == q or e has no inverse modulo phi
generateKeys :: IO ((Integer, Integer), (Integer, Integer))
generateKeys = do
    p <- generatePrime
    q <- generatePrime
    let n = p * q
        phi = (p - 1) * (q - 1)
        e = 65537
        d = modInv e phi
    if p == q || gcd e phi /= 1
        then generateKeys
        else return ((e, n), (d, n))

-- String to int
-- input: string
-- output: integer
bytesToLong :: String -> Integer
bytesToLong bytes = foldl (\acc byte -> (acc `shiftL` 8) + fromIntegral (ord byte)) 0 bytes

-- Int to string
-- input: integer
-- output: string
longToBytes :: Integer -> String
longToBytes 0 = ""
longToBytes n = longToBytes (n `div` 2^8) ++ [chr (fromIntegral (n `mod` 2^8))]

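-- Merge encrypted chunks into one integer
-- each chunk is below n < 2^66, so it fits in a 66-bit slot
mergeChunks :: [Integer] -> Integer
mergeChunks bytes = foldl (\acc byte -> (acc `shiftL` 66) + fromIntegral byte) 0 bytes

-- Split the merged integer back into 66-bit encrypted chunks
splitChunks :: Integer -> [Integer]
splitChunks 0 = []
splitChunks n = splitChunks (n `div` (2^66)) ++ [fromIntegral (n `mod` (2^66))]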

-- Encrypt a message
-- input: message, public key: (e, n)
-- output: encrypted message
encrypt :: String -> (Integer, Integer) -> Integer
encrypt m (e, n) = powMod (bytesToLong m) e n

-- Decrypt a message
-- input: encrypted message, private key: (d, n)
-- output: decrypted message
decrypt :: Integer -> (Integer, Integer) -> String
decrypt c (d, n) = longToBytes ( powMod c d n )

-- Because the length of the message is limited by the modulus n, 
-- we can split the message into chunks of size 8 and encrypt each chunk separately. 
encrypt2 :: String -> (Integer, Integer) -> Integer
encrypt2 m (e, n) = mergeChunks (foldMap (\c -> [encrypt c (e, n)]) (splitMessage m))

decrypt2 :: Integer -> (Integer, Integer) -> String
decrypt2 m (d, n) = foldMap (\c -> decrypt c (d, n)) (splitChunks m)

-- Efficient modular exponentiation
-- input: base, exponent (non-negative), modulus
-- output: base^exponent mod modulus
-- algorithm: recursive exponentiation by squaring; for the iterative form see
-- https://en.wikipedia.org/wiki/Modular_exponentiation#Right-to-left_binary_method
-- a^b = (a^(b/2)) * (a^(b/2)) if b is even
-- a^b = a * (a^(b-1)) if b is odd
powMod :: Integer -> Integer -> Integer -> Integer
powMod _ 0 _ = 1
powMod base exp modulus
    | even exp = let half = powMod base (exp `div` 2) modulus in (half * half) `mod` modulus
    | otherwise = (base * powMod base (exp - 1) modulus) `mod` modulus


-- Split message into chunks of size 8
-- mainly because each character is represented by 8 bits
-- so each chunk encodes to a value below 2^64
-- and n is larger than 2^64 (65 or 66 bits)
-- so each chunk is less than n and the message can
-- be handled properly by the RSA algorithm
-- input: message
-- output: list of chunks
splitMessage :: String -> [String]
splitMessage [] = []
splitMessage message = let (chunk, rest) = splitAt 8 message in chunk : splitMessage rest

-- Random prime generation
-- draws candidates from [2^32, 2^33] until one passes isPrime
generatePrime :: IO Integer
generatePrime = randomRIO (2^32, 2^33) >>= \n -> if isPrime n then return n else generatePrime

-- Define the helper function for divisibility check
notDivisible :: Integer -> Integer -> Bool
notDivisible n x = n `mod` x /= 0

-- Simple primality test (not efficient for large primes)
-- isPrime algorithm: https://en.wikipedia.org/wiki/Primality_test#Simple_methods
isPrime :: Integer -> Bool
isPrime n = n > 1 && all (notDivisible n) [2..((floor . sqrt . fromIntegral $ n))]

-- Encrypt a list of messages using Foldable
encryptMessages :: Foldable t => t String -> (Integer, Integer) -> [Integer]
encryptMessages messages publicKey = foldMap (\m -> [encrypt m publicKey]) messages

-- Decrypt a list of messages using Traversable
decryptMessages :: Traversable t => t Integer -> (Integer, Integer) -> IO (t String)
decryptMessages messages privateKey = traverse (pure . flip decrypt privateKey) messages

main :: IO ()
main = do
    (publicKey, privateKey) <- generateKeys
    putStrLn $ "Public Key: " ++ show publicKey
    putStrLn $ "Private Key: " ++ show privateKey
    let messages = "Hello World"
    let encryptedMessages = encrypt2 messages publicKey
    putStrLn $ "Encrypted Messages: " ++ show encryptedMessages
    let decryptedMessages = decrypt2 encryptedMessages privateKey
    putStrLn $ "Decrypted Messages: " ++ decryptedMessages