diff --git a/benches/bench.rs b/benches/bench.rs index 8be78345..9e2dadff 100644 --- a/benches/bench.rs +++ b/benches/bench.rs @@ -89,12 +89,12 @@ fn bench_modpow<'a, M: Measurement>(group: &mut BenchmarkGroup<'a, M>) { let params = moduli .iter() - .map(|modulus| DynResidueParams::new(*modulus)) + .map(|modulus| DynResidueParams::new(modulus)) .collect::>(); let xs_m = xs .iter() .zip(params.iter()) - .map(|(x, p)| DynResidue::new(*x, *p)) + .map(|(x, p)| DynResidue::new(x, *p)) .collect::>(); group.bench_function("modpow, 4^4", |b| { diff --git a/src/ct_choice.rs b/src/ct_choice.rs new file mode 100644 index 00000000..1308dd32 --- /dev/null +++ b/src/ct_choice.rs @@ -0,0 +1,83 @@ +use subtle::Choice; + +use crate::Word; + +/// A boolean value returned by constant-time `const fn`s. +// TODO: should be replaced by `subtle::Choice` or `CtOption` +// when `subtle` starts supporting const fns. +#[derive(Debug, Copy, Clone)] +pub struct CtChoice(Word); + +impl CtChoice { + /// The falsy vaue. + pub const FALSE: Self = Self(0); + + /// The truthy vaue. + pub const TRUE: Self = Self(Word::MAX); + + /// Returns the truthy value if `value == Word::MAX`, and the falsy value if `value == 0`. + /// Panics for other values. + pub(crate) const fn from_mask(value: Word) -> Self { + debug_assert!(value == Self::FALSE.0 || value == Self::TRUE.0); + Self(value) + } + + /// Returns the truthy value if `value == 1`, and the falsy value if `value == 0`. + /// Panics for other values. + pub(crate) const fn from_lsb(value: Word) -> Self { + debug_assert!(value == Self::FALSE.0 || value == 1); + Self(value.wrapping_neg()) + } + + pub(crate) const fn not(&self) -> Self { + Self(!self.0) + } + + pub(crate) const fn and(&self, other: Self) -> Self { + Self(self.0 & other.0) + } + + pub(crate) const fn or(&self, other: Self) -> Self { + Self(self.0 | other.0) + } + + /// Return `b` if `self` is truthy, otherwise return `a`. + pub(crate) const fn select(&self, a: Word, b: Word) -> Word { + a ^ (self.0 & (a ^ b)) + } + + /// Return `x` if `self` is truthy, otherwise return 0. + pub(crate) const fn if_true(&self, x: Word) -> Word { + x & self.0 + } + + pub(crate) const fn is_true_vartime(&self) -> bool { + self.0 == CtChoice::TRUE.0 + } +} + +impl From for Choice { + fn from(choice: CtChoice) -> Self { + Choice::from(choice.0 as u8 & 1) + } +} + +impl From for bool { + fn from(choice: CtChoice) -> Self { + choice.is_true_vartime() + } +} + +#[cfg(test)] +mod tests { + use super::CtChoice; + use crate::Word; + + #[test] + fn select() { + let a: Word = 1; + let b: Word = 2; + assert_eq!(CtChoice::TRUE.select(a, b), b); + assert_eq!(CtChoice::FALSE.select(a, b), a); + } +} diff --git a/src/lib.rs b/src/lib.rs index 76db5221..4c883960 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -161,6 +161,7 @@ mod nlimbs; #[cfg(feature = "generic-array")] mod array; mod checked; +mod ct_choice; mod limb; mod non_zero; mod traits; @@ -169,6 +170,7 @@ mod wrapping; pub use crate::{ checked::Checked, + ct_choice::CtChoice, limb::{Limb, WideWord, Word}, non_zero::NonZero, traits::*, @@ -178,8 +180,6 @@ pub use crate::{ }; pub use subtle; -pub(crate) use limb::{SignedWord, WideSignedWord}; - #[cfg(feature = "generic-array")] pub use { crate::array::{ArrayDecoding, ArrayEncoding, ByteArray}, diff --git a/src/limb.rs b/src/limb.rs index 3e0bc7e0..3842270a 100644 --- a/src/limb.rs +++ b/src/limb.rs @@ -38,18 +38,10 @@ compile_error!("this crate builds on 32-bit and 64-bit platforms only"); #[cfg(target_pointer_width = "32")] pub type Word = u32; -/// Signed integer type that corresponds to [`Word`]. -#[cfg(target_pointer_width = "32")] -pub(crate) type SignedWord = i32; - /// Unsigned wide integer type: double the width of [`Word`]. #[cfg(target_pointer_width = "32")] pub type WideWord = u64; -/// Signed wide integer type: double the width of [`Limb`]. -#[cfg(target_pointer_width = "32")] -pub(crate) type WideSignedWord = i64; - // // 64-bit definitions // @@ -58,18 +50,10 @@ pub(crate) type WideSignedWord = i64; #[cfg(target_pointer_width = "64")] pub type Word = u64; -/// Signed integer type that corresponds to [`Word`]. -#[cfg(target_pointer_width = "64")] -pub(crate) type SignedWord = i64; - /// Wide integer type: double the width of [`Word`]. #[cfg(target_pointer_width = "64")] pub type WideWord = u128; -/// Signed wide integer type: double the width of [`SignedWord`]. -#[cfg(target_pointer_width = "64")] -pub(crate) type WideSignedWord = i128; - /// Highest bit in a [`Limb`]. pub(crate) const HI_BIT: usize = Limb::BITS - 1; @@ -106,14 +90,6 @@ impl Limb { /// Size of the inner integer in bytes. #[cfg(target_pointer_width = "64")] pub const BYTES: usize = 8; - - /// Return `a` if `c`==0 or `b` if `c`==`Word::MAX`. - /// - /// Const-friendly: we can't yet use `subtle` in `const fn` contexts. - #[inline] - pub(crate) const fn ct_select(a: Self, b: Self, c: Word) -> Self { - Self(a.0 ^ (c & (a.0 ^ b.0))) - } } impl Bounded for Limb { diff --git a/src/limb/cmp.rs b/src/limb/cmp.rs index 009ec5ec..4cdec5b5 100644 --- a/src/limb/cmp.rs +++ b/src/limb/cmp.rs @@ -1,6 +1,7 @@ //! Limb comparisons -use super::{Limb, SignedWord, WideSignedWord, Word, HI_BIT}; +use super::HI_BIT; +use crate::{CtChoice, Limb}; use core::cmp::Ordering; use subtle::{Choice, ConstantTimeEq, ConstantTimeGreater, ConstantTimeLess}; @@ -24,58 +25,45 @@ impl Limb { self.0 == other.0 } - /// Returns all 1's if `a`!=0 or 0 if `a`==0. - /// - /// Const-friendly: we can't yet use `subtle` in `const fn` contexts. + /// Return `b` if `c` is truthy, otherwise return `a`. #[inline] - pub(crate) const fn is_nonzero(self) -> Word { - let inner = self.0 as SignedWord; - ((inner | inner.saturating_neg()) >> HI_BIT) as Word + pub(crate) const fn ct_select(a: Self, b: Self, c: CtChoice) -> Self { + Self(c.select(a.0, b.0)) } + /// Returns the truthy value if `self != 0` and the falsy value otherwise. #[inline] - pub(crate) const fn ct_cmp(lhs: Self, rhs: Self) -> SignedWord { - let a = lhs.0 as WideSignedWord; - let b = rhs.0 as WideSignedWord; - let gt = ((b - a) >> Limb::BITS) & 1; - let lt = ((a - b) >> Limb::BITS) & 1 & !gt; - (gt as SignedWord) - (lt as SignedWord) + pub(crate) const fn ct_is_nonzero(&self) -> CtChoice { + let inner = self.0; + CtChoice::from_lsb((inner | inner.wrapping_neg()) >> HI_BIT) } - /// Returns `Word::MAX` if `lhs == rhs` and `0` otherwise. + /// Returns the truthy value if `lhs == rhs` and the falsy value otherwise. #[inline] - pub(crate) const fn ct_eq(lhs: Self, rhs: Self) -> Word { + pub(crate) const fn ct_eq(lhs: Self, rhs: Self) -> CtChoice { let x = lhs.0; let y = rhs.0; - // c == 0 if and only if x == y - let c = x ^ y; - - // If c == 0, then c and -c are both equal to zero; - // otherwise, one or both will have its high bit set. - let d = (c | c.wrapping_neg()) >> (Limb::BITS - 1); - - // Result is the opposite of the high bit (now shifted to low). - // Convert 1 to Word::MAX. - (d ^ 1).wrapping_neg() + // x ^ y == 0 if and only if x == y + Self(x ^ y).ct_is_nonzero().not() } - /// Returns `Word::MAX` if `lhs < rhs` and `0` otherwise. + /// Returns the truthy value if `lhs < rhs` and the falsy value otherwise. #[inline] - pub(crate) const fn ct_lt(lhs: Self, rhs: Self) -> Word { + pub(crate) const fn ct_lt(lhs: Self, rhs: Self) -> CtChoice { let x = lhs.0; let y = rhs.0; let bit = (((!x) & y) | (((!x) | y) & (x.wrapping_sub(y)))) >> (Limb::BITS - 1); - bit.wrapping_neg() + CtChoice::from_lsb(bit) } - /// Returns `Word::MAX` if `lhs <= rhs` and `0` otherwise. + /// Returns the truthy value if `lhs <= rhs` and the falsy value otherwise. #[inline] - pub(crate) const fn ct_le(lhs: Self, rhs: Self) -> Word { + pub(crate) const fn ct_le(lhs: Self, rhs: Self) -> CtChoice { let x = lhs.0; let y = rhs.0; let bit = (((!x) | y) & ((x ^ y) | !(y.wrapping_sub(x)))) >> (Limb::BITS - 1); - bit.wrapping_neg() + CtChoice::from_lsb(bit) } } diff --git a/src/uint/add.rs b/src/uint/add.rs index a105127e..21aa5d57 100644 --- a/src/uint/add.rs +++ b/src/uint/add.rs @@ -1,6 +1,6 @@ //! [`Uint`] addition operations. -use crate::{Checked, CheckedAdd, Limb, Uint, Word, Wrapping, Zero}; +use crate::{Checked, CheckedAdd, CtChoice, Limb, Uint, Wrapping, Zero}; use core::ops::{Add, AddAssign}; use subtle::CtOption; @@ -37,12 +37,16 @@ impl Uint { self.adc(rhs, Limb::ZERO).0 } - /// Perform wrapping addition, returning the overflow bit as a `Word` that is either 0...0 or 1...1. - pub(crate) const fn conditional_wrapping_add(&self, rhs: &Self, choice: Word) -> (Self, Word) { - let actual_rhs = Uint::ct_select(Uint::ZERO, *rhs, choice); + /// Perform wrapping addition, returning the truthy value as the second element of the tuple + /// if an overflow has occurred. + pub(crate) const fn conditional_wrapping_add( + &self, + rhs: &Self, + choice: CtChoice, + ) -> (Self, CtChoice) { + let actual_rhs = Uint::ct_select(&Uint::ZERO, rhs, choice); let (sum, carry) = self.adc(&actual_rhs, Limb::ZERO); - - (sum, carry.0.wrapping_mul(Word::MAX)) + (sum, CtChoice::from_lsb(carry.0)) } } diff --git a/src/uint/bits.rs b/src/uint/bits.rs index 66fb9ee7..83401434 100644 --- a/src/uint/bits.rs +++ b/src/uint/bits.rs @@ -1,14 +1,13 @@ -use crate::{Limb, Uint, Word}; +use crate::{CtChoice, Limb, Uint, Word}; impl Uint { - /// Get the value of the bit at position `index`, as a 0- or 1-valued Word. - /// Returns 0 for indices out of range. + /// Returns `true` if the bit at position `index` is set, `false` otherwise. #[inline(always)] - pub const fn bit_vartime(self, index: usize) -> Word { + pub const fn bit_vartime(self, index: usize) -> bool { if index >= Self::BITS { - 0 + false } else { - (self.limbs[index / Limb::BITS].0 >> (index % Limb::BITS)) & 1 + (self.limbs[index / Limb::BITS].0 >> (index % Limb::BITS)) & 1 == 1 } } @@ -21,14 +20,7 @@ impl Uint { } let limb = self.limbs[i].0; - let bits = (Limb::BITS * (i + 1)) as Word - limb.leading_zeros() as Word; - - Limb::ct_select( - Limb(bits), - Limb::ZERO, - !self.limbs[0].is_nonzero() & !Limb(i as Word).is_nonzero(), - ) - .0 as usize + Limb::BITS * (i + 1) - limb.leading_zeros() as usize } /// Calculate the number of leading zeros in the binary representation of this number. @@ -37,13 +29,14 @@ impl Uint { let mut count: Word = 0; let mut i = LIMBS; - let mut mask = Word::MAX; + let mut nonzero_limb_not_encountered = CtChoice::TRUE; while i > 0 { i -= 1; let l = limbs[i]; let z = l.leading_zeros() as Word; - count += z & mask; - mask &= !l.is_nonzero(); + count += nonzero_limb_not_encountered.if_true(z); + nonzero_limb_not_encountered = + nonzero_limb_not_encountered.and(l.ct_is_nonzero().not()); } count as usize @@ -55,12 +48,13 @@ impl Uint { let mut count: Word = 0; let mut i = 0; - let mut mask = Word::MAX; + let mut nonzero_limb_not_encountered = CtChoice::TRUE; while i < LIMBS { let l = limbs[i]; let z = l.trailing_zeros() as Word; - count += z & mask; - mask &= !l.is_nonzero(); + count += nonzero_limb_not_encountered.if_true(z); + nonzero_limb_not_encountered = + nonzero_limb_not_encountered.and(l.ct_is_nonzero().not()); i += 1; } @@ -72,9 +66,9 @@ impl Uint { Self::BITS - self.leading_zeros() } - /// Get the value of the bit at position `index`, as a 0- or 1-valued Word. - /// Returns 0 for indices out of range. - pub const fn bit(self, index: usize) -> Word { + /// Get the value of the bit at position `index`, as a truthy or falsy `CtChoice`. + /// Returns the falsy value for indices out of range. + pub const fn bit(self, index: usize) -> CtChoice { let limb_num = Limb((index / Limb::BITS) as Word); let index_in_limb = index % Limb::BITS; let index_mask = 1 << index_in_limb; @@ -86,11 +80,11 @@ impl Uint { while i < LIMBS { let bit = limbs[i] & index_mask; let is_right_limb = Limb::ct_eq(limb_num, Limb(i as Word)); - result |= bit & is_right_limb; + result |= is_right_limb.if_true(bit); i += 1; } - result >> index_in_limb + CtChoice::from_lsb(result >> index_in_limb) } } @@ -109,25 +103,25 @@ mod tests { #[test] fn bit_vartime() { let u = uint_with_bits_at(&[16, 48, 112, 127, 255]); - assert_eq!(u.bit_vartime(0), 0); - assert_eq!(u.bit_vartime(1), 0); - assert_eq!(u.bit_vartime(16), 1); - assert_eq!(u.bit_vartime(127), 1); - assert_eq!(u.bit_vartime(255), 1); - assert_eq!(u.bit_vartime(256), 0); - assert_eq!(u.bit_vartime(260), 0); + assert!(!u.bit_vartime(0)); + assert!(!u.bit_vartime(1)); + assert!(u.bit_vartime(16)); + assert!(u.bit_vartime(127)); + assert!(u.bit_vartime(255)); + assert!(!u.bit_vartime(256)); + assert!(!u.bit_vartime(260)); } #[test] fn bit() { let u = uint_with_bits_at(&[16, 48, 112, 127, 255]); - assert_eq!(u.bit(0), 0); - assert_eq!(u.bit(1), 0); - assert_eq!(u.bit(16), 1); - assert_eq!(u.bit(127), 1); - assert_eq!(u.bit(255), 1); - assert_eq!(u.bit(256), 0); - assert_eq!(u.bit(260), 0); + assert!(!u.bit(0).is_true_vartime()); + assert!(!u.bit(1).is_true_vartime()); + assert!(u.bit(16).is_true_vartime()); + assert!(u.bit(127).is_true_vartime()); + assert!(u.bit(255).is_true_vartime()); + assert!(!u.bit(256).is_true_vartime()); + assert!(!u.bit(260).is_true_vartime()); } #[test] diff --git a/src/uint/cmp.rs b/src/uint/cmp.rs index 146fafe4..8815369e 100644 --- a/src/uint/cmp.rs +++ b/src/uint/cmp.rs @@ -3,16 +3,14 @@ //! By default these are all constant-time and use the `subtle` crate. use super::Uint; -use crate::{limb::HI_BIT, Limb, SignedWord, WideSignedWord, Word, Zero}; +use crate::{CtChoice, Limb}; use core::cmp::Ordering; use subtle::{Choice, ConstantTimeEq, ConstantTimeGreater, ConstantTimeLess}; impl Uint { - /// Return `a` if `c`==0 or `b` if `c`==`Word::MAX`. - /// - /// Const-friendly: we can't yet use `subtle` in `const fn` contexts. + /// Return `b` if `c` is truthy, otherwise return `a`. #[inline] - pub(crate) const fn ct_select(a: Uint, b: Uint, c: Word) -> Self { + pub(crate) const fn ct_select(a: &Self, b: &Self, c: CtChoice) -> Self { let mut limbs = [Limb::ZERO; LIMBS]; let mut i = 0; @@ -25,88 +23,81 @@ impl Uint { } #[inline] - pub(crate) const fn ct_swap(a: Uint, b: Uint, c: Word) -> (Self, Self) { + pub(crate) const fn ct_swap(a: &Self, b: &Self, c: CtChoice) -> (Self, Self) { let new_a = Self::ct_select(a, b, c); let new_b = Self::ct_select(b, a, c); (new_a, new_b) } - /// Returns all 1's if `self`!=0 or 0 if `self`==0. - /// - /// Const-friendly: we can't yet use `subtle` in `const fn` contexts. + /// Returns the truthy value if `self`!=0 or the falsy value otherwise. #[inline] - pub(crate) const fn ct_is_nonzero(&self) -> Word { + pub(crate) const fn ct_is_nonzero(&self) -> CtChoice { let mut b = 0; let mut i = 0; while i < LIMBS { b |= self.limbs[i].0; i += 1; } - Limb::is_nonzero(Limb(b)) + Limb(b).ct_is_nonzero() } - pub(crate) const fn ct_is_odd(&self) -> Word { - (self.limbs[0].0 & 1).wrapping_mul(Word::MAX) + /// Returns the truthy value if `self` is odd or the falsy value otherwise. + pub(crate) const fn ct_is_odd(&self) -> CtChoice { + CtChoice::from_lsb(self.limbs[0].0 & 1) } - /// Returns -1 if self < rhs - /// 0 if self == rhs - /// 1 if self > rhs - /// - /// Const-friendly: we can't yet use `subtle` in `const fn` contexts. + /// Returns the truthy value if `self == rhs` or the falsy value otherwise. #[inline] - pub(crate) const fn ct_cmp(&self, rhs: &Self) -> SignedWord { - let mut gt = 0; - let mut lt = 0; - let mut i = LIMBS; - - while i > 0 { - let a = self.limbs[i - 1].0 as WideSignedWord; - let b = rhs.limbs[i - 1].0 as WideSignedWord; - gt |= ((b - a) >> Limb::BITS) & 1 & !lt; - lt |= ((a - b) >> Limb::BITS) & 1 & !gt; - i -= 1; - } - (gt as SignedWord) - (lt as SignedWord) - } - - /// Returns 0 if self == rhs or Word::MAX if self != rhs. - /// Const-friendly: we can't yet use `subtle` in `const fn` contexts. - #[inline] - pub(crate) const fn ct_not_eq(&self, rhs: &Self) -> Word { + pub(crate) const fn ct_eq(lhs: &Self, rhs: &Self) -> CtChoice { let mut acc = 0; let mut i = 0; while i < LIMBS { - acc |= self.limbs[i].0 ^ rhs.limbs[i].0; + acc |= lhs.limbs[i].0 ^ rhs.limbs[i].0; i += 1; } - let acc = acc as SignedWord; - ((acc | acc.wrapping_neg()) >> HI_BIT) as Word + + // acc == 0 if and only if self == rhs + Limb(acc).ct_is_nonzero().not() + } + + /// Returns the truthy value if `self <= rhs` and the falsy value otherwise. + #[inline] + pub(crate) const fn ct_lt(lhs: &Self, rhs: &Self) -> CtChoice { + // We could use the same approach as in Limb::ct_lt(), + // but since we have to use Uint::wrapping_sub(), which calls `sbb()`, + // there are no savings compared to just calling `sbb()` directly. + let (_res, borrow) = lhs.sbb(rhs, Limb::ZERO); + CtChoice::from_mask(borrow.0) + } + + /// Returns the truthy value if `self <= rhs` and the falsy value otherwise. + #[inline] + pub(crate) const fn ct_gt(lhs: &Self, rhs: &Self) -> CtChoice { + let (_res, borrow) = rhs.sbb(lhs, Limb::ZERO); + CtChoice::from_mask(borrow.0) } } impl ConstantTimeEq for Uint { #[inline] fn ct_eq(&self, other: &Self) -> Choice { - Choice::from((!self.ct_not_eq(other) as u8) & 1) + Uint::ct_eq(self, other).into() } } impl ConstantTimeGreater for Uint { #[inline] fn ct_gt(&self, other: &Self) -> Choice { - let underflow = other.sbb(self, Limb::ZERO).1; - !underflow.is_zero() + Uint::ct_gt(self, other).into() } } impl ConstantTimeLess for Uint { #[inline] fn ct_lt(&self, other: &Self) -> Choice { - let underflow = self.sbb(other, Limb::ZERO).1; - !underflow.is_zero() + Uint::ct_lt(self, other).into() } } @@ -114,14 +105,15 @@ impl Eq for Uint {} impl Ord for Uint { fn cmp(&self, other: &Self) -> Ordering { - match self.ct_cmp(other) { - -1 => Ordering::Less, - 1 => Ordering::Greater, - n => { - debug_assert_eq!(n, 0); - debug_assert!(bool::from(self.ct_eq(other))); - Ordering::Equal - } + let is_lt = self.ct_lt(other); + let is_eq = self.ct_eq(other); + + if is_lt.into() { + Ordering::Less + } else if is_eq.into() { + Ordering::Equal + } else { + Ordering::Greater } } } diff --git a/src/uint/div.rs b/src/uint/div.rs index 4d078f50..0f3df256 100644 --- a/src/uint/div.rs +++ b/src/uint/div.rs @@ -1,9 +1,7 @@ //! [`Uint`] division operations. use super::div_limb::{div_rem_limb_with_reciprocal, Reciprocal}; -use super::Uint; -use crate::limb::Word; -use crate::{Limb, NonZero, Wrapping}; +use crate::{CtChoice, Limb, NonZero, Uint, Word, Wrapping}; use core::ops::{Div, DivAssign, Rem, RemAssign}; use subtle::CtOption; @@ -26,9 +24,11 @@ impl Uint { } /// Computes `self` / `rhs`, returns the quotient (q) and remainder (r). + /// Returns the truthy value as the third element of the tuple if `rhs != 0`, + /// and the falsy value otherwise. #[inline(always)] - pub(crate) fn ct_div_rem_limb(&self, rhs: Limb) -> (Self, Limb, u8) { - let (reciprocal, is_some) = Reciprocal::new_const(rhs); + pub(crate) const fn ct_div_rem_limb(&self, rhs: Limb) -> (Self, Limb, CtChoice) { + let (reciprocal, is_some) = Reciprocal::ct_new(rhs); let (quo, rem) = div_rem_limb_with_reciprocal(self, &reciprocal); (quo, rem, is_some) } @@ -36,22 +36,22 @@ impl Uint { /// Computes `self` / `rhs`, returns the quotient (q) and remainder (r). #[inline(always)] pub fn div_rem_limb(&self, rhs: NonZero) -> (Self, Limb) { - let (quo, rem, is_some) = self.ct_div_rem_limb(*rhs); // Guaranteed to succeed since `rhs` is nonzero. - debug_assert!(is_some == 1); + let (quo, rem, _is_some) = self.ct_div_rem_limb(*rhs); (quo, rem) } /// Computes `self` / `rhs`, returns the quotient (q), remainder (r) - /// and 1 for is_some or 0 for is_none. The results can be wrapped in [`CtOption`]. - /// NOTE: Use only if you need to access const fn. Otherwise use `div_rem` because + /// and the truthy value for is_some or the falsy value for is_none. + /// + /// NOTE: Use only if you need to access const fn. Otherwise use [`div_rem`] because /// the value for is_some needs to be checked before using `q` and `r`. /// /// This is variable only with respect to `rhs`. /// /// When used with a fixed `rhs`, this function is constant-time with respect /// to `self`. - pub(crate) const fn ct_div_rem(&self, rhs: &Self) -> (Self, Self, u8) { + pub(crate) const fn ct_div_rem(&self, rhs: &Self) -> (Self, Self, CtChoice) { let mb = rhs.bits_vartime(); let mut bd = Self::BITS - mb; let mut rem = *self; @@ -60,9 +60,9 @@ impl Uint { loop { let (mut r, borrow) = rem.sbb(&c, Limb::ZERO); - rem = Self::ct_select(r, rem, borrow.0); + rem = Self::ct_select(&r, &rem, CtChoice::from_mask(borrow.0)); r = quo.bitor(&Self::ONE); - quo = Self::ct_select(r, quo, borrow.0); + quo = Self::ct_select(&r, &quo, CtChoice::from_mask(borrow.0)); if bd == 0 { break; } @@ -71,19 +71,20 @@ impl Uint { quo = quo.shl_vartime(1); } - let is_some = Limb(mb as Word).is_nonzero(); - quo = Self::ct_select(Self::ZERO, quo, is_some); - (quo, rem, (is_some & 1) as u8) + let is_some = Limb(mb as Word).ct_is_nonzero(); + quo = Self::ct_select(&Self::ZERO, &quo, is_some); + (quo, rem, is_some) } /// Computes `self` % `rhs`, returns the remainder and - /// and 1 for is_some or 0 for is_none. The results can be wrapped in [`CtOption`]. - /// NOTE: Use only if you need to access const fn. Otherwise use `reduce` + /// and the truthy value for is_some or the falsy value for is_none. + /// + /// NOTE: Use only if you need to access const fn. Otherwise use [`rem`]. /// This is variable only with respect to `rhs`. /// /// When used with a fixed `rhs`, this function is constant-time with respect /// to `self`. - pub(crate) const fn ct_rem(&self, rhs: &Self) -> (Self, u8) { + pub(crate) const fn ct_rem(&self, rhs: &Self) -> (Self, CtChoice) { let mb = rhs.bits_vartime(); let mut bd = Self::BITS - mb; let mut rem = *self; @@ -91,7 +92,7 @@ impl Uint { loop { let (r, borrow) = rem.sbb(&c, Limb::ZERO); - rem = Self::ct_select(r, rem, borrow.0); + rem = Self::ct_select(&r, &rem, CtChoice::from_mask(borrow.0)); if bd == 0 { break; } @@ -99,19 +100,18 @@ impl Uint { c = c.shr_vartime(1); } - let is_some = Limb(mb as Word).is_nonzero(); - (rem, (is_some & 1) as u8) + let is_some = Limb(mb as Word).ct_is_nonzero(); + (rem, is_some) } /// Computes `self` % `rhs`, returns the remainder and - /// and 1 for is_some or 0 for is_none. The results can be wrapped in [`CtOption`]. - /// NOTE: Use only if you need to access const fn. Otherwise use `reduce` + /// and the truthy value for is_some or the falsy value for is_none. + /// /// This is variable only with respect to `rhs`. /// /// When used with a fixed `rhs`, this function is constant-time with respect /// to `self`. - #[allow(dead_code)] - pub(crate) const fn ct_rem_wide(lower_upper: (Self, Self), rhs: &Self) -> (Self, u8) { + pub(crate) const fn ct_rem_wide(lower_upper: (Self, Self), rhs: &Self) -> (Self, CtChoice) { let mb = rhs.bits_vartime(); // The number of bits to consider is two sets of limbs * BITS - mb (modulus bitcount) @@ -127,8 +127,8 @@ impl Uint { let (lower_sub, borrow) = lower.sbb(&c.0, Limb::ZERO); let (upper_sub, borrow) = upper.sbb(&c.1, borrow); - lower = Self::ct_select(lower_sub, lower, borrow.0); - upper = Self::ct_select(upper_sub, upper, borrow.0); + lower = Self::ct_select(&lower_sub, &lower, CtChoice::from_mask(borrow.0)); + upper = Self::ct_select(&upper_sub, &upper, CtChoice::from_mask(borrow.0)); if bd == 0 { break; } @@ -136,8 +136,8 @@ impl Uint { c = Self::shr_vartime_wide(c, 1); } - let is_some = Limb(mb as Word).is_nonzero(); - (lower, (is_some & 1) as u8) + let is_some = Limb(mb as Word).ct_is_nonzero(); + (lower, is_some) } /// Computes `self` % 2^k. Faster than reduce since its a power of 2. @@ -145,8 +145,7 @@ impl Uint { pub const fn rem2k(&self, k: usize) -> Self { let highest = (LIMBS - 1) as u32; let index = k as u32 / (Limb::BITS as u32); - let res = Limb::ct_cmp(Limb::from_u32(index), Limb::from_u32(highest)) - 1; - let le = Limb::is_nonzero(Limb(res as Word)); + let le = Limb::ct_le(Limb::from_u32(index), Limb::from_u32(highest)); let word = Limb::ct_select(Limb::from_u32(highest), Limb::from_u32(index), le).0 as usize; let base = k % Limb::BITS; @@ -168,17 +167,15 @@ impl Uint { /// Computes self / rhs, returns the quotient, remainder. pub fn div_rem(&self, rhs: &NonZero) -> (Self, Self) { - let (q, r, c) = self.ct_div_rem(rhs); // Since `rhs` is nonzero, this should always hold. - debug_assert!(c == 1); + let (q, r, _c) = self.ct_div_rem(rhs); (q, r) } /// Computes self % rhs, returns the remainder. pub fn rem(&self, rhs: &NonZero) -> Self { - let (r, c) = self.ct_rem(rhs); // Since `rhs` is nonzero, this should always hold. - debug_assert!(c == 1); + let (r, _c) = self.ct_rem(rhs); r } @@ -189,7 +186,7 @@ impl Uint { /// Panics if `rhs == 0`. pub const fn wrapping_div(&self, rhs: &Self) -> Self { let (q, _, c) = self.ct_div_rem(rhs); - assert!(c == 1, "divide by zero"); + assert!(c.is_true_vartime(), "divide by zero"); q } @@ -209,7 +206,7 @@ impl Uint { /// Panics if `rhs == 0`. pub const fn wrapping_rem(&self, rhs: &Self) -> Self { let (r, c) = self.ct_rem(rhs); - assert!(c == 1, "modulo zero"); + assert!(c.is_true_vartime(), "modulo zero"); r } @@ -613,7 +610,7 @@ mod tests { let lhs = U256::from(*n); let rhs = U256::from(*d); let (q, r, is_some) = lhs.ct_div_rem(&rhs); - assert_eq!(is_some, 1); + assert!(is_some.is_true_vartime()); assert_eq!(U256::from(*e), q); assert_eq!(U256::from(*ee), r); } @@ -627,9 +624,9 @@ mod tests { let num = U256::random(&mut rng).shr_vartime(128); let den = U256::random(&mut rng).shr_vartime(128); let n = num.checked_mul(&den); - if n.is_some().unwrap_u8() == 1 { + if n.is_some().into() { let (q, _, is_some) = n.unwrap().ct_div_rem(&den); - assert_eq!(is_some, 1); + assert!(is_some.is_true_vartime()); assert_eq!(q, num); } } @@ -651,7 +648,7 @@ mod tests { #[test] fn div_zero() { let (q, r, is_some) = U256::ONE.ct_div_rem(&U256::ZERO); - assert_eq!(is_some, 0); + assert!(!is_some.is_true_vartime()); assert_eq!(q, U256::ZERO); assert_eq!(r, U256::ONE); } @@ -659,7 +656,7 @@ mod tests { #[test] fn div_one() { let (q, r, is_some) = U256::from(10u8).ct_div_rem(&U256::ONE); - assert_eq!(is_some, 1); + assert!(is_some.is_true_vartime()); assert_eq!(q, U256::from(10u8)); assert_eq!(r, U256::ZERO); } @@ -667,7 +664,7 @@ mod tests { #[test] fn reduce_one() { let (r, is_some) = U256::from(10u8).ct_rem(&U256::ONE); - assert_eq!(is_some, 1); + assert!(is_some.is_true_vartime()); assert_eq!(r, U256::ZERO); } @@ -675,33 +672,33 @@ mod tests { fn reduce_zero() { let u = U256::from(10u8); let (r, is_some) = u.ct_rem(&U256::ZERO); - assert_eq!(is_some, 0); + assert!(!is_some.is_true_vartime()); assert_eq!(r, u); } #[test] fn reduce_tests() { let (r, is_some) = U256::from(10u8).ct_rem(&U256::from(2u8)); - assert_eq!(is_some, 1); + assert!(is_some.is_true_vartime()); assert_eq!(r, U256::ZERO); let (r, is_some) = U256::from(10u8).ct_rem(&U256::from(3u8)); - assert_eq!(is_some, 1); + assert!(is_some.is_true_vartime()); assert_eq!(r, U256::ONE); let (r, is_some) = U256::from(10u8).ct_rem(&U256::from(7u8)); - assert_eq!(is_some, 1); + assert!(is_some.is_true_vartime()); assert_eq!(r, U256::from(3u8)); } #[test] fn reduce_tests_wide_zero_padded() { let (r, is_some) = U256::ct_rem_wide((U256::from(10u8), U256::ZERO), &U256::from(2u8)); - assert_eq!(is_some, 1); + assert!(is_some.is_true_vartime()); assert_eq!(r, U256::ZERO); let (r, is_some) = U256::ct_rem_wide((U256::from(10u8), U256::ZERO), &U256::from(3u8)); - assert_eq!(is_some, 1); + assert!(is_some.is_true_vartime()); assert_eq!(r, U256::ONE); let (r, is_some) = U256::ct_rem_wide((U256::from(10u8), U256::ZERO), &U256::from(7u8)); - assert_eq!(is_some, 1); + assert!(is_some.is_true_vartime()); assert_eq!(r, U256::from(3u8)); } diff --git a/src/uint/div_limb.rs b/src/uint/div_limb.rs index 0709e615..9bbd828e 100644 --- a/src/uint/div_limb.rs +++ b/src/uint/div_limb.rs @@ -3,7 +3,7 @@ //! (DOI: 10.1109/TC.2010.143, ). use subtle::{Choice, ConditionallySelectable, CtOption}; -use crate::{Limb, Uint, WideWord, Word}; +use crate::{CtChoice, Limb, Uint, WideWord, Word}; /// Calculates the reciprocal of the given 32-bit divisor with the highmost bit set. #[cfg(target_pointer_width = "32")] @@ -33,7 +33,7 @@ pub const fn reciprocal(d: Word) -> Word { // Hence the `ct_select()`. let x = v2.wrapping_add(1); let (hi, _lo) = mulhilo(x, d); - let hi = Limb::ct_select(Limb(d), Limb(hi), Limb(x).is_nonzero()).0; + let hi = Limb::ct_select(Limb(d), Limb(hi), Limb(x).ct_is_nonzero()).0; v2.wrapping_sub(hi).wrapping_sub(d) } @@ -63,7 +63,7 @@ pub const fn reciprocal(d: Word) -> Word { // Hence the `ct_select()`. let x = v3.wrapping_add(1); let (hi, _lo) = mulhilo(x, d); - let hi = Limb::ct_select(Limb(d), Limb(hi), Limb(x).is_nonzero()).0; + let hi = Limb::ct_select(Limb(d), Limb(hi), Limb(x).ct_is_nonzero()).0; v3.wrapping_sub(hi).wrapping_sub(d) } @@ -168,17 +168,18 @@ pub struct Reciprocal { impl Reciprocal { /// Pre-calculates a reciprocal for a known divisor, /// to be used in the single-limb division later. - /// Returns the reciprocal, and `1` if `divisor != 0` and `0` otherwise. + /// Returns the reciprocal, and the truthy value if `divisor != 0` + /// and the falsy value otherwise. /// - /// Note: if the returned flag is `0`, the returned reciprocal object is still self-consistent + /// Note: if the returned flag is falsy, the returned reciprocal object is still self-consistent /// and can be passed to functions here without causing them to panic, /// but the results are naturally not to be used. - pub const fn new_const(divisor: Limb) -> (Self, u8) { + pub const fn ct_new(divisor: Limb) -> (Self, CtChoice) { // Assuming this is constant-time for primitive types. let shift = divisor.0.leading_zeros(); #[allow(trivial_numeric_casts)] - let is_some = Limb((Word::BITS - shift) as Word).is_nonzero(); + let is_some = Limb((Word::BITS - shift) as Word).ct_is_nonzero(); // If `divisor = 0`, shifting `divisor` by `leading_zeros == Word::BITS` will cause a panic. // Have to substitute a "bogus" shift in that case. @@ -199,7 +200,7 @@ impl Reciprocal { shift, reciprocal: reciprocal(divisor_normalized), }, - (is_some & 1) as u8, + is_some, ) } @@ -220,8 +221,8 @@ impl Reciprocal { /// A non-const-fn version of `new_const()`, wrapping the result in a `CtOption`. pub fn new(divisor: Limb) -> CtOption { - let (rec, is_some) = Self::new_const(divisor); - CtOption::new(rec, Choice::from(is_some)) + let (rec, is_some) = Self::ct_new(divisor); + CtOption::new(rec, is_some.into()) } } diff --git a/src/uint/inv_mod.rs b/src/uint/inv_mod.rs index 4e0e650d..ef3c161f 100644 --- a/src/uint/inv_mod.rs +++ b/src/uint/inv_mod.rs @@ -1,7 +1,5 @@ -use subtle::{Choice, CtOption}; - use super::Uint; -use crate::{Limb, Word}; +use crate::{CtChoice, Limb}; impl Uint { /// Computes 1/`self` mod 2^k as specified in Algorithm 4 from @@ -22,54 +20,69 @@ impl Uint { x = x.bitor(&x_i.shl_vartime(i)); let t = b.wrapping_sub(self); - b = Self::ct_select(b, t, j.wrapping_neg()).shr_vartime(1); + b = Self::ct_select(&b, &t, CtChoice::from_lsb(j)).shr_vartime(1); i += 1; } x } - /// Computes the multiplicative inverse of `self` mod `modulus`. In other words `self^-1 mod modulus`. Returns `(inverse, 1...1)` if an inverse exists, otherwise `(undefined, 0...0)`. The algorithm is the same as in GMP 6.2.1's `mpn_sec_invert`. - pub const fn inv_odd_mod(self, modulus: Uint) -> (Self, Word) { - debug_assert!(modulus.ct_is_odd() == Word::MAX); + /// Computes the multiplicative inverse of `self` mod `modulus`, where `modulus` is odd. + /// In other words `self^-1 mod modulus`. + /// `bits` and `modulus_bits` are the bounds on the bit size + /// of `self` and `modulus`, respectively + /// (the inversion speed will be proportional to `bits + modulus_bits`). + /// The second element of the tuple is the truthy value if an inverse exists, + /// otherwise it is a falsy value. + /// + /// **Note:** variable time in `bits` and `modulus_bits`. + /// + /// The algorithm is the same as in GMP 6.2.1's `mpn_sec_invert`. + pub const fn inv_odd_mod_bounded( + &self, + modulus: &Self, + bits: usize, + modulus_bits: usize, + ) -> (Self, CtChoice) { + debug_assert!(modulus.ct_is_odd().is_true_vartime()); - let mut a = self; + let mut a = *self; let mut u = Uint::ONE; let mut v = Uint::ZERO; - let mut b = modulus; + let mut b = *modulus; - // TODO: This can be lower if `self` is known to be small. - let bit_size = 2 * LIMBS * 64; + // `bit_size` can be anything >= `self.bits()` + `modulus.bits()`, setting to the minimum. + let bit_size = bits + modulus_bits; - let mut m1hp = modulus; + let mut m1hp = *modulus; let (m1hp_new, carry) = m1hp.shr_1(); - debug_assert!(carry == Word::MAX); + debug_assert!(carry.is_true_vartime()); m1hp = m1hp_new.wrapping_add(&Uint::ONE); let mut i = 0; while i < bit_size { - debug_assert!(b.ct_is_odd() == Word::MAX); + debug_assert!(b.ct_is_odd().is_true_vartime()); let self_odd = a.ct_is_odd(); // Set `self -= b` if `self` is odd. let (new_a, swap) = a.conditional_wrapping_sub(&b, self_odd); // Set `b += self` if `swap` is true. - b = Uint::ct_select(b, b.wrapping_add(&new_a), swap); + b = Uint::ct_select(&b, &b.wrapping_add(&new_a), swap); // Negate `self` if `swap` is true. a = new_a.conditional_wrapping_neg(swap); - let (new_u, new_v) = Uint::ct_swap(u, v, swap); + let (new_u, new_v) = Uint::ct_swap(&u, &v, swap); let (new_u, cy) = new_u.conditional_wrapping_sub(&new_v, self_odd); - let (new_u, cyy) = new_u.conditional_wrapping_add(&modulus, cy); - debug_assert!(cy == cyy); + let (new_u, cyy) = new_u.conditional_wrapping_add(modulus, cy); + debug_assert!(cy.is_true_vartime() == cyy.is_true_vartime()); let (new_a, overflow) = a.shr_1(); - debug_assert!(overflow == 0); + debug_assert!(!overflow.is_true_vartime()); let (new_u, cy) = new_u.shr_1(); let (new_u, cy) = new_u.conditional_wrapping_add(&m1hp, cy); - debug_assert!(cy == 0); + debug_assert!(!cy.is_true_vartime()); a = new_a; u = new_u; @@ -78,15 +91,15 @@ impl Uint { i += 1; } - debug_assert!(a.ct_cmp(&Uint::ZERO) == 0); + debug_assert!(!a.ct_is_nonzero().is_true_vartime()); - (v, b.ct_not_eq(&Uint::ONE) ^ Word::MAX) + (v, Uint::ct_eq(&b, &Uint::ONE)) } - /// Computes the multiplicative inverse of `self` mod `modulus`. In other words `self^-1 mod modulus`. Returns `None` if the inverse does not exist. The algorithm is the same as in GMP 6.2.1's `mpn_sec_invert`. - pub fn inv_odd_mod_option(self, modulus: Uint) -> CtOption { - let (inverse, exists) = self.inv_odd_mod(modulus); - CtOption::new(inverse, Choice::from((exists == Word::MAX) as u8)) + /// Computes the multiplicative inverse of `self` mod `modulus`, where `modulus` is odd. + /// Returns `(inverse, Word::MAX)` if an inverse exists, otherwise `(undefined, Word::ZERO)`. + pub const fn inv_odd_mod(&self, modulus: &Self) -> (Self, CtChoice) { + self.inv_odd_mod_bounded(modulus, Uint::::BITS, Uint::::BITS) } } @@ -96,42 +109,73 @@ mod tests { #[test] fn inv_mod2k() { - let v = U256::from_be_slice(&[ - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe, - 0xff, 0xff, 0xfc, 0x2f, - ]); - let e = U256::from_be_slice(&[ - 0x36, 0x42, 0xe6, 0xfa, 0xea, 0xac, 0x7c, 0x66, 0x63, 0xb9, 0x3d, 0x3d, 0x6a, 0x0d, - 0x48, 0x9e, 0x43, 0x4d, 0xdc, 0x01, 0x23, 0xdb, 0x5f, 0xa6, 0x27, 0xc7, 0xf6, 0xe2, - 0x2d, 0xda, 0xca, 0xcf, - ]); + let v = + U256::from_be_hex("fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f"); + let e = + U256::from_be_hex("3642e6faeaac7c6663b93d3d6a0d489e434ddc0123db5fa627c7f6e22ddacacf"); let a = v.inv_mod2k(256); assert_eq!(e, a); - let v = U256::from_be_slice(&[ - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xfe, 0xba, 0xae, 0xdc, 0xe6, 0xaf, 0x48, 0xa0, 0x3b, 0xbf, 0xd2, 0x5e, 0x8c, - 0xd0, 0x36, 0x41, 0x41, - ]); - let e = U256::from_be_slice(&[ - 0x26, 0x17, 0x76, 0xf2, 0x9b, 0x6b, 0x10, 0x6c, 0x76, 0x80, 0xcf, 0x3e, 0xd8, 0x30, - 0x54, 0xa1, 0xaf, 0x5a, 0xe5, 0x37, 0xcb, 0x46, 0x13, 0xdb, 0xb4, 0xf2, 0x00, 0x99, - 0xaa, 0x77, 0x4e, 0xc1, - ]); + let v = + U256::from_be_hex("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141"); + let e = + U256::from_be_hex("261776f29b6b106c7680cf3ed83054a1af5ae537cb4613dbb4f20099aa774ec1"); let a = v.inv_mod2k(256); assert_eq!(e, a); } #[test] fn test_invert() { - let a = U1024::from_be_hex("000225E99153B467A5B451979A3F451DAEF3BF8D6C6521D2FA24BBB17F29544E347A412B065B75A351EA9719E2430D2477B11CC9CF9C1AD6EDEE26CB15F463F8BCC72EF87EA30288E95A48AA792226CEC959DCB0672D8F9D80A54CBBEA85CAD8382EC224DEB2F5784E62D0CC2F81C2E6AD14EBABE646D6764B30C32B87688985"); - let m = U1024::from_be_hex("D509E7854ABDC81921F669F1DC6F61359523F3949803E58ED4EA8BC16483DC6F37BFE27A9AC9EEA2969B357ABC5C0EE214BE16A7D4C58FC620D5B5A20AFF001AD198D3155E5799DC4EA76652D64983A7E130B5EACEBAC768D28D589C36EC749C558D0B64E37CD0775C0D0104AE7D98BA23C815185DD43CD8B16292FD94156767"); + let a = U1024::from_be_hex(concat![ + "000225E99153B467A5B451979A3F451DAEF3BF8D6C6521D2FA24BBB17F29544E", + "347A412B065B75A351EA9719E2430D2477B11CC9CF9C1AD6EDEE26CB15F463F8", + "BCC72EF87EA30288E95A48AA792226CEC959DCB0672D8F9D80A54CBBEA85CAD8", + "382EC224DEB2F5784E62D0CC2F81C2E6AD14EBABE646D6764B30C32B87688985" + ]); + let m = U1024::from_be_hex(concat![ + "D509E7854ABDC81921F669F1DC6F61359523F3949803E58ED4EA8BC16483DC6F", + "37BFE27A9AC9EEA2969B357ABC5C0EE214BE16A7D4C58FC620D5B5A20AFF001A", + "D198D3155E5799DC4EA76652D64983A7E130B5EACEBAC768D28D589C36EC749C", + "558D0B64E37CD0775C0D0104AE7D98BA23C815185DD43CD8B16292FD94156767" + ]); - let res = a.inv_odd_mod_option(m); + let (res, is_some) = a.inv_odd_mod(&m); - let expected = U1024::from_be_hex("B03623284B0EBABCABD5C5881893320281460C0A8E7BF4BFDCFFCBCCBF436A55D364235C8171E46C7D21AAD0680676E57274A8FDA6D12768EF961CACDD2DAE5788D93DA5EB8EDC391EE3726CDCF4613C539F7D23E8702200CB31B5ED5B06E5CA3E520968399B4017BF98A864FABA2B647EFC4998B56774D4F2CB026BC024A336"); - assert_eq!(res.unwrap(), expected); + let expected = U1024::from_be_hex(concat![ + "B03623284B0EBABCABD5C5881893320281460C0A8E7BF4BFDCFFCBCCBF436A55", + "D364235C8171E46C7D21AAD0680676E57274A8FDA6D12768EF961CACDD2DAE57", + "88D93DA5EB8EDC391EE3726CDCF4613C539F7D23E8702200CB31B5ED5B06E5CA", + "3E520968399B4017BF98A864FABA2B647EFC4998B56774D4F2CB026BC024A336" + ]); + assert!(is_some.is_true_vartime()); + assert_eq!(res, expected); + } + + #[test] + fn test_invert_bounded() { + let a = U1024::from_be_hex(concat![ + "0000000000000000000000000000000000000000000000000000000000000000", + "347A412B065B75A351EA9719E2430D2477B11CC9CF9C1AD6EDEE26CB15F463F8", + "BCC72EF87EA30288E95A48AA792226CEC959DCB0672D8F9D80A54CBBEA85CAD8", + "382EC224DEB2F5784E62D0CC2F81C2E6AD14EBABE646D6764B30C32B87688985" + ]); + let m = U1024::from_be_hex(concat![ + "0000000000000000000000000000000000000000000000000000000000000000", + "0000000000000000000000000000000000000000000000000000000000000000", + "D198D3155E5799DC4EA76652D64983A7E130B5EACEBAC768D28D589C36EC749C", + "558D0B64E37CD0775C0D0104AE7D98BA23C815185DD43CD8B16292FD94156767" + ]); + + let (res, is_some) = a.inv_odd_mod_bounded(&m, 768, 512); + + let expected = U1024::from_be_hex(concat![ + "0000000000000000000000000000000000000000000000000000000000000000", + "0000000000000000000000000000000000000000000000000000000000000000", + "0DCC94E2FE509E6EBBA0825645A38E73EF85D5927C79C1AD8FFE7C8DF9A822FA", + "09EB396A21B1EF05CBE51E1A8EF284EF01EBDD36A9A4EA17039D8EEFDD934768" + ]); + assert!(is_some.is_true_vartime()); + assert_eq!(res, expected); } #[test] @@ -139,9 +183,10 @@ mod tests { let a = U64::from(3u64); let m = U64::from(13u64); - let res = a.inv_odd_mod_option(m); + let (res, is_some) = a.inv_odd_mod(&m); - assert_eq!(U64::from(9u64), res.unwrap()); + assert!(is_some.is_true_vartime()); + assert_eq!(U64::from(9u64), res); } #[test] @@ -149,8 +194,8 @@ mod tests { let a = U64::from(14u64); let m = U64::from(49u64); - let res = a.inv_odd_mod_option(m); + let (_res, is_some) = a.inv_odd_mod(&m); - assert!(res.is_none().unwrap_u8() == 1); + assert!(!is_some.is_true_vartime()); } } diff --git a/src/uint/modular.rs b/src/uint/modular.rs index a3350b15..c5b7d2eb 100644 --- a/src/uint/modular.rs +++ b/src/uint/modular.rs @@ -64,8 +64,8 @@ mod tests { // Divide the value R by R, which should equal 1 assert_eq!( montgomery_reduction::<{ Modulus2::LIMBS }>( - (Modulus2::R, Uint::ZERO), - Modulus2::MODULUS, + &(Modulus2::R, Uint::ZERO), + &Modulus2::MODULUS, Modulus2::MOD_NEG_INV ), Uint::ONE @@ -77,8 +77,8 @@ mod tests { // Divide the value R^2 by R, which should equal R assert_eq!( montgomery_reduction::<{ Modulus2::LIMBS }>( - (Modulus2::R2, Uint::ZERO), - Modulus2::MODULUS, + &(Modulus2::R2, Uint::ZERO), + &Modulus2::MODULUS, Modulus2::MOD_NEG_INV ), Modulus2::R @@ -91,8 +91,8 @@ mod tests { let (hi, lo) = Modulus2::R.square().split(); assert_eq!( montgomery_reduction::<{ Modulus2::LIMBS }>( - (lo, hi), - Modulus2::MODULUS, + &(lo, hi), + &Modulus2::MODULUS, Modulus2::MOD_NEG_INV ), Modulus2::R @@ -107,8 +107,8 @@ mod tests { let product = x.mul_wide(&Modulus2::R); assert_eq!( montgomery_reduction::<{ Modulus2::LIMBS }>( - product, - Modulus2::MODULUS, + &product, + &Modulus2::MODULUS, Modulus2::MOD_NEG_INV ), x @@ -131,8 +131,8 @@ mod tests { assert_eq!( montgomery_reduction::<{ Modulus2::LIMBS }>( - product, - Modulus2::MODULUS, + &product, + &Modulus2::MODULUS, Modulus2::MOD_NEG_INV ), lo @@ -143,7 +143,7 @@ mod tests { fn test_new_retrieve() { let x = U256::from_be_hex("44acf6b7e36c1342c2c5897204fe09504e1e2efb1a900377dbc4e7a6a133ec56"); - let x_mod = Residue::::new(x); + let x_mod = Residue::::new(&x); // Confirm that when creating a Modular and retrieving the value, that it equals the original assert_eq!(x, x_mod.retrieve()); @@ -154,7 +154,7 @@ mod tests { let x = U256::from_be_hex("44acf6b7e36c1342c2c5897204fe09504e1e2efb1a900377dbc4e7a6a133ec56"); assert_eq!( - Residue::::new(x), + Residue::::new(&x), const_residue!(x, Modulus2) ); } diff --git a/src/uint/modular/constant_mod.rs b/src/uint/modular/constant_mod.rs index ddf9e600..2cba8104 100644 --- a/src/uint/modular/constant_mod.rs +++ b/src/uint/modular/constant_mod.rs @@ -41,7 +41,7 @@ pub trait ResidueParams: Copy { /// R^3, used to perform a multiplicative inverse const R3: Uint; /// The lowest limbs of -(MODULUS^-1) mod R - // We only need the LSB because during reduction this value is multiplied modulo 2**64. + // We only need the LSB because during reduction this value is multiplied modulo 2**Limb::BITS. const MOD_NEG_INV: Limb; } @@ -69,24 +69,22 @@ impl, const LIMBS: usize> Residue { }; /// Instantiates a new `Residue` that represents this `integer` mod `MOD`. - pub const fn new(integer: Uint) -> Self { - let mut modular_integer = Residue { - montgomery_form: integer, - phantom: PhantomData, - }; - + pub const fn new(integer: &Uint) -> Self { let product = integer.mul_wide(&MOD::R2); - modular_integer.montgomery_form = - montgomery_reduction::(product, MOD::MODULUS, MOD::MOD_NEG_INV); + let montgomery_form = + montgomery_reduction::(&product, &MOD::MODULUS, MOD::MOD_NEG_INV); - modular_integer + Self { + montgomery_form, + phantom: PhantomData, + } } /// Retrieves the integer currently encoded in this `Residue`, guaranteed to be reduced. pub const fn retrieve(&self) -> Uint { montgomery_reduction::( - (self.montgomery_form, Uint::ZERO), - MOD::MODULUS, + &(self.montgomery_form, Uint::ZERO), + &MOD::MODULUS, MOD::MOD_NEG_INV, ) } @@ -109,7 +107,7 @@ impl + Copy, const LIMBS: usize> ConditionallySelectab impl, const LIMBS: usize> ConstantTimeEq for Residue { fn ct_eq(&self, other: &Self) -> Choice { - self.montgomery_form.ct_eq(&other.montgomery_form) + ConstantTimeEq::ct_eq(&self.montgomery_form, &other.montgomery_form) } } diff --git a/src/uint/modular/constant_mod/const_inv.rs b/src/uint/modular/constant_mod/const_inv.rs index 5d957b97..5c8da7b8 100644 --- a/src/uint/modular/constant_mod/const_inv.rs +++ b/src/uint/modular/constant_mod/const_inv.rs @@ -1,20 +1,20 @@ use core::marker::PhantomData; -use subtle::{Choice, CtOption}; +use subtle::CtOption; -use crate::{modular::inv::inv_montgomery_form, traits::Invert, NonZero}; +use crate::{modular::inv::inv_montgomery_form, traits::Invert, CtChoice, NonZero}; use super::{Residue, ResidueParams}; impl, const LIMBS: usize> Residue { /// Computes the residue `self^-1` representing the multiplicative inverse of `self`. /// I.e. `self * self^-1 = 1`. - /// If the number was invertible, the second element of the tuple is `1`, - /// otherwise it is `0` (in which case the first element's value is unspecified). - pub const fn invert(&self) -> (Self, u8) { + /// If the number was invertible, the second element of the tuple is the truthy value, + /// otherwise it is the falsy value (in which case the first element's value is unspecified). + pub const fn invert(&self) -> (Self, CtChoice) { let (montgomery_form, is_some) = inv_montgomery_form( - self.montgomery_form, - MOD::MODULUS, + &self.montgomery_form, + &MOD::MODULUS, &MOD::R3, MOD::MOD_NEG_INV, ); @@ -24,7 +24,7 @@ impl, const LIMBS: usize> Residue { phantom: PhantomData, }; - (value, (is_some & 1) as u8) + (value, is_some) } } @@ -32,7 +32,7 @@ impl, const LIMBS: usize> Invert for Residue; fn invert(&self) -> Self::Output { let (value, is_some) = self.invert(); - CtOption::new(value, Choice::from(is_some)) + CtOption::new(value, is_some.into()) } } diff --git a/src/uint/modular/constant_mod/const_mul.rs b/src/uint/modular/constant_mod/const_mul.rs index 1693790e..c3ea4558 100644 --- a/src/uint/modular/constant_mod/const_mul.rs +++ b/src/uint/modular/constant_mod/const_mul.rs @@ -17,7 +17,7 @@ impl, const LIMBS: usize> Residue { montgomery_form: mul_montgomery_form( &self.montgomery_form, &rhs.montgomery_form, - MOD::MODULUS, + &MOD::MODULUS, MOD::MOD_NEG_INV, ), phantom: PhantomData, @@ -80,7 +80,7 @@ impl, const LIMBS: usize> Square for Residue, const LIMBS: usize> Residue { impl, const LIMBS: usize> Neg for Residue { type Output = Self; fn neg(self) -> Self { - (&self).neg() + Residue::neg(&self) + } +} + +impl, const LIMBS: usize> Neg for &Residue { + type Output = Residue; + fn neg(self) -> Residue { + Residue::neg(self) } } diff --git a/src/uint/modular/constant_mod/const_pow.rs b/src/uint/modular/constant_mod/const_pow.rs index acfc2613..60455b05 100644 --- a/src/uint/modular/constant_mod/const_pow.rs +++ b/src/uint/modular/constant_mod/const_pow.rs @@ -20,11 +20,11 @@ impl, const LIMBS: usize> Residue { ) -> Residue { Self { montgomery_form: pow_montgomery_form( - self.montgomery_form, + &self.montgomery_form, exponent, exponent_bits, - MOD::MODULUS, - MOD::R, + &MOD::MODULUS, + &MOD::R, MOD::MOD_NEG_INV, ), phantom: core::marker::PhantomData, diff --git a/src/uint/modular/constant_mod/macros.rs b/src/uint/modular/constant_mod/macros.rs index 5299e91b..b47b4cf9 100644 --- a/src/uint/modular/constant_mod/macros.rs +++ b/src/uint/modular/constant_mod/macros.rs @@ -27,8 +27,8 @@ macro_rules! impl_modulus { ); const R3: $crate::Uint<{ nlimbs!(<$uint_type>::BITS) }> = $crate::uint::modular::reduction::montgomery_reduction( - Self::R2.square_wide(), - Self::MODULUS, + &Self::R2.square_wide(), + &Self::MODULUS, Self::MOD_NEG_INV, ); } @@ -41,7 +41,7 @@ macro_rules! impl_modulus { macro_rules! const_residue { ($variable:ident, $modulus:ident) => { $crate::uint::modular::constant_mod::Residue::<$modulus, { $modulus::LIMBS }>::new( - $variable, + &$variable, ) }; } diff --git a/src/uint/modular/inv.rs b/src/uint/modular/inv.rs index 7fb6e282..408c03fb 100644 --- a/src/uint/modular/inv.rs +++ b/src/uint/modular/inv.rs @@ -1,14 +1,14 @@ -use crate::{modular::reduction::montgomery_reduction, Limb, Uint, Word}; +use crate::{modular::reduction::montgomery_reduction, CtChoice, Limb, Uint}; pub const fn inv_montgomery_form( - x: Uint, - modulus: Uint, + x: &Uint, + modulus: &Uint, r3: &Uint, mod_neg_inv: Limb, -) -> (Uint, Word) { - let (inverse, error) = x.inv_odd_mod(modulus); +) -> (Uint, CtChoice) { + let (inverse, is_some) = x.inv_odd_mod(modulus); ( - montgomery_reduction(inverse.mul_wide(r3), modulus, mod_neg_inv), - error, + montgomery_reduction(&inverse.mul_wide(r3), modulus, mod_neg_inv), + is_some, ) } diff --git a/src/uint/modular/mul.rs b/src/uint/modular/mul.rs index 74da6b04..b84ceb5c 100644 --- a/src/uint/modular/mul.rs +++ b/src/uint/modular/mul.rs @@ -5,18 +5,18 @@ use super::reduction::montgomery_reduction; pub(crate) const fn mul_montgomery_form( a: &Uint, b: &Uint, - modulus: Uint, + modulus: &Uint, mod_neg_inv: Limb, ) -> Uint { let product = a.mul_wide(b); - montgomery_reduction::(product, modulus, mod_neg_inv) + montgomery_reduction::(&product, modulus, mod_neg_inv) } pub(crate) const fn square_montgomery_form( a: &Uint, - modulus: Uint, + modulus: &Uint, mod_neg_inv: Limb, ) -> Uint { let product = a.square_wide(); - montgomery_reduction::(product, modulus, mod_neg_inv) + montgomery_reduction::(&product, modulus, mod_neg_inv) } diff --git a/src/uint/modular/pow.rs b/src/uint/modular/pow.rs index 2695e30b..5ab1fd63 100644 --- a/src/uint/modular/pow.rs +++ b/src/uint/modular/pow.rs @@ -7,26 +7,26 @@ use super::mul::{mul_montgomery_form, square_montgomery_form}; /// /// NOTE: this value is leaked in the time pattern. pub const fn pow_montgomery_form( - x: Uint, + x: &Uint, exponent: &Uint, exponent_bits: usize, - modulus: Uint, - r: Uint, + modulus: &Uint, + r: &Uint, mod_neg_inv: Limb, ) -> Uint { if exponent_bits == 0 { - return r; // 1 in Montgomery form + return *r; // 1 in Montgomery form } const WINDOW: usize = 4; const WINDOW_MASK: Word = (1 << WINDOW) - 1; // powers[i] contains x^i - let mut powers = [r; 1 << WINDOW]; - powers[1] = x; + let mut powers = [*r; 1 << WINDOW]; + powers[1] = *x; let mut i = 2; while i < powers.len() { - powers[i] = mul_montgomery_form(&powers[i - 1], &x, modulus, mod_neg_inv); + powers[i] = mul_montgomery_form(&powers[i - 1], x, modulus, mod_neg_inv); i += 1; } @@ -35,7 +35,7 @@ pub const fn pow_montgomery_form( let starting_window = starting_bit_in_limb / WINDOW; let starting_window_mask = (1 << (starting_bit_in_limb % WINDOW + 1)) - 1; - let mut z = r; // 1 in Montgomery form + let mut z = *r; // 1 in Montgomery form let mut limb_num = starting_limb + 1; while limb_num > 0 { @@ -67,7 +67,7 @@ pub const fn pow_montgomery_form( let mut i = 1; while i < 1 << WINDOW { let choice = Limb::ct_eq(Limb(i as Word), Limb(idx)); - power = Uint::::ct_select(power, powers[i], choice); + power = Uint::::ct_select(&power, &powers[i], choice); i += 1; } diff --git a/src/uint/modular/reduction.rs b/src/uint/modular/reduction.rs index cd967a35..10c12888 100644 --- a/src/uint/modular/reduction.rs +++ b/src/uint/modular/reduction.rs @@ -1,14 +1,14 @@ -use crate::{Limb, Uint, WideWord, Word}; +use crate::{CtChoice, Limb, Uint, WideWord, Word}; /// Algorithm 14.32 in Handbook of Applied Cryptography (https://cacr.uwaterloo.ca/hac/about/chap14.pdf) pub(crate) const fn montgomery_reduction( - lower_upper: (Uint, Uint), - modulus: Uint, + lower_upper: &(Uint, Uint), + modulus: &Uint, mod_neg_inv: Limb, ) -> Uint { - let (mut lower, mut upper) = lower_upper; + let (mut lower, mut upper) = *lower_upper; - let mut meta_carry = 0; + let mut meta_carry: WideWord = 0; let mut i = 0; while i < LIMBS { @@ -49,9 +49,8 @@ pub(crate) const fn montgomery_reduction( // Division is simply taking the upper half of the limbs // Final reduction (at this point, the value is at most 2 * modulus) - let must_reduce = (meta_carry as Word).saturating_mul(Word::MAX) - | ((upper.ct_cmp(&modulus) != -1) as Word).saturating_mul(Word::MAX); - upper = upper.wrapping_sub(&Uint::ct_select(Uint::ZERO, modulus, must_reduce)); + let must_reduce = CtChoice::from_lsb(meta_carry as Word).or(Uint::ct_gt(modulus, &upper).not()); + upper = upper.wrapping_sub(&Uint::ct_select(&Uint::ZERO, modulus, must_reduce)); upper } diff --git a/src/uint/modular/runtime_mod.rs b/src/uint/modular/runtime_mod.rs index 13ce3abd..9ddb5f94 100644 --- a/src/uint/modular/runtime_mod.rs +++ b/src/uint/modular/runtime_mod.rs @@ -27,21 +27,21 @@ pub struct DynResidueParams { // R^3, used to compute the multiplicative inverse r3: Uint, // The lowest limbs of -(MODULUS^-1) mod R - // We only need the LSB because during reduction this value is multiplied modulo 2**64. + // We only need the LSB because during reduction this value is multiplied modulo 2**Limb::BITS. mod_neg_inv: Limb, } impl DynResidueParams { /// Instantiates a new set of `ResidueParams` representing the given `modulus`. - pub fn new(modulus: Uint) -> Self { - let r = Uint::MAX.ct_rem(&modulus).0.wrapping_add(&Uint::ONE); - let r2 = Uint::ct_rem_wide(r.square_wide(), &modulus).0; + pub fn new(modulus: &Uint) -> Self { + let r = Uint::MAX.ct_rem(modulus).0.wrapping_add(&Uint::ONE); + let r2 = Uint::ct_rem_wide(r.square_wide(), modulus).0; let mod_neg_inv = Limb(Word::MIN.wrapping_sub(modulus.inv_mod2k(Word::BITS as usize).limbs[0].0)); - let r3 = montgomery_reduction(r2.square_wide(), modulus, mod_neg_inv); + let r3 = montgomery_reduction(&r2.square_wide(), modulus, mod_neg_inv); Self { - modulus, + modulus: *modulus, r, r2, r3, @@ -59,24 +59,25 @@ pub struct DynResidue { impl DynResidue { /// Instantiates a new `Residue` that represents this `integer` mod `MOD`. - pub const fn new(integer: Uint, residue_params: DynResidueParams) -> Self { - let mut modular_integer = Self { - montgomery_form: integer, - residue_params, - }; - + pub const fn new(integer: &Uint, residue_params: DynResidueParams) -> Self { let product = integer.mul_wide(&residue_params.r2); - modular_integer.montgomery_form = - montgomery_reduction(product, residue_params.modulus, residue_params.mod_neg_inv); + let montgomery_form = montgomery_reduction( + &product, + &residue_params.modulus, + residue_params.mod_neg_inv, + ); - modular_integer + Self { + montgomery_form, + residue_params, + } } /// Retrieves the integer currently encoded in this `Residue`, guaranteed to be reduced. pub const fn retrieve(&self) -> Uint { montgomery_reduction( - (self.montgomery_form, Uint::ZERO), - self.residue_params.modulus, + &(self.montgomery_form, Uint::ZERO), + &self.residue_params.modulus, self.residue_params.mod_neg_inv, ) } diff --git a/src/uint/modular/runtime_mod/runtime_add.rs b/src/uint/modular/runtime_mod/runtime_add.rs index 6ee6ad4f..eb470860 100644 --- a/src/uint/modular/runtime_mod/runtime_add.rs +++ b/src/uint/modular/runtime_mod/runtime_add.rs @@ -70,17 +70,17 @@ mod tests { #[test] fn add_overflow() { - let params = DynResidueParams::new(U256::from_be_hex( + let params = DynResidueParams::new(&U256::from_be_hex( "ffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551", )); let x = U256::from_be_hex("44acf6b7e36c1342c2c5897204fe09504e1e2efb1a900377dbc4e7a6a133ec56"); - let mut x_mod = DynResidue::new(x, params); + let mut x_mod = DynResidue::new(&x, params); let y = U256::from_be_hex("d5777c45019673125ad240f83094d4252d829516fac8601ed01979ec1ec1a251"); - let y_mod = DynResidue::new(y, params); + let y_mod = DynResidue::new(&y, params); x_mod += &y_mod; diff --git a/src/uint/modular/runtime_mod/runtime_inv.rs b/src/uint/modular/runtime_mod/runtime_inv.rs index 03437748..5e639d43 100644 --- a/src/uint/modular/runtime_mod/runtime_inv.rs +++ b/src/uint/modular/runtime_mod/runtime_inv.rs @@ -1,18 +1,18 @@ -use subtle::{Choice, CtOption}; +use subtle::CtOption; -use crate::{modular::inv::inv_montgomery_form, traits::Invert}; +use crate::{modular::inv::inv_montgomery_form, traits::Invert, CtChoice}; use super::DynResidue; impl DynResidue { /// Computes the residue `self^-1` representing the multiplicative inverse of `self`. /// I.e. `self * self^-1 = 1`. - /// If the number was invertible, the second element of the tuple is `1`, - /// otherwise it is `0` (in which case the first element's value is unspecified). - pub const fn invert(&self) -> (Self, u8) { + /// If the number was invertible, the second element of the tuple is the truthy value, + /// otherwise it is the falsy value (in which case the first element's value is unspecified). + pub const fn invert(&self) -> (Self, CtChoice) { let (montgomery_form, is_some) = inv_montgomery_form( - self.montgomery_form, - self.residue_params.modulus, + &self.montgomery_form, + &self.residue_params.modulus, &self.residue_params.r3, self.residue_params.mod_neg_inv, ); @@ -22,7 +22,7 @@ impl DynResidue { residue_params: self.residue_params, }; - (value, (is_some & 1) as u8) + (value, is_some) } } @@ -30,6 +30,6 @@ impl Invert for DynResidue { type Output = CtOption; fn invert(&self) -> Self::Output { let (value, is_some) = self.invert(); - CtOption::new(value, Choice::from(is_some)) + CtOption::new(value, is_some.into()) } } diff --git a/src/uint/modular/runtime_mod/runtime_mul.rs b/src/uint/modular/runtime_mod/runtime_mul.rs index 662f187d..b260461c 100644 --- a/src/uint/modular/runtime_mod/runtime_mul.rs +++ b/src/uint/modular/runtime_mod/runtime_mul.rs @@ -14,7 +14,7 @@ impl DynResidue { montgomery_form: mul_montgomery_form( &self.montgomery_form, &rhs.montgomery_form, - self.residue_params.modulus, + &self.residue_params.modulus, self.residue_params.mod_neg_inv, ), residue_params: self.residue_params, @@ -70,7 +70,7 @@ impl Square for DynResidue { Self { montgomery_form: square_montgomery_form( &self.montgomery_form, - self.residue_params.modulus, + &self.residue_params.modulus, self.residue_params.mod_neg_inv, ), residue_params: self.residue_params, diff --git a/src/uint/modular/runtime_mod/runtime_neg.rs b/src/uint/modular/runtime_mod/runtime_neg.rs index c97e9283..fca1ff87 100644 --- a/src/uint/modular/runtime_mod/runtime_neg.rs +++ b/src/uint/modular/runtime_mod/runtime_neg.rs @@ -12,6 +12,13 @@ impl DynResidue { impl Neg for DynResidue { type Output = Self; fn neg(self) -> Self { - (&self).neg() + DynResidue::neg(&self) + } +} + +impl Neg for &DynResidue { + type Output = DynResidue; + fn neg(self) -> DynResidue { + DynResidue::neg(self) } } diff --git a/src/uint/modular/runtime_mod/runtime_pow.rs b/src/uint/modular/runtime_mod/runtime_pow.rs index d7e23f0d..270f4c01 100644 --- a/src/uint/modular/runtime_mod/runtime_pow.rs +++ b/src/uint/modular/runtime_mod/runtime_pow.rs @@ -16,11 +16,11 @@ impl DynResidue { pub const fn pow_bounded_exp(&self, exponent: &Uint, exponent_bits: usize) -> Self { Self { montgomery_form: pow_montgomery_form( - self.montgomery_form, + &self.montgomery_form, exponent, exponent_bits, - self.residue_params.modulus, - self.residue_params.r, + &self.residue_params.modulus, + &self.residue_params.r, self.residue_params.mod_neg_inv, ), residue_params: self.residue_params, diff --git a/src/uint/modular/runtime_mod/runtime_sub.rs b/src/uint/modular/runtime_mod/runtime_sub.rs index 635bd8ed..dd6fd84c 100644 --- a/src/uint/modular/runtime_mod/runtime_sub.rs +++ b/src/uint/modular/runtime_mod/runtime_sub.rs @@ -70,17 +70,17 @@ mod tests { #[test] fn sub_overflow() { - let params = DynResidueParams::new(U256::from_be_hex( + let params = DynResidueParams::new(&U256::from_be_hex( "ffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551", )); let x = U256::from_be_hex("44acf6b7e36c1342c2c5897204fe09504e1e2efb1a900377dbc4e7a6a133ec56"); - let mut x_mod = DynResidue::new(x, params); + let mut x_mod = DynResidue::new(&x, params); let y = U256::from_be_hex("d5777c45019673125ad240f83094d4252d829516fac8601ed01979ec1ec1a251"); - let y_mod = DynResidue::new(y, params); + let y_mod = DynResidue::new(&y, params); x_mod -= &y_mod; diff --git a/src/uint/mul_mod.rs b/src/uint/mul_mod.rs index bac9864e..0916ede4 100644 --- a/src/uint/mul_mod.rs +++ b/src/uint/mul_mod.rs @@ -20,7 +20,7 @@ impl Uint { let (lo, hi) = self.mul_wide(rhs); // Now use Algorithm 14.47 for the reduction - let (lo, carry) = mac_by_limb(lo, hi, c, Limb::ZERO); + let (lo, carry) = mac_by_limb(&lo, &hi, c, Limb::ZERO); let (lo, carry) = { let rhs = (carry.0 + 1) as WideWord * c.0 as WideWord; @@ -38,12 +38,14 @@ impl Uint { /// Computes `a + (b * c) + carry`, returning the result along with the new carry. const fn mac_by_limb( - mut a: Uint, - b: Uint, + a: &Uint, + b: &Uint, c: Limb, - mut carry: Limb, + carry: Limb, ) -> (Uint, Limb) { let mut i = 0; + let mut a = *a; + let mut carry = carry; while i < LIMBS { let (n, c) = a.limbs[i].mac(b.limbs[i], c, carry); diff --git a/src/uint/neg.rs b/src/uint/neg.rs index 6ea16a14..5cdba201 100644 --- a/src/uint/neg.rs +++ b/src/uint/neg.rs @@ -1,6 +1,6 @@ use core::ops::Neg; -use crate::{Uint, Word, Wrapping}; +use crate::{CtChoice, Uint, Wrapping}; impl Neg for Wrapping> { type Output = Self; @@ -13,10 +13,10 @@ impl Neg for Wrapping> { impl Uint { /// Negates based on `choice` by wrapping the integer. - pub(crate) const fn conditional_wrapping_neg(self, choice: Word) -> Uint { + pub(crate) const fn conditional_wrapping_neg(&self, choice: CtChoice) -> Uint { let (shifted, _) = self.shl_1(); let negated_self = self.wrapping_sub(&shifted); - Uint::ct_select(self, negated_self, choice) + Uint::ct_select(self, &negated_self, choice) } } diff --git a/src/uint/neg_mod.rs b/src/uint/neg_mod.rs index 7472a558..aaed2768 100644 --- a/src/uint/neg_mod.rs +++ b/src/uint/neg_mod.rs @@ -12,7 +12,7 @@ impl Uint { while i < LIMBS { // Set ret to 0 if the original value was 0, in which // case ret would be p. - ret.limbs[i].0 &= z; + ret.limbs[i].0 = z.if_true(ret.limbs[i].0); i += 1; } ret diff --git a/src/uint/shl.rs b/src/uint/shl.rs index e08ada57..eb8c713c 100644 --- a/src/uint/shl.rs +++ b/src/uint/shl.rs @@ -1,11 +1,11 @@ //! [`Uint`] bitwise left shift operations. -use crate::{limb::HI_BIT, Limb, Uint, Word}; +use crate::{limb::HI_BIT, CtChoice, Limb, Uint, Word}; use core::ops::{Shl, ShlAssign}; impl Uint { - /// Computes `self << 1` in constant-time, returning the overflowing bit as a `Word` that is either 0...0 or 1...1. - pub(crate) const fn shl_1(&self) -> (Self, Word) { + /// Computes `self << 1` in constant-time, returning the overflowing bit as a `CtChoice`. + pub(crate) const fn shl_1(&self) -> (Self, CtChoice) { let mut shifted_bits = [0; LIMBS]; let mut i = 0; while i < LIMBS { @@ -29,10 +29,7 @@ impl Uint { i += 1; } - ( - Uint::new(limbs), - carry_bits[LIMBS - 1].wrapping_mul(Word::MAX), - ) + (Uint::new(limbs), CtChoice::from_lsb(carry_bits[LIMBS - 1])) } /// Computes `self << shift` where `0 <= shift < Limb::BITS`, @@ -41,7 +38,7 @@ impl Uint { pub(crate) const fn shl_limb(&self, n: usize) -> (Self, Limb) { let mut limbs = [Limb::ZERO; LIMBS]; - let nz = Limb(n as Word).is_nonzero(); + let nz = Limb(n as Word).ct_is_nonzero(); let lshift = n as Word; let rshift = Limb::ct_select(Limb::ZERO, Limb((Limb::BITS - n) as Word), nz).0; let carry = Limb::ct_select( @@ -54,7 +51,7 @@ impl Uint { while i > 0 { let mut limb = self.limbs[i].0 << lshift; let hi = self.limbs[i - 1].0 >> rshift; - limb |= hi & nz; + limb |= nz.if_true(hi); limbs[i] = Limb(limb); i -= 1 } diff --git a/src/uint/shr.rs b/src/uint/shr.rs index c7cb0f23..071788fb 100644 --- a/src/uint/shr.rs +++ b/src/uint/shr.rs @@ -1,12 +1,12 @@ //! [`Uint`] bitwise right shift operations. use super::Uint; -use crate::{limb::HI_BIT, Limb, Word}; +use crate::{limb::HI_BIT, CtChoice, Limb}; use core::ops::{Shr, ShrAssign}; impl Uint { /// Computes `self >> 1` in constant-time, returning the overflowing bit as a `Word` that is either 0...0 or 1...1. - pub(crate) const fn shr_1(&self) -> (Self, Word) { + pub(crate) const fn shr_1(&self) -> (Self, CtChoice) { let mut shifted_bits = [0; LIMBS]; let mut i = 0; while i < LIMBS { @@ -30,9 +30,10 @@ impl Uint { } limbs[LIMBS - 1] = Limb(shifted_bits[LIMBS - 1]); + debug_assert!(carry_bits[LIMBS - 1] == 0 || carry_bits[LIMBS - 1] == (1 << HI_BIT)); ( Uint::new(limbs), - (carry_bits[0] >> HI_BIT).wrapping_mul(Word::MAX), + CtChoice::from_lsb(carry_bits[0] >> HI_BIT), ) } diff --git a/src/uint/sqrt.rs b/src/uint/sqrt.rs index 26d39410..56815e2d 100644 --- a/src/uint/sqrt.rs +++ b/src/uint/sqrt.rs @@ -21,13 +21,12 @@ impl Uint { // If guess increased, the initial guess was low. // Repeat until reverse course. - while guess.ct_cmp(&xn) == -1 { + while Uint::ct_lt(&guess, &xn).is_true_vartime() { // Sometimes an increase is too far, especially with large // powers, and then takes a long time to walk back. The upper // bound is based on bit size, so saturate on that. - let res = Limb::ct_cmp(Limb(xn.bits_vartime() as Word), Limb(max_bits as Word)) - 1; - let le = Limb::is_nonzero(Limb(res as Word)); - guess = Self::ct_select(cap, xn, le); + let le = Limb::ct_le(Limb(xn.bits_vartime() as Word), Limb(max_bits as Word)); + guess = Self::ct_select(&cap, &xn, le); xn = { let q = self.wrapping_div(&guess); let t = guess.wrapping_add(&q); @@ -36,7 +35,7 @@ impl Uint { } // Repeat while guess decreases. - while guess.ct_cmp(&xn) == 1 && xn.ct_is_nonzero() == Word::MAX { + while Uint::ct_gt(&guess, &xn).is_true_vartime() && xn.ct_is_nonzero().is_true_vartime() { guess = xn; xn = { let q = self.wrapping_div(&guess); @@ -45,7 +44,7 @@ impl Uint { }; } - Self::ct_select(Self::ZERO, guess, self.ct_is_nonzero()) + Self::ct_select(&Self::ZERO, &guess, self.ct_is_nonzero()) } /// Wrapped sqrt is just normal √(`self`) @@ -60,7 +59,7 @@ impl Uint { pub fn checked_sqrt(&self) -> CtOption { let r = self.sqrt(); let s = r.wrapping_mul(&r); - CtOption::new(r, self.ct_eq(&s)) + CtOption::new(r, ConstantTimeEq::ct_eq(self, &s)) } } diff --git a/src/uint/sub.rs b/src/uint/sub.rs index aafbc3f9..c39e5492 100644 --- a/src/uint/sub.rs +++ b/src/uint/sub.rs @@ -1,7 +1,7 @@ //! [`Uint`] addition operations. use super::Uint; -use crate::{Checked, CheckedSub, Limb, Word, Wrapping, Zero}; +use crate::{Checked, CheckedSub, CtChoice, Limb, Wrapping, Zero}; use core::ops::{Sub, SubAssign}; use subtle::CtOption; @@ -39,13 +39,16 @@ impl Uint { self.sbb(rhs, Limb::ZERO).0 } - /// Perform wrapping subtraction, returning the underflow bit as a `Word` that is either 0...0 or 1...1. - pub(crate) const fn conditional_wrapping_sub(&self, rhs: &Self, choice: Word) -> (Self, Word) { - let actual_rhs = Uint::ct_select(Uint::ZERO, *rhs, choice); + /// Perform wrapping subtraction, returning the truthy value as the second element of the tuple + /// if an underflow has occurred. + pub(crate) const fn conditional_wrapping_sub( + &self, + rhs: &Self, + choice: CtChoice, + ) -> (Self, CtChoice) { + let actual_rhs = Uint::ct_select(&Uint::ZERO, rhs, choice); let (res, borrow) = self.sbb(&actual_rhs, Limb::ZERO); - - // Here we use a saturating multiplication to get the result to 0...0 or 1...1 - (res, borrow.0.saturating_mul(Word::MAX)) + (res, CtChoice::from_mask(borrow.0)) } } diff --git a/tests/proptests.rs b/tests/proptests.rs index b1862c0f..9f489f0a 100644 --- a/tests/proptests.rs +++ b/tests/proptests.rs @@ -258,8 +258,8 @@ proptest! { let expected = to_uint(a_bi.modpow(&b_bi, &p_bi)); - let params = DynResidueParams::new(P); - let a_m = DynResidue::new(a, params); + let params = DynResidueParams::new(&P); + let a_m = DynResidue::new(&a, params); let actual = a_m.pow(&b).retrieve(); assert_eq!(expected, actual); @@ -276,8 +276,8 @@ proptest! { let expected = to_uint(a_bi.modpow(&b_bi, &p_bi)); - let params = DynResidueParams::new(P); - let a_m = DynResidue::new(a, params); + let params = DynResidueParams::new(&P); + let a_m = DynResidue::new(&a, params); let actual = a_m.pow_bounded_exp(&b, exponent_bits.into()).retrieve(); assert_eq!(expected, actual);