这里会显示出您选择的修订版和当前版本之间的差别。
| 两侧同时换到之前的修订记录 前一修订版 后一修订版 | 前一修订版 | ||
|
模型参数的载入 [2020/09/02 11:24] 218.104.204.98 |
模型参数的载入 [2020/09/02 17:07] (当前版本) 218.104.204.98 |
||
|---|---|---|---|
| 行 4: | 行 4: | ||
| ====== 参数的筛选 ====== | ====== 参数的筛选 ====== | ||
| pretrained_dict =... | pretrained_dict =... | ||
| - | |||
| model_dict = model.state_dict() | model_dict = model.state_dict() | ||
| 行 22: | 行 21: | ||
| model.load_state_dict(model_dict) | model.load_state_dict(model_dict) | ||
| - | ====== 载入参数到cpu ====== | + | ====== 载入参数到指定设备 ====== |
| + | |||
| + | ===== 1. cpu -> cpu或者gpu -> gpu ===== | ||
| + | |||
| + | |||
| + | checkpoint = torch.load('modelparameters.pth') | ||
| + | model.load_state_dict(checkpoint) | ||
| + | |||
| + | ===== 2. cpu -> gpu 1 ===== | ||
| + | |||
| + | '' | ||
| + | torch.load('modelparameters.pth', map_location=lambda storage, loc: storage.cuda(1)) | ||
| + | '' | ||
| + | |||
| + | |||
| + | ===== 3. gpu 1 -> gpu 0 ===== | ||
| + | |||
| + | |||
| + | ''torch.load('modelparameters.pth', map_location={'cuda:1':'cuda:0'})'' | ||
| + | ===== 4. gpu -> cpu ===== | ||
| - | ''torch.load(opt.model,map_location='cpu')'' | + | ''torch.load('modelparameters.pth', map_location=lambda storage, loc: storage)'' |
| + | ''torch.load('modelparameters.pth', map_location='cpu')'' | ||