Skip to content

Commit 3eca659

Browse files
version1.2.1
1 parent 92a985c commit 3eca659

File tree

3 files changed

+12
-4
lines changed

3 files changed

+12
-4
lines changed

CHANGELOG.md

+6
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22

33
## Unreleased
44

5+
## v1.2.1
6+
7+
### Fixed
8+
9+
- Jax can be omited in the env
10+
511
## v1.2.0
612

713
### Added

tensorcircuit/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "1.2.0"
1+
__version__ = "1.2.1"
22
__author__ = "TensorCircuit Authors"
33
__creator__ = "refraction-ray"
44

tensorcircuit/interfaces/jax.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,6 @@
55
from typing import Any, Callable, Tuple, Optional, Union, Sequence
66
from functools import wraps, partial
77

8-
import jax
9-
from jax import custom_vjp
10-
118
from ..cons import backend
129
from .tensortrans import general_args_to_backend
1310

@@ -22,6 +19,8 @@ def jax_wrapper(
2219
] = None,
2320
output_dtype: Optional[Union[Any, Sequence[Any]]] = None,
2421
) -> Callable[..., Any]:
22+
import jax
23+
2524
@wraps(fun)
2625
def fun_jax(*x: Any) -> Any:
2726
def wrapped_fun(*args: Any) -> Any:
@@ -129,6 +128,9 @@ def create_jax_function(
129128
output_shape: Optional[Union[Tuple[int, ...], Tuple[()]]] = None,
130129
output_dtype: Optional[Any] = None,
131130
) -> Callable[..., Any]:
131+
import jax
132+
from jax import custom_vjp
133+
132134
if jit:
133135
fun = backend.jit(fun)
134136

0 commit comments

Comments
 (0)