Skip to content

Commit f4dd305

Browse files
committed
.WIP
1 parent bcbe5cf commit f4dd305

File tree

2 files changed

+65
-48
lines changed

2 files changed

+65
-48
lines changed

pytensor/compile/builders.py

Lines changed: 46 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
from functools import partial
77
from typing import Union, cast
88

9-
from pytensor.compile.function import function
9+
from pytensor.compile import get_default_mode, insert_deepcopy
1010
from pytensor.compile.function.pfunc import rebuild_collect_shared
11+
from pytensor.compile.function.types import add_supervisor_to_fgraph
12+
from pytensor.compile.io import In, Out
1113
from pytensor.compile.sharedvalue import SharedVariable
1214
from pytensor.configdefaults import config
1315
from pytensor.gradient import DisconnectedType, Rop, grad
@@ -433,6 +435,7 @@ def __init__(
433435
assert isinstance(name, str), "name must be None or string object"
434436
self.name = name
435437
self.destroy_map = destroy_map if destroy_map is not None else {}
438+
self._prepared_fgraph = None
436439

437440
def __eq__(self, other):
438441
# TODO: recognize a copy
@@ -847,16 +850,48 @@ def infer_shape(self, fgraph, node, shapes):
847850

848851
return ret
849852

853+
def _prepare_fgraph(self, impl):
854+
if self._prepared_fgraph is None:
855+
mode = get_default_mode()
856+
if impl == "py":
857+
mode = mode.excluding("cxx")
858+
rewriter = mode.optimizer
859+
860+
fgraph = self.fgraph
861+
wrapped_inputs = [
862+
In(inp, borrow=False, mutable=False) for inp in self.fgraph.inputs
863+
]
864+
wrapped_outputs = [Out(out, borrow=True) for out in self.fgraph.outputs]
865+
add_supervisor_to_fgraph(
866+
fgraph,
867+
wrapped_inputs,
868+
accept_inplace=False,
869+
)
870+
rewriter(fgraph)
871+
insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs)
872+
self._prepared_fgraph = fgraph
873+
874+
return self._prepared_fgraph
875+
850876
@property
851877
def fn(self):
852878
"""Lazily compile the inner function graph."""
853-
if getattr(self, "_fn", None) is not None:
854-
return self._fn
855-
856-
self._fn = function(self.inner_inputs, self.inner_outputs, **self.kwargs)
857-
self._fn.trust_input = True
858-
859-
return self._fn
879+
return None
880+
# if getattr(self, "_fn", None) is not None:
881+
# return self._fn
882+
#
883+
# self._fn = pfunc(
884+
# wrapped_inputs,
885+
# wrapped_outputs,
886+
# mode=mode_instance,
887+
# accept_inplace=True,
888+
# on_unused_input="ignore",
889+
# fgraph=self.fgraph,
890+
# )
891+
# self._fn = function(self.inner_inputs, self.inner_outputs, **self.kwargs)
892+
# self._fn.trust_input = True
893+
#
894+
# return self._fn
860895

861896
@property
862897
def inner_inputs(self):
@@ -875,11 +910,7 @@ def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None):
875910
from pytensor.link.c.basic import CLinker
876911
from pytensor.link.vm import VMLinker
877912

878-
# FIXME: Don't call self.fn just to get the optimized fgraph
879-
fg = self.fn.maker.fgraph
880-
# fg = self.fgraph
881-
# rewriter = get_default_mode().optimizer
882-
# rewriter(fg)
913+
fg = self._prepare_fgraph(impl)
883914
fg_no_recycling = [
884915
new_o
885916
for (new_o, old_o) in zip(fg.outputs, node.outputs, strict=True)
@@ -890,8 +921,8 @@ def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None):
890921
node_output_storage = [storage_map[r] for r in node.outputs]
891922

892923
def create_thunk(linker):
893-
linker.accept(fg, no_recycling=fg_no_recycling)
894-
thunk, _, _ = linker.make_thunk(
924+
linker.accept(fg.clone(), no_recycling=fg_no_recycling)
925+
thunk, i, o = linker.make_thunk(
895926
input_storage=node_input_storage, output_storage=node_output_storage
896927
)
897928

tests/compile/test_builders.py

Lines changed: 19 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -538,17 +538,6 @@ def test_infer_shape(self):
538538
assert opt_res.shape_feature.shape_of[x] is None
539539
assert opt_res.shape_feature.shape_of[z][0].data == 2
540540

541-
@config.change_flags(compute_test_value="raise")
542-
def test_compute_test_value(self):
543-
x = scalar("x")
544-
x.tag.test_value = np.array(1.0, dtype=config.floatX)
545-
op = OpFromGraph([x], [x**3])
546-
y = scalar("y")
547-
y.tag.test_value = np.array(1.0, dtype=config.floatX)
548-
f = op(y)
549-
grad_f = grad(f, y)
550-
assert grad_f.tag.test_value is not None
551-
552541
def test_make_node_shared(self):
553542
"""Make sure we can provide `OpFromGraph.make_node` new shared inputs and get a valid `OpFromGraph`."""
554543

@@ -619,24 +608,24 @@ def test_shared_to_nonshared_input(self):
619608

620609
assert np.array_equal(res_2, 1.0)
621610

622-
def test_outputs_consistency(self):
623-
"""Make sure that `OpFromGraph.fn` doesn't change the value of `OpFromGraph.inner_outputs`."""
624-
625-
x = scalar("x")
626-
op = OpFromGraph([x], [x**2 / x], mode="FAST_RUN")
627-
628-
# Confirm that the inner-graph is as expected
629-
assert equal_computations(op.inner_outputs, [x**2 / x], op.inner_inputs, [x])
630-
631-
# These outputs of the compiled `op.fgraph` should differ from the
632-
# original, uncompiled `op.fgraph` outputs
633-
fn = op.fn
634-
new_inputs = fn.maker.fgraph.inputs
635-
new_outputs = fn.maker.fgraph.outputs
636-
assert not equal_computations(new_outputs, [x**2 / x], new_inputs, [x])
637-
638-
# The original `op.fgraph` outputs should stay the same, though
639-
assert equal_computations(op.inner_outputs, [x**2 / x], op.inner_inputs, [x])
611+
# def test_outputs_consistency(self):
612+
# """Make sure that `OpFromGraph.fn` doesn't change the value of `OpFromGraph.inner_outputs`."""
613+
#
614+
# x = scalar("x")
615+
# op = OpFromGraph([x], [x**2 / x], mode="FAST_RUN")
616+
#
617+
# # Confirm that the inner-graph is as expected
618+
# assert equal_computations(op.inner_outputs, [x**2 / x], op.inner_inputs, [x])
619+
#
620+
# # These outputs of the compiled `op.fgraph` should differ from the
621+
# # original, uncompiled `op.fgraph` outputs
622+
# fn = op.fn
623+
# new_inputs = fn.maker.fgraph.inputs
624+
# new_outputs = fn.maker.fgraph.outputs
625+
# assert not equal_computations(new_outputs, [x**2 / x], new_inputs, [x])
626+
#
627+
# # The original `op.fgraph` outputs should stay the same, though
628+
# assert equal_computations(op.inner_outputs, [x**2 / x], op.inner_inputs, [x])
640629

641630
def test_explicit_input_from_constant(self):
642631
x = pt.dscalar("x")
@@ -783,10 +772,7 @@ def _f(x):
783772
out = f(out)
784773

785774
compiled_fn = function([x], out, trust_input=True, mode=mode)
786-
compiled_fn.dprint(print_memory_map=True)
787-
compiled_fn.vm.allow_gc = (
788-
False # For fairness to the default VM, since OFG inner VM does not do GC
789-
)
775+
compiled_fn.vm.allow_gc = False
790776

791777
rng = np.random.default_rng(1)
792778
x_test = rng.normal(size=(10,))

0 commit comments

Comments
 (0)