Skip to content

Commit 13de849

Browse files
committed
Introduce the CtChoice newtype
1 parent 6687023 commit 13de849

19 files changed

+192
-124
lines changed

src/ct_choice.rs

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
use subtle::Choice;
2+
3+
use crate::Word;
4+
5+
/// A boolean value returned by constant-time `const fn`s.
6+
// TODO: should be replaced by `subtle::Choice` or `CtOption`
7+
// when `subtle` starts supporting const fns.
8+
#[derive(Debug, Copy, Clone)]
9+
pub struct CtChoice(Word);
10+
11+
impl CtChoice {
12+
/// The falsy vaue.
13+
pub const FALSE: Self = Self(0);
14+
15+
/// The truthy vaue.
16+
pub const TRUE: Self = Self(Word::MAX);
17+
18+
/// Returns the truthy value if `value == Word::MAX`, and the falsy value if `value == 0`.
19+
/// Panics for other values.
20+
pub(crate) const fn from_mask(value: Word) -> Self {
21+
debug_assert!(value == Self::FALSE.0 || value == Self::TRUE.0);
22+
Self(value)
23+
}
24+
25+
/// Returns the truthy value if `value == 1`, and the falsy value if `value == 0`.
26+
/// Panics for other values.
27+
pub(crate) const fn from_lsb(value: Word) -> Self {
28+
debug_assert!(value == Self::FALSE.0 || value == 1);
29+
Self(value.wrapping_neg())
30+
}
31+
32+
pub(crate) const fn not(&self) -> Self {
33+
Self(!self.0)
34+
}
35+
36+
pub(crate) const fn and(&self, other: Self) -> Self {
37+
Self(self.0 & other.0)
38+
}
39+
40+
pub(crate) const fn or(&self, other: Self) -> Self {
41+
Self(self.0 | other.0)
42+
}
43+
44+
/// Return `b` if `self` is truthy, otherwise return `a`.
45+
pub(crate) const fn select(&self, a: Word, b: Word) -> Word {
46+
a ^ (self.0 & (a ^ b))
47+
}
48+
49+
/// Return `x` if `self` is truthy, otherwise return 0.
50+
pub(crate) const fn if_true(&self, x: Word) -> Word {
51+
x & self.0
52+
}
53+
54+
pub(crate) const fn is_true_vartime(&self) -> bool {
55+
self.0 == CtChoice::TRUE.0
56+
}
57+
}
58+
59+
impl From<CtChoice> for Choice {
60+
fn from(choice: CtChoice) -> Self {
61+
Choice::from(choice.0 as u8 & 1)
62+
}
63+
}
64+
65+
impl From<CtChoice> for bool {
66+
fn from(choice: CtChoice) -> Self {
67+
choice.is_true_vartime()
68+
}
69+
}
70+
71+
#[cfg(test)]
72+
mod tests {
73+
use super::CtChoice;
74+
use crate::Word;
75+
76+
#[test]
77+
fn select() {
78+
let a: Word = 1;
79+
let b: Word = 2;
80+
assert_eq!(CtChoice::TRUE.select(a, b), b);
81+
assert_eq!(CtChoice::FALSE.select(a, b), a);
82+
}
83+
}

src/lib.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ mod nlimbs;
161161
#[cfg(feature = "generic-array")]
162162
mod array;
163163
mod checked;
164+
mod ct_choice;
164165
mod limb;
165166
mod non_zero;
166167
mod traits;
@@ -169,7 +170,8 @@ mod wrapping;
169170

170171
pub use crate::{
171172
checked::Checked,
172-
limb::{CtChoice, Limb, WideWord, Word},
173+
ct_choice::CtChoice,
174+
limb::{Limb, WideWord, Word},
173175
non_zero::NonZero,
174176
traits::*,
175177
uint::div_limb::Reciprocal,

src/limb.rs

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ mod sub;
2020
#[cfg(feature = "rand_core")]
2121
mod rand;
2222

23-
use crate::{Bounded, Zero};
23+
use crate::{Bounded, CtChoice, Zero};
2424
use core::fmt;
2525
use subtle::{Choice, ConditionallySelectable};
2626

@@ -57,12 +57,6 @@ pub type WideWord = u128;
5757
/// Highest bit in a [`Limb`].
5858
pub(crate) const HI_BIT: usize = Limb::BITS - 1;
5959

60-
/// A boolean value returned by constant-time `const fn`s.
61-
/// `Word::MAX` signifies `true`, and `0` signifies `false`.
62-
// TODO: should be replaced by `subtle::Choice` or `CtOption`
63-
// when `subtle` starts supporting const fns.
64-
pub type CtChoice = Word;
65-
6660
/// Big integers are represented as an array of smaller CPU word-size integers
6761
/// called "limbs".
6862
#[derive(Copy, Clone, Debug, Default, Hash)]
@@ -100,7 +94,7 @@ impl Limb {
10094
/// Return `b` if `c` is truthy, otherwise return `a`.
10195
#[inline]
10296
pub(crate) const fn ct_select(a: Self, b: Self, c: CtChoice) -> Self {
103-
Self(a.0 ^ (c & (a.0 ^ b.0)))
97+
Self(c.select(a.0, b.0))
10498
}
10599
}
106100

src/limb/cmp.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
//! Limb comparisons
22
3-
use super::{CtChoice, Limb, HI_BIT};
3+
use super::HI_BIT;
4+
use crate::{CtChoice, Limb};
45
use core::cmp::Ordering;
56
use subtle::{Choice, ConstantTimeEq, ConstantTimeGreater, ConstantTimeLess};
67

@@ -28,7 +29,7 @@ impl Limb {
2829
#[inline]
2930
pub(crate) const fn ct_is_nonzero(&self) -> CtChoice {
3031
let inner = self.0;
31-
((inner | inner.wrapping_neg()) >> HI_BIT).wrapping_neg()
32+
CtChoice::from_lsb((inner | inner.wrapping_neg()) >> HI_BIT)
3233
}
3334

3435
/// Returns the truthy value if `lhs == rhs` and the falsy value otherwise.
@@ -38,7 +39,7 @@ impl Limb {
3839
let y = rhs.0;
3940

4041
// x ^ y == 0 if and only if x == y
41-
!Self(x ^ y).ct_is_nonzero()
42+
Self(x ^ y).ct_is_nonzero().not()
4243
}
4344

4445
/// Returns the truthy value if `lhs < rhs` and the falsy value otherwise.
@@ -47,7 +48,7 @@ impl Limb {
4748
let x = lhs.0;
4849
let y = rhs.0;
4950
let bit = (((!x) & y) | (((!x) | y) & (x.wrapping_sub(y)))) >> (Limb::BITS - 1);
50-
bit.wrapping_neg()
51+
CtChoice::from_lsb(bit)
5152
}
5253

5354
/// Returns the truthy value if `lhs <= rhs` and the falsy value otherwise.
@@ -56,7 +57,7 @@ impl Limb {
5657
let x = lhs.0;
5758
let y = rhs.0;
5859
let bit = (((!x) | y) & ((x ^ y) | !(y.wrapping_sub(x)))) >> (Limb::BITS - 1);
59-
bit.wrapping_neg()
60+
CtChoice::from_lsb(bit)
6061
}
6162
}
6263

src/uint/add.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,7 @@ impl<const LIMBS: usize> Uint<LIMBS> {
4646
) -> (Self, CtChoice) {
4747
let actual_rhs = Uint::ct_select(&Uint::ZERO, rhs, choice);
4848
let (sum, carry) = self.adc(&actual_rhs, Limb::ZERO);
49-
50-
debug_assert!(carry.0 == 0 || carry.0 == 1);
51-
(sum, carry.0.wrapping_neg())
49+
(sum, CtChoice::from_lsb(carry.0))
5250
}
5351
}
5452

src/uint/bits.rs

Lines changed: 30 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
use crate::{CtChoice, Limb, Uint, Word};
22

33
impl<const LIMBS: usize> Uint<LIMBS> {
4-
/// Get the value of the bit at position `index`, as a truthy or falsy `CtChoice`.
5-
/// Returns the falsy value for indices out of range.
4+
/// Returns `true` if the bit at position `index` is set, `false` otherwise.
65
#[inline(always)]
7-
pub const fn bit_vartime(self, index: usize) -> CtChoice {
6+
pub const fn bit_vartime(self, index: usize) -> bool {
87
if index >= Self::BITS {
9-
0
8+
false
109
} else {
11-
((self.limbs[index / Limb::BITS].0 >> (index % Limb::BITS)) & 1).wrapping_neg()
10+
(self.limbs[index / Limb::BITS].0 >> (index % Limb::BITS)) & 1 == 1
1211
}
1312
}
1413

@@ -21,14 +20,7 @@ impl<const LIMBS: usize> Uint<LIMBS> {
2120
}
2221

2322
let limb = self.limbs[i].0;
24-
let bits = (Limb::BITS * (i + 1)) as Word - limb.leading_zeros() as Word;
25-
26-
Limb::ct_select(
27-
Limb(bits),
28-
Limb::ZERO,
29-
!self.limbs[0].ct_is_nonzero() & !Limb(i as Word).ct_is_nonzero(),
30-
)
31-
.0 as usize
23+
Limb::BITS * (i + 1) - limb.leading_zeros() as usize
3224
}
3325

3426
/// Calculate the number of leading zeros in the binary representation of this number.
@@ -37,13 +29,14 @@ impl<const LIMBS: usize> Uint<LIMBS> {
3729

3830
let mut count: Word = 0;
3931
let mut i = LIMBS;
40-
let mut mask = Word::MAX;
32+
let mut nonzero_limb_not_encountered = CtChoice::TRUE;
4133
while i > 0 {
4234
i -= 1;
4335
let l = limbs[i];
4436
let z = l.leading_zeros() as Word;
45-
count += z & mask;
46-
mask &= !l.ct_is_nonzero();
37+
count += nonzero_limb_not_encountered.if_true(z);
38+
nonzero_limb_not_encountered =
39+
nonzero_limb_not_encountered.and(l.ct_is_nonzero().not());
4740
}
4841

4942
count as usize
@@ -55,12 +48,13 @@ impl<const LIMBS: usize> Uint<LIMBS> {
5548

5649
let mut count: Word = 0;
5750
let mut i = 0;
58-
let mut mask = Word::MAX;
51+
let mut nonzero_limb_not_encountered = CtChoice::TRUE;
5952
while i < LIMBS {
6053
let l = limbs[i];
6154
let z = l.trailing_zeros() as Word;
62-
count += z & mask;
63-
mask &= !l.ct_is_nonzero();
55+
count += nonzero_limb_not_encountered.if_true(z);
56+
nonzero_limb_not_encountered =
57+
nonzero_limb_not_encountered.and(l.ct_is_nonzero().not());
6458
i += 1;
6559
}
6660

@@ -86,17 +80,17 @@ impl<const LIMBS: usize> Uint<LIMBS> {
8680
while i < LIMBS {
8781
let bit = limbs[i] & index_mask;
8882
let is_right_limb = Limb::ct_eq(limb_num, Limb(i as Word));
89-
result |= bit & is_right_limb;
83+
result |= is_right_limb.if_true(bit);
9084
i += 1;
9185
}
9286

93-
(result >> index_in_limb).wrapping_neg()
87+
CtChoice::from_lsb(result >> index_in_limb)
9488
}
9589
}
9690

9791
#[cfg(test)]
9892
mod tests {
99-
use crate::{Word, U256};
93+
use crate::U256;
10094

10195
fn uint_with_bits_at(positions: &[usize]) -> U256 {
10296
let mut result = U256::ZERO;
@@ -109,25 +103,25 @@ mod tests {
109103
#[test]
110104
fn bit_vartime() {
111105
let u = uint_with_bits_at(&[16, 48, 112, 127, 255]);
112-
assert_eq!(u.bit_vartime(0), 0);
113-
assert_eq!(u.bit_vartime(1), 0);
114-
assert_eq!(u.bit_vartime(16), Word::MAX);
115-
assert_eq!(u.bit_vartime(127), Word::MAX);
116-
assert_eq!(u.bit_vartime(255), Word::MAX);
117-
assert_eq!(u.bit_vartime(256), 0);
118-
assert_eq!(u.bit_vartime(260), 0);
106+
assert!(!u.bit_vartime(0));
107+
assert!(!u.bit_vartime(1));
108+
assert!(u.bit_vartime(16));
109+
assert!(u.bit_vartime(127));
110+
assert!(u.bit_vartime(255));
111+
assert!(!u.bit_vartime(256));
112+
assert!(!u.bit_vartime(260));
119113
}
120114

121115
#[test]
122116
fn bit() {
123117
let u = uint_with_bits_at(&[16, 48, 112, 127, 255]);
124-
assert_eq!(u.bit(0), 0);
125-
assert_eq!(u.bit(1), 0);
126-
assert_eq!(u.bit(16), Word::MAX);
127-
assert_eq!(u.bit(127), Word::MAX);
128-
assert_eq!(u.bit(255), Word::MAX);
129-
assert_eq!(u.bit(256), 0);
130-
assert_eq!(u.bit(260), 0);
118+
assert!(!u.bit(0).is_true_vartime());
119+
assert!(!u.bit(1).is_true_vartime());
120+
assert!(u.bit(16).is_true_vartime());
121+
assert!(u.bit(127).is_true_vartime());
122+
assert!(u.bit(255).is_true_vartime());
123+
assert!(!u.bit(256).is_true_vartime());
124+
assert!(!u.bit(260).is_true_vartime());
131125
}
132126

133127
#[test]

0 commit comments

Comments
 (0)