Skip to content

Commit 4d606eb

Browse files
authored
Fix using return value of produce (#191)
* Fix a bug with using the return value of produce * Bump patch version to 0.9.3 * Simplify test
1 parent baf1c50 commit 4d606eb

File tree

3 files changed

+20
-1
lines changed

3 files changed

+20
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
33
license = "MIT"
44
desc = "Tape based task copying in Turing"
55
repo = "https://github.com/TuringLang/Libtask.jl.git"
6-
version = "0.9.2"
6+
version = "0.9.3"
77

88
[deps]
99
MistyClosures = "dbe65cb8-6be2-42dd-bbc5-4196aaced4f4"

src/copyable_task.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -898,6 +898,15 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple,Vector{Any}}
898898
prod_val = deref_id
899899
end
900900

901+
# Set the ref for this statement, as we would for any other call or invoke.
902+
# The TapedTask may need to read this ref when it resumes, if the return
903+
# value of `produce` is used within the original function.
904+
if is_used_dict[id]
905+
out_ind = ssa_id_to_ref_index_map[id]
906+
set_ref = Expr(:call, set_ref_at!, refs_id, out_ind, prod_val)
907+
push!(inst_pairs, (ID(), new_inst(set_ref)))
908+
end
909+
901910
# Construct a `ProducedValue`.
902911
val_id = ID()
903912
push!(inst_pairs, (val_id, new_inst(Expr(:call, ProducedValue, prod_val))))

test/copyable_task.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,4 +220,14 @@
220220
g() = produce(rand() > -1.0 ? 2 : 0.1)
221221
@test Libtask.consume(Libtask.TapedTask(nothing, g)) == 2
222222
end
223+
224+
@testset "Return produce" begin
225+
# Test calling a function that does something with the return value of `produce`.
226+
# In this case it just returns it. This used to error, see
227+
# https://github.com/TuringLang/Libtask.jl/issues/190.
228+
f(obs) = produce(obs)
229+
tt = Libtask.TapedTask(nothing, f, :a)
230+
@test Libtask.consume(tt) === :a
231+
@test Libtask.consume(tt) === nothing
232+
end
223233
end

0 commit comments

Comments
 (0)