diff --git a/mova/engine/trainer/accelerate/accelerate_trainer.py b/mova/engine/trainer/accelerate/accelerate_trainer.py index 1f25d95..7b8df7f 100644 --- a/mova/engine/trainer/accelerate/accelerate_trainer.py +++ b/mova/engine/trainer/accelerate/accelerate_trainer.py @@ -146,8 +146,8 @@ def _setup(self): if self.use_fsdp: from accelerate import FullyShardedDataParallelPlugin from torch.distributed.fsdp.fully_sharded_data_parallel import ( - FullOptimStateDictConfig, - FullStateDictConfig, + ShardedOptimStateDictConfig, + ShardedStateDictConfig, ) for attr_name in ['text_encoder', 'prompter', 'video_vae', 'audio_vae']: @@ -170,13 +170,11 @@ def _setup(self): fsdp_config['reshard_after_forward'] = True fsdp_plugin = FullyShardedDataParallelPlugin( - state_dict_config=FullStateDictConfig( + state_dict_config=ShardedStateDictConfig( offload_to_cpu=True, - # rank0_only=True, ), - optim_state_dict_config=FullOptimStateDictConfig( + optim_state_dict_config=ShardedOptimStateDictConfig( offload_to_cpu=True, - # rank0_only=True, ), **fsdp_config ) @@ -511,6 +509,17 @@ def _save_checkpoint(self, final: bool = False): os.makedirs(os.path.join(step_dir, "accelerator"), exist_ok=True) self.accelerator.wait_for_everyone() self.accelerator.save_state(os.path.join(step_dir, "accelerator")) + if self.use_fsdp: + rank = self.accelerator.process_index + accel_dir = os.path.join(step_dir, "accelerator") + os.makedirs(accel_dir, exist_ok=True) + torch.save(self.optimizer.state_dict(), + os.path.join(accel_dir, f"optimizer_{rank}.bin")) + torch.save(self.lr_scheduler.state_dict(), + os.path.join(accel_dir, f"scheduler_{rank}.bin")) + else: + self.accelerator.save_state(os.path.join(step_dir, "accelerator")) + def _resume_checkpoint(self, checkpoint_path: str): """Resume checkpoint""" @@ -522,7 +531,17 @@ def _resume_checkpoint(self, checkpoint_path: str): accelerator_path = os.path.join(checkpoint_path, "accelerator") if os.path.exists(accelerator_path): - self.accelerator.load_state(accelerator_path) + if self.use_fsdp: + rank = self.accelerator.process_index + optim_path = os.path.join(accelerator_path, f"optimizer_{rank}.bin") + sched_path = os.path.join(accelerator_path, f"scheduler_{rank}.bin") + if os.path.exists(optim_path): + self.optimizer.load_state_dict(torch.load(optim_path, map_location="cpu")) + if os.path.exists(sched_path): + self.lr_scheduler.load_state_dict(torch.load(sched_path, map_location="cpu")) + else: + self.accelerator.load_state(accelerator_path) + if self.use_lora: from mova.engine.trainer.accelerate.lora_utils import load_lora_weights