我今天调了一天,还是不行,就感觉是zero的问题,我开个issue在这里。后续再debug一下,防止遗忘。
主要原因还是shard问题。zero只能要么全部learnable的parameters都shard,要么都不shard。然后中间如果要gather sharded的parameters,只能暂时得到,比如model算loss的时候wrapp一个ctx使得其可以gather sharded parameter,然而backward的时候依然会出问题,因为这个ctx只在forward的时候有效果。如果要在backward那里也加上就比较麻烦,因为改loss.backward()需要覆写__inner_training_loop, 改动会有点大。
感觉没必要,本来zero3感觉就是fsdp full shard的下位替代?