 from functools import partial
 from typing import Union, cast

-from pytensor.compile.function import function
+from pytensor.compile import get_default_mode, insert_deepcopy
 from pytensor.compile.function.pfunc import rebuild_collect_shared
+from pytensor.compile.function.types import add_supervisor_to_fgraph
+from pytensor.compile.io import In, Out
 from pytensor.compile.sharedvalue import SharedVariable
 from pytensor.configdefaults import config
 from pytensor.gradient import DisconnectedType, Rop, grad
@@ -433,6 +435,7 @@ def __init__(
         assert isinstance(name, str), "name must be None or string object"
         self.name = name
         self.destroy_map = destroy_map if destroy_map is not None else {}
+        self._prepared_fgraph = None

     def __eq__(self, other):
         # TODO: recognize a copy
@@ -847,16 +850,48 @@ def infer_shape(self, fgraph, node, shapes):
         return ret

+    def _prepare_fgraph(self, impl):
+        if self._prepared_fgraph is None:
+            mode = get_default_mode()
+            if impl == "py":
+                mode = mode.excluding("cxx")
+            rewriter = mode.optimizer
+
+            fgraph = self.fgraph
+            wrapped_inputs = [
+                In(inp, borrow=False, mutable=False) for inp in self.fgraph.inputs
+            ]
+            wrapped_outputs = [Out(out, borrow=True) for out in self.fgraph.outputs]
+            add_supervisor_to_fgraph(
+                fgraph,
+                wrapped_inputs,
+                accept_inplace=False,
+            )
+            rewriter(fgraph)
+            insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs)
+            self._prepared_fgraph = fgraph
+
+        return self._prepared_fgraph
+
     @property
     def fn(self):
         """Lazily compile the inner function graph."""
-        if getattr(self, "_fn", None) is not None:
-            return self._fn
-
-        self._fn = function(self.inner_inputs, self.inner_outputs, **self.kwargs)
-        self._fn.trust_input = True
-
-        return self._fn
+        return None
+        # if getattr(self, "_fn", None) is not None:
+        #     return self._fn
+        #
+        # self._fn = pfunc(
+        #     wrapped_inputs,
+        #     wrapped_outputs,
+        #     mode=mode_instance,
+        #     accept_inplace=True,
+        #     on_unused_input="ignore",
+        #     fgraph=self.fgraph,
+        # )
+        # self._fn = function(self.inner_inputs, self.inner_outputs, **self.kwargs)
+        # self._fn.trust_input = True
+        #
+        # return self._fn

     @property
     def inner_inputs(self):
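(Illustrative note, not part of the commit.) In `_prepare_fgraph` above, the `impl == "py"` branch derives the rewriter from the default mode with rewrites tagged `"cxx"` excluded. Below is a minimal sketch of that same pattern, assuming only the `get_default_mode` import already used in this diff; the toy graph and values are made up:

```python
import pytensor
import pytensor.tensor as pt
from pytensor.compile import get_default_mode

# Mirror the impl == "py" branch: start from the default mode and exclude
# rewrites tagged "cxx", as the diff does when a Python impl is requested.
py_mode = get_default_mode().excluding("cxx")

x = pt.vector("x")
f = pytensor.function([x], 2 * x, mode=py_mode)
print(f([1.0, 2.0]))  # [2. 4.]
```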
@@ -875,11 +910,7 @@ def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None):
         from pytensor.link.c.basic import CLinker
         from pytensor.link.vm import VMLinker

-        # FIXME: Don't call self.fn just to get the optimized fgraph
-        fg = self.fn.maker.fgraph
-        # fg = self.fgraph
-        # rewriter = get_default_mode().optimizer
-        # rewriter(fg)
+        fg = self._prepare_fgraph(impl)
         fg_no_recycling = [
             new_o
             for (new_o, old_o) in zip(fg.outputs, node.outputs, strict=True)
@@ -890,8 +921,8 @@ def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None):
         node_output_storage = [storage_map[r] for r in node.outputs]

         def create_thunk(linker):
-            linker.accept(fg, no_recycling=fg_no_recycling)
-            thunk, _, _ = linker.make_thunk(
+            linker.accept(fg.clone(), no_recycling=fg_no_recycling)
+            thunk, i, o = linker.make_thunk(
                 input_storage=node_input_storage, output_storage=node_output_storage
             )
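(Illustrative usage, not part of the commit.) With the changes above, an `OpFromGraph`'s inner `FunctionGraph` is rewritten lazily inside `make_thunk` via `_prepare_fgraph`, cached on the Op, and handed to each linker as a clone. A minimal sketch using only the public `OpFromGraph` constructor; the toy graph and values are made up:

```python
import pytensor
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph

# Wrap a small inner graph as a single Op.
x, y = pt.scalar("x"), pt.scalar("y")
ofg = OpFromGraph([x, y], [x * y + x])

# The inner FunctionGraph is only prepared (rewritten and cached) when the
# outer function builds its thunks, not when the Op or the outer graph is made.
a, b = pt.scalar("a"), pt.scalar("b")
f = pytensor.function([a, b], ofg(a, b))
print(f(2.0, 3.0))  # 8.0
```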