38
38
%res = scf.for %arg0 = %c0 to %n_steps step %c1 iter_args(%arg = %noisy_sample) -> (tensor<{batch_size}x4x{lw}x{lh}x{precision}>) {{
39
39
%step_64 = arith.index_cast %arg0 : index to i64
40
40
%this_step = tensor.from_elements %step_64 : tensor<1xi64>
41
- %inner = func.call @compiled_scheduled_unet.run_forward(%arg, %p_embeds, %t_embeds, %time_ids, %guidance_scale, %this_step ) : (tensor<{batch_size}x4x{lw}x{lh}x{precision}>, tensor<{ bd}x{max_length}x2048x{precision}>, tensor<{bd}x1280x{precision}>, tensor<{bd}x6x{precision}>, tensor<1x{precision}>, tensor<1xi64 >) -> tensor<{batch_size}x4x{lw}x{lh}x{precision}>
41
+ %inner = func.call @compiled_scheduled_unet.run_forward(%arg, %this_step, % p_embeds, %t_embeds, %time_ids, %guidance_scale) : (tensor<{batch_size}x4x{lw}x{lh}x{precision}>, tensor<1x{precision}>, tensor<{ bd}x{max_length}x2048x{precision}>, tensor<{bd}x1280x{precision}>, tensor<{bd}x6x{precision}>, tensor<1x{precision}>) -> tensor<{batch_size}x4x{lw}x{lh}x{precision}>
42
42
scf.yield %inner : tensor<{batch_size}x4x{lw}x{lh}x{precision}>
43
43
}}
44
44
return %res : tensor<{batch_size}x4x{lw}x{lh}x{precision}>
48
48
49
49
produce_img_split = r"""
50
50
module @sdxl_compiled_pipeline {{
51
- func.func private @{scheduler_module}.run_initialize(%arg0: !torch.vtensor<[ {batch_size},4, {lh}, {lw}], {precision}>) -> (!torch.vtensor<[ {batch_size},4, {lh}, {lw}], {precision}>, !torch.vtensor<[ {bd},6], {precision}>, !torch.vtensor<[1],f16 >, !torch.vtensor<[ {num_steps}],f32 >) attributes {{torch.assume_strict_symbolic_shapes}}
52
- func.func private @{scheduler_module}.run_scale(%arg0: !torch.vtensor<[ {batch_size},4, {lh}, {lw}], {precision}>, %arg1: !torch.vtensor<[1],si64 >, %arg2: !torch.vtensor<[ {num_steps}],f32 >) -> (!torch.vtensor<[{bd},4, {lh}, {lw}], {precision}>, !torch.vtensor<[1], {precision}>) attributes {{torch.assume_strict_symbolic_shapes}}
53
- func.func private @{scheduler_module}.run_step(%arg0: !torch.vtensor<[ {batch_size},4, {lh}, {lw}], {precision}>, %arg1: !torch.vtensor<[1], {precision}>, %arg2: !torch.vtensor<[ {batch_size},4, {lh}, {lw}], {precision}>) -> !torch.vtensor<[ {batch_size},4, {lh}, {lw}], {precision}> attributes {{torch.assume_strict_symbolic_shapes}}
54
- func.func private @{unet_module}.{unet_function}(%arg0: !torch.vtensor<[{bd},4, {lh}, {lw}], {precision}>, %arg1: !torch.vtensor<[1], {precision}>, %arg2: !torch.vtensor<[ {bd}, {max_length},2048], {precision}>, %arg3: !torch.vtensor<[ {bd},1280], {precision}>, %arg4: !torch.vtensor<[ {bd},6], {precision}>, %arg5: !torch.vtensor<[1], {precision}>) -> !torch.vtensor<[ {batch_size},4, {lh}, {lw}], {precision}> attributes {{torch.assume_strict_symbolic_shapes}}
55
- func.func private @{vae_module}.decode(%arg0: !torch.vtensor<[ {batch_size},4, {lh}, {lw}], {precision}>) -> !torch.vtensor<[ {batch_size},3, {height}, {width}], {precision}> attributes {{torch.assume_strict_symbolic_shapes}}
56
-
57
- func.func @produce_image_latents(%sample: !torch.vtensor<[ {batch_size},4, {lh}, {lw}], {precision}>, %p_embeds: !torch.vtensor<[ {bd}, {max_length},2048], {precision}>, %t_embeds: !torch.vtensor<[ {bd},1280], {precision}>, %guidance_scale: !torch.vtensor<[1], {precision}>) -> !torch.vtensor<[ {batch_size},3, {height}, {width}], {precision}> {{
58
- %noisy_sample, %time_ids, %delete, %timesteps = func.call @{scheduler_module}.run_initialize(%sample) : (!torch.vtensor<[ {batch_size},4, {lh}, {lw}], {precision}>) -> (!torch.vtensor<[ {batch_size},4, {lh}, {lw}], {precision}>, !torch.vtensor<[ {bd},6], {precision}>, !torch.vtensor<[1], {precision}>, !torch.vtensor<[ {num_steps}],f32 >)
51
+ func.func private @{scheduler_module}.run_initialize(%arg0: tensor< {batch_size}x4x {lh}x {lw}x {precision}>) -> (tensor< {batch_size}x4x {lh}x {lw}x {precision}>, tensor< {bd}x6x {precision}>, tensor<1xf16 >, tensor< {num_steps}xf32 >) attributes {{torch.assume_strict_symbolic_shapes}}
52
+ func.func private @{scheduler_module}.run_scale(%arg0: tensor< {batch_size}x4x {lh}x {lw}x {precision}>, %arg1: tensor<1xi64 >, %arg2: tensor< {num_steps}xf32 >) -> (tensor<{batch_size}x4x {lh}x {lw}x {precision}>, tensor<1x {precision}>) attributes {{torch.assume_strict_symbolic_shapes}}
53
+ func.func private @{scheduler_module}.run_step(%arg0: tensor< {batch_size}x4x {lh}x {lw}x {precision}>, %arg1: tensor<1x {precision}>, %arg2: tensor< {batch_size}x4x {lh}x {lw}x {precision}>) -> tensor< {batch_size}x4x {lh}x {lw}x {precision}> attributes {{torch.assume_strict_symbolic_shapes}}
54
+ func.func private @{unet_module}.{unet_function}(%arg0: tensor<{batch_size}x4x {lh}x {lw}x {precision}>, %arg1: tensor<1x {precision}>, %arg2: tensor< {bd}x {max_length}x2048x {precision}>, %arg3: tensor< {bd}x1280x {precision}>, %arg4: tensor< {bd}x6x {precision}>, %arg5: tensor<1x {precision}>) -> tensor< {batch_size}x4x {lh}x {lw}x {precision}> attributes {{torch.assume_strict_symbolic_shapes}}
55
+ func.func private @{vae_module}.decode(%arg0: tensor< {batch_size}x4x {lh}x {lw}x {precision}>) -> tensor< {batch_size}x3x {height}x {width}x {precision}> attributes {{torch.assume_strict_symbolic_shapes}}
56
+
57
+ func.func @produce_image_latents(%sample: tensor< {batch_size}x4x {lh}x {lw}x {precision}>, %p_embeds: tensor< {bd}x {max_length}x2048x {precision}>, %t_embeds: tensor< {bd}x1280x {precision}>, %guidance_scale: tensor<1x {precision}>) -> tensor< {batch_size}x3x {height}x {width}x {precision}> {{
58
+ %noisy_sample, %time_ids, %delete, %timesteps = func.call @{scheduler_module}.run_initialize(%sample) : (tensor< {batch_size}x4x {lh}x {lw}x {precision}>) -> (tensor< {batch_size}x4x {lh}x {lw}x {precision}>, tensor< {bd}x6x {precision}>, tensor<1x {precision}>, tensor< {num_steps}xf32 >)
59
59
%c0 = arith.constant 0 : index
60
60
%c1 = arith.constant 1 : index
61
61
%n_steps = arith.constant {num_steps} : index
62
- %res = scf.for %arg0 = %c0 to %n_steps step %c1 iter_args(%arg = %noisy_sample) -> (!torch.vtensor<[ {batch_size},4, {lh}, {lw}], {precision}>) {{
62
+ %res = scf.for %arg0 = %c0 to %n_steps step %c1 iter_args(%arg = %noisy_sample) -> (tensor< {batch_size}x4x {lh}x {lw}x {precision}>) {{
63
63
%step_64 = arith.index_cast %arg0 : index to i64
64
64
%this_step = tensor.from_elements %step_64 : tensor<1xi64>
65
- %step_torch = torch_c.from_builtin_tensor %this_step : tensor<1xi64> -> !torch.vtensor<[1],si64>
66
- %scaled, %timestep = func.call @{scheduler_module}.run_scale(%arg, %step_torch, %timesteps) : (!torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>, !torch.vtensor<[1],si64>, !torch.vtensor<[{num_steps}],f32>) -> (!torch.vtensor<[{bd},4,{lh},{lw}],{precision}>, !torch.vtensor<[1],{precision}>)
67
- %inner = func.call @{unet_module}.{unet_function}(%scaled, %timestep, %p_embeds, %t_embeds, %time_ids, %guidance_scale) : (!torch.vtensor<[{bd},4,{lh},{lw}],{precision}>, !torch.vtensor<[1],{precision}>, !torch.vtensor<[{bd},{max_length},2048],{precision}>, !torch.vtensor<[{bd},1280],{precision}>, !torch.vtensor<[{bd},6],{precision}>, !torch.vtensor<[1],{precision}>) -> !torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>
68
- %pred = func.call @{scheduler_module}.run_step(%inner, %timestep, %arg) : (!torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>, !torch.vtensor<[1],{precision}>, !torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>) -> !torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>
69
- scf.yield %pred : !torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>
65
+ %scaled, %timestep = func.call @{scheduler_module}.run_scale(%arg, %this_step, %timesteps) : (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<1xi64>, tensor<{num_steps}xf32>) -> (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<1x{precision}>)
66
+ %inner = func.call @{unet_module}.{unet_function}(%scaled, %timestep, %p_embeds, %t_embeds, %time_ids, %guidance_scale) : (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<1x{precision}>, tensor<{bd}x{max_length}x2048x{precision}>, tensor<{bd}x1280x{precision}>, tensor<{bd}x6x{precision}>, tensor<1x{precision}>) -> tensor<{batch_size}x4x{lh}x{lw}x{precision}>
67
+ %pred = func.call @{scheduler_module}.run_step(%inner, %timestep, %arg) : (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<1x{precision}>, tensor<{batch_size}x4x{lh}x{lw}x{precision}>) -> tensor<{batch_size}x4x{lh}x{lw}x{precision}>
68
+ scf.yield %pred : tensor<{batch_size}x4x{lh}x{lw}x{precision}>
70
69
}}
71
- %image = func.call @{vae_module}.decode(%res): (!torch.vtensor<[ {batch_size},4, {lh}, {lw}], {precision}>) -> !torch.vtensor<[ {batch_size},3, {height}, {width}], {precision}>
72
- return %image : !torch.vtensor<[ {batch_size},3, {height}, {width}], {precision}>
70
+ %image = func.call @{vae_module}.decode(%res): (tensor< {batch_size}x4x {lh}x {lw}x {precision}>) -> tensor< {batch_size}x3x {height}x {width}x {precision}>
71
+ return %image : tensor< {batch_size}x3x {height}x {width}x {precision}>
73
72
}}
74
73
}}
75
74
"""
@@ -128,4 +127,4 @@ def get_pipeline_ir(
128
127
scheduler_module = scheduler_module_name ,
129
128
vae_module = vae_module_name ,
130
129
num_steps = num_steps ,
131
- )
130
+ )
0 commit comments