Skip to content

Commit fb47546

Browse files
Make the casting in lists the same as regular inputs. (#8373)
1 parent 180db67 commit fb47546

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

comfy/model_base.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,13 @@ class ModelSampling(s, c):
102102
return ModelSampling(model_config)
103103

104104

105+
def convert_tensor(extra, dtype):
106+
if hasattr(extra, "dtype"):
107+
if extra.dtype != torch.int and extra.dtype != torch.long:
108+
extra = extra.to(dtype)
109+
return extra
110+
111+
105112
class BaseModel(torch.nn.Module):
106113
def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel):
107114
super().__init__()
@@ -165,13 +172,13 @@ def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, tran
165172
extra_conds = {}
166173
for o in kwargs:
167174
extra = kwargs[o]
175+
168176
if hasattr(extra, "dtype"):
169-
if extra.dtype != torch.int and extra.dtype != torch.long:
170-
extra = extra.to(dtype)
171-
if isinstance(extra, list):
177+
extra = convert_tensor(extra, dtype)
178+
elif isinstance(extra, list):
172179
ex = []
173180
for ext in extra:
174-
ex.append(ext.to(dtype))
181+
ex.append(convert_tensor(ext, dtype))
175182
extra = ex
176183
extra_conds[o] = extra
177184

0 commit comments

Comments
 (0)