modifica per usare un file di config
This commit is contained in:
parent
1179f93485
commit
9964843459
74
fetus-event-detection-classification/experiments.json
Normal file
74
fetus-event-detection-classification/experiments.json
Normal file
@ -0,0 +1,74 @@
|
||||
[
|
||||
{
|
||||
"name": "all_dataset_k1",
|
||||
"k_fold": 1,
|
||||
"dataset_type": "all",
|
||||
"oversampling": false,
|
||||
"undersampling": false,
|
||||
"epochs": 50,
|
||||
"path": "~/Documents/womb-wise-data"
|
||||
},
|
||||
{
|
||||
"name": "fetus_dataset_k1",
|
||||
"k_fold": 1,
|
||||
"dataset_type": "fetus",
|
||||
"oversampling": false,
|
||||
"undersampling": false,
|
||||
"epochs": 50,
|
||||
"path": "~/Documents/womb-wise-data"
|
||||
},
|
||||
{
|
||||
"name": "mother_dataset_k1",
|
||||
"k_fold": 1,
|
||||
"dataset_type": "mother",
|
||||
"oversampling": false,
|
||||
"undersampling": false,
|
||||
"epochs": 50,
|
||||
"path": "~/Documents/womb-wise-data"
|
||||
},
|
||||
{
|
||||
"name": "fetus_mother_dataset_k1",
|
||||
"k_fold": 1,
|
||||
"dataset_type": "fetus-mother",
|
||||
"oversampling": false,
|
||||
"undersampling": false,
|
||||
"epochs": 50,
|
||||
"path": "~/Documents/womb-wise-data"
|
||||
},
|
||||
{
|
||||
"name": "mother_fetus_dataset_k1",
|
||||
"k_fold": 1,
|
||||
"dataset_type": "mother-fetus",
|
||||
"oversampling": false,
|
||||
"undersampling": false,
|
||||
"epochs": 50,
|
||||
"path": "~/Documents/womb-wise-data"
|
||||
},
|
||||
{
|
||||
"name": "all_dataset_k5",
|
||||
"k_fold": 5,
|
||||
"dataset_type": "all",
|
||||
"oversampling": false,
|
||||
"undersampling": false,
|
||||
"epochs": 50,
|
||||
"path": "~/Documents/womb-wise-data"
|
||||
},
|
||||
{
|
||||
"name": "fetus_dataset_k5",
|
||||
"k_fold": 5,
|
||||
"dataset_type": "fetus",
|
||||
"oversampling": false,
|
||||
"undersampling": false,
|
||||
"epochs": 50,
|
||||
"path": "~/Documents/womb-wise-data"
|
||||
},
|
||||
{
|
||||
"name": "mother_dataset_k5",
|
||||
"k_fold": 5,
|
||||
"dataset_type": "mother",
|
||||
"oversampling": false,
|
||||
"undersampling": false,
|
||||
"epochs": 50,
|
||||
"path": "~/Documents/womb-wise-data"
|
||||
}
|
||||
]
|
||||
83
fetus-event-detection-classification/src/experiments.py
Normal file
83
fetus-event-detection-classification/src/experiments.py
Normal file
@ -0,0 +1,83 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
|
||||
def create_experiments(experiments_path, base_path):
|
||||
if os.path.exists(experiments_path):
|
||||
with open(experiments_path, "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
experiments = [
|
||||
{
|
||||
"name": "all_dataset_k1",
|
||||
"k_fold": 1,
|
||||
"dataset_type": "all",
|
||||
"oversampling": False,
|
||||
"undersampling": False,
|
||||
"epochs": 50,
|
||||
},
|
||||
{
|
||||
"name": "fetus_dataset_k1",
|
||||
"k_fold": 1,
|
||||
"dataset_type": "fetus",
|
||||
"oversampling": False,
|
||||
"undersampling": False,
|
||||
"epochs": 50,
|
||||
},
|
||||
{
|
||||
"name": "mother_dataset_k1",
|
||||
"k_fold": 1,
|
||||
"dataset_type": "mother",
|
||||
"oversampling": False,
|
||||
"undersampling": False,
|
||||
"epochs": 50,
|
||||
},
|
||||
{
|
||||
"name": "fetus_mother_dataset_k1",
|
||||
"k_fold": 1,
|
||||
"dataset_type": "fetus-mother",
|
||||
"oversampling": False,
|
||||
"undersampling": False,
|
||||
"epochs": 50,
|
||||
},
|
||||
{
|
||||
"name": "mother_fetus_dataset_k1",
|
||||
"k_fold": 1,
|
||||
"dataset_type": "mother-fetus",
|
||||
"oversampling": False,
|
||||
"undersampling": False,
|
||||
"epochs": 50,
|
||||
},
|
||||
{
|
||||
"name": "all_dataset_k5",
|
||||
"k_fold": 5,
|
||||
"dataset_type": "all",
|
||||
"oversampling": False,
|
||||
"undersampling": False,
|
||||
"epochs": 50,
|
||||
},
|
||||
{
|
||||
"name": "fetus_dataset_k5",
|
||||
"k_fold": 5,
|
||||
"dataset_type": "fetus",
|
||||
"oversampling": False,
|
||||
"undersampling": False,
|
||||
"epochs": 50,
|
||||
},
|
||||
{
|
||||
"name": "mother_dataset_k5",
|
||||
"k_fold": 5,
|
||||
"dataset_type": "mother",
|
||||
"oversampling": False,
|
||||
"undersampling": False,
|
||||
"epochs": 50,
|
||||
},
|
||||
]
|
||||
|
||||
for experiment in experiments:
|
||||
experiment["path"] = base_path
|
||||
|
||||
with open(experiments_path, "w") as f:
|
||||
json.dump(experiments, f, indent=4)
|
||||
|
||||
return experiments
|
||||
@ -9,18 +9,17 @@ import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import seaborn as sns
|
||||
import argparse
|
||||
from imblearn.over_sampling import RandomOverSampler
|
||||
from imblearn.under_sampling import RandomUnderSampler
|
||||
from collections import Counter
|
||||
from sklearn.model_selection import StratifiedKFold
|
||||
from sklearn.model_selection import train_test_split
|
||||
from datetime import datetime
|
||||
|
||||
from load_dataset import get_dataset
|
||||
from model import SimpleLSTM
|
||||
from training import training_loop
|
||||
from validation import validation
|
||||
from experiments import create_experiments
|
||||
|
||||
warnings.simplefilter(action="ignore", category=FutureWarning)
|
||||
|
||||
@ -53,8 +52,6 @@ def setup_model_training(
|
||||
optimizer, mode="min", factor=0.1, patience=25
|
||||
)
|
||||
|
||||
# criterion = nn.CrossEntropyLoss()
|
||||
# criterion = nn.BCELoss()
|
||||
criterion = nn.BCEWithLogitsLoss()
|
||||
|
||||
return (model, optimizer, scheduler, criterion)
|
||||
@ -161,112 +158,31 @@ class FetusDataset(Dataset):
|
||||
x = self.data[idx]["data"]
|
||||
y = self.data[idx]["label"]
|
||||
|
||||
# Conversione del tipo di dato
|
||||
x = x.astype(np.float32)
|
||||
y = np.eye(self.classes)[y]
|
||||
|
||||
# Conversione in tensori
|
||||
x = torch.tensor(x, dtype=torch.float32)
|
||||
y = torch.tensor(y, dtype=torch.int32)
|
||||
|
||||
# Gestione di valori NaN o infiniti
|
||||
x = torch.nan_to_num(
|
||||
x
|
||||
) # Sostituisce NaN con 0 e valori infiniti con numeri molto grandi o piccoli
|
||||
)
|
||||
|
||||
# Normalizzazione solo durante il training
|
||||
if self.train:
|
||||
mean = x.mean()
|
||||
std = x.std()
|
||||
|
||||
# Normalizzazione condizionale (solo se std > 0)
|
||||
if std > 0:
|
||||
x = (x - mean) / std
|
||||
|
||||
return x, y
|
||||
|
||||
|
||||
def createArgParser():
|
||||
parser = argparse.ArgumentParser(description="Womb Wise")
|
||||
parser.add_argument(
|
||||
"-rd",
|
||||
"--reload-dataset",
|
||||
action="store_true",
|
||||
help="Reload the dataset",
|
||||
)
|
||||
|
||||
# path to the dataset
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--path",
|
||||
action="store",
|
||||
help="Path to the dataset",
|
||||
default="~/Documents/womb-wise-data",
|
||||
)
|
||||
|
||||
# epoch
|
||||
parser.add_argument(
|
||||
"-e",
|
||||
"--epochs",
|
||||
action="store",
|
||||
help="Number of epochs",
|
||||
default=10,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-k",
|
||||
"--kfold",
|
||||
action="store",
|
||||
help="Number of folds for kfold cross validation",
|
||||
default=1,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--oversampling",
|
||||
action="store_true",
|
||||
help="Apply oversampling",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-u",
|
||||
"--undersampling",
|
||||
action="store_true",
|
||||
help="Apply undersampling",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--dataset",
|
||||
action="store",
|
||||
default="all",
|
||||
choices=["all", "fetus", "mother", "fetus-mother", "mother-fetus"],
|
||||
help="Choose the dataset: all, fetus, mother or train with mother and test with fetus or viceversa",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print(
|
||||
f"""
|
||||
ARGS:
|
||||
\n
|
||||
reload-dataset: {args.reload_dataset}
|
||||
path: {args.path}
|
||||
epochs: {args.epochs}
|
||||
kfold: {args.kfold}
|
||||
oversampling: {args.oversampling}
|
||||
undersampling: {args.undersampling}
|
||||
dataset: {args.dataset}
|
||||
"""
|
||||
)
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
BASE_PATH = "~/Documents/womb-wise-data"
|
||||
CLASSES = ["baseline", "opcl", "yawn"]
|
||||
FEATURE_SIZE = 10
|
||||
SERIES_LENGTH = 60
|
||||
# SINGLE_FRAME_LENGTH = FEATURE_SIZE * SERIES_LENGTH
|
||||
BATCH_SIZE = 4
|
||||
WEIGHT_DECAY = 1e-5
|
||||
LEARNING_RATE = 1e-3
|
||||
@ -275,15 +191,27 @@ if __name__ == "__main__":
|
||||
DROP_OUT = 0.0
|
||||
NUM_LAYERS = 2
|
||||
EPS = 1e-7
|
||||
EARLY_STOPPING = True
|
||||
|
||||
TEST_NAME = "0_k1_all"
|
||||
# TEST_NAME = "1_k1_fetus"
|
||||
# TEST_NAME = "2_k1_mother"
|
||||
# TEST_NAME = "3_k1_mother_fetus"
|
||||
# TEST_NAME = "4_k1_fetus_mother"
|
||||
# TEST_NAME = "5_k5_all"
|
||||
# TEST_NAME = "6_k5_fetus"
|
||||
# TEST_NAME = "7_k5_mother"
|
||||
seed = 42
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
device = get_device()
|
||||
|
||||
experiments = create_experiments("experiments.json", BASE_PATH)
|
||||
for experiment in experiments:
|
||||
TEST_NAME = experiment["name"]
|
||||
K_FOLD = experiment["k_fold"]
|
||||
OVER_SAMPLING = experiment["oversampling"]
|
||||
UNDER_SAMPLING = experiment["undersampling"]
|
||||
DATASET_TYPE = experiment["dataset_type"]
|
||||
EPOCHS = experiment["epochs"]
|
||||
PATH = experiment["path"]
|
||||
|
||||
print(f"Running experiment: {TEST_NAME}")
|
||||
|
||||
if not os.path.exists("output/" + TEST_NAME):
|
||||
os.makedirs("output/" + TEST_NAME)
|
||||
@ -297,26 +225,7 @@ if __name__ == "__main__":
|
||||
if not os.path.exists("output/" + TEST_NAME + "/metrics"):
|
||||
os.makedirs("output/" + TEST_NAME + "/metrics")
|
||||
|
||||
# fix the seed
|
||||
seed = 42
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
device = get_device()
|
||||
args = createArgParser()
|
||||
|
||||
PATH = args.path
|
||||
EPOCHS = int(args.epochs)
|
||||
K_FOLD = int(args.kfold)
|
||||
OVER_SAMPLING = args.oversampling
|
||||
UNDER_SAMPLING = args.undersampling
|
||||
EARLY_STOPPING = True
|
||||
DATASET_TYPE = args.dataset
|
||||
|
||||
|
||||
if os.path.exists("dataset.csv") and args.reload_dataset is False:
|
||||
if os.path.exists("dataset.csv") and os.path.exists("mother.csv") and os.path.exists("fetus.csv"):
|
||||
dataset = pd.read_csv("dataset.csv")
|
||||
mother = pd.read_csv("mother.csv")
|
||||
fetus = pd.read_csv("fetus.csv")
|
||||
@ -359,8 +268,6 @@ if __name__ == "__main__":
|
||||
dataset = pd.concat([mother, fetus])
|
||||
dataset.to_csv("dataset.csv")
|
||||
|
||||
|
||||
|
||||
mother = mother.drop(columns=["top_bottom_distance"])
|
||||
fetus = fetus.drop(columns=["top_bottom_distance"])
|
||||
dataset = dataset.drop(columns=["top_bottom_distance"])
|
||||
@ -388,7 +295,8 @@ if __name__ == "__main__":
|
||||
|
||||
if group.shape[0] < SERIES_LENGTH:
|
||||
group = np.vstack(
|
||||
[group, np.zeros((SERIES_LENGTH - group.shape[0], FEATURE_SIZE))]
|
||||
[group, np.zeros(
|
||||
(SERIES_LENGTH - group.shape[0], FEATURE_SIZE))]
|
||||
)
|
||||
|
||||
elif group.shape[0] > SERIES_LENGTH:
|
||||
@ -420,7 +328,8 @@ if __name__ == "__main__":
|
||||
|
||||
if group.shape[0] < SERIES_LENGTH:
|
||||
group = np.vstack(
|
||||
[group, np.zeros((SERIES_LENGTH - group.shape[0], FEATURE_SIZE))]
|
||||
[group, np.zeros(
|
||||
(SERIES_LENGTH - group.shape[0], FEATURE_SIZE))]
|
||||
)
|
||||
|
||||
elif group.shape[0] > SERIES_LENGTH:
|
||||
@ -453,7 +362,8 @@ if __name__ == "__main__":
|
||||
|
||||
if group.shape[0] < SERIES_LENGTH:
|
||||
group = np.vstack(
|
||||
[group, np.zeros((SERIES_LENGTH - group.shape[0], FEATURE_SIZE))]
|
||||
[group, np.zeros(
|
||||
(SERIES_LENGTH - group.shape[0], FEATURE_SIZE))]
|
||||
)
|
||||
|
||||
elif group.shape[0] > SERIES_LENGTH:
|
||||
@ -470,7 +380,6 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
if K_FOLD == 1:
|
||||
|
||||
x_all = [d["data"] for d in data]
|
||||
y_all = [d["label"] for d in data]
|
||||
|
||||
@ -565,9 +474,9 @@ if __name__ == "__main__":
|
||||
device,
|
||||
)
|
||||
|
||||
# save classification report to a file
|
||||
df = pd.DataFrame(classification_rep).transpose()
|
||||
df.to_csv("output/" + TEST_NAME + "/metrics/classification_report.csv")
|
||||
df.to_csv("output/" + TEST_NAME +
|
||||
"/metrics/classification_report.csv")
|
||||
|
||||
torch.save(
|
||||
trained_model.state_dict(),
|
||||
@ -588,11 +497,13 @@ if __name__ == "__main__":
|
||||
plt.xlabel("Predicted")
|
||||
plt.ylabel("Actual")
|
||||
|
||||
plt.savefig("output/" + TEST_NAME + "/confusion_matrix/confusion_matrix.png")
|
||||
plt.savefig("output/" + TEST_NAME +
|
||||
"/confusion_matrix/confusion_matrix.png")
|
||||
|
||||
plt.figure(figsize=(19.20, 10.80))
|
||||
plt.title("Confusion Matrix Percentage")
|
||||
conf_matrix_percent = conf_matrix.astype('float') / conf_matrix.sum(axis=1)[:, np.newaxis] * 100
|
||||
conf_matrix_percent = conf_matrix.astype(
|
||||
'float') / conf_matrix.sum(axis=1)[:, np.newaxis] * 100
|
||||
sns.heatmap(
|
||||
conf_matrix_percent,
|
||||
annot=True,
|
||||
@ -605,7 +516,8 @@ if __name__ == "__main__":
|
||||
plt.xlabel("Predicted")
|
||||
plt.ylabel("Actual")
|
||||
|
||||
plt.savefig("output/" + TEST_NAME + "/confusion_matrix/confusion_matrix_percentage.png")
|
||||
plt.savefig("output/" + TEST_NAME +
|
||||
"/confusion_matrix/confusion_matrix_percentage.png")
|
||||
|
||||
else:
|
||||
x_all = [d["data"] for d in data]
|
||||
@ -659,7 +571,8 @@ if __name__ == "__main__":
|
||||
eps=EPS,
|
||||
)
|
||||
|
||||
kf = StratifiedKFold(n_splits=K_FOLD, shuffle=True, random_state=seed)
|
||||
kf = StratifiedKFold(
|
||||
n_splits=K_FOLD, shuffle=True, random_state=seed)
|
||||
model_index = 0
|
||||
|
||||
x = None
|
||||
@ -720,7 +633,8 @@ if __name__ == "__main__":
|
||||
device=device,
|
||||
epochs=EPOCHS,
|
||||
early_stopping=EARLY_STOPPING,
|
||||
log_dir="output/" + TEST_NAME + "/metrics/" + f"{model_index}",
|
||||
log_dir="output/" + TEST_NAME +
|
||||
"/metrics/" + f"{model_index}",
|
||||
)
|
||||
|
||||
loss, conf_matrix, classification_rep = validation(
|
||||
@ -730,7 +644,6 @@ if __name__ == "__main__":
|
||||
device,
|
||||
)
|
||||
|
||||
# save classification report to a file
|
||||
df = pd.DataFrame(classification_rep).transpose()
|
||||
df.to_csv(
|
||||
"output/"
|
||||
@ -742,7 +655,8 @@ if __name__ == "__main__":
|
||||
|
||||
torch.save(
|
||||
trained_model.state_dict(),
|
||||
"output/" + TEST_NAME + "/weights/model_" + str(model_index) + ".pth",
|
||||
"output/" + TEST_NAME + "/weights/model_" +
|
||||
str(model_index) + ".pth",
|
||||
)
|
||||
plt.figure(figsize=(19.20, 10.80))
|
||||
plt.title("Confusion Matrix")
|
||||
@ -764,7 +678,8 @@ if __name__ == "__main__":
|
||||
+ str(model_index)
|
||||
+ ".png"
|
||||
)
|
||||
conf_matrix_percent = conf_matrix.astype('float') / conf_matrix.sum(axis=1)[:, np.newaxis] * 100
|
||||
conf_matrix_percent = conf_matrix.astype(
|
||||
'float') / conf_matrix.sum(axis=1)[:, np.newaxis] * 100
|
||||
plt.figure(figsize=(19.20, 10.80))
|
||||
plt.title("Confusion Matrix Percentage")
|
||||
sns.heatmap(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user