@@ -1276,11 +1276,30 @@ def __init__(
12761276
12771277 self .to_modality_shape_fn = cast_tuple (to_modality_shape_fn , self .num_modalities )
12781278
1279+ # default token lengths for respective modality
1280+ # fallback if the language model does not come up with valid dimensions
1281+
1282+ if not exists (modality_default_shape ) or is_bearable (modality_default_shape , tuple [int , ...]):
1283+ modality_default_shape = (modality_default_shape ,) * self .num_modalities
1284+
1285+ self .modality_default_shape = modality_default_shape
1286+
1287+ assert len (self .modality_default_shape ) == self .num_modalities
1288+
1289+ self .fallback_to_default_shape_if_invalid = fallback_to_default_shape_if_invalid
1290+
1291+ # default `modality_num_dim` to `len(modality_default_shape)` if latter is specified but former not
1292+
1293+ modality_num_dim = default (modality_num_dim , tuple (len (shape ) for shape in self .modality_default_shape ))
1294+
12791295 # specifying the number of dimensions for the modality, which will be hard validated
12801296
12811297 self .modality_num_dim = cast_tuple (modality_num_dim , self .num_modalities )
1298+
12821299 assert len (self .modality_num_dim ) == self .num_modalities
12831300
1301+ assert all ([not exists (ndim ) or not exists (shape ) or len (shape ) == ndim for ndim , shape in zip (self .modality_num_dim , self .modality_default_shape )])
1302+
12841303 # whether to add an extra axial positional embedding per modality
12851304
12861305 self .add_pos_emb = cast_tuple (add_pos_emb , self .num_modalities )
@@ -1318,18 +1337,6 @@ def __init__(
13181337
13191338 self .maybe_add_temp_batch_dim = add_temp_batch_dim if modality_encoder_decoder_requires_batch_dim else identity
13201339
1321- # default token lengths for respective modality
1322- # fallback if the language model does not come up with valid dimensions
1323-
1324- if not exists (modality_default_shape ) or is_bearable (modality_default_shape , tuple [int , ...]):
1325- modality_default_shape = (modality_default_shape ,) * self .num_modalities
1326-
1327- self .modality_default_shape = modality_default_shape
1328-
1329- assert len (self .modality_default_shape ) == self .num_modalities
1330-
1331- self .fallback_to_default_shape_if_invalid = fallback_to_default_shape_if_invalid
1332-
13331340 # store number of text tokens
13341341
13351342 self .num_text_tokens = num_text_tokens
0 commit comments