freyza commited on
Commit
8f2d077
1 Parent(s): 4cc130e

Update src/mdx.py

Browse files
Files changed (1) hide show
  1. src/mdx.py +4 -2
src/mdx.py CHANGED
@@ -65,8 +65,10 @@ class MDX:
65
  def __init__(self, model_path: str, params: MDXModel, processor=DEFAULT_PROCESSOR):
66
 
67
  # Set the device and the provider (CPU or CUDA)
68
- self.device = torch.device(f'cuda:{processor}') if processor >= 0 else torch.device('cpu')
69
- self.provider = ['CUDAExecutionProvider'] if processor >= 0 else ['CPUExecutionProvider']
 
 
70
 
71
  self.model = params
72
 
 
65
  def __init__(self, model_path: str, params: MDXModel, processor=DEFAULT_PROCESSOR):
66
 
67
  # Set the device and the provider (CPU or CUDA)
68
+ #self.device = torch.device(f'cuda:{processor}') if processor >= 0 else torch.device('cpu')
69
+ self.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
70
+ #self.provider = ['CUDAExecutionProvider'] if processor >= 0 else ['CPUExecutionProvider']
71
+ self.provider = ['CPUExecutionProvider']
72
 
73
  self.model = params
74