Skip to content

check if model is instance of DataParallel before saving checkpoint #14

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

jiayangshi
Copy link

Problem

In current save_checkpoint function, if the model is on multiple GPUs, i.e. model is a instance of torch.nn.DataParallel, and then the saved checkpoint could not be loaded again.

Describe your changes

Follow the pytorch tutorial, first the current model is checked, if it is a instance of DataParallel class.
If the model is a instance of DataParallel class, then model.module.state_dict() is saved instead of model.state_dict() in current implementation.

@Bjarten Bjarten deleted the branch Bjarten:main October 14, 2024 03:52
@Bjarten Bjarten closed this Oct 14, 2024
@Bjarten Bjarten reopened this Oct 14, 2024
@Bjarten Bjarten changed the base branch from master to main October 14, 2024 04:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants