790 lines
22 KiB
Python
Executable File
790 lines
22 KiB
Python
Executable File
#!/usr/bin/env python3.12
|
|
|
|
import os
|
|
import warnings
|
|
import pandas as pd
|
|
import numpy as np
|
|
from typing import List, Dict
|
|
from matplotlib import pyplot as plt
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from torch.utils.data import Dataset, DataLoader
|
|
import seaborn as sns
|
|
import argparse
|
|
from imblearn.over_sampling import RandomOverSampler
|
|
from imblearn.under_sampling import RandomUnderSampler
|
|
from collections import Counter
|
|
from sklearn.model_selection import StratifiedKFold
|
|
from sklearn.model_selection import train_test_split
|
|
from datetime import datetime
|
|
|
|
from load_dataset import get_dataset
|
|
from model import SimpleLSTM
|
|
from training import training_loop
|
|
from validation import validation
|
|
|
|
warnings.simplefilter(action="ignore", category=FutureWarning)
|
|
|
|
|
|
def setup_model_training(
|
|
input_size,
|
|
hidden_size,
|
|
num_layers,
|
|
num_classes,
|
|
sequence_length,
|
|
device,
|
|
lr,
|
|
weight_decay,
|
|
eps,
|
|
):
|
|
model = SimpleLSTM(
|
|
input_size=input_size,
|
|
hidden_size=hidden_size,
|
|
num_layers=num_layers,
|
|
num_classes=num_classes,
|
|
sequence_length=sequence_length,
|
|
device=device,
|
|
)
|
|
|
|
optimizer = optim.Adam(
|
|
model.parameters(), lr=lr, weight_decay=weight_decay, eps=eps
|
|
)
|
|
|
|
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
|
optimizer, mode="min", factor=0.1, patience=25
|
|
)
|
|
|
|
# criterion = nn.CrossEntropyLoss()
|
|
# criterion = nn.BCELoss()
|
|
criterion = nn.BCEWithLogitsLoss()
|
|
|
|
return (model, optimizer, scheduler, criterion)
|
|
|
|
|
|
def create_loaders(
|
|
x, y, data, under=False, over=False, classes=["base", "yawn"], batch_size=2
|
|
):
|
|
if under is True:
|
|
x, y = undersample(x, y, data)
|
|
|
|
if over is True:
|
|
x, y = oversample(x, y, data)
|
|
|
|
x_train, x_test, y_train, y_test = train_test_split(
|
|
x, y, test_size=TEST_OPTIMAL_SIZE, random_state=seed
|
|
)
|
|
|
|
train_dataset = FetusDataset(
|
|
[{"data": x, "label": y} for x, y in zip(x_train, y_train)],
|
|
train=True,
|
|
classes=len(classes),
|
|
)
|
|
|
|
test_dataset = FetusDataset(
|
|
[{"data": x, "label": y} for x, y in zip(x_test, y_test)],
|
|
train=True,
|
|
classes=len(classes),
|
|
)
|
|
|
|
train_loader = DataLoader(
|
|
train_dataset,
|
|
batch_size=batch_size,
|
|
shuffle=True,
|
|
drop_last=True,
|
|
)
|
|
|
|
test_loader = DataLoader(
|
|
test_dataset,
|
|
batch_size=batch_size,
|
|
shuffle=True,
|
|
drop_last=True,
|
|
)
|
|
|
|
return (train_loader, test_loader)
|
|
|
|
|
|
def get_device() -> torch.device:
|
|
if torch.backends.mps.is_built():
|
|
return torch.device("mps")
|
|
|
|
if torch.cuda.is_available():
|
|
return torch.device("cuda")
|
|
|
|
return torch.device("cpu")
|
|
|
|
|
|
def undersample(x: np.ndarray, y: np.ndarray, data):
|
|
print("Before undersampling")
|
|
print(Counter([d["label"] for d in data]))
|
|
|
|
rus = RandomUnderSampler(random_state=seed, sampling_strategy="majority")
|
|
x = np.array([d["data"] for d in data])
|
|
y = np.array([d["label"] for d in data])
|
|
flat_x = np.array([x.flatten() for x in x])
|
|
flat_y = np.array(y)
|
|
x, y = rus.fit_resample(flat_x, flat_y)
|
|
x = x.reshape(-1, SERIES_LENGTH, FEATURE_SIZE)
|
|
|
|
print("After undersampling")
|
|
print(Counter(y))
|
|
return x, y
|
|
|
|
|
|
def oversample(x: np.ndarray, y: np.ndarray, data):
|
|
print("Before oversampling")
|
|
print(Counter([d["label"] for d in data]))
|
|
|
|
ros = RandomOverSampler(random_state=seed, sampling_strategy="all")
|
|
x = np.array([d["data"] for d in data])
|
|
y = np.array([d["label"] for d in data])
|
|
flat_x = np.array([x.flatten() for x in x])
|
|
flat_y = np.array(y)
|
|
x, y = ros.fit_resample(flat_x, flat_y)
|
|
x = x.reshape(-1, SERIES_LENGTH, FEATURE_SIZE)
|
|
|
|
print("After oversampling")
|
|
print(Counter(y))
|
|
return x, y
|
|
|
|
|
|
class FetusDataset(Dataset):
|
|
def __init__(
|
|
self, data: List[Dict[np.ndarray, int]], train: bool = False, classes: int = 2
|
|
):
|
|
self.data = data
|
|
self.train = train
|
|
self.classes = classes
|
|
|
|
def __len__(self):
|
|
return len(self.data)
|
|
|
|
def __getitem__(self, idx):
|
|
x = self.data[idx]["data"]
|
|
y = self.data[idx]["label"]
|
|
|
|
# Conversione del tipo di dato
|
|
x = x.astype(np.float32)
|
|
y = np.eye(self.classes)[y]
|
|
|
|
# Conversione in tensori
|
|
x = torch.tensor(x, dtype=torch.float32)
|
|
y = torch.tensor(y, dtype=torch.int32)
|
|
|
|
# Gestione di valori NaN o infiniti
|
|
x = torch.nan_to_num(
|
|
x
|
|
) # Sostituisce NaN con 0 e valori infiniti con numeri molto grandi o piccoli
|
|
|
|
# Normalizzazione solo durante il training
|
|
if self.train:
|
|
mean = x.mean()
|
|
std = x.std()
|
|
|
|
# Normalizzazione condizionale (solo se std > 0)
|
|
if std > 0:
|
|
x = (x - mean) / std
|
|
|
|
return x, y
|
|
|
|
|
|
def createArgParser():
|
|
parser = argparse.ArgumentParser(description="Womb Wise")
|
|
parser.add_argument(
|
|
"-rd",
|
|
"--reload-dataset",
|
|
action="store_true",
|
|
help="Reload the dataset",
|
|
)
|
|
|
|
# path to the dataset
|
|
parser.add_argument(
|
|
"-p",
|
|
"--path",
|
|
action="store",
|
|
help="Path to the dataset",
|
|
default="~/Documents/womb-wise/Data/",
|
|
)
|
|
|
|
# epoch
|
|
parser.add_argument(
|
|
"-e",
|
|
"--epochs",
|
|
action="store",
|
|
help="Number of epochs",
|
|
default=10,
|
|
)
|
|
|
|
parser.add_argument(
|
|
"-k",
|
|
"--kfold",
|
|
action="store",
|
|
help="Number of folds for kfold cross validation",
|
|
default=1,
|
|
)
|
|
|
|
parser.add_argument(
|
|
"-o",
|
|
"--oversampling",
|
|
action="store_true",
|
|
help="Apply oversampling",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"-u",
|
|
"--undersampling",
|
|
action="store_true",
|
|
help="Apply undersampling",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"-d",
|
|
"--dataset",
|
|
action="store",
|
|
default="all",
|
|
choices=["all", "fetus", "mother", "fetus-mother", "mother-fetus"],
|
|
help="Choose the dataset: all, fetus, mother or train with mother and test with fetus or viceversa",
|
|
)
|
|
|
|
args = parser.parse_args()
|
|
|
|
print(
|
|
f"""
|
|
ARGS:
|
|
\n
|
|
reload-dataset: {args.reload_dataset}
|
|
path: {args.path}
|
|
epochs: {args.epochs}
|
|
kfold: {args.kfold}
|
|
oversampling: {args.oversampling}
|
|
undersampling: {args.undersampling}
|
|
dataset: {args.dataset}
|
|
"""
|
|
)
|
|
return args
|
|
|
|
|
|
if __name__ == "__main__":
|
|
CLASSES = ["baseline", "opcl", "yawn"]
|
|
FEATURE_SIZE = 10
|
|
SERIES_LENGTH = 60
|
|
# SINGLE_FRAME_LENGTH = FEATURE_SIZE * SERIES_LENGTH
|
|
BATCH_SIZE = 4
|
|
WEIGHT_DECAY = 1e-5
|
|
LEARNING_RATE = 1e-3
|
|
TEST_OPTIMAL_SIZE = 0.2
|
|
HIDDEN_SIZE = 256
|
|
DROP_OUT = 0.0
|
|
NUM_LAYERS = 2
|
|
EPS = 1e-7
|
|
|
|
# TEST_NAME = "0_k1_all"
|
|
# TEST_NAME = "1_k1_fetus"
|
|
# TEST_NAME = "2_k1_mother"
|
|
# TEST_NAME = "3_k1_mother_fetus"
|
|
# TEST_NAME = "4_k1_fetus_mother"
|
|
# TEST_NAME = "5_k5_all"
|
|
TEST_NAME = "6_k5_fetus"
|
|
# TEST_NAME = "7_k5_mother"
|
|
|
|
if not os.path.exists("output/" + TEST_NAME):
|
|
os.makedirs("output/" + TEST_NAME)
|
|
|
|
if not os.path.exists("output/" + TEST_NAME + "/weights"):
|
|
os.makedirs("output/" + TEST_NAME + "/weights")
|
|
|
|
if not os.path.exists("output/" + TEST_NAME + "/confusion_matrix"):
|
|
os.makedirs("output/" + TEST_NAME + "/confusion_matrix")
|
|
|
|
if not os.path.exists("output/" + TEST_NAME + "/metrics"):
|
|
os.makedirs("output/" + TEST_NAME + "/metrics")
|
|
|
|
# fix the seed
|
|
seed = 42
|
|
np.random.seed(seed)
|
|
torch.manual_seed(seed)
|
|
torch.cuda.manual_seed(seed)
|
|
torch.cuda.manual_seed_all(seed)
|
|
|
|
device = get_device()
|
|
args = createArgParser()
|
|
|
|
PATH = args.path
|
|
EPOCHS = int(args.epochs)
|
|
K_FOLD = int(args.kfold)
|
|
OVER_SAMPLING = args.oversampling
|
|
UNDER_SAMPLING = args.undersampling
|
|
EARLY_STOPPING = True
|
|
DATASET_TYPE = args.dataset
|
|
|
|
if os.path.exists("dataset.csv") and args.reload_dataset is False:
|
|
dataset = pd.read_csv("dataset.csv")
|
|
mother = pd.read_csv("mother.csv")
|
|
fetus = pd.read_csv("fetus.csv")
|
|
else:
|
|
baseline_fetus = get_dataset(
|
|
"~/Documents/kanopo/womb-wise/Data/Ultrasound_Scans/tracked_frames/",
|
|
"baseline",
|
|
)
|
|
yawn_fetus = get_dataset(
|
|
"~/Documents/kanopo/womb-wise/Data/Ultrasound_Scans/tracked_frames/",
|
|
"yawn",
|
|
)
|
|
opcl_fetus = get_dataset(
|
|
"~/Documents/kanopo/womb-wise/Data/Ultrasound_Scans/tracked_frames/",
|
|
"opcl",
|
|
)
|
|
|
|
fetus = pd.concat([baseline_fetus, yawn_fetus, opcl_fetus])
|
|
|
|
baseline_mother = get_dataset(
|
|
"~/Documents/kanopo/womb-wise/Data/Mothers_videos/Tracked/",
|
|
"baseline",
|
|
)
|
|
yawn_mother = get_dataset(
|
|
"~/Documents/kanopo/womb-wise/Data/Mothers_videos/Tracked/",
|
|
"yawn",
|
|
)
|
|
opcl_mother = get_dataset(
|
|
"~/Documents/kanopo/womb-wise/Data/Mothers_videos/Tracked/",
|
|
"opcl",
|
|
)
|
|
|
|
mother = pd.concat([baseline_mother, yawn_mother, opcl_mother])
|
|
fetus["type"] = "fetus"
|
|
mother["type"] = "mother"
|
|
|
|
fetus.to_csv("fetus.csv")
|
|
mother.to_csv("mother.csv")
|
|
|
|
dataset = pd.concat([mother, fetus])
|
|
dataset.to_csv("dataset.csv")
|
|
|
|
|
|
|
|
mother = mother.drop(columns=["top_bottom_distance"])
|
|
fetus = fetus.drop(columns=["top_bottom_distance"])
|
|
dataset = dataset.drop(columns=["top_bottom_distance"])
|
|
grouped_dataset = dataset.groupby(["label", "frame", "test", "type"])
|
|
grouped_mother = mother.groupby(["label", "frame", "test"])
|
|
grouped_fetus = fetus.groupby(["label", "frame", "test"])
|
|
|
|
data: List[Dict[np.ndarray, int]] = []
|
|
mother_data: List[Dict[np.ndarray, int]] = []
|
|
fetus_data: List[Dict[np.ndarray, int]] = []
|
|
|
|
for name, group in grouped_mother:
|
|
label = group["label"]
|
|
frame = group["frame"]
|
|
test = group["test"]
|
|
|
|
group = group.drop(columns=["test", "frame", "label", "type"])
|
|
|
|
group.set_index("image_name", inplace=True)
|
|
|
|
if group.columns[0] == "Unnamed: 0":
|
|
group = group.drop(columns=["Unnamed: 0"])
|
|
|
|
group = group.to_numpy()
|
|
|
|
if group.shape[0] < SERIES_LENGTH:
|
|
group = np.vstack(
|
|
[group, np.zeros((SERIES_LENGTH - group.shape[0], FEATURE_SIZE))]
|
|
)
|
|
|
|
elif group.shape[0] > SERIES_LENGTH:
|
|
group = group[:SERIES_LENGTH]
|
|
|
|
group = group.astype(np.float32)
|
|
|
|
label = CLASSES.index(label.iat[0])
|
|
mother_data.append(
|
|
{
|
|
"data": group,
|
|
"label": label,
|
|
}
|
|
)
|
|
|
|
for name, group in grouped_fetus:
|
|
label = group["label"]
|
|
frame = group["frame"]
|
|
test = group["test"]
|
|
|
|
group = group.drop(columns=["test", "frame", "label", "type"])
|
|
|
|
group.set_index("image_name", inplace=True)
|
|
|
|
if group.columns[0] == "Unnamed: 0":
|
|
group = group.drop(columns=["Unnamed: 0"])
|
|
|
|
group = group.to_numpy()
|
|
|
|
if group.shape[0] < SERIES_LENGTH:
|
|
group = np.vstack(
|
|
[group, np.zeros((SERIES_LENGTH - group.shape[0], FEATURE_SIZE))]
|
|
)
|
|
|
|
elif group.shape[0] > SERIES_LENGTH:
|
|
group = group[:SERIES_LENGTH]
|
|
|
|
group = group.astype(np.float32)
|
|
|
|
label = CLASSES.index(label.iat[0])
|
|
fetus_data.append(
|
|
{
|
|
"data": group,
|
|
"label": label,
|
|
}
|
|
)
|
|
|
|
for name, group in grouped_dataset:
|
|
label = group["label"]
|
|
frame = group["frame"]
|
|
test = group["test"]
|
|
data_type = group["type"]
|
|
|
|
group = group.drop(columns=["test", "frame", "label", "type"])
|
|
|
|
group.set_index("image_name", inplace=True)
|
|
|
|
if group.columns[0] == "Unnamed: 0":
|
|
group = group.drop(columns=["Unnamed: 0"])
|
|
|
|
group = group.to_numpy()
|
|
|
|
if group.shape[0] < SERIES_LENGTH:
|
|
group = np.vstack(
|
|
[group, np.zeros((SERIES_LENGTH - group.shape[0], FEATURE_SIZE))]
|
|
)
|
|
|
|
elif group.shape[0] > SERIES_LENGTH:
|
|
group = group[:SERIES_LENGTH]
|
|
|
|
group = group.astype(np.float32)
|
|
|
|
label = CLASSES.index(label.iat[0])
|
|
data.append(
|
|
{
|
|
"data": group,
|
|
"label": label,
|
|
}
|
|
)
|
|
|
|
if K_FOLD == 1:
|
|
|
|
x_all = [d["data"] for d in data]
|
|
y_all = [d["label"] for d in data]
|
|
|
|
x_mother = [d["data"] for d in mother_data]
|
|
y_mother = [d["label"] for d in mother_data]
|
|
|
|
x_fetus = [d["data"] for d in fetus_data]
|
|
y_fetus = [d["label"] for d in fetus_data]
|
|
|
|
(train_loader_all, test_loader_all) = create_loaders(
|
|
x_all,
|
|
y_all,
|
|
data,
|
|
over=OVER_SAMPLING,
|
|
under=UNDER_SAMPLING,
|
|
classes=CLASSES,
|
|
batch_size=BATCH_SIZE,
|
|
)
|
|
|
|
(train_loader_mother, test_loader_mother) = create_loaders(
|
|
x_mother,
|
|
y_mother,
|
|
mother_data,
|
|
over=OVER_SAMPLING,
|
|
under=UNDER_SAMPLING,
|
|
classes=CLASSES,
|
|
batch_size=BATCH_SIZE,
|
|
)
|
|
|
|
(train_loader_fetus, test_loader_fetus) = create_loaders(
|
|
x_fetus,
|
|
y_fetus,
|
|
fetus_data,
|
|
over=OVER_SAMPLING,
|
|
under=UNDER_SAMPLING,
|
|
classes=CLASSES,
|
|
batch_size=BATCH_SIZE,
|
|
)
|
|
|
|
(model, optimizer, scheduler, criterion) = setup_model_training(
|
|
input_size=FEATURE_SIZE,
|
|
hidden_size=HIDDEN_SIZE,
|
|
num_layers=NUM_LAYERS,
|
|
num_classes=len(CLASSES),
|
|
sequence_length=SERIES_LENGTH,
|
|
device=device,
|
|
lr=LEARNING_RATE,
|
|
weight_decay=WEIGHT_DECAY,
|
|
eps=EPS,
|
|
)
|
|
|
|
if DATASET_TYPE == "all":
|
|
train_loader = train_loader_all
|
|
test_loader = test_loader_all
|
|
|
|
elif DATASET_TYPE == "fetus":
|
|
train_loader = train_loader_fetus
|
|
test_loader = test_loader_fetus
|
|
|
|
elif DATASET_TYPE == "mother":
|
|
train_loader = train_loader_mother
|
|
test_loader = test_loader_mother
|
|
|
|
elif DATASET_TYPE == "fetus-mother":
|
|
train_loader = train_loader_fetus
|
|
test_loader = test_loader_mother
|
|
|
|
elif DATASET_TYPE == "mother-fetus":
|
|
train_loader = train_loader_mother
|
|
test_loader = test_loader_fetus
|
|
|
|
else:
|
|
Exception("Invalid dataset type")
|
|
|
|
trained_model = training_loop(
|
|
model=model,
|
|
train_loader=train_loader,
|
|
test_loader=test_loader,
|
|
optimizer=optimizer,
|
|
scheduler=scheduler,
|
|
criterion=criterion,
|
|
device=device,
|
|
epochs=EPOCHS,
|
|
early_stopping=EARLY_STOPPING,
|
|
log_dir="output/" + TEST_NAME + "/metrics",
|
|
)
|
|
|
|
loss, conf_matrix, classification_rep = validation(
|
|
trained_model,
|
|
test_loader,
|
|
criterion,
|
|
device,
|
|
)
|
|
|
|
# save classification report to a file
|
|
df = pd.DataFrame(classification_rep).transpose()
|
|
df.to_csv("output/" + TEST_NAME + "/metrics/classification_report.csv")
|
|
|
|
torch.save(
|
|
trained_model.state_dict(),
|
|
"output/" + TEST_NAME + "/weights/model.pth",
|
|
)
|
|
|
|
plt.figure(figsize=(19.20, 10.80))
|
|
plt.title("Confusion Matrix")
|
|
sns.heatmap(
|
|
conf_matrix,
|
|
annot=True,
|
|
fmt=".2f",
|
|
xticklabels=["Baseline", "Opcl", "Yawn"],
|
|
yticklabels=["Baseline", "Opcl", "Yawn"],
|
|
cmap="viridis",
|
|
)
|
|
|
|
plt.xlabel("Predicted")
|
|
plt.ylabel("Actual")
|
|
|
|
plt.savefig("output/" + TEST_NAME + "/confusion_matrix/confusion_matrix.png")
|
|
|
|
plt.figure(figsize=(19.20, 10.80))
|
|
plt.title("Confusion Matrix Percentage")
|
|
conf_matrix_percent = conf_matrix.astype('float') / conf_matrix.sum(axis=1)[:, np.newaxis] * 100
|
|
sns.heatmap(
|
|
conf_matrix_percent,
|
|
annot=True,
|
|
fmt=".2f",
|
|
xticklabels=["Baseline", "Opcl", "Yawn"],
|
|
yticklabels=["Baseline", "Opcl", "Yawn"],
|
|
cmap="viridis",
|
|
)
|
|
|
|
plt.xlabel("Predicted")
|
|
plt.ylabel("Actual")
|
|
|
|
plt.savefig("output/" + TEST_NAME + "/confusion_matrix/confusion_matrix_percentage.png")
|
|
|
|
else:
|
|
x_all = [d["data"] for d in data]
|
|
y_all = [d["label"] for d in data]
|
|
|
|
x_mother = [d["data"] for d in mother_data]
|
|
y_mother = [d["label"] for d in mother_data]
|
|
|
|
x_fetus = [d["data"] for d in fetus_data]
|
|
y_fetus = [d["label"] for d in fetus_data]
|
|
|
|
(train_loader_all, test_loader_all) = create_loaders(
|
|
x_all,
|
|
y_all,
|
|
data,
|
|
over=OVER_SAMPLING,
|
|
under=UNDER_SAMPLING,
|
|
classes=CLASSES,
|
|
batch_size=BATCH_SIZE,
|
|
)
|
|
|
|
(train_loader_mother, test_loader_mother) = create_loaders(
|
|
x_mother,
|
|
y_mother,
|
|
mother_data,
|
|
over=OVER_SAMPLING,
|
|
under=UNDER_SAMPLING,
|
|
classes=CLASSES,
|
|
batch_size=BATCH_SIZE,
|
|
)
|
|
|
|
(train_loader_fetus, test_loader_fetus) = create_loaders(
|
|
x_fetus,
|
|
y_fetus,
|
|
fetus_data,
|
|
over=OVER_SAMPLING,
|
|
under=UNDER_SAMPLING,
|
|
classes=CLASSES,
|
|
batch_size=BATCH_SIZE,
|
|
)
|
|
|
|
(model, optimizer, scheduler, criterion) = setup_model_training(
|
|
input_size=FEATURE_SIZE,
|
|
hidden_size=HIDDEN_SIZE,
|
|
num_layers=NUM_LAYERS,
|
|
num_classes=len(CLASSES),
|
|
sequence_length=SERIES_LENGTH,
|
|
device=device,
|
|
lr=LEARNING_RATE,
|
|
weight_decay=WEIGHT_DECAY,
|
|
eps=EPS,
|
|
)
|
|
|
|
kf = StratifiedKFold(n_splits=K_FOLD, shuffle=True, random_state=seed)
|
|
model_index = 0
|
|
|
|
x = None
|
|
y = None
|
|
|
|
if DATASET_TYPE == "all":
|
|
x = [d["data"] for d in data]
|
|
y = [d["label"] for d in data]
|
|
elif DATASET_TYPE == "fetus":
|
|
x = [d["data"] for d in fetus_data]
|
|
y = [d["label"] for d in fetus_data]
|
|
elif DATASET_TYPE == "mother":
|
|
x = [d["data"] for d in mother_data]
|
|
y = [d["label"] for d in mother_data]
|
|
|
|
# TODO: HOW TO HANDLE THIS CASES? for the mixed training and validation
|
|
else:
|
|
Exception("Invalid dataset type")
|
|
|
|
for train_index, test_index in kf.split(X=x, y=y):
|
|
train_data = [data[i] for i in train_index]
|
|
test_data = [data[i] for i in test_index]
|
|
|
|
data = train_data + test_data
|
|
|
|
x = [d["data"] for d in train_data]
|
|
y = [d["label"] for d in train_data]
|
|
|
|
(train_loader, test_loader) = create_loaders(
|
|
x,
|
|
y,
|
|
data,
|
|
over=OVER_SAMPLING,
|
|
under=UNDER_SAMPLING,
|
|
classes=CLASSES,
|
|
batch_size=BATCH_SIZE,
|
|
)
|
|
|
|
(model, optimizer, scheduler, criterion) = setup_model_training(
|
|
input_size=FEATURE_SIZE,
|
|
hidden_size=HIDDEN_SIZE,
|
|
num_layers=NUM_LAYERS,
|
|
num_classes=len(CLASSES),
|
|
sequence_length=SERIES_LENGTH,
|
|
device=device,
|
|
lr=LEARNING_RATE,
|
|
weight_decay=WEIGHT_DECAY,
|
|
eps=EPS,
|
|
)
|
|
|
|
trained_model = training_loop(
|
|
model=model,
|
|
train_loader=train_loader,
|
|
test_loader=test_loader,
|
|
optimizer=optimizer,
|
|
scheduler=scheduler,
|
|
criterion=criterion,
|
|
device=device,
|
|
epochs=EPOCHS,
|
|
early_stopping=EARLY_STOPPING,
|
|
log_dir="output/" + TEST_NAME + "/metrics/" + f"{model_index}",
|
|
)
|
|
|
|
loss, conf_matrix, classification_rep = validation(
|
|
trained_model,
|
|
test_loader,
|
|
criterion,
|
|
device,
|
|
)
|
|
|
|
# save classification report to a file
|
|
df = pd.DataFrame(classification_rep).transpose()
|
|
df.to_csv(
|
|
"output/"
|
|
+ TEST_NAME
|
|
+ "/metrics/classification_report_"
|
|
+ str(model_index)
|
|
+ ".csv"
|
|
)
|
|
|
|
torch.save(
|
|
trained_model.state_dict(),
|
|
"output/" + TEST_NAME + "/weights/model_" + str(model_index) + ".pth",
|
|
)
|
|
plt.figure(figsize=(19.20, 10.80))
|
|
plt.title("Confusion Matrix")
|
|
sns.heatmap(
|
|
conf_matrix,
|
|
annot=True,
|
|
fmt=".2f",
|
|
xticklabels=["Baseline", "Opcl", "Yawn"],
|
|
yticklabels=["Baseline", "Opcl", "Yawn"],
|
|
cmap="viridis",
|
|
)
|
|
|
|
plt.xlabel("Predicted")
|
|
plt.ylabel("Actual")
|
|
plt.savefig(
|
|
"output/"
|
|
+ TEST_NAME
|
|
+ "/confusion_matrix/confusion_matrix_"
|
|
+ str(model_index)
|
|
+ ".png"
|
|
)
|
|
conf_matrix_percent = conf_matrix.astype('float') / conf_matrix.sum(axis=1)[:, np.newaxis] * 100
|
|
plt.figure(figsize=(19.20, 10.80))
|
|
plt.title("Confusion Matrix Percentage")
|
|
sns.heatmap(
|
|
conf_matrix_percent,
|
|
annot=True,
|
|
fmt=".2f",
|
|
xticklabels=["Baseline", "Opcl", "Yawn"],
|
|
yticklabels=["Baseline", "Opcl", "Yawn"],
|
|
cmap="viridis",
|
|
)
|
|
|
|
plt.xlabel("Predicted")
|
|
plt.ylabel("Actual")
|
|
plt.savefig(
|
|
"output/"
|
|
+ TEST_NAME
|
|
+ "/confusion_matrix/confusion_matrix_percentage_"
|
|
+ str(model_index)
|
|
+ ".png"
|
|
)
|
|
model_index += 1
|