modifica per usare un file di config

This commit is contained in:
Dmitri 2025-11-20 00:11:05 +01:00
parent 1179f93485
commit 9964843459
Signed by: kanopo
GPG Key ID: 759ADD40E3132AC7
3 changed files with 559 additions and 487 deletions

View File

@ -0,0 +1,74 @@
[
{
"name": "all_dataset_k1",
"k_fold": 1,
"dataset_type": "all",
"oversampling": false,
"undersampling": false,
"epochs": 50,
"path": "~/Documents/womb-wise-data"
},
{
"name": "fetus_dataset_k1",
"k_fold": 1,
"dataset_type": "fetus",
"oversampling": false,
"undersampling": false,
"epochs": 50,
"path": "~/Documents/womb-wise-data"
},
{
"name": "mother_dataset_k1",
"k_fold": 1,
"dataset_type": "mother",
"oversampling": false,
"undersampling": false,
"epochs": 50,
"path": "~/Documents/womb-wise-data"
},
{
"name": "fetus_mother_dataset_k1",
"k_fold": 1,
"dataset_type": "fetus-mother",
"oversampling": false,
"undersampling": false,
"epochs": 50,
"path": "~/Documents/womb-wise-data"
},
{
"name": "mother_fetus_dataset_k1",
"k_fold": 1,
"dataset_type": "mother-fetus",
"oversampling": false,
"undersampling": false,
"epochs": 50,
"path": "~/Documents/womb-wise-data"
},
{
"name": "all_dataset_k5",
"k_fold": 5,
"dataset_type": "all",
"oversampling": false,
"undersampling": false,
"epochs": 50,
"path": "~/Documents/womb-wise-data"
},
{
"name": "fetus_dataset_k5",
"k_fold": 5,
"dataset_type": "fetus",
"oversampling": false,
"undersampling": false,
"epochs": 50,
"path": "~/Documents/womb-wise-data"
},
{
"name": "mother_dataset_k5",
"k_fold": 5,
"dataset_type": "mother",
"oversampling": false,
"undersampling": false,
"epochs": 50,
"path": "~/Documents/womb-wise-data"
}
]

View File

@ -0,0 +1,83 @@
import json
import os
def create_experiments(experiments_path, base_path):
if os.path.exists(experiments_path):
with open(experiments_path, "r") as f:
return json.load(f)
experiments = [
{
"name": "all_dataset_k1",
"k_fold": 1,
"dataset_type": "all",
"oversampling": False,
"undersampling": False,
"epochs": 50,
},
{
"name": "fetus_dataset_k1",
"k_fold": 1,
"dataset_type": "fetus",
"oversampling": False,
"undersampling": False,
"epochs": 50,
},
{
"name": "mother_dataset_k1",
"k_fold": 1,
"dataset_type": "mother",
"oversampling": False,
"undersampling": False,
"epochs": 50,
},
{
"name": "fetus_mother_dataset_k1",
"k_fold": 1,
"dataset_type": "fetus-mother",
"oversampling": False,
"undersampling": False,
"epochs": 50,
},
{
"name": "mother_fetus_dataset_k1",
"k_fold": 1,
"dataset_type": "mother-fetus",
"oversampling": False,
"undersampling": False,
"epochs": 50,
},
{
"name": "all_dataset_k5",
"k_fold": 5,
"dataset_type": "all",
"oversampling": False,
"undersampling": False,
"epochs": 50,
},
{
"name": "fetus_dataset_k5",
"k_fold": 5,
"dataset_type": "fetus",
"oversampling": False,
"undersampling": False,
"epochs": 50,
},
{
"name": "mother_dataset_k5",
"k_fold": 5,
"dataset_type": "mother",
"oversampling": False,
"undersampling": False,
"epochs": 50,
},
]
for experiment in experiments:
experiment["path"] = base_path
with open(experiments_path, "w") as f:
json.dump(experiments, f, indent=4)
return experiments

View File

@ -9,18 +9,17 @@ import torch.nn as nn
import torch.optim as optim import torch.optim as optim
from torch.utils.data import Dataset, DataLoader from torch.utils.data import Dataset, DataLoader
import seaborn as sns import seaborn as sns
import argparse
from imblearn.over_sampling import RandomOverSampler from imblearn.over_sampling import RandomOverSampler
from imblearn.under_sampling import RandomUnderSampler from imblearn.under_sampling import RandomUnderSampler
from collections import Counter from collections import Counter
from sklearn.model_selection import StratifiedKFold from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
from datetime import datetime
from load_dataset import get_dataset from load_dataset import get_dataset
from model import SimpleLSTM from model import SimpleLSTM
from training import training_loop from training import training_loop
from validation import validation from validation import validation
from experiments import create_experiments
warnings.simplefilter(action="ignore", category=FutureWarning) warnings.simplefilter(action="ignore", category=FutureWarning)
@ -53,8 +52,6 @@ def setup_model_training(
optimizer, mode="min", factor=0.1, patience=25 optimizer, mode="min", factor=0.1, patience=25
) )
# criterion = nn.CrossEntropyLoss()
# criterion = nn.BCELoss()
criterion = nn.BCEWithLogitsLoss() criterion = nn.BCEWithLogitsLoss()
return (model, optimizer, scheduler, criterion) return (model, optimizer, scheduler, criterion)
@ -161,112 +158,31 @@ class FetusDataset(Dataset):
x = self.data[idx]["data"] x = self.data[idx]["data"]
y = self.data[idx]["label"] y = self.data[idx]["label"]
# Conversione del tipo di dato
x = x.astype(np.float32) x = x.astype(np.float32)
y = np.eye(self.classes)[y] y = np.eye(self.classes)[y]
# Conversione in tensori
x = torch.tensor(x, dtype=torch.float32) x = torch.tensor(x, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.int32) y = torch.tensor(y, dtype=torch.int32)
# Gestione di valori NaN o infiniti
x = torch.nan_to_num( x = torch.nan_to_num(
x x
) # Sostituisce NaN con 0 e valori infiniti con numeri molto grandi o piccoli )
# Normalizzazione solo durante il training
if self.train: if self.train:
mean = x.mean() mean = x.mean()
std = x.std() std = x.std()
# Normalizzazione condizionale (solo se std > 0)
if std > 0: if std > 0:
x = (x - mean) / std x = (x - mean) / std
return x, y return x, y
def createArgParser():
parser = argparse.ArgumentParser(description="Womb Wise")
parser.add_argument(
"-rd",
"--reload-dataset",
action="store_true",
help="Reload the dataset",
)
# path to the dataset
parser.add_argument(
"-p",
"--path",
action="store",
help="Path to the dataset",
default="~/Documents/womb-wise-data",
)
# epoch
parser.add_argument(
"-e",
"--epochs",
action="store",
help="Number of epochs",
default=10,
)
parser.add_argument(
"-k",
"--kfold",
action="store",
help="Number of folds for kfold cross validation",
default=1,
)
parser.add_argument(
"-o",
"--oversampling",
action="store_true",
help="Apply oversampling",
)
parser.add_argument(
"-u",
"--undersampling",
action="store_true",
help="Apply undersampling",
)
parser.add_argument(
"-d",
"--dataset",
action="store",
default="all",
choices=["all", "fetus", "mother", "fetus-mother", "mother-fetus"],
help="Choose the dataset: all, fetus, mother or train with mother and test with fetus or viceversa",
)
args = parser.parse_args()
print(
f"""
ARGS:
\n
reload-dataset: {args.reload_dataset}
path: {args.path}
epochs: {args.epochs}
kfold: {args.kfold}
oversampling: {args.oversampling}
undersampling: {args.undersampling}
dataset: {args.dataset}
"""
)
return args
if __name__ == "__main__": if __name__ == "__main__":
BASE_PATH = "~/Documents/womb-wise-data"
CLASSES = ["baseline", "opcl", "yawn"] CLASSES = ["baseline", "opcl", "yawn"]
FEATURE_SIZE = 10 FEATURE_SIZE = 10
SERIES_LENGTH = 60 SERIES_LENGTH = 60
# SINGLE_FRAME_LENGTH = FEATURE_SIZE * SERIES_LENGTH
BATCH_SIZE = 4 BATCH_SIZE = 4
WEIGHT_DECAY = 1e-5 WEIGHT_DECAY = 1e-5
LEARNING_RATE = 1e-3 LEARNING_RATE = 1e-3
@ -275,15 +191,27 @@ if __name__ == "__main__":
DROP_OUT = 0.0 DROP_OUT = 0.0
NUM_LAYERS = 2 NUM_LAYERS = 2
EPS = 1e-7 EPS = 1e-7
EARLY_STOPPING = True
TEST_NAME = "0_k1_all" seed = 42
# TEST_NAME = "1_k1_fetus" np.random.seed(seed)
# TEST_NAME = "2_k1_mother" torch.manual_seed(seed)
# TEST_NAME = "3_k1_mother_fetus" torch.cuda.manual_seed(seed)
# TEST_NAME = "4_k1_fetus_mother" torch.cuda.manual_seed_all(seed)
# TEST_NAME = "5_k5_all"
# TEST_NAME = "6_k5_fetus" device = get_device()
# TEST_NAME = "7_k5_mother"
experiments = create_experiments("experiments.json", BASE_PATH)
for experiment in experiments:
TEST_NAME = experiment["name"]
K_FOLD = experiment["k_fold"]
OVER_SAMPLING = experiment["oversampling"]
UNDER_SAMPLING = experiment["undersampling"]
DATASET_TYPE = experiment["dataset_type"]
EPOCHS = experiment["epochs"]
PATH = experiment["path"]
print(f"Running experiment: {TEST_NAME}")
if not os.path.exists("output/" + TEST_NAME): if not os.path.exists("output/" + TEST_NAME):
os.makedirs("output/" + TEST_NAME) os.makedirs("output/" + TEST_NAME)
@ -297,26 +225,7 @@ if __name__ == "__main__":
if not os.path.exists("output/" + TEST_NAME + "/metrics"): if not os.path.exists("output/" + TEST_NAME + "/metrics"):
os.makedirs("output/" + TEST_NAME + "/metrics") os.makedirs("output/" + TEST_NAME + "/metrics")
# fix the seed if os.path.exists("dataset.csv") and os.path.exists("mother.csv") and os.path.exists("fetus.csv"):
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
device = get_device()
args = createArgParser()
PATH = args.path
EPOCHS = int(args.epochs)
K_FOLD = int(args.kfold)
OVER_SAMPLING = args.oversampling
UNDER_SAMPLING = args.undersampling
EARLY_STOPPING = True
DATASET_TYPE = args.dataset
if os.path.exists("dataset.csv") and args.reload_dataset is False:
dataset = pd.read_csv("dataset.csv") dataset = pd.read_csv("dataset.csv")
mother = pd.read_csv("mother.csv") mother = pd.read_csv("mother.csv")
fetus = pd.read_csv("fetus.csv") fetus = pd.read_csv("fetus.csv")
@ -359,8 +268,6 @@ if __name__ == "__main__":
dataset = pd.concat([mother, fetus]) dataset = pd.concat([mother, fetus])
dataset.to_csv("dataset.csv") dataset.to_csv("dataset.csv")
mother = mother.drop(columns=["top_bottom_distance"]) mother = mother.drop(columns=["top_bottom_distance"])
fetus = fetus.drop(columns=["top_bottom_distance"]) fetus = fetus.drop(columns=["top_bottom_distance"])
dataset = dataset.drop(columns=["top_bottom_distance"]) dataset = dataset.drop(columns=["top_bottom_distance"])
@ -388,7 +295,8 @@ if __name__ == "__main__":
if group.shape[0] < SERIES_LENGTH: if group.shape[0] < SERIES_LENGTH:
group = np.vstack( group = np.vstack(
[group, np.zeros((SERIES_LENGTH - group.shape[0], FEATURE_SIZE))] [group, np.zeros(
(SERIES_LENGTH - group.shape[0], FEATURE_SIZE))]
) )
elif group.shape[0] > SERIES_LENGTH: elif group.shape[0] > SERIES_LENGTH:
@ -420,7 +328,8 @@ if __name__ == "__main__":
if group.shape[0] < SERIES_LENGTH: if group.shape[0] < SERIES_LENGTH:
group = np.vstack( group = np.vstack(
[group, np.zeros((SERIES_LENGTH - group.shape[0], FEATURE_SIZE))] [group, np.zeros(
(SERIES_LENGTH - group.shape[0], FEATURE_SIZE))]
) )
elif group.shape[0] > SERIES_LENGTH: elif group.shape[0] > SERIES_LENGTH:
@ -453,7 +362,8 @@ if __name__ == "__main__":
if group.shape[0] < SERIES_LENGTH: if group.shape[0] < SERIES_LENGTH:
group = np.vstack( group = np.vstack(
[group, np.zeros((SERIES_LENGTH - group.shape[0], FEATURE_SIZE))] [group, np.zeros(
(SERIES_LENGTH - group.shape[0], FEATURE_SIZE))]
) )
elif group.shape[0] > SERIES_LENGTH: elif group.shape[0] > SERIES_LENGTH:
@ -470,7 +380,6 @@ if __name__ == "__main__":
) )
if K_FOLD == 1: if K_FOLD == 1:
x_all = [d["data"] for d in data] x_all = [d["data"] for d in data]
y_all = [d["label"] for d in data] y_all = [d["label"] for d in data]
@ -565,9 +474,9 @@ if __name__ == "__main__":
device, device,
) )
# save classification report to a file
df = pd.DataFrame(classification_rep).transpose() df = pd.DataFrame(classification_rep).transpose()
df.to_csv("output/" + TEST_NAME + "/metrics/classification_report.csv") df.to_csv("output/" + TEST_NAME +
"/metrics/classification_report.csv")
torch.save( torch.save(
trained_model.state_dict(), trained_model.state_dict(),
@ -588,11 +497,13 @@ if __name__ == "__main__":
plt.xlabel("Predicted") plt.xlabel("Predicted")
plt.ylabel("Actual") plt.ylabel("Actual")
plt.savefig("output/" + TEST_NAME + "/confusion_matrix/confusion_matrix.png") plt.savefig("output/" + TEST_NAME +
"/confusion_matrix/confusion_matrix.png")
plt.figure(figsize=(19.20, 10.80)) plt.figure(figsize=(19.20, 10.80))
plt.title("Confusion Matrix Percentage") plt.title("Confusion Matrix Percentage")
conf_matrix_percent = conf_matrix.astype('float') / conf_matrix.sum(axis=1)[:, np.newaxis] * 100 conf_matrix_percent = conf_matrix.astype(
'float') / conf_matrix.sum(axis=1)[:, np.newaxis] * 100
sns.heatmap( sns.heatmap(
conf_matrix_percent, conf_matrix_percent,
annot=True, annot=True,
@ -605,7 +516,8 @@ if __name__ == "__main__":
plt.xlabel("Predicted") plt.xlabel("Predicted")
plt.ylabel("Actual") plt.ylabel("Actual")
plt.savefig("output/" + TEST_NAME + "/confusion_matrix/confusion_matrix_percentage.png") plt.savefig("output/" + TEST_NAME +
"/confusion_matrix/confusion_matrix_percentage.png")
else: else:
x_all = [d["data"] for d in data] x_all = [d["data"] for d in data]
@ -659,7 +571,8 @@ if __name__ == "__main__":
eps=EPS, eps=EPS,
) )
kf = StratifiedKFold(n_splits=K_FOLD, shuffle=True, random_state=seed) kf = StratifiedKFold(
n_splits=K_FOLD, shuffle=True, random_state=seed)
model_index = 0 model_index = 0
x = None x = None
@ -720,7 +633,8 @@ if __name__ == "__main__":
device=device, device=device,
epochs=EPOCHS, epochs=EPOCHS,
early_stopping=EARLY_STOPPING, early_stopping=EARLY_STOPPING,
log_dir="output/" + TEST_NAME + "/metrics/" + f"{model_index}", log_dir="output/" + TEST_NAME +
"/metrics/" + f"{model_index}",
) )
loss, conf_matrix, classification_rep = validation( loss, conf_matrix, classification_rep = validation(
@ -730,7 +644,6 @@ if __name__ == "__main__":
device, device,
) )
# save classification report to a file
df = pd.DataFrame(classification_rep).transpose() df = pd.DataFrame(classification_rep).transpose()
df.to_csv( df.to_csv(
"output/" "output/"
@ -742,7 +655,8 @@ if __name__ == "__main__":
torch.save( torch.save(
trained_model.state_dict(), trained_model.state_dict(),
"output/" + TEST_NAME + "/weights/model_" + str(model_index) + ".pth", "output/" + TEST_NAME + "/weights/model_" +
str(model_index) + ".pth",
) )
plt.figure(figsize=(19.20, 10.80)) plt.figure(figsize=(19.20, 10.80))
plt.title("Confusion Matrix") plt.title("Confusion Matrix")
@ -764,7 +678,8 @@ if __name__ == "__main__":
+ str(model_index) + str(model_index)
+ ".png" + ".png"
) )
conf_matrix_percent = conf_matrix.astype('float') / conf_matrix.sum(axis=1)[:, np.newaxis] * 100 conf_matrix_percent = conf_matrix.astype(
'float') / conf_matrix.sum(axis=1)[:, np.newaxis] * 100
plt.figure(figsize=(19.20, 10.80)) plt.figure(figsize=(19.20, 10.80))
plt.title("Confusion Matrix Percentage") plt.title("Confusion Matrix Percentage")
sns.heatmap( sns.heatmap(