# kardionet / predict_potassium_12lead.py
# Author: outofray -- commit ff8e6c1 ("copy data from repo")
import torch
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from utils.datasets import ECGDataset
from utils.models import EffNet
from utils.training_models import RegressionModel
# +
# --- Configuration: edit these before running. ---
# Directory that holds the ECG data samples to be read.
data_path: str = "your/ecg/data/folder"
# Manifest file listing the filenames to run inference on.
manifest_path: str = "your/manifest/path"
# -
# Everything below runs only when this file is executed directly (as a script
# or notebook). The DataLoader below uses worker processes (num_workers=16);
# with the "spawn" multiprocessing start method (Windows, and macOS by
# default), each worker re-imports this module, so unguarded top-level code
# would rebuild the dataset/model and re-launch prediction in every worker.
# Names assigned inside this block still become module-level globals when the
# file is run as __main__, so interactive inspection keeps working.
if __name__ == "__main__":
    # Initialize a dataset that contains the examples you want to run
    # prediction on; data_path and manifest_path are configured above.
    test_ds = ECGDataset(
        split="test",
        data_path=data_path,
        manifest_path=manifest_path,
        update_manifest_func=None,
    )

    # Wrap the dataset in a dataloader to handle batching and multi-process
    # loading. shuffle=False keeps the dataset order; drop_last=False keeps
    # the final partial batch so no samples are skipped.
    test_dl = DataLoader(
        test_ds,
        num_workers=16,
        batch_size=256,
        drop_last=False,
        shuffle=False,
    )

    # Initialize the "backbone", the core network that acts on the data:
    # 12 input channels (presumably one per ECG lead) and one regression output.
    backbone = EffNet(input_channels=12, output_neurons=1)
    model = RegressionModel(backbone)

    # Deserialize the checkpoint onto the CPU so loading does not depend on the
    # device the weights were saved from; Lightning moves the model to the
    # accelerator when prediction starts.
    # NOTE(review): torch.load unpickles the file -- only load trusted
    # checkpoints.
    weights = torch.load("model_12_lead.pt", map_location="cpu")
    # load_state_dict (strict by default) returns the missing/unexpected key
    # report; printing it confirms the checkpoint matched the architecture.
    print(model.load_state_dict(weights))

    trainer = Trainer(accelerator="gpu", devices=1)
    # Keep the predictions instead of discarding them. Their structure is
    # whatever RegressionModel's predict step returns, collected per batch
    # -- TODO confirm against utils.training_models.
    predictions = trainer.predict(model, dataloaders=test_dl)