본문 바로가기

Error and Solve

[코드 정리] accelerator.save_state(path:str) 기본 사용법

반응형

 

 

1.accelerator.save_state(path:str) 기본 사용법 

 

`accelerator.save_state()` 함수는 기본적으로 하나의 인자를 받습니다. 

그 인자는 파일 경로 (string) 입니다. 이 인자는 모델, 옵티마이저, 스케줄러 등 훈련 상태를 저장할 위치를 지정하는 역할을 합니다.

accelerator.save_state(path: str)

accelerator.save_state(path: str)

 


파라미터:

- `path`: 저장할 파일의 경로를 지정하는 문자열입니다. 이 경로는 디렉토리와 파일 이름을 포함할 수 있습니다. 예를 들어, `"checkpoints/epoch_10"`와 같이 경로를 지정할 수 있습니다.

예시 1:

accelerator.save_state("checkpoints/model_epoch_10")

accelerator.save_state("checkpoints/model_epoch_10")


이 코드는 `checkpoints` 디렉토리에 `model_epoch_10`이라는 파일 이름으로 훈련 상태를 저장합니다.

 

 

 

 

예시 2:

accelerator.save_state(f"checkpoints/model_epoch_{epoch + 1}")

accelerator.save_state(f"checkpoints/model_epoch_{epoch + 1}")



이 코드는 각 에폭마다 훈련 상태를 저장하며, 파일 이름은 `model_epoch_1`, `model_epoch_2`, ...처럼 생성됩니다.


 

 

 

 

 

2.그 외 사용법  

 

 

다중 GPU에서 주 훈련 프로세스에서만 저장 

 

   - `accelerator.is_main_process`를 사용하여 다중 GPU 설정에서 주 프로세스에서만 상태를 저장하도록 할 수 있습니다. 이를 통해 불필요하게 여러 프로세스에서 상태가 저장되지 않도록 방지할 수 있습니다.

   if accelerator.is_main_process:
       accelerator.save_state("checkpoint")

 

 

 


정기적으로 상태 저장

 

   - 훈련 중 일정 주기마다 상태를 저장할 수 있습니다. 예를 들어, `save_freq` 파라미터를 사용하여 몇 번째 에폭마다 저장할지 설정할 수 있습니다.

  if (epoch + 1) % save_freq == 0:
       accelerator.save_state(f"checkpoint_epoch_{epoch + 1}")

 

 

 



디렉토리 구조 만들기

 

필요한 경우 디렉토리 구조를 미리 만들어 놓고, 해당 디렉토리에 저장할 수 있습니다.

   import os
   os.makedirs("checkpoints", exist_ok=True)
   accelerator.save_state("checkpoints/checkpoint_epoch_10")

 

 

 

 

 


자동 저장을 위한 콜백 사용

 

   - 훈련을 관리할 때 `save_state()`를 콜백 함수로 호출하여 훈련 중 자동으로 상태를 저장할 수 있습니다.
  훈련 중에 save_callback을 호출하여 상태 저장

  from accelerate import Accelerator
   accelerator = Accelerator()

   def save_callback(epoch):
       if epoch % save_freq == 0:
           accelerator.save_state(f"checkpoint_epoch_{epoch}")

 

 

 

 

 


훈련 중단 후 상태 복원

 

   - 훈련이 중단되었을 때 `accelerator.load_state()`를 사용하여 저장된 상태에서 훈련을 재개할 수 있습니다. 이는 `save_state()`와 함께 자주 사용됩니다.

   accelerator.load_state("checkpoint_epoch_10")

   accelerator.load_state("checkpoint_epoch_10")

 

 

 

 

End.

반응형