OriLib commited on
Commit
65ecd1d
1 Parent(s): b0fb67c

Update example_inference.py

Browse files
Files changed (1) hide show
  1. example_inference.py +2 -4
example_inference.py CHANGED
@@ -10,10 +10,8 @@ def example_inference():
10
  im_path = f"{os.path.dirname(__file__)}/example_input.jpg"
11
 
12
  net = BriaRMBG()
13
- if torch.cuda.is_available():
14
- net.load_state_dict(torch.load(model_path)).cuda()
15
- else:
16
- net.load_state_dict(torch.load(model_path,map_location="cpu"))
17
  net.eval()
18
 
19
  # prepare input
 
10
  im_path = f"{os.path.dirname(__file__)}/example_input.jpg"
11
 
12
  net = BriaRMBG()
13
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
+ net.load_state_dict(torch.load(model_path, map_location=device))
 
 
15
  net.eval()
16
 
17
  # prepare input