File size: 1,207 Bytes
ff8e6c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from utils.datasets import ECGDataset
from utils.models import EffNet
from utils.training_models import RegressionModel


# +
# This is the path where your data samples are stored.
data_path = "your/ecg/data/folder"

# This is the path where your manifest, containing filenames for inference to be run on, is stored.
manifest_path = 'your/manifest/path'
# -


# Initialize a dataset that contains the examples you want to run prediction on. 
test_ds = ECGDataset(
    split="test",
    data_path=data_path,
    manifest_path=manifest_path,
    update_manifest_func=None,
)

# Wrap the dataset in a dataloader to handle batching and multithreading.
test_dl = DataLoader(
    test_ds, 
    num_workers=16, 
    batch_size=256, 
    drop_last=False, 
    shuffle=False
)

# Initialize the "backbone", the core model weights that will act on the data.
backbone = EffNet(input_channels=12, output_neurons=1)

model = RegressionModel(backbone)

weights = torch.load("model_12_lead.pt")
print(model.load_state_dict(weights))

# +
trainer = Trainer(accelerator="gpu", devices=1)

trainer.predict(model, dataloaders=test_dl)