
[Bug][Dynamic Embedding] Improper optimizer state_dict momentum2 key while constructing PSCollection #2177

Open
@JacoCheung

Description


Describe the bug

A PSCollection should contain optimizer states in addition to weights. The optimizer state tensors are obtained directly from the EmbeddingCollection module.

However, sharded_module.fused_optimizer.state_dict()['state'] does not contain the key {table_name}.momentum2 because:

  1. TBE::get_optimizer_state(), which is used by PSCollection, does not return keys like xxx.momentum1 or xxx.momentum2; its state keys are customized by TBE.
  2. The state keys are then renamed by torchrec::EmbeddingFusedOptimizer. The first state falls back to xxx.momentum1, while the remaining keys are copied verbatim from the TBE results retrieved above.

See the illustration below, where the optimizer is Adam. The expected number of state tensors is 2, but only momentum1 is eventually produced, leaving momentum2 (i.e., exp_avg_sq) out.

[Screenshot: opt report]

This affects every kind of optimizer that maintains a momentum2 state.
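The renaming behavior in step 2 can be sketched as follows. This is a hypothetical, simplified illustration (the function name and TBE state keys here are assumptions, not the actual torchrec source); it only mimics the described key-renaming step to show why a "<table>.momentum2" key never appears.

```python
# Hypothetical sketch of EmbeddingFusedOptimizer's state-key renaming,
# as described in the issue. Names are illustrative, not the real code.

def rename_optimizer_states(table_name, tbe_states):
    """tbe_states: ordered (key, tensor) pairs as returned by a
    TBE-style get_optimizer_state()."""
    renamed = {}
    for i, (tbe_key, tensor) in enumerate(tbe_states):
        if i == 0:
            # The first state falls back to "<table>.momentum1" ...
            renamed[f"{table_name}.momentum1"] = tensor
        else:
            # ... while the remaining states keep their TBE-specific
            # keys verbatim, so "<table>.momentum2" is never produced.
            renamed[tbe_key] = tensor
    return renamed


# For Adam, TBE exposes two state tensors (keys here are assumed):
states = rename_optimizer_states(
    "table_0", [("momentum1", "exp_avg"), ("exp_avg_sq", "exp_avg_sq")]
)
print(sorted(states))  # ['exp_avg_sq', 'table_0.momentum1']
```

Because only the first state is mapped into the "<table>.momentum{n}" namespace, any consumer (such as PSCollection) that looks up {table_name}.momentum2 finds nothing, even though the second state tensor exists under its TBE-specific key.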

Metadata

Assignees

Labels

bug (Something isn't working)
