diff --git a/src/lib.rs b/src/lib.rs index 29d6f1c..c448252 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,15 +2,25 @@ use reikna::totient::totient; use reikna::factor::quick_factorize; use std::collections::HashMap; -// Modular arithmetic functions using i64 +/// Modular arithmetic functions using i64 fn mod_add(a: i64, b: i64, p: i64) -> i64 { (a + b) % p } +/// Modular multiplication fn mod_mul(a: i64, b: i64, p: i64) -> i64 { (a * b) % p } +/// Modular exponentiation +/// # Arguments +/// +/// * `base` - Base of the exponentiation. +/// * `exp` - Exponent. +/// * `p` - Prime modulus for the operations. +/// +/// # Returns +/// The result of the exponentiation modulo `p`. pub fn mod_exp(mut base: i64, mut exp: i64, p: i64) -> i64 { let mut result = 1; base %= p; @@ -24,6 +34,14 @@ pub fn mod_exp(mut base: i64, mut exp: i64, p: i64) -> i64 { result } +/// Extended Euclidean algorithm +/// # Arguments +/// +/// * `a` - First number. +/// * `b` - Second number. +/// +/// # Returns +/// A tuple with the greatest common divisor and the Bézout coefficients. fn extended_gcd(a: i64, b: i64) -> (i64, i64, i64) { if b == 0 { (a, 1, 0) // gcd, x, y @@ -33,7 +51,8 @@ fn extended_gcd(a: i64, b: i64) -> (i64, i64, i64) { } } -pub fn mod_inv(a: i64, modulus: i64) -> i64 { +/// Compute the modular inverse of a modulo modulus +fn mod_inv(a: i64, modulus: i64) -> i64 { let (gcd, x, _) = extended_gcd(a, modulus); if gcd != 1 { panic!("{} and {} are not coprime, no inverse exists", a, modulus); @@ -41,7 +60,29 @@ pub fn mod_inv(a: i64, modulus: i64) -> i64 { (x % modulus + modulus) % modulus // Ensure a positive result } -// Compute n-th root of unity (omega) for p not necessarily prime +/// Compute n-th root of unity (omega) for p not necessarily prime +/// # Arguments +/// +/// * `modulus` - Modulus. n must divide each prime power factor. +/// * `n` - Order of the root of unity. +/// +/// # Returns +/// The n-th root of unity modulo `modulus`. +/// +/// # Examples +/// +/// ``` +/// // For modulus = 17^2 = 289, we compute and verify an 8th root of unity. +/// let modulus = 17 * 17; +/// let n = 8; +/// let omega = ntt::omega(modulus, n); +/// assert!(ntt::verify_root_of_unity(omega,n.try_into().unwrap(),modulus)); +/// +/// // For modulus = 17*41*73, we compute and verify an 8th root of unity. +/// let modulus = 17*41*73; +/// let omega = ntt::omega(modulus, n); +/// assert!(ntt::verify_root_of_unity(omega,n.try_into().unwrap(),modulus)); +/// ``` pub fn omega(modulus: i64, n: usize) -> i64 { let factors = factorize(modulus as i64); if factors.len() == 1 { @@ -56,7 +97,29 @@ pub fn omega(modulus: i64, n: usize) -> i64 { } } -// Forward transform using NTT, output bit-reversed +/// Forward transform using NTT, output bit-reversed +/// # Arguments +/// +/// * `a` - Input vector. +/// * `omega` - Primitive root of unity modulo `p`. +/// * `n` - Length of the input vector and the result. +/// * `p` - Prime modulus for the operations. +/// +/// # Returns +/// A vector representing the NTT of the input vector. +/// +/// # Examples +/// +/// ``` +/// let modulus: i64 = 17; // modulus, n must divide phi(p^k) for each prime factor p +/// let n: usize = 8; // Length of the NTT (must be a power of 2) +/// let omega = ntt::omega(modulus, n); // n-th root of unity +/// let mut a = vec![1, 2, 3, 4]; +/// a.resize(n, 0); +/// // Perform the forward NTT +/// let a_ntt = ntt::ntt(&a, omega, n, modulus); +/// let a_ntt_expected = vec![10, 15, 6, 7, 16, 13, 11, 15]; +/// assert_eq!(a_ntt, a_ntt_expected); pub fn ntt(a: &[i64], omega: i64, n: usize, p: i64) -> Vec { let mut result = a.to_vec(); let mut step = n/2; @@ -77,7 +140,16 @@ pub fn ntt(a: &[i64], omega: i64, n: usize, p: i64) -> Vec { result } -// Inverse transform using INTT, input bit-reversed +/// Inverse transform using INTT, input bit-reversed +/// # Arguments +/// +/// * `a` - Input vector (bit-reversed). +/// * `omega` - Primitive root of unity modulo `p`. +/// * `n` - Length of the input vector and the result. +/// * `p` - Prime modulus for the operations. +/// +/// # Returns +/// A vector representing the inverse NTT of the input vector. pub fn intt(a: &[i64], omega: i64, n: usize, p: i64) -> Vec { let omega_inv = mod_inv(omega, p); let n_inv = mod_inv(n as i64, p); @@ -103,7 +175,16 @@ pub fn intt(a: &[i64], omega: i64, n: usize, p: i64) -> Vec { .collect() } -// Naive polynomial multiplication +/// Naive polynomial multiplication +/// # Arguments +/// +/// * `a` - First polynomial (as a vector of coefficients). +/// * `b` - Second polynomial (as a vector of coefficients). +/// * `n` - Length of the polynomials and the result. +/// * `p` - Prime modulus for the operations. +/// +/// # Returns +/// A vector representing the polynomial product modulo `p`. pub fn polymul(a: &Vec, b: &Vec, n: i64, p: i64) -> Vec { let mut result = vec![0; n as usize]; for i in 0..a.len() { @@ -145,7 +226,14 @@ pub fn polymul_ntt(a: &[i64], b: &[i64], n: usize, p: i64, omega: i64) -> Vec HashMap { let mut factors = HashMap::new(); for factor in quick_factorize(n as u64) { @@ -155,6 +243,23 @@ fn factorize(n: i64) -> HashMap { } /// Fast computation of a primitive root mod p^e +/// Computes a primitive root mod p and lifts it to p^e by adding successive powers of p +/// # Arguments +/// +/// * `p` - Prime modulus. +/// * `e` - Exponent. +/// +/// # Returns +/// A primitive root modulo `p^e`. +/// +/// # Examples +/// +/// ``` +/// // For p = 17 and e = 2, we compute a primitive root modulo 289. +/// let p = 17; +/// let e = 2; +/// let g = ntt::primitive_root(p, e); +/// assert_eq!(ntt::mod_exp(g, p*(p-1), p*p), 1); pub fn primitive_root(p: i64, e: u32) -> i64 { let g = primitive_root_mod_p(p); let mut g_lifted = g; // Lift it to p^e @@ -167,6 +272,12 @@ pub fn primitive_root(p: i64, e: u32) -> i64 { } /// Finds a primitive root modulo a prime p +/// # Arguments +/// +/// * `p` - Prime modulus. +/// +/// # Returns +/// A primitive root modulo `p`. fn primitive_root_mod_p(p: i64) -> i64 { let phi = p - 1; let factors = factorize(phi); // Reusing factorize to get both prime factors and multiplicities @@ -179,7 +290,16 @@ fn primitive_root_mod_p(p: i64) -> i64 { 0 // Should never happen } -// the Chinese remainder theorem for two moduli +/// the Chinese remainder theorem for two moduli +/// # Arguments +/// +/// * `a1` - First residue. +/// * `n1` - First modulus. +/// * `a2` - Second residue. +/// * `n2` - Second modulus. +/// +/// # Returns +/// The solution to the system of congruences x = a1 (mod n1) and x = a2 (mod n2). pub fn crt(a1: i64, n1: i64, a2: i64, n2: i64) -> i64 { let n = n1 * n2; let m1 = mod_inv(n1, n2); // Inverse of n1 mod n2 @@ -188,10 +308,17 @@ pub fn crt(a1: i64, n1: i64, a2: i64, n2: i64) -> i64 { if x < 0 { x + n } else { x } } -// computes an n^th root of unity modulo a composite modulus -// note we require that an n^th root of unity exists for each multiplicative group modulo p^e -// use the CRT isomorphism to pull back each n^th root of unity to the composite modulus -// for the NTT, we require than a 2n^th root of unity exists +/// computes an n^th root of unity modulo a composite modulus +/// note we require that an n^th root of unity exists for each multiplicative group modulo p^e +/// use the CRT isomorphism to pull back the list of n^th roots of unity to the composite modulus +/// for the NTT, we require than a 2n^th root of unity exists +/// # Arguments +/// +/// * `modulus` - Modulus. n must divide each prime power factor. +/// * `n` - Order of the root of unity. +/// +/// # Returns +/// The n-th root of unity modulo `modulus`. pub fn root_of_unity(modulus: i64, n: i64) -> i64 { let factors = factorize(modulus); let mut result = 1; @@ -202,10 +329,17 @@ pub fn root_of_unity(modulus: i64, n: i64) -> i64 { result } -//ensure the root of unity satisfies sum_{j=0}^{n-1} omega^{jk} = 0 for 1 \le k < n +/// ensure the root of unity satisfies sum_{j=0}^{n-1} omega^{jk} = 0 for 1 \le k < n +/// # Arguments +/// +/// * `omega` - n-th root of unity. +/// * `n` - Order of the root of unity. +/// * `modulus` - Modulus. +/// +/// # Returns +/// True if the root of unity satisfies the condition. pub fn verify_root_of_unity(omega: i64, n: i64, modulus: i64) -> bool { assert!(mod_exp(omega, n, modulus as i64) == 1, "omega is not an n-th root of unity"); assert!(mod_exp(omega, n/2, modulus as i64) == modulus-1, "omgea^(n/2) != -1 (mod modulus)"); true -} - +} \ No newline at end of file