#!/usr/bin/env python3.12
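"""Train and evaluate a SimpleLSTM classifier on mother/fetus facial time series.

Fixed-length (SERIES_LENGTH x FEATURE_SIZE) landmark sequences are classified
as "baseline", "opcl", or "yawn", with optional over/undersampling and
stratified k-fold cross validation. Weights, metrics, and confusion matrices
are written under output/<TEST_NAME>/.

Example invocation (the script name and values are illustrative):
    python main.py -d fetus -e 100 -k 5 -o
"""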
import os
import warnings
import pandas as pd
import numpy as np
from typing import List, Dict
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import seaborn as sns
import argparse
from imblearn.over_sampling import RandomOverSampler
from imblearn.under_sampling import RandomUnderSampler
from collections import Counter
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import train_test_split
from load_dataset import get_dataset
from model import SimpleLSTM
from training import training_loop
from validation import validation
warnings.simplefilter(action="ignore", category=FutureWarning)
def setup_model_training(
input_size,
hidden_size,
num_layers,
num_classes,
sequence_length,
device,
lr,
weight_decay,
eps,
):
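    """Build the SimpleLSTM model together with its Adam optimizer,
    ReduceLROnPlateau scheduler, and BCEWithLogitsLoss criterion."""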
model = SimpleLSTM(
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
num_classes=num_classes,
sequence_length=sequence_length,
device=device,
)
optimizer = optim.Adam(
model.parameters(), lr=lr, weight_decay=weight_decay, eps=eps
)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode="min", factor=0.1, patience=25
)
# criterion = nn.CrossEntropyLoss()
# criterion = nn.BCELoss()
criterion = nn.BCEWithLogitsLoss()
return (model, optimizer, scheduler, criterion)
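
# NOTE: the helpers below read module-level settings (TEST_OPTIMAL_SIZE, seed,
# SERIES_LENGTH, FEATURE_SIZE) that are defined in the __main__ block, so they
# must only be called after those are set.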
def create_loaders(
    x,
    y,
    data,
    under=False,
    over=False,
    classes=("baseline", "opcl", "yawn"),
    batch_size=2,
):
    """Split (x, y) into train/test sets and wrap them in DataLoaders,
    optionally balancing the classes first via under/oversampling."""
    if under:
        x, y = undersample(x, y, data)
    if over:
        x, y = oversample(x, y, data)
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=TEST_OPTIMAL_SIZE, random_state=seed
    )
    train_dataset = FetusDataset(
        [{"data": xi, "label": yi} for xi, yi in zip(x_train, y_train)],
        train=True,
        classes=len(classes),
    )
    # train=True here as well, so test samples get the same per-sample normalization
    test_dataset = FetusDataset(
        [{"data": xi, "label": yi} for xi, yi in zip(x_test, y_test)],
        train=True,
        classes=len(classes),
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
    )
    return (train_loader, test_loader)
def get_device() -> torch.device:
    # Prefer Apple MPS, then CUDA, then fall back to the CPU
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
def undersample(x, y, data):
    """Randomly drop majority-class samples (sampling_strategy="majority")."""
    print("Before undersampling")
    print(Counter([d["label"] for d in data]))
    rus = RandomUnderSampler(random_state=seed, sampling_strategy="majority")
    # imblearn expects 2-D input, so flatten each (SERIES_LENGTH, FEATURE_SIZE) sample
    flat_x = np.array([sample.flatten() for sample in x])
    flat_y = np.array(y)
    x_res, y_res = rus.fit_resample(flat_x, flat_y)
    x_res = x_res.reshape(-1, SERIES_LENGTH, FEATURE_SIZE)
    print("After undersampling")
    print(Counter(y_res))
    return x_res, y_res
def oversample(x, y, data):
    """Randomly duplicate samples until every class matches the largest one."""
    print("Before oversampling")
    print(Counter([d["label"] for d in data]))
    ros = RandomOverSampler(random_state=seed, sampling_strategy="all")
    # imblearn expects 2-D input, so flatten each (SERIES_LENGTH, FEATURE_SIZE) sample
    flat_x = np.array([sample.flatten() for sample in x])
    flat_y = np.array(y)
    x_res, y_res = ros.fit_resample(flat_x, flat_y)
    x_res = x_res.reshape(-1, SERIES_LENGTH, FEATURE_SIZE)
    print("After oversampling")
    print(Counter(y_res))
    return x_res, y_res
class FetusDataset(Dataset):
    """Dataset of fixed-length landmark sequences with one-hot encoded labels."""

    def __init__(
        self,
        data: List[Dict[str, np.ndarray | int]],
        train: bool = False,
        classes: int = 2,
    ):
        self.data = data
        self.train = train
        self.classes = classes

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x = self.data[idx]["data"]
        y = self.data[idx]["label"]
        # Convert the data type and one-hot encode the label
        x = x.astype(np.float32)
        y = np.eye(self.classes)[y]
        # Convert to tensors; BCEWithLogitsLoss needs float targets
        x = torch.tensor(x, dtype=torch.float32)
        y = torch.tensor(y, dtype=torch.float32)
        # Replace NaN with 0 and infinities with large finite numbers
        x = torch.nan_to_num(x)
        # Per-sample normalization, applied only when train=True and std > 0
        if self.train:
            mean = x.mean()
            std = x.std()
            if std > 0:
                x = (x - mean) / std
        return x, y
def parse_arguments():
parser = argparse.ArgumentParser(description="Womb Wise")
parser.add_argument(
"-rd",
"--reload-dataset",
action="store_true",
help="Reload the dataset",
)
# path to the dataset
parser.add_argument(
"-p",
"--path",
action="store",
help="Path to the dataset",
default="~/Documents/womb-wise/Data/",
)
# epoch
    parser.add_argument(
        "-e",
        "--epochs",
        action="store",
        type=int,
        help="Number of epochs",
        default=10,
    )
    parser.add_argument(
        "-k",
        "--kfold",
        action="store",
        type=int,
        help="Number of folds for kfold cross validation",
        default=1,
    )
parser.add_argument(
"-o",
"--oversampling",
action="store_true",
help="Apply oversampling",
)
parser.add_argument(
"-u",
"--undersampling",
action="store_true",
help="Apply undersampling",
)
    parser.add_argument(
        "-d",
        "--dataset",
        action="store",
        default="all",
        choices=["all", "fetus", "mother", "fetus-mother", "mother-fetus"],
        help="Choose the dataset: all, fetus, mother, or train on mother and test on fetus (and vice versa)",
    )
args = parser.parse_args()
    print(
        f"""
        ARGS:
        reload-dataset: {args.reload_dataset}
        path: {args.path}
        epochs: {args.epochs}
        kfold: {args.kfold}
        oversampling: {args.oversampling}
        undersampling: {args.undersampling}
        dataset: {args.dataset}
        """
    )
return args
if __name__ == "__main__":
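    # Experiment configuration: classes, input shape, and training hyperparameters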
CLASSES = ["baseline", "opcl", "yawn"]
FEATURE_SIZE = 10
SERIES_LENGTH = 60
# SINGLE_FRAME_LENGTH = FEATURE_SIZE * SERIES_LENGTH
BATCH_SIZE = 4
WEIGHT_DECAY = 1e-5
LEARNING_RATE = 1e-3
TEST_OPTIMAL_SIZE = 0.2
HIDDEN_SIZE = 256
DROP_OUT = 0.0
NUM_LAYERS = 2
EPS = 1e-7
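    # TEST_NAME selects the output folder; keep it consistent with the
    # --dataset/--kfold options used for the run (see the variants below).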
# TEST_NAME = "0_k1_all"
# TEST_NAME = "1_k1_fetus"
# TEST_NAME = "2_k1_mother"
# TEST_NAME = "3_k1_mother_fetus"
# TEST_NAME = "4_k1_fetus_mother"
# TEST_NAME = "5_k5_all"
TEST_NAME = "6_k5_fetus"
# TEST_NAME = "7_k5_mother"
    # Create the output directory tree for this experiment
    os.makedirs("output/" + TEST_NAME, exist_ok=True)
    os.makedirs("output/" + TEST_NAME + "/weights", exist_ok=True)
    os.makedirs("output/" + TEST_NAME + "/confusion_matrix", exist_ok=True)
    os.makedirs("output/" + TEST_NAME + "/metrics", exist_ok=True)
# fix the seed
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
device = get_device()
    args = parse_arguments()
PATH = args.path
EPOCHS = int(args.epochs)
K_FOLD = int(args.kfold)
OVER_SAMPLING = args.oversampling
UNDER_SAMPLING = args.undersampling
EARLY_STOPPING = True
DATASET_TYPE = args.dataset
if os.path.exists("dataset.csv") and args.reload_dataset is False:
dataset = pd.read_csv("dataset.csv")
mother = pd.read_csv("mother.csv")
fetus = pd.read_csv("fetus.csv")
else:
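        # NOTE: these source paths are hard-coded; the --path argument is not used here.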
baseline_fetus = get_dataset(
"~/Documents/kanopo/womb-wise/Data/Ultrasound_Scans/tracked_frames/",
"baseline",
)
yawn_fetus = get_dataset(
"~/Documents/kanopo/womb-wise/Data/Ultrasound_Scans/tracked_frames/",
"yawn",
)
opcl_fetus = get_dataset(
"~/Documents/kanopo/womb-wise/Data/Ultrasound_Scans/tracked_frames/",
"opcl",
)
fetus = pd.concat([baseline_fetus, yawn_fetus, opcl_fetus])
baseline_mother = get_dataset(
"~/Documents/kanopo/womb-wise/Data/Mothers_videos/Tracked/",
"baseline",
)
yawn_mother = get_dataset(
"~/Documents/kanopo/womb-wise/Data/Mothers_videos/Tracked/",
"yawn",
)
opcl_mother = get_dataset(
"~/Documents/kanopo/womb-wise/Data/Mothers_videos/Tracked/",
"opcl",
)
mother = pd.concat([baseline_mother, yawn_mother, opcl_mother])
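        # Tag each row with its source and cache the assembled tables to CSV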
fetus["type"] = "fetus"
mother["type"] = "mother"
fetus.to_csv("fetus.csv")
mother.to_csv("mother.csv")
dataset = pd.concat([mother, fetus])
dataset.to_csv("dataset.csv")
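    # Drop the top_bottom_distance column from the feature set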
mother = mother.drop(columns=["top_bottom_distance"])
fetus = fetus.drop(columns=["top_bottom_distance"])
dataset = dataset.drop(columns=["top_bottom_distance"])
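    # Group rows into per-recording sequences (one group per label/frame/test combination)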
grouped_dataset = dataset.groupby(["label", "frame", "test", "type"])
grouped_mother = mother.groupby(["label", "frame", "test"])
grouped_fetus = fetus.groupby(["label", "frame", "test"])
    def extract_sequences(grouped) -> List[Dict[str, np.ndarray | int]]:
        # Turn each grouped recording into a fixed-length
        # (SERIES_LENGTH, FEATURE_SIZE) float32 sequence plus its class index.
        sequences: List[Dict[str, np.ndarray | int]] = []
        for _, group in grouped:
            label = group["label"]
            group = group.drop(columns=["test", "frame", "label", "type"])
            group.set_index("image_name", inplace=True)
            if group.columns[0] == "Unnamed: 0":
                group = group.drop(columns=["Unnamed: 0"])
            group = group.to_numpy()
            # Zero-pad short series and truncate long ones to SERIES_LENGTH rows
            if group.shape[0] < SERIES_LENGTH:
                group = np.vstack(
                    [group, np.zeros((SERIES_LENGTH - group.shape[0], FEATURE_SIZE))]
                )
            elif group.shape[0] > SERIES_LENGTH:
                group = group[:SERIES_LENGTH]
            sequences.append(
                {
                    "data": group.astype(np.float32),
                    "label": CLASSES.index(label.iat[0]),
                }
            )
        return sequences

    data: List[Dict[str, np.ndarray | int]] = extract_sequences(grouped_dataset)
    mother_data: List[Dict[str, np.ndarray | int]] = extract_sequences(grouped_mother)
    fetus_data: List[Dict[str, np.ndarray | int]] = extract_sequences(grouped_fetus)
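    # K_FOLD == 1 does a single hold-out split; otherwise run stratified k-fold below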
if K_FOLD == 1:
x_all = [d["data"] for d in data]
y_all = [d["label"] for d in data]
x_mother = [d["data"] for d in mother_data]
y_mother = [d["label"] for d in mother_data]
x_fetus = [d["data"] for d in fetus_data]
y_fetus = [d["label"] for d in fetus_data]
(train_loader_all, test_loader_all) = create_loaders(
x_all,
y_all,
data,
over=OVER_SAMPLING,
under=UNDER_SAMPLING,
classes=CLASSES,
batch_size=BATCH_SIZE,
)
(train_loader_mother, test_loader_mother) = create_loaders(
x_mother,
y_mother,
mother_data,
over=OVER_SAMPLING,
under=UNDER_SAMPLING,
classes=CLASSES,
batch_size=BATCH_SIZE,
)
(train_loader_fetus, test_loader_fetus) = create_loaders(
x_fetus,
y_fetus,
fetus_data,
over=OVER_SAMPLING,
under=UNDER_SAMPLING,
classes=CLASSES,
batch_size=BATCH_SIZE,
)
(model, optimizer, scheduler, criterion) = setup_model_training(
input_size=FEATURE_SIZE,
hidden_size=HIDDEN_SIZE,
num_layers=NUM_LAYERS,
num_classes=len(CLASSES),
sequence_length=SERIES_LENGTH,
device=device,
lr=LEARNING_RATE,
weight_decay=WEIGHT_DECAY,
eps=EPS,
)
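        # Select the train/test loaders per --dataset; the mixed modes train on
        # one domain (mother or fetus) and evaluate on the other.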
if DATASET_TYPE == "all":
train_loader = train_loader_all
test_loader = test_loader_all
elif DATASET_TYPE == "fetus":
train_loader = train_loader_fetus
test_loader = test_loader_fetus
elif DATASET_TYPE == "mother":
train_loader = train_loader_mother
test_loader = test_loader_mother
elif DATASET_TYPE == "fetus-mother":
train_loader = train_loader_fetus
test_loader = test_loader_mother
elif DATASET_TYPE == "mother-fetus":
train_loader = train_loader_mother
test_loader = test_loader_fetus
        else:
            raise ValueError("Invalid dataset type")
trained_model = training_loop(
model=model,
train_loader=train_loader,
test_loader=test_loader,
optimizer=optimizer,
scheduler=scheduler,
criterion=criterion,
device=device,
epochs=EPOCHS,
early_stopping=EARLY_STOPPING,
log_dir="output/" + TEST_NAME + "/metrics",
)
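        # Evaluate the trained model on the held-out test set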
loss, conf_matrix, classification_rep = validation(
trained_model,
test_loader,
criterion,
device,
)
# save classification report to a file
df = pd.DataFrame(classification_rep).transpose()
df.to_csv("output/" + TEST_NAME + "/metrics/classification_report.csv")
torch.save(
trained_model.state_dict(),
"output/" + TEST_NAME + "/weights/model.pth",
)
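        # Plot the confusion matrix as raw counts, then as row percentages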
plt.figure(figsize=(19.20, 10.80))
plt.title("Confusion Matrix")
sns.heatmap(
conf_matrix,
annot=True,
fmt=".2f",
xticklabels=["Baseline", "Opcl", "Yawn"],
yticklabels=["Baseline", "Opcl", "Yawn"],
cmap="viridis",
)
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.savefig("output/" + TEST_NAME + "/confusion_matrix/confusion_matrix.png")
plt.figure(figsize=(19.20, 10.80))
plt.title("Confusion Matrix Percentage")
        # Row-normalize to percentages of samples per true class
        conf_matrix_percent = (
            conf_matrix.astype(float) / conf_matrix.sum(axis=1)[:, np.newaxis] * 100
        )
sns.heatmap(
conf_matrix_percent,
annot=True,
fmt=".2f",
xticklabels=["Baseline", "Opcl", "Yawn"],
yticklabels=["Baseline", "Opcl", "Yawn"],
cmap="viridis",
)
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.savefig("output/" + TEST_NAME + "/confusion_matrix/confusion_matrix_percentage.png")
else:
x_all = [d["data"] for d in data]
y_all = [d["label"] for d in data]
x_mother = [d["data"] for d in mother_data]
y_mother = [d["label"] for d in mother_data]
x_fetus = [d["data"] for d in fetus_data]
y_fetus = [d["label"] for d in fetus_data]
(train_loader_all, test_loader_all) = create_loaders(
x_all,
y_all,
data,
over=OVER_SAMPLING,
under=UNDER_SAMPLING,
classes=CLASSES,
batch_size=BATCH_SIZE,
)
(train_loader_mother, test_loader_mother) = create_loaders(
x_mother,
y_mother,
mother_data,
over=OVER_SAMPLING,
under=UNDER_SAMPLING,
classes=CLASSES,
batch_size=BATCH_SIZE,
)
(train_loader_fetus, test_loader_fetus) = create_loaders(
x_fetus,
y_fetus,
fetus_data,
over=OVER_SAMPLING,
under=UNDER_SAMPLING,
classes=CLASSES,
batch_size=BATCH_SIZE,
)
(model, optimizer, scheduler, criterion) = setup_model_training(
input_size=FEATURE_SIZE,
hidden_size=HIDDEN_SIZE,
num_layers=NUM_LAYERS,
num_classes=len(CLASSES),
sequence_length=SERIES_LENGTH,
device=device,
lr=LEARNING_RATE,
weight_decay=WEIGHT_DECAY,
eps=EPS,
)
kf = StratifiedKFold(n_splits=K_FOLD, shuffle=True, random_state=seed)
model_index = 0
x = None
y = None
if DATASET_TYPE == "all":
x = [d["data"] for d in data]
y = [d["label"] for d in data]
elif DATASET_TYPE == "fetus":
x = [d["data"] for d in fetus_data]
y = [d["label"] for d in fetus_data]
elif DATASET_TYPE == "mother":
x = [d["data"] for d in mother_data]
y = [d["label"] for d in mother_data]
# TODO: HOW TO HANDLE THIS CASES? for the mixed training and validation
else:
Exception("Invalid dataset type")
        for train_index, test_index in kf.split(X=x, y=y):
            # Build this fold's train/test sets directly from the split indices
            train_data = [source_data[i] for i in train_index]
            test_data = [source_data[i] for i in test_index]
            x_train = [d["data"] for d in train_data]
            y_train = [d["label"] for d in train_data]
            if UNDER_SAMPLING:
                x_train, y_train = undersample(x_train, y_train, train_data)
            if OVER_SAMPLING:
                x_train, y_train = oversample(x_train, y_train, train_data)
            train_dataset = FetusDataset(
                [{"data": xi, "label": yi} for xi, yi in zip(x_train, y_train)],
                train=True,
                classes=len(CLASSES),
            )
            # train=True so test samples get the same per-sample normalization
            test_dataset = FetusDataset(test_data, train=True, classes=len(CLASSES))
            train_loader = DataLoader(
                train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True
            )
            test_loader = DataLoader(
                test_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True
            )
(model, optimizer, scheduler, criterion) = setup_model_training(
input_size=FEATURE_SIZE,
hidden_size=HIDDEN_SIZE,
num_layers=NUM_LAYERS,
num_classes=len(CLASSES),
sequence_length=SERIES_LENGTH,
device=device,
lr=LEARNING_RATE,
weight_decay=WEIGHT_DECAY,
eps=EPS,
)
trained_model = training_loop(
model=model,
train_loader=train_loader,
test_loader=test_loader,
optimizer=optimizer,
scheduler=scheduler,
criterion=criterion,
device=device,
epochs=EPOCHS,
early_stopping=EARLY_STOPPING,
log_dir="output/" + TEST_NAME + "/metrics/" + f"{model_index}",
)
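            # Evaluate this fold's model on its test loader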
loss, conf_matrix, classification_rep = validation(
trained_model,
test_loader,
criterion,
device,
)
# save classification report to a file
df = pd.DataFrame(classification_rep).transpose()
df.to_csv(
"output/"
+ TEST_NAME
+ "/metrics/classification_report_"
+ str(model_index)
+ ".csv"
)
torch.save(
trained_model.state_dict(),
"output/" + TEST_NAME + "/weights/model_" + str(model_index) + ".pth",
)
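            # Per-fold confusion matrices: raw counts and row percentages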
plt.figure(figsize=(19.20, 10.80))
plt.title("Confusion Matrix")
sns.heatmap(
conf_matrix,
annot=True,
fmt=".2f",
xticklabels=["Baseline", "Opcl", "Yawn"],
yticklabels=["Baseline", "Opcl", "Yawn"],
cmap="viridis",
)
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.savefig(
"output/"
+ TEST_NAME
+ "/confusion_matrix/confusion_matrix_"
+ str(model_index)
+ ".png"
)
            # Row-normalize to percentages of samples per true class
            conf_matrix_percent = (
                conf_matrix.astype(float) / conf_matrix.sum(axis=1)[:, np.newaxis] * 100
            )
plt.figure(figsize=(19.20, 10.80))
plt.title("Confusion Matrix Percentage")
sns.heatmap(
conf_matrix_percent,
annot=True,
fmt=".2f",
xticklabels=["Baseline", "Opcl", "Yawn"],
yticklabels=["Baseline", "Opcl", "Yawn"],
cmap="viridis",
)
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.savefig(
"output/"
+ TEST_NAME
+ "/confusion_matrix/confusion_matrix_percentage_"
+ str(model_index)
+ ".png"
)
model_index += 1