import logging

import torch


def select_device(min_memory=2048):
    """Select the best available torch device.

    Scans all visible CUDA devices and picks the one with the most free
    memory, estimated as total memory minus memory reserved by this
    process's caching allocator. Falls back to CPU when CUDA is
    unavailable, when no devices are visible, or when the best GPU has
    less than ``min_memory`` MB free.

    Args:
        min_memory: Minimum free GPU memory in megabytes required to
            select a GPU instead of the CPU. Defaults to 2048.

    Returns:
        torch.device: ``cuda:<idx>`` of the chosen GPU, or ``cpu``.
    """
    logger = logging.getLogger(__name__)

    # Guard both conditions: device_count() can be 0 even when
    # is_available() is True (e.g. CUDA_VISIBLE_DEVICES=""), in which
    # case the original code crashed on max([]) with ValueError.
    if not torch.cuda.is_available() or torch.cuda.device_count() == 0:
        logger.warning('No GPU found, use CPU instead')
        return torch.device('cpu')

    # (index, estimated free bytes) for every visible device.
    # NOTE(review): memory_reserved only accounts for this process's
    # allocator, so memory held by other processes is not subtracted.
    available_gpus = []
    for i in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(i)
        free_memory = props.total_memory - torch.cuda.memory_reserved(i)
        available_gpus.append((i, free_memory))

    selected_gpu, max_free_memory = max(available_gpus, key=lambda x: x[1])
    free_memory_mb = max_free_memory / (1024 * 1024)

    if free_memory_mb < min_memory:
        # Best candidate is below the threshold -> fall back to CPU.
        # Lazy %-args instead of an eagerly-formatted f-string.
        logger.warning(
            'GPU %s has %s MB memory left.',
            selected_gpu,
            round(free_memory_mb, 2),
        )
        return torch.device('cpu')

    return torch.device(f'cuda:{selected_gpu}')