Skip to content

Commit d24c4f9

Browse files
committed
Insert extra let bindings to avoid evaluating expressions multiple times during pattern matching
1 parent 0dd795c commit d24c4f9

File tree

6 files changed

+167
-53
lines changed

6 files changed

+167
-53
lines changed

fathom/src/core.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,20 @@ impl<'arena> Term<'arena> {
388388
),
389389
}
390390
}
391+
392+
pub fn is_trivial(&self) -> bool {
393+
match self {
394+
Term::ItemVar(_, _)
395+
| Term::LocalVar(_, _)
396+
| Term::MetaVar(_, _)
397+
| Term::InsertedMeta(_, _, _)
398+
| Term::Universe(_)
399+
| Term::Prim(_, _)
400+
| Term::ConstLit(_, _) => true,
401+
Term::RecordProj(_, head, _) => head.is_trivial(),
402+
_ => false,
403+
}
404+
}
391405
}
392406

393407
/// Simple patterns that have had some initial elaboration performed on them
@@ -431,6 +445,16 @@ impl<'arena> CheckedPattern<'arena> {
431445
CheckedPattern::ConstLit(_, _) | CheckedPattern::RecordLit(_, _, _) => false,
432446
}
433447
}
448+
449+
pub fn is_trivial(&self) -> bool {
450+
match self {
451+
CheckedPattern::ReportedError(_)
452+
| CheckedPattern::Placeholder(_)
453+
| CheckedPattern::Binder(_, _)
454+
| CheckedPattern::ConstLit(_, _) => true,
455+
CheckedPattern::RecordLit(_, _, _) => false,
456+
}
457+
}
434458
}
435459

436460
macro_rules! def_prims {

fathom/src/surface/distillation.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -393,11 +393,11 @@ impl<'interner, 'arena, 'env> Context<'interner, 'arena, 'env> {
393393
match core_term {
394394
core::Term::ItemVar(_span, var) => match self.get_item_name(*var) {
395395
Some(name) => Term::Name((), name),
396-
None => todo!("misbound variable"), // TODO: error?
396+
None => panic!("misbound item variable: {var:?}"),
397397
},
398398
core::Term::LocalVar(_span, var) => match self.get_local_name(*var) {
399399
Some(name) => Term::Name((), name),
400-
None => todo!("misbound variable"), // TODO: error?
400+
None => panic!("Unbound local variable: {var:?}"),
401401
},
402402
core::Term::MetaVar(_span, var) => match self.get_hole_name(*var) {
403403
Some(name) => Term::Hole((), name),

fathom/src/surface/elaboration.rs

Lines changed: 105 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -804,21 +804,52 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
804804
match (surface_term, expected_type.as_ref()) {
805805
(Term::Let(range, def_pattern, def_type, def_expr, body_expr), _) => {
806806
let (def_pattern, def_type_value) = self.synth_ann_pattern(def_pattern, *def_type);
807-
let scrut = self.check_scrutinee(def_expr, def_type_value.clone());
807+
let mut scrut = self.check_scrutinee(def_expr, def_type_value.clone());
808808
let value = self.eval_env().eval(scrut.expr);
809+
810+
// Bind the scrut to a fresh variable if it is unsafe to evaluate multiple times,
811+
// and may be evaluated multiple times by the pattern match compiler
812+
let extra_def = match (scrut.expr.is_trivial(), def_pattern.is_trivial()) {
813+
(false, false) => {
814+
let def_name = None; // TODO: generate a fresh name
815+
let def_type = self.quote_env().quote(self.scope, &scrut.r#type);
816+
let def_expr = scrut.expr.clone();
817+
818+
let var = core::Term::LocalVar(def_expr.span(), env::Index::last());
819+
scrut.expr = self.scope.to_scope(var);
820+
(self.local_env).push_def(def_name, value.clone(), scrut.r#type.clone());
821+
Some((def_name, def_type, def_expr))
822+
}
823+
_ => None,
824+
};
825+
809826
let initial_len = self.local_env.len();
810827
self.push_local_def(&def_pattern, value, scrut.r#type.clone());
811828
let body_expr = self.check(body_expr, &expected_type);
812829
self.local_env.truncate(initial_len);
813830

814831
let matrix = PatMatrix::singleton(scrut, def_pattern);
815-
self.elab_match(
832+
let expr = self.elab_match(
816833
matrix,
817834
&[body_expr],
818835
*range,
819836
def_expr.range(),
820837
PatternMode::Let,
821-
)
838+
);
839+
let expr = match extra_def {
840+
None => expr,
841+
Some((def_name, def_type, def_expr)) => {
842+
self.local_env.pop();
843+
core::Term::Let(
844+
range.into(),
845+
def_name,
846+
self.scope.to_scope(def_type),
847+
self.scope.to_scope(def_expr),
848+
self.scope.to_scope(expr),
849+
)
850+
}
851+
};
852+
expr
822853
}
823854
(Term::If(range, cond_expr, then_expr, else_expr), _) => {
824855
let cond_expr = self.check(cond_expr, &self.bool_type.clone());
@@ -1110,9 +1141,25 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
11101141
}
11111142
Term::Let(range, def_pattern, def_type, def_expr, body_expr) => {
11121143
let (def_pattern, def_type_value) = self.synth_ann_pattern(def_pattern, *def_type);
1113-
let scrut = self.check_scrutinee(def_expr, def_type_value.clone());
1144+
let mut scrut = self.check_scrutinee(def_expr, def_type_value.clone());
11141145
let value = self.eval_env().eval(scrut.expr);
11151146

1147+
// Bind the scrut to a fresh variable if it is unsafe to evaluate multiple times,
1148+
// and may be evaluated multiple times by the pattern match compiler
1149+
let extra_def = match (scrut.expr.is_trivial(), def_pattern.is_trivial()) {
1150+
(false, false) => {
1151+
let def_name = None; // TODO: generate a fresh name
1152+
let def_type = self.quote_env().quote(self.scope, &scrut.r#type);
1153+
let def_expr = scrut.expr.clone();
1154+
1155+
let var = core::Term::LocalVar(def_expr.span(), env::Index::last());
1156+
scrut.expr = self.scope.to_scope(var);
1157+
(self.local_env).push_def(def_name, value.clone(), scrut.r#type.clone());
1158+
Some((def_name, def_type, def_expr))
1159+
}
1160+
_ => None,
1161+
};
1162+
11161163
let initial_len = self.local_env.len();
11171164
self.push_local_def(&def_pattern, value, scrut.r#type.clone());
11181165
let (body_expr, body_type) = self.synth(body_expr);
@@ -1126,6 +1173,19 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
11261173
def_expr.range(),
11271174
PatternMode::Let,
11281175
);
1176+
let expr = match extra_def {
1177+
None => expr,
1178+
Some((def_name, def_type, def_expr)) => {
1179+
self.local_env.pop();
1180+
core::Term::Let(
1181+
range.into(),
1182+
def_name,
1183+
self.scope.to_scope(def_type),
1184+
self.scope.to_scope(def_expr),
1185+
self.scope.to_scope(expr),
1186+
)
1187+
}
1188+
};
11291189
(expr, body_type)
11301190
}
11311191
Term::If(range, cond_expr, then_expr, else_expr) => {
@@ -1817,15 +1877,37 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
18171877
expected_type: &ArcValue<'arena>,
18181878
) -> core::Term<'arena> {
18191879
let expected_type = self.elim_env().force(expected_type);
1820-
let scrut = self.synth_scrutinee(scrutinee_expr);
1880+
let mut scrut = self.synth_scrutinee(scrutinee_expr);
18211881
let value = self.eval_env().eval(scrut.expr);
18221882

1883+
let patterns: Vec<_> = equations
1884+
.iter()
1885+
.map(|(pat, _)| self.check_pattern(pat, &scrut.r#type))
1886+
.collect();
1887+
1888+
// Bind the scrut to a fresh variable if it is unsafe to evaluate multiple times,
1889+
// and may be evaluated multiple times by the pattern match compiler
1890+
let extra_def = match (
1891+
scrut.expr.is_trivial(),
1892+
patterns.iter().all(|pat| pat.is_trivial()),
1893+
) {
1894+
(false, false) => {
1895+
let def_name = None; // TODO: generate a fresh name
1896+
let def_type = self.quote_env().quote(self.scope, &scrut.r#type);
1897+
let def_expr = scrut.expr.clone();
1898+
1899+
let var = core::Term::LocalVar(def_expr.span(), env::Index::last());
1900+
scrut.expr = self.scope.to_scope(var);
1901+
(self.local_env).push_def(def_name, value.clone(), scrut.r#type.clone());
1902+
Some((def_name, def_type, def_expr))
1903+
}
1904+
_ => None,
1905+
};
1906+
18231907
let mut rows = Vec::with_capacity(equations.len());
18241908
let mut exprs = Vec::with_capacity(equations.len());
18251909

1826-
for (pat, expr) in equations {
1827-
let pattern = self.check_pattern(pat, &scrut.r#type);
1828-
1910+
for (pattern, (_, expr)) in patterns.into_iter().zip(equations) {
18291911
let initial_len = self.local_env.len();
18301912
self.push_pattern(
18311913
&pattern,
@@ -1841,7 +1923,21 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
18411923
}
18421924

18431925
let matrix = patterns::PatMatrix::new(rows);
1844-
self.elab_match(matrix, &exprs, range, scrut.range, PatternMode::Match)
1926+
let expr = self.elab_match(matrix, &exprs, range, scrut.range, PatternMode::Match);
1927+
let expr = match extra_def {
1928+
None => expr,
1929+
Some((def_name, def_type, def_expr)) => {
1930+
self.local_env.pop();
1931+
core::Term::Let(
1932+
range.into(),
1933+
def_name,
1934+
self.scope.to_scope(def_type),
1935+
self.scope.to_scope(def_expr),
1936+
self.scope.to_scope(expr),
1937+
)
1938+
}
1939+
};
1940+
expr
18451941
}
18461942

18471943
fn synth_scrutinee(&mut self, scrutinee_expr: &Term<'_, ByteRange>) -> Scrutinee<'arena> {

tests/succeed/record-patterns/let-check.snap

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
stdout = '''
22
let _ : () = ();
3-
let x : Bool = (false, true)._0;
4-
let y : Bool = (false, true)._1;
5-
let a : Bool = (false, true)._0;
6-
let b : Bool = (false, true)._1;
3+
let _ : () = ();
4+
let _ : (Bool, Bool) = (false, true);
5+
let x : Bool = _._0;
6+
let y : Bool = _._1;
7+
let _ : (Bool, Bool) = (false, true);
8+
let a : Bool = _._0;
9+
let _ : (Bool, Bool) = (false, true);
10+
let b : Bool = _._1;
11+
let _ : (Bool, Bool) = (false, true);
712
let _ : (Bool, Bool) = (false, true);
813
() : ()
914
'''

tests/succeed/record-patterns/let-synth.snap

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
stdout = '''
22
let _ : () = ();
3-
let x : Bool = (false, true)._0;
4-
let y : Bool = (false, true)._1;
5-
let a : Bool = (false, true)._0;
6-
let b : Bool = (false, true)._1;
3+
let _ : () = ();
4+
let _ : (Bool, Bool) = (false, true);
5+
let x : Bool = _._0;
6+
let y : Bool = _._1;
7+
let _ : (Bool, Bool) = (false, true);
8+
let a : Bool = _._0;
9+
let _ : (Bool, Bool) = (false, true);
10+
let b : Bool = _._1;
11+
let _ : (Bool, Bool) = (false, true);
712
let _ : (Bool, Bool) = (false, true);
813
() : ()
914
'''
Lines changed: 18 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,22 @@
11
stdout = '''
2-
let and1 : Bool -> Bool -> Bool = fun x y => if (x, y)._0
3-
then if (x, y)._1 then true else false
4-
else false;
5-
let and2 : Bool -> Bool -> Bool = fun x y => if (x, y)._0
6-
then if (x, y)._1 then true else false
7-
else false;
8-
let and3 : Bool -> Bool -> Bool = fun x y => if (x, y)._0
9-
then if (x, y)._1 then true else false
10-
else if (x, y)._1 then false
11-
else false;
12-
let or1 : Bool -> Bool -> Bool = fun x y => if (x, y)._0
13-
then true
14-
else if (x, y)._1 then true
15-
else false;
16-
let or2 : Bool -> Bool -> Bool = fun x y => if (x, y)._0
17-
then true
18-
else if (x, y)._1 then true
19-
else false;
20-
let or3 : Bool -> Bool -> Bool = fun x y => if (x, y)._0
21-
then if (x, y)._1 then true else true
22-
else if (x, y)._1 then true
23-
else false;
24-
let xor1 : Bool -> Bool -> Bool = fun x y => if (x, y)._0
25-
then if (x, y)._1 then false else true
26-
else if (x, y)._1 then true
27-
else false;
28-
let xor2 : Bool -> Bool -> Bool = fun x y => if (x, y)._0
29-
then if (x, y)._1 then false else true
30-
else if (x, y)._1 then true
31-
else false;
32-
let xor3 : Bool -> Bool -> Bool = fun x y => if (x, y)._0
33-
then if (x, y)._1 then false else true
34-
else if (x, y)._1 then true
35-
else false;
2+
let and1 : Bool -> Bool -> Bool = fun x y => let _ : (Bool, Bool) = (x, y);
3+
if _._0 then if _._1 then true else false else false;
4+
let and2 : Bool -> Bool -> Bool = fun x y => let _ : (Bool, Bool) = (x, y);
5+
if _._0 then if _._1 then true else false else false;
6+
let and3 : Bool -> Bool -> Bool = fun x y => let _ : (Bool, Bool) = (x, y);
7+
if _._0 then if _._1 then true else false else if _._1 then false else false;
8+
let or1 : Bool -> Bool -> Bool = fun x y => let _ : (Bool, Bool) = (x, y);
9+
if _._0 then true else if _._1 then true else false;
10+
let or2 : Bool -> Bool -> Bool = fun x y => let _ : (Bool, Bool) = (x, y);
11+
if _._0 then true else if _._1 then true else false;
12+
let or3 : Bool -> Bool -> Bool = fun x y => let _ : (Bool, Bool) = (x, y);
13+
if _._0 then if _._1 then true else true else if _._1 then true else false;
14+
let xor1 : Bool -> Bool -> Bool = fun x y => let _ : (Bool, Bool) = (x, y);
15+
if _._0 then if _._1 then false else true else if _._1 then true else false;
16+
let xor2 : Bool -> Bool -> Bool = fun x y => let _ : (Bool, Bool) = (x, y);
17+
if _._0 then if _._1 then false else true else if _._1 then true else false;
18+
let xor3 : Bool -> Bool -> Bool = fun x y => let _ : (Bool, Bool) = (x, y);
19+
if _._0 then if _._1 then false else true else if _._1 then true else false;
3620
() : ()
3721
'''
3822
stderr = ''

0 commit comments

Comments
 (0)