RSA Implementation in Haskell
a. Introduction
RSA (Rivest-Shamir-Adleman) is one of the most widely used public-key cryptosystems for secure data transmission. Its security rests on the computational difficulty of factoring the product of two large prime numbers. RSA is used in many applications, including secure communication protocols, digital signatures, and encryption.
In RSA, each user generates a pair of keys: a public key and a private key. The public key is used to encrypt messages and can be shared with anyone, while the private key is used to decrypt messages and must be kept secret. Because the two keys are mathematically linked, a message encrypted with a public key can only be decrypted by the holder of the matching private key.
b. Algorithm Explanation
Key Generation:
- Select two distinct large prime numbers $p$ and $q$:
  - These primes should be large to ensure security.
- Compute $n = p \times q$:
  - The product $n$ (the modulus) is part of both the public and private keys.
- Calculate Euler's totient $\phi(n) = (p-1) \times (q-1)$:
  - The totient links the two exponents, because $d$ is computed modulo $\phi(n)$.
- Choose an integer $e$ such that $1 < e < \phi(n)$ and $\gcd(e, \phi(n)) = 1$:
  - $e$ is the public exponent, typically $65537 = 2^{16} + 1$, which is prime and has only two set bits, so encryption with it is fast.
- Compute $d$ as the modular multiplicative inverse of $e$ modulo $\phi(n)$:
  - $d$ is the private exponent, satisfying $e \times d \equiv 1 \mod \phi(n)$.

The public key is $(e, n)$ and the private key is $(d, n)$.
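As a sanity check of these steps, here is the small textbook example with $p = 61$ and $q = 53$ (primes this tiny are chosen only so the arithmetic can be followed by hand; they provide no security):

$$
\begin{aligned}
n &= 61 \times 53 = 3233 \\
\phi(n) &= 60 \times 52 = 3120 \\
e &= 17, \quad \gcd(17, 3120) = 1 \\
d &= 2753, \quad 17 \times 2753 = 46801 = 15 \times 3120 + 1
\end{aligned}
$$

So the public key is $(17, 3233)$ and the private key is $(2753, 3233)$.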
Proof of Correctness:
- Encryption: $c = m^e \mod n$
- Decryption: $m = c^d \mod n$
To show that decryption works correctly, we need to prove that $m = (m^e)^d \mod n$.
Given $e \times d \equiv 1 \mod \phi(n)$, there exists an integer $k$ such that $e \times d = 1 + k \times \phi(n)$. Thus:
$(m^e)^d = m^{e \times d} = m^{1 + k \times \phi(n)} = m \times (m^{\phi(n)})^k$
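By Euler’s theorem, $m^{\phi(n)} \equiv 1 \mod n$ for any $m$ coprime with $n$, leading to:
$m \times (m^{\phi(n)})^k \equiv m \times 1^k \equiv m \mod n$
The result also holds when $m$ shares a factor with $n$. Modulo $p$: if $p \mid m$, both sides are $0$; otherwise Fermat's little theorem gives $m^{p-1} \equiv 1 \mod p$, and since $(p-1)$ divides $\phi(n)$, $m^{e \times d} = m \times (m^{p-1})^{k(q-1)} \equiv m \mod p$. The same argument works modulo $q$, and because $p \ne q$ the Chinese Remainder Theorem combines the two congruences into $m^{e \times d} \equiv m \mod n$.
Thus, for any message $0 \le m < n$, decryption recovers the original message $m$.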
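The proof can be checked exhaustively on a toy key. The sketch below assumes the small textbook values $p = 61$, $q = 53$ (so $n = 3233$), $e = 17$ and $d = 2753$, and verifies that decryption inverts encryption for every $m$ with $0 \le m < n$, including the 113 messages that are multiples of 61 or 53 and therefore fall outside Euler's theorem. It defines its own square-and-multiply `powMod` so it compiles on its own; `roundTrips` and `nonCoprime` are illustrative names.

```haskell
-- Exhaustive check of RSA correctness on the toy key
-- p = 61, q = 53, n = 3233, e = 17, d = 2753 (illustration only, not secure).

-- Square-and-multiply modular exponentiation.
powMod :: Integer -> Integer -> Integer -> Integer
powMod _ 0 m = 1 `mod` m
powMod b k m
  | even k    = let h = powMod b (k `div` 2) m in (h * h) `mod` m
  | otherwise = (b * powMod b (k - 1) m) `mod` m

n, e, d :: Integer
n = 3233
e = 17
d = 2753

-- True when decrypting the encryption of m gives m back.
roundTrips :: Integer -> Bool
roundTrips m = powMod (powMod m e n) d n == m

-- Messages in [0, n) that share a factor with n (multiples of 61 or 53).
nonCoprime :: [Integer]
nonCoprime = [m | m <- [0 .. n - 1], gcd m n /= 1]

main :: IO ()
main = do
  print (all roundTrips [0 .. n - 1])  -- True
  print (length nonCoprime)            -- 113
  print (powMod 65 e n)                -- 2790
```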
c. Implementation
Let’s break down the Haskell implementation into multiple parts.
Extended Euclidean Algorithm
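The Extended Euclidean Algorithm computes the greatest common divisor (GCD) of two integers $a$ and $b$ and also expresses it as a linear combination of them: it returns $(g, s, t)$ with $a \times s + b \times t = g = \gcd(a, b)$. If $\gcd(a, m) = 1$, then $a \times s \equiv 1 \mod m$, so the same algorithm yields the modular inverse.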
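```haskell
-- Returns (g, s, t) such that a*s + b*t == g == gcd a b.
-- Recurses on (b mod a, a), exactly like the ordinary Euclidean algorithm.
extendedGcd :: Integral a => a -> a -> (a, a, a)
extendedGcd 0 b = (b, 0, 1)
extendedGcd a b =
  let (g, s, t) = extendedGcd (b `mod` a) a
  in (g, t - (b `div` a) * s, s)
```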
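A small usage sketch: for $a = 240$ and $b = 46$ the recursion returns $\gcd = 2$ with $s = -9$ and $t = 47$, and indeed $240 \times (-9) + 46 \times 47 = 2$. The helper `bezoutHolds` is illustrative only and checks the identity over a range of positive inputs.

```haskell
-- extendedGcd as defined above, repeated so this sketch compiles on its own.
extendedGcd :: Integral a => a -> a -> (a, a, a)
extendedGcd 0 b = (b, 0, 1)
extendedGcd a b =
  let (g, s, t) = extendedGcd (b `mod` a) a
  in (g, t - (b `div` a) * s, s)

-- Illustrative check of Bezout's identity: a*s + b*t == gcd a b.
bezoutHolds :: Integer -> Integer -> Bool
bezoutHolds a b =
  let (g, s, t) = extendedGcd a b
  in a * s + b * t == g && g == gcd a b

main :: IO ()
main = do
  print (extendedGcd 240 (46 :: Integer))                         -- (2,-9,47)
  print (and [bezoutHolds a b | a <- [1 .. 50], b <- [1 .. 50]])  -- True
```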
Modular Inverse
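The modular inverse $a^{-1}$ of $a$ modulo $m$ satisfies $a \times a^{-1} \equiv 1 \mod m$ and is computed with the Extended Euclidean Algorithm. It exists only when $\gcd(a, m) = 1$; `modInv` does not check this, so callers must ensure that $a$ and $m$ are coprime.
```haskell
-- Take the Bezout coefficient of a and reduce it into the range [0, m).
modInv :: Integral a => a -> a -> a
modInv a m = let (_, x, _) = extendedGcd a m in x `mod` m
```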
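Because `modInv` silently returns a meaningless value when $\gcd(a, m) \ne 1$, a caller may prefer a variant that reports failure. The sketch below uses a hypothetical `modInvMaybe` (not part of the original code) that returns `Nothing` in that case; it repeats `extendedGcd` so it compiles on its own.

```haskell
-- extendedGcd as defined above, repeated so this sketch compiles on its own.
extendedGcd :: Integer -> Integer -> (Integer, Integer, Integer)
extendedGcd 0 b = (b, 0, 1)
extendedGcd a b =
  let (g, s, t) = extendedGcd (b `mod` a) a
  in (g, t - (b `div` a) * s, s)

-- Hypothetical variant of modInv: Nothing when a has no inverse modulo m.
-- Assumes m > 0 and a >= 0, so the gcd returned by extendedGcd is non-negative.
modInvMaybe :: Integer -> Integer -> Maybe Integer
modInvMaybe a m =
  let (g, x, _) = extendedGcd a m
  in if g == 1 then Just (x `mod` m) else Nothing

main :: IO ()
main = do
  print (modInvMaybe 17 3120)  -- Just 2753
  print (modInvMaybe 6 9)      -- Nothing
```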
Prime Number Generation
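To select the primes $p$ and $q$, a random prime generator draws candidates uniformly from $[2^{32}, 2^{33}]$ and keeps the first one that passes a simple trial-division primality test.
```haskell
import System.Random (randomRIO)

-- Draw random candidates in [2^32, 2^33] until one passes isPrime.
generatePrime :: IO Integer
generatePrime = randomRIO (2 ^ 32, 2 ^ 33) >>= \n -> if isPrime n then return n else generatePrime

notDivisible :: Integer -> Integer -> Bool
notDivisible n x = n `mod` x /= 0

-- Trial division by every integer from 2 up to floor(sqrt n).
isPrime :: Integer -> Bool
isPrime n = n > 1 && all (notDivisible n) [2 .. floor (sqrt (fromIntegral n :: Double))]
```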
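Trial division needs up to $\sqrt{n}$ divisions, roughly 92,700 for a 33-bit candidate, and becomes hopeless for the 1024-bit primes used in practice. A common replacement, swapped in here and not used by the original code, is the Miller-Rabin test. The sketch below (the name `millerRabin` is illustrative) uses the first twelve primes as bases, which is known to give a deterministic answer for every $n < 2^{64}$; for larger inputs it is only a probable-prime test.

```haskell
-- Miller-Rabin primality test with fixed bases 2..37 (the first twelve primes).
-- Deterministic for all n < 2^64; above that, True means "probable prime".

-- Square-and-multiply modular exponentiation.
powMod :: Integer -> Integer -> Integer -> Integer
powMod _ 0 m = 1 `mod` m
powMod b k m
  | even k    = let h = powMod b (k `div` 2) m in (h * h) `mod` m
  | otherwise = (b * powMod b (k - 1) m) `mod` m

millerRabin :: Integer -> Bool
millerRabin n
  | n < 2                            = False
  | n `elem` bases                   = True
  | any (\p -> n `mod` p == 0) bases = False
  | otherwise                        = all passes bases
  where
    bases = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37]
    -- Write n - 1 = 2^s * r with r odd (n is odd here, so s >= 1).
    (s, r) = decompose 0 (n - 1)
    decompose :: Int -> Integer -> (Int, Integer)
    decompose k x
      | even x    = decompose (k + 1) (x `div` 2)
      | otherwise = (k, x)
    -- Base a passes if a^r == 1 or a^(2^j * r) == n - 1 (mod n) for some 0 <= j < s.
    passes a =
      let x0      = powMod a r n
          squares = take s (iterate (\y -> (y * y) `mod` n) x0)  -- x0, x0^2, x0^4, ...
      in x0 == 1 || (n - 1) `elem` squares

main :: IO ()
main = do
  print (filter millerRabin [1 .. 50])
  print (millerRabin 2147483647)  -- 2^31 - 1, a Mersenne prime
  print (millerRabin 4294967297)  -- 2^32 + 1 = 641 * 6700417, composite
```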
Key Generation
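The key generation process generates two random primes, computes their product and totient, and derives the public and private exponents. Because $e$ is fixed at 65537, the code also checks that $p \ne q$ and $\gcd(e, \phi(n)) = 1$ and draws new primes otherwise; without that check `modInv` would return an invalid $d$. The 33-bit primes give a modulus of only 65 or 66 bits, which is enough for a demonstration but easy to factor; practical RSA keys use moduli of 2048 bits or more.
```haskell
generateKeys :: IO ((Integer, Integer), (Integer, Integer))
generateKeys = do
  p <- generatePrime
  q <- generatePrime
  let n = p * q
      phi = (p - 1) * (q - 1)
      e = 65537
  -- p and q must differ and e must be coprime with phi(n); otherwise retry.
  if p == q || gcd e phi /= 1
    then generateKeys
    else let d = modInv e phi in return ((e, n), (d, n))
```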
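Because `generateKeys` draws random primes, it is hard to test directly. One option is to factor the deterministic part out into a pure function. The sketch below uses `keysFromPrimes`, a hypothetical helper (not in the original code) that returns `Nothing` when $p = q$, when $e$ is outside $(1, \phi(n))$, or when $\gcd(e, \phi(n)) \ne 1$.

```haskell
-- extendedGcd as defined earlier, repeated so this sketch compiles on its own.
extendedGcd :: Integer -> Integer -> (Integer, Integer, Integer)
extendedGcd 0 b = (b, 0, 1)
extendedGcd a b =
  let (g, s, t) = extendedGcd (b `mod` a) a
  in (g, t - (b `div` a) * s, s)

-- Hypothetical helper: build ((e, n), (d, n)) from fixed primes p, q and exponent e.
keysFromPrimes :: Integer -> Integer -> Integer -> Maybe ((Integer, Integer), (Integer, Integer))
keysFromPrimes p q e
  | p == q             = Nothing  -- the primes must be distinct
  | e <= 1 || e >= phi = Nothing  -- require 1 < e < phi(n)
  | g /= 1             = Nothing  -- e must be coprime with phi(n)
  | otherwise          = Just ((e, n), (x `mod` phi, n))
  where
    n         = p * q
    phi       = (p - 1) * (q - 1)
    (g, x, _) = extendedGcd e phi

main :: IO ()
main = do
  print (keysFromPrimes 61 53 17)  -- Just ((17,3233),(2753,3233))
  print (keysFromPrimes 7 11 3)    -- Nothing, since gcd 3 60 = 3
```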
Efficient Modular Exponentiation
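Modular exponentiation computes $a^b \mod m$ efficiently. This implementation uses recursive exponentiation by squaring: for even $b$ it computes $a^{b/2} \mod m$ once and squares it, for odd $b$ it multiplies $a$ by $a^{b-1} \mod m$, and it reduces every intermediate product modulo $m$. It therefore needs only $O(\log b)$ multiplications, on numbers that stay small, instead of $b$ multiplications on ever-growing numbers.
```haskell
-- Recursive exponentiation by squaring:
--   a^b = (a^(b/2))^2  if b is even
--   a^b = a * a^(b-1)  if b is odd
powMod :: Integer -> Integer -> Integer -> Integer
powMod _ 0 _ = 1
powMod base exp modulus
  | even exp = let half = powMod base (exp `div` 2) modulus in (half * half) `mod` modulus
  | otherwise = (base * powMod base (exp - 1) modulus) `mod` modulus
```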
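For comparison with the recursive version, the right-to-left binary method scans the exponent's bits from least to most significant: it squares the base at every step and multiplies it into a running result whenever the current bit is 1. A minimal sketch, assuming a non-negative exponent (the name `powModRL` is illustrative), using `testBit` and `shiftR` from `Data.Bits`:

```haskell
import Data.Bits (shiftR, testBit)

-- Right-to-left binary modular exponentiation (exponent must be >= 0).
-- Invariant: acc * b^k == base0^exp0 (mod modulus) at every step.
powModRL :: Integer -> Integer -> Integer -> Integer
powModRL base0 exp0 modulus = go (base0 `mod` modulus) exp0 (1 `mod` modulus)
  where
    go _ 0 acc = acc
    go b k acc =
      let acc' = if testBit k 0 then (acc * b) `mod` modulus else acc
      in acc' `seq` go ((b * b) `mod` modulus) (k `shiftR` 1) acc'

main :: IO ()
main = do
  print (powModRL 65 17 3233)     -- 2790
  print (powModRL 2790 2753 3233) -- 65
```

Both versions need $O(\log b)$ modular multiplications; the loop form simply makes the bit-by-bit scan explicit.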
Encrypt and Decrypt Functions
The encryption function raises the message to the power of the public exponent, and the decryption function raises the ciphertext to the power of the private exponent, both modulo $n$.
The `bytesToLong` and `longToBytes` functions convert the message string into a number and back, treating each character as one byte of a base-256 number. This assumes character codes below 256, and leading `'\0'` characters do not survive the conversion back.
```haskell
import Data.Bits (shiftL)
import Data.Char (chr, ord)

-- Interpret the string as a big-endian base-256 number.
bytesToLong :: String -> Integer
bytesToLong bytes = foldl (\acc byte -> (acc `shiftL` 8) + fromIntegral (ord byte)) 0 bytes

-- Inverse of bytesToLong: peel off the lowest byte until nothing is left.
longToBytes :: Integer -> String
longToBytes 0 = ""
longToBytes n = longToBytes (n `div` 2^8) ++ [chr (fromIntegral (n `mod` 2^8))]

encrypt :: String -> (Integer, Integer) -> Integer
encrypt m (e, n) = powMod (bytesToLong m) e n

decrypt :: Integer -> (Integer, Integer) -> String
decrypt c (d, n) = longToBytes (powMod c d n)
```
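A short sketch of the encoding round trip, with the two conversion functions repeated so the snippet compiles on its own: `"Hi"` becomes $72 \times 256 + 105 = 18537$, and converting back restores the string. The last line shows a limitation of this encoding: a leading `'\0'` contributes nothing to the number, so it is lost on the way back.

```haskell
import Data.Bits (shiftL)
import Data.Char (chr, ord)

bytesToLong :: String -> Integer
bytesToLong = foldl (\acc c -> (acc `shiftL` 8) + fromIntegral (ord c)) 0

longToBytes :: Integer -> String
longToBytes 0 = ""
longToBytes n = longToBytes (n `div` 256) ++ [chr (fromIntegral (n `mod` 256))]

main :: IO ()
main = do
  print (bytesToLong "Hi")                   -- 18537
  putStrLn (longToBytes (bytesToLong "Hi"))  -- Hi
  print (longToBytes (bytesToLong "\0A"))    -- "A"
```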
Encrypt and Decrypt Lists of Messages
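Using the `Foldable` and `Traversable` type classes, we can encrypt and decrypt whole collections of messages (lists, `Maybe`, and so on). Each message must still encode to a value below $n$ on its own.
```haskell
-- Encrypt every message in any Foldable container, collecting the ciphertexts in a list.
encryptMessages :: Foldable t => t String -> (Integer, Integer) -> [Integer]
encryptMessages messages publicKey = foldMap (\m -> [encrypt m publicKey]) messages

-- Decrypt every ciphertext in any Traversable container, keeping its shape.
decryptMessages :: Traversable t => t Integer -> (Integer, Integer) -> IO (t String)
decryptMessages messages privateKey = traverse (pure . flip decrypt privateKey) messages
```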
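A usage sketch with the small textbook key $p = 61$, $q = 53$, so $n = 3233$, $e = 17$ and $d = 2753$ (an assumption chosen for readability, not a secure key). With such a small modulus only single-character messages fit. The helpers are repeated in minimal form so the snippet compiles on its own; `toyPublic` and `toyPrivate` are illustrative names.

```haskell
import Data.Bits (shiftL)
import Data.Char (chr, ord)

-- Minimal copies of the helpers above so the sketch compiles on its own.
powMod :: Integer -> Integer -> Integer -> Integer
powMod _ 0 _ = 1
powMod b k m
  | even k    = let h = powMod b (k `div` 2) m in (h * h) `mod` m
  | otherwise = (b * powMod b (k - 1) m) `mod` m

bytesToLong :: String -> Integer
bytesToLong = foldl (\acc c -> (acc `shiftL` 8) + fromIntegral (ord c)) 0

longToBytes :: Integer -> String
longToBytes 0 = ""
longToBytes n = longToBytes (n `div` 256) ++ [chr (fromIntegral (n `mod` 256))]

encrypt :: String -> (Integer, Integer) -> Integer
encrypt m (e, n) = powMod (bytesToLong m) e n

decrypt :: Integer -> (Integer, Integer) -> String
decrypt c (d, n) = longToBytes (powMod c d n)

encryptMessages :: Foldable t => t String -> (Integer, Integer) -> [Integer]
encryptMessages messages publicKey = foldMap (\m -> [encrypt m publicKey]) messages

decryptMessages :: Traversable t => t Integer -> (Integer, Integer) -> IO (t String)
decryptMessages messages privateKey = traverse (pure . flip decrypt privateKey) messages

-- Toy key (p = 61, q = 53): every message must encode to a value below n = 3233.
toyPublic, toyPrivate :: (Integer, Integer)
toyPublic = (17, 3233)
toyPrivate = (2753, 3233)

main :: IO ()
main = do
  let ciphertexts = encryptMessages ["H", "i", "!"] toyPublic
  print ciphertexts
  plaintexts <- decryptMessages ciphertexts toyPrivate
  print plaintexts                                 -- ["H","i","!"]
  -- The same function works on other Traversable containers, e.g. Maybe:
  single <- decryptMessages (Just 2790) toyPrivate
  print single                                     -- Just "A"
```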
Encrypt and Decrypt for Longer Messages
The `encrypt` and `decrypt` functions only work for messages of at most 8 characters. RSA recovers a message value $m$ only when $0 \le m < n$, and here $n = p \times q$ with $2^{32} < p, q < 2^{33}$, so $2^{64} < n < 2^{66}$. An 8-character string encodes to a value below $2^{64}$, which always fits, but a longer string such as `"123456789"` encodes to a value of more than 64 bits that exceeds $n$; encryption then effectively works on $m \mod n$, and decryption returns the wrong text.
So I split the message into chunks of 8 characters (64 bits) and encrypt/decrypt each chunk separately:
```haskell
import Data.Bits (shiftL)

-- splitMessage splits the message into chunks of 8 characters (64 bits).
splitMessage :: String -> [String]
splitMessage [] = []
splitMessage message = let (chunk, rest) = splitAt 8 message in chunk : splitMessage rest

-- mergeChunks packs the encrypted chunks into one integer, 66 bits per chunk
-- (every ciphertext is below n < 2^66, so it fits).
mergeChunks :: [Integer] -> Integer
mergeChunks chunks = foldl (\acc chunk -> (acc `shiftL` 66) + chunk) 0 chunks

-- splitChunks unpacks that integer back into the encrypted chunks,
-- which decrypt2 then decrypts one by one.
splitChunks :: Integer -> [Integer]
splitChunks 0 = []
splitChunks n = splitChunks (n `div` (2^66)) ++ [n `mod` (2^66)]

-- Because the message value must stay below the modulus n,
-- we split the message into 8-character chunks and encrypt each chunk separately.
encrypt2 :: String -> (Integer, Integer) -> Integer
encrypt2 m (e, n) = mergeChunks (foldMap (\c -> [encrypt c (e, n)]) (splitMessage m))

decrypt2 :: Integer -> (Integer, Integer) -> String
decrypt2 m (d, n) = foldMap (\c -> decrypt c (d, n)) (splitChunks m)
```
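A small sketch of the chunking helpers on their own (repeated so the snippet compiles): a 9-character string splits into an 8-character chunk and a 1-character chunk, and packing then unpacking preserves the chunk values, including a zero in the middle. It also shows an edge case: a zero first chunk disappears, because `splitChunks` stops as soon as the remaining value is 0. A ciphertext chunk is 0 only when its plaintext chunk encodes to 0 (all `'\0'` characters), so ordinary text is unaffected.

```haskell
import Data.Bits (shiftL)

splitMessage :: String -> [String]
splitMessage [] = []
splitMessage message = let (chunk, rest) = splitAt 8 message in chunk : splitMessage rest

-- Pack chunks into one integer, 66 bits per chunk.
mergeChunks :: [Integer] -> Integer
mergeChunks = foldl (\acc c -> (acc `shiftL` 66) + c) 0

-- Unpack 66-bit chunks, most significant first.
splitChunks :: Integer -> [Integer]
splitChunks 0 = []
splitChunks n = splitChunks (n `div` (2 ^ 66)) ++ [n `mod` (2 ^ 66)]

main :: IO ()
main = do
  print (splitMessage "123456789")            -- ["12345678","9"]
  print (splitChunks (mergeChunks [5, 0, 7])) -- [5,0,7]
  print (splitChunks (mergeChunks [0, 7]))    -- [7]: a leading zero chunk is lost
```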
Main Function
The main function demonstrates generating keys, encrypting a list of messages, and decrypting them.
```haskell
main :: IO ()
main = do
  (publicKey, privateKey) <- generateKeys
  putStrLn $ "Public Key: " ++ show publicKey
  putStrLn $ "Private Key: " ++ show privateKey
  let messages = "Hello World"
  let encryptedMessages = encrypt2 messages publicKey
  putStrLn $ "Encrypted Messages: " ++ show encryptedMessages
  let decryptedMessages = decrypt2 encryptedMessages privateKey
  putStrLn $ "Decrypted Messages: " ++ decryptedMessages
```
d. Appendix: Code Listing
Here is the complete Haskell code listing for the RSA implementation.
```haskell
import Data.Bits (shiftL)
import Data.Char (ord, chr)
import System.Random (randomRIO) -- provided by the random package

-- Extended Euclidean algorithm
-- used for the modular inverse
-- a*s + b*t = gcd(a, b)
-- input: a, b
-- output: (gcd(a, b), s, t)
extendedGcd :: Integral a => a -> a -> (a, a, a)
extendedGcd 0 b = (b, 0, 1)
extendedGcd a b =
  let (g, s, t) = extendedGcd (b `mod` a) a
  in (g, t - (b `div` a) * s, s)

-- Modular inverse
-- a * a^-1 ≡ 1 (mod m), defined only when gcd(a, m) = 1
-- input: a, m
-- output: a^-1
modInv :: Integral a => a -> a -> a
modInv a m = let (_, x, _) = extendedGcd a m in x `mod` m

-- RSA key generation
-- Draws new primes if p == q or gcd(e, phi) /= 1, since no valid d exists then.
generateKeys :: IO ((Integer, Integer), (Integer, Integer))
generateKeys = do
  p <- generatePrime
  q <- generatePrime
  let n = p * q
      phi = (p - 1) * (q - 1)
      e = 65537
  if p == q || gcd e phi /= 1
    then generateKeys
    else let d = modInv e phi in return ((e, n), (d, n))

-- String to integer (one byte per character, base 256)
-- input: string
-- output: integer
bytesToLong :: String -> Integer
bytesToLong bytes = foldl (\acc byte -> (acc `shiftL` 8) + fromIntegral (ord byte)) 0 bytes

-- Integer back to string (inverse of bytesToLong)
longToBytes :: Integer -> String
longToBytes 0 = ""
longToBytes n = longToBytes (n `div` 2^8) ++ [chr (fromIntegral (n `mod` 2^8))]

-- Pack encrypted chunks into one integer, 66 bits per chunk (every ciphertext is below n < 2^66)
mergeChunks :: [Integer] -> Integer
mergeChunks bytes = foldl (\acc byte -> (acc `shiftL` 66) + fromIntegral byte) 0 bytes

-- Unpack the integer built by mergeChunks back into encrypted chunks
splitChunks :: Integer -> [Integer]
splitChunks 0 = []
splitChunks n = splitChunks (n `div` (2^66)) ++ [fromIntegral (n `mod` (2^66))]

-- Encrypt a message
-- input: message, public key: (e, n)
-- output: encrypted message
encrypt :: String -> (Integer, Integer) -> Integer
encrypt m (e, n) = powMod (bytesToLong m) e n

-- Decrypt a message
-- input: encrypted message, private key: (d, n)
-- output: decrypted message
decrypt :: Integer -> (Integer, Integer) -> String
decrypt c (d, n) = longToBytes (powMod c d n)

-- Because the message value must stay below the modulus n,
-- we split the message into chunks of 8 characters and encrypt each chunk separately.
encrypt2 :: String -> (Integer, Integer) -> Integer
encrypt2 m (e, n) = mergeChunks (foldMap (\c -> [encrypt c (e, n)]) (splitMessage m))

decrypt2 :: Integer -> (Integer, Integer) -> String
decrypt2 m (d, n) = foldMap (\c -> decrypt c (d, n)) (splitChunks m)

-- Efficient modular exponentiation (recursive exponentiation by squaring)
-- input: base, exponent, modulus
-- output: base^exponent mod modulus
-- background: https://en.wikipedia.org/wiki/Modular_exponentiation#Right-to-left_binary_method
-- a^b = (a^(b/2)) * (a^(b/2)) if b is even
-- a^b = a * (a^(b-1)) if b is odd
powMod :: Integer -> Integer -> Integer -> Integer
powMod _ 0 _ = 1
powMod base exp modulus
  | even exp = let half = powMod base (exp `div` 2) modulus in (half * half) `mod` modulus
  | otherwise = (base * powMod base (exp - 1) modulus) `mod` modulus

-- Split message into chunks of 8 characters
-- Each character is one byte, so every chunk encodes to a value below 2^64,
-- while n = p * q lies between 2^64 and 2^66.
-- Every chunk is therefore less than n and can be handled by RSA.
-- input: message
-- output: list of chunks
splitMessage :: String -> [String]
splitMessage [] = []
splitMessage message = let (chunk, rest) = splitAt 8 message in chunk : splitMessage rest

-- Random prime in [2^32, 2^33]
generatePrime :: IO Integer
generatePrime = randomRIO (2^32, 2^33) >>= \n -> if isPrime n then return n else generatePrime

-- Helper for the divisibility check
notDivisible :: Integer -> Integer -> Bool
notDivisible n x = n `mod` x /= 0

-- Simple primality test (not efficient for large primes)
-- isPrime algorithm: https://en.wikipedia.org/wiki/Primality_test#Simple_methods
isPrime :: Integer -> Bool
isPrime n = n > 1 && all (notDivisible n) [2 .. floor (sqrt (fromIntegral n :: Double))]

-- Encrypt a list of messages using Foldable
encryptMessages :: Foldable t => t String -> (Integer, Integer) -> [Integer]
encryptMessages messages publicKey = foldMap (\m -> [encrypt m publicKey]) messages

-- Decrypt a list of messages using Traversable
decryptMessages :: Traversable t => t Integer -> (Integer, Integer) -> IO (t String)
decryptMessages messages privateKey = traverse (pure . flip decrypt privateKey) messages

main :: IO ()
main = do
  (publicKey, privateKey) <- generateKeys
  putStrLn $ "Public Key: " ++ show publicKey
  putStrLn $ "Private Key: " ++ show privateKey
  let messages = "Hello World"
  let encryptedMessages = encrypt2 messages publicKey
  putStrLn $ "Encrypted Messages: " ++ show encryptedMessages
  let decryptedMessages = decrypt2 encryptedMessages privateKey
  putStrLn $ "Decrypted Messages: " ++ decryptedMessages
```