这里会显示出您选择的修订版和当前版本之间的差别。
两侧同时换到之前的修订记录 前一修订版 后一修订版 | 前一修订版 | ||
模型参数的载入 [2020/02/01 21:41] 127.0.0.1 外部编辑 |
模型参数的载入 [2020/09/02 17:07] (当前版本) 218.104.204.98 |
||
---|---|---|---|
行 4: | 行 4: | ||
====== 参数的筛选 ====== | ====== 参数的筛选 ====== | ||
pretrained_dict =... | pretrained_dict =... | ||
- | |||
model_dict = model.state_dict() | model_dict = model.state_dict() | ||
行 22: | 行 21: | ||
model.load_state_dict(model_dict) | model.load_state_dict(model_dict) | ||
+ | ====== 载入参数到指定设备 ====== | ||
+ | |||
+ | ===== 1. cpu -> cpu或者gpu -> gpu ===== | ||
+ | |||
+ | |||
+ | checkpoint = torch.load('modelparameters.pth') | ||
+ | model.load_state_dict(checkpoint) | ||
+ | |||
+ | ===== 2. cpu -> gpu 1 ===== | ||
+ | |||
+ | '' | ||
+ | torch.load('modelparameters.pth', map_location=lambda storage, loc: storage.cuda(1)) | ||
+ | '' | ||
+ | |||
+ | |||
+ | ===== 3. gpu 1 -> gpu 0 ===== | ||
+ | |||
+ | |||
+ | ''torch.load('modelparameters.pth', map_location={'cuda:1':'cuda:0'})'' | ||
+ | ===== 4. gpu -> cpu ===== | ||
+ | |||
+ | |||
+ | ''torch.load('modelparameters.pth', map_location=lambda storage, loc: storage)'' | ||
+ | |||
+ | ''torch.load('modelparameters.pth', map_location='cpu')'' | ||