
Commit f042578

Start deprecating updates API
Using a DeprecationWarning to keep it visible only to devs for now
1 parent 5fc02e1 · commit f042578

9 files changed: +256 −120 lines

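For context, this is the migration the commit begins: `scan` currently returns an `(outputs, updates)` tuple, and the new `return_updates=False` flag makes it return only the outputs. A minimal before/after sketch (the toy step function is illustrative, mirroring the test changes below):

import pytensor
import pytensor.tensor as pt

x0 = pt.scalar("x0")

# Old API: scan always returns (outputs, updates); callers routinely discard updates.
xs, updates = pytensor.scan(
    lambda xtm1: xtm1 + 1,  # each step adds 1 to the previous state
    outputs_info=[x0],
    n_steps=10,
)

# New API: opt in with return_updates=False to receive only the outputs
# and silence the DeprecationWarning introduced in this commit.
xs = pytensor.scan(
    lambda xtm1: xtm1 + 1,
    outputs_info=[x0],
    n_steps=10,
    return_updates=False,
)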

pytensor/gradient.py

Lines changed: 2 additions & 4 deletions
@@ -2188,7 +2188,7 @@ def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):
         # It is possible that the inputs are disconnected from expr,
         # even if they are connected to cost.
         # This should not be an error.
-        hess, updates = pytensor.scan(
+        hess = pytensor.scan(
             lambda i, y, x: grad(
                 y[i],
                 x,
@@ -2197,9 +2197,7 @@ def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):
             ),
             sequences=pytensor.tensor.arange(expr.shape[0]),
             non_sequences=[expr, input],
-        )
-        assert not updates, (
-            "Scan has returned a list of updates; this should not happen."
+            return_updates=False,
         )
         hessians.append(hess)
     return as_list_or_tuple(using_list, using_tuple, hessians)
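A quick usage sketch of the function touched above (assuming `hessian` keeps its public signature; the toy cost is illustrative):

import pytensor.tensor as pt
from pytensor.gradient import hessian

x = pt.vector("x")
cost = (x**2).sum()  # scalar cost
# hessian now calls scan(..., return_updates=False) internally,
# so no updates tuple needs to be unpacked or asserted empty.
H = hessian(cost, wrt=x)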

pytensor/scan/basic.py

Lines changed: 33 additions & 3 deletions
@@ -163,6 +163,26 @@ def isNaN_or_Inf_or_None(x):
     return isNone or isNaN or isInf or isStr
 
 
+def _manage_output_api_change(outputs, updates, return_updates):
+    if return_updates:
+        warnings.warn(
+            "Scan return signature will change. Updates dict will not be returned, only the first argument. "
+            "Pass `return_updates=False` to conform to the new API and avoid this warning.",
+            DeprecationWarning,
+            # Only meant for developers for now. Switch to FutureWarning to warn users, before removing.
+            stacklevel=2,
+        )
+    else:
+        if updates:
+            raise ValueError(
+                f"return_updates=False but Scan produced updates {updates}. "
+                "Make sure to use outputs_info to handle all recurrent states, and not rely on shared variable updates."
+            )
+        return outputs
+
+    return outputs, updates
+
+
 def scan(
     fn,
     sequences=None,
@@ -177,6 +197,7 @@ def scan(
     allow_gc=None,
     strict=False,
     return_list=False,
+    return_updates: bool = True,
 ):
     r"""This function constructs and applies a `Scan` `Op` to the provided arguments.
 
@@ -873,6 +894,11 @@ def wrap_into_list(x):
     raw_inner_outputs = fn(*args)
 
     condition, outputs, updates = get_updates_and_outputs(raw_inner_outputs)
+    if updates:
+        warnings.warn(
+            "Updates functionality in Scan is deprecated. Use explicit outputs_info and build shared update expressions manually, even for RNGs.",
+            DeprecationWarning,  # Only meant for developers for now. Switch to FutureWarning to warn users, before removing.
+        )
     if condition is not None:
         as_while = True
     else:
@@ -895,7 +921,7 @@ def wrap_into_list(x):
         if not return_list and len(outputs) == 1:
             outputs = outputs[0]
 
-        return (outputs, updates)
+        return _manage_output_api_change(outputs, updates, return_updates)
 
     ##
     # Step 4. Compile the dummy function
@@ -914,6 +940,8 @@ def wrap_into_list(x):
         fake_outputs = clone_replace(
             outputs, replace=dict(zip(non_seqs, fake_nonseqs, strict=True))
         )
+        # TODO: Once we don't treat shared variables specially we should use `truncated_graph_inputs`
+        # to find implicit inputs in a way that reduces the size of the inner function
         known_inputs = [*args, *fake_nonseqs]
         extra_inputs = [
             x for x in explicit_graph_inputs(fake_outputs) if x not in known_inputs
@@ -1206,12 +1234,14 @@ def remove_dimensions(outs, offsets=None):
 
     offset += n_nit_sot
 
-    # Support for explicit untraced sit_sot
+    # Legacy support for explicit untraced sit_sot and those built with update dictionary
+    # Switch to n_untraced_sit_sot_outs after deprecation period
     n_explicit_untraced_sit_sot_outs = len(untraced_sit_sot_rightOrder)
     untraced_sit_sot_outs = scan_outs[
         offset : offset + n_explicit_untraced_sit_sot_outs
     ]
 
+    # Legacy support: map shared outputs to their updates
     offset += n_explicit_untraced_sit_sot_outs
     for idx, update_rule in enumerate(scan_outs[offset:]):
         update_map[untraced_sit_sot_scan_inputs[idx]] = update_rule
@@ -1244,4 +1274,4 @@ def remove_dimensions(outs, offsets=None):
     elif len(scan_out_list) == 0:
         scan_out_list = None
 
-    return scan_out_list, update_map
+    return _manage_output_api_change(scan_out_list, update_map, return_updates)
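Taken together, `_manage_output_api_change` gives `scan` two exits. A small behavioral sketch under the code above (the step function is illustrative):

import warnings
import pytensor
import pytensor.tensor as pt

x0 = pt.scalar("x0")

# Default path: the legacy (outputs, updates) tuple still comes back,
# accompanied by a DeprecationWarning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    xs, updates = pytensor.scan(lambda xtm1: xtm1 + 1, outputs_info=[x0], n_steps=3)
assert any(issubclass(w.category, DeprecationWarning) for w in caught)

# Opt-in path: only the outputs are returned. If the inner function had
# produced shared-variable updates, a ValueError would be raised instead.
xs = pytensor.scan(
    lambda xtm1: xtm1 + 1, outputs_info=[x0], n_steps=3, return_updates=False
)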

pytensor/tensor/pad.py

Lines changed: 2 additions & 1 deletion
@@ -314,11 +314,12 @@ def _wrap_pad(x: TensorVariable, pad_width: TensorVariable) -> TensorVariable:
 
 
 def _build_padding_one_direction(array, array_flipped, repeats, *, inner_func, axis):
-    [_, parts], _ = scan(
+    [_, parts] = scan(
         inner_func,
         non_sequences=[array, array_flipped],
         outputs_info=[0, None],
         n_steps=repeats,
+        return_updates=False,
     )
 
     parts = moveaxis(parts, 0, axis)

tests/link/jax/test_scan.py

Lines changed: 2 additions & 1 deletion
@@ -23,10 +23,11 @@
 @pytest.mark.parametrize("view", [None, (-1,), slice(-2, None, None)])
 def test_scan_sit_sot(view):
     x0 = pt.scalar("x0", dtype="float64")
-    xs, _ = scan(
+    xs = scan(
         lambda xtm1: xtm1 + 1,
         outputs_info=[x0],
         n_steps=10,
+        return_updates=False,
     )
     if view:
         xs = xs[view]

tests/link/numba/test_scan.py

Lines changed: 6 additions & 2 deletions
@@ -343,8 +343,12 @@ def power_step(prior_result, x):
 
 def test_grad_sitsot():
     def get_sum_of_grad(inp):
-        scan_outputs, _updates = scan(
-            fn=lambda x: x * 2, outputs_info=[inp], n_steps=5, mode="NUMBA"
+        scan_outputs = scan(
+            fn=lambda x: x * 2,
+            outputs_info=[inp],
+            n_steps=5,
+            mode="NUMBA",
+            return_updates=False,
         )
         return grad(scan_outputs.sum(), inp).sum()
