fancyfeast commited on
Commit
df9e86f
1 Parent(s): 6982e15
Files changed (1) hide show
  1. Models.py +2 -2
Models.py CHANGED
@@ -134,13 +134,13 @@ class VisionModel(nn.Module):
134
  from safetensors.torch import load_file
135
  resume = load_file(Path(path) / 'model.safetensors', device='cpu')
136
  else:
137
- resume = torch.load(Path(path) / 'model.pt', map_location=torch.device('cpu'))
138
 
139
  model_classes = VisionModel.__subclasses__()
140
  model_cls = next(cls for cls in model_classes if cls.__name__ == config['class'])
141
 
142
  model = model_cls(**{k: v for k, v in config.items() if k != 'class'})
143
- model.load(resume['model'])
144
  if device is not None:
145
  model = model.to(device)
146
 
 
134
  from safetensors.torch import load_file
135
  resume = load_file(Path(path) / 'model.safetensors', device='cpu')
136
  else:
137
+ resume = torch.load(Path(path) / 'model.pt', map_location=torch.device('cpu'))['model']
138
 
139
  model_classes = VisionModel.__subclasses__()
140
  model_cls = next(cls for cls in model_classes if cls.__name__ == config['class'])
141
 
142
  model = model_cls(**{k: v for k, v in config.items() if k != 'class'})
143
+ model.load(resume)
144
  if device is not None:
145
  model = model.to(device)
146