Skip to content

Commit e46e967

Browse files
inducerkaushikcfd
authored andcommitted
pytato: Convert einsum arg_names into PrefixNamed tags
1 parent 4374e44 commit e46e967

File tree

1 file changed

+22
-8
lines changed

1 file changed

+22
-8
lines changed

arraycontext/impl/pytato/__init__.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -219,19 +219,33 @@ def tag_axis(self, iaxis, tags: Union[Sequence[Tag], Tag], array):
219219
def einsum(self, spec, *args, arg_names=None, tagged=()):
220220
import pyopencl.array as cla
221221
import pytato as pt
222-
if arg_names is not None:
223-
from warnings import warn
224-
warn("'arg_names' don't bear any significance in "
225-
"PytatoPyOpenCLArrayContext.", stacklevel=2)
222+
if arg_names is None:
223+
arg_names = (None,) * len(args)
226224

227-
def preprocess_arg(arg):
225+
def preprocess_arg(name, arg):
228226
if isinstance(arg, cla.Array):
229-
return self.thaw(arg)
227+
ary = self.thaw(arg)
230228
else:
231229
assert isinstance(arg, pt.Array)
232-
return arg
230+
ary = arg
233231

234-
return pt.einsum(spec, *(preprocess_arg(arg) for arg in args))
232+
if name is not None:
233+
from pytato.tags import PrefixNamed
234+
235+
# Tagging Placeholders with naming-related tags is pointless:
236+
# They already have names. It's also counterproductive, as
237+
# multiple placeholders with the same name that are not
238+
# also the same object are not allowed, and this would produce
239+
# a different Placeholder object of the same name.
240+
if not isinstance(ary, pt.Placeholder):
241+
ary = ary.tagged(PrefixNamed(name))
242+
243+
return ary
244+
245+
return pt.einsum(spec, *[
246+
preprocess_arg(name, arg)
247+
for name, arg in zip(arg_names, args)
248+
])
235249

236250
@property
237251
def permits_inplace_modification(self):

0 commit comments

Comments
 (0)