@@ -234,6 +234,24 @@ def calc_ctmrg_expectation_custom(
234
234
peps_tensors , unitcell , spiral_vectors = _map_tensors (
235
235
input_tensors , unitcell , convert_to_unitcell_func , True
236
236
)
237
+
238
+ if any (i .size == 1 for i in spiral_vectors ):
239
+ spiral_vectors_x = additional_input .get ("spiral_vectors_x" )
240
+ spiral_vectors_y = additional_input .get ("spiral_vectors_y" )
241
+ if spiral_vectors_x is not None :
242
+ if isinstance (spiral_vectors_x , jnp .ndarray ):
243
+ spiral_vectors_x = (spiral_vectors_x ,)
244
+ spiral_vectors = tuple (
245
+ jnp .array ((sx , sy ))
246
+ for sx , sy in safe_zip (spiral_vectors_x , spiral_vectors )
247
+ )
248
+ elif spiral_vectors_y is not None :
249
+ if isinstance (spiral_vectors_y , jnp .ndarray ):
250
+ spiral_vectors_y = (spiral_vectors_y ,)
251
+ spiral_vectors = tuple (
252
+ jnp .array ((sx , sy ))
253
+ for sx , sy in safe_zip (spiral_vectors , spiral_vectors_y )
254
+ )
237
255
else :
238
256
peps_tensors , unitcell = _map_tensors (
239
257
input_tensors , unitcell , convert_to_unitcell_func , False
0 commit comments