Change to use a config file
This commit is contained in:
parent 1179f93485
commit 9964843459
fetus-event-detection-classification/experiments.json (new file, +74 lines)
@@ -0,0 +1,74 @@
+[
+    {
+        "name": "all_dataset_k1",
+        "k_fold": 1,
+        "dataset_type": "all",
+        "oversampling": false,
+        "undersampling": false,
+        "epochs": 50,
+        "path": "~/Documents/womb-wise-data"
+    },
+    {
+        "name": "fetus_dataset_k1",
+        "k_fold": 1,
+        "dataset_type": "fetus",
+        "oversampling": false,
+        "undersampling": false,
+        "epochs": 50,
+        "path": "~/Documents/womb-wise-data"
+    },
+    {
+        "name": "mother_dataset_k1",
+        "k_fold": 1,
+        "dataset_type": "mother",
+        "oversampling": false,
+        "undersampling": false,
+        "epochs": 50,
+        "path": "~/Documents/womb-wise-data"
+    },
+    {
+        "name": "fetus_mother_dataset_k1",
+        "k_fold": 1,
+        "dataset_type": "fetus-mother",
+        "oversampling": false,
+        "undersampling": false,
+        "epochs": 50,
+        "path": "~/Documents/womb-wise-data"
+    },
+    {
+        "name": "mother_fetus_dataset_k1",
+        "k_fold": 1,
+        "dataset_type": "mother-fetus",
+        "oversampling": false,
+        "undersampling": false,
+        "epochs": 50,
+        "path": "~/Documents/womb-wise-data"
+    },
+    {
+        "name": "all_dataset_k5",
+        "k_fold": 5,
+        "dataset_type": "all",
+        "oversampling": false,
+        "undersampling": false,
+        "epochs": 50,
+        "path": "~/Documents/womb-wise-data"
+    },
+    {
+        "name": "fetus_dataset_k5",
+        "k_fold": 5,
+        "dataset_type": "fetus",
+        "oversampling": false,
+        "undersampling": false,
+        "epochs": 50,
+        "path": "~/Documents/womb-wise-data"
+    },
+    {
+        "name": "mother_dataset_k5",
+        "k_fold": 5,
+        "dataset_type": "mother",
+        "oversampling": false,
+        "undersampling": false,
+        "epochs": 50,
+        "path": "~/Documents/womb-wise-data"
+    }
+]
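Note for anyone editing this file by hand: the seven keys above are exactly the ones the training script reads. A minimal sanity-check sketch for an edited config; the REQUIRED set and file name mirror this commit, and the script itself performs no such validation:

    import json

    REQUIRED = {"name", "k_fold", "dataset_type", "oversampling",
                "undersampling", "epochs", "path"}

    with open("experiments.json") as f:
        for entry in json.load(f):
            missing = REQUIRED - entry.keys()
            assert not missing, f"{entry.get('name', '?')} is missing {missing}"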
fetus-event-detection-classification/src/experiments.py (new file, +83 lines)
@@ -0,0 +1,83 @@
+import json
+import os
+
+
+def create_experiments(experiments_path, base_path):
+    if os.path.exists(experiments_path):
+        with open(experiments_path, "r") as f:
+            return json.load(f)
+
+    experiments = [
+        {
+            "name": "all_dataset_k1",
+            "k_fold": 1,
+            "dataset_type": "all",
+            "oversampling": False,
+            "undersampling": False,
+            "epochs": 50,
+        },
+        {
+            "name": "fetus_dataset_k1",
+            "k_fold": 1,
+            "dataset_type": "fetus",
+            "oversampling": False,
+            "undersampling": False,
+            "epochs": 50,
+        },
+        {
+            "name": "mother_dataset_k1",
+            "k_fold": 1,
+            "dataset_type": "mother",
+            "oversampling": False,
+            "undersampling": False,
+            "epochs": 50,
+        },
+        {
+            "name": "fetus_mother_dataset_k1",
+            "k_fold": 1,
+            "dataset_type": "fetus-mother",
+            "oversampling": False,
+            "undersampling": False,
+            "epochs": 50,
+        },
+        {
+            "name": "mother_fetus_dataset_k1",
+            "k_fold": 1,
+            "dataset_type": "mother-fetus",
+            "oversampling": False,
+            "undersampling": False,
+            "epochs": 50,
+        },
+        {
+            "name": "all_dataset_k5",
+            "k_fold": 5,
+            "dataset_type": "all",
+            "oversampling": False,
+            "undersampling": False,
+            "epochs": 50,
+        },
+        {
+            "name": "fetus_dataset_k5",
+            "k_fold": 5,
+            "dataset_type": "fetus",
+            "oversampling": False,
+            "undersampling": False,
+            "epochs": 50,
+        },
+        {
+            "name": "mother_dataset_k5",
+            "k_fold": 5,
+            "dataset_type": "mother",
+            "oversampling": False,
+            "undersampling": False,
+            "epochs": 50,
+        },
+    ]
+
+    for experiment in experiments:
+        experiment["path"] = base_path
+
+    with open(experiments_path, "w") as f:
+        json.dump(experiments, f, indent=4)
+
+    return experiments
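create_experiments doubles as a cache: when experiments_path already exists, the file's contents win and the built-in defaults (including base_path) are ignored, so editing the JSON is the way to change runs. A minimal usage sketch, mirroring how the main script calls it in the hunks below:

    from experiments import create_experiments

    # First call writes experiments.json with the defaults; later calls reload it.
    experiments = create_experiments("experiments.json", "~/Documents/womb-wise-data")
    for exp in experiments:
        print(exp["name"], exp["k_fold"], exp["dataset_type"])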
@@ -9,18 +9,17 @@ import torch.nn as nn
 import torch.optim as optim
 from torch.utils.data import Dataset, DataLoader
 import seaborn as sns
-import argparse
 from imblearn.over_sampling import RandomOverSampler
 from imblearn.under_sampling import RandomUnderSampler
 from collections import Counter
 from sklearn.model_selection import StratifiedKFold
 from sklearn.model_selection import train_test_split
-from datetime import datetime

 from load_dataset import get_dataset
 from model import SimpleLSTM
 from training import training_loop
 from validation import validation
+from experiments import create_experiments

 warnings.simplefilter(action="ignore", category=FutureWarning)

@@ -53,8 +52,6 @@ def setup_model_training(
         optimizer, mode="min", factor=0.1, patience=25
     )

-    # criterion = nn.CrossEntropyLoss()
-    # criterion = nn.BCELoss()
     criterion = nn.BCEWithLogitsLoss()

     return (model, optimizer, scheduler, criterion)
@@ -161,112 +158,31 @@ class FetusDataset(Dataset):
         x = self.data[idx]["data"]
         y = self.data[idx]["label"]

-        # Convert the data type
         x = x.astype(np.float32)
         y = np.eye(self.classes)[y]

-        # Convert to tensors
         x = torch.tensor(x, dtype=torch.float32)
         y = torch.tensor(y, dtype=torch.int32)

-        # Handle NaN or infinite values
         x = torch.nan_to_num(
             x
-        )  # Replaces NaN with 0 and infinite values with very large or very small numbers
+        )

-        # Normalize only during training
         if self.train:
             mean = x.mean()
             std = x.std()

-            # Conditional normalization (only if std > 0)
             if std > 0:
                 x = (x - mean) / std

         return x, y


-def createArgParser():
-    parser = argparse.ArgumentParser(description="Womb Wise")
-    parser.add_argument(
-        "-rd",
-        "--reload-dataset",
-        action="store_true",
-        help="Reload the dataset",
-    )
-
-    # path to the dataset
-    parser.add_argument(
-        "-p",
-        "--path",
-        action="store",
-        help="Path to the dataset",
-        default="~/Documents/womb-wise-data",
-    )
-
-    # epoch
-    parser.add_argument(
-        "-e",
-        "--epochs",
-        action="store",
-        help="Number of epochs",
-        default=10,
-    )
-
-    parser.add_argument(
-        "-k",
-        "--kfold",
-        action="store",
-        help="Number of folds for kfold cross validation",
-        default=1,
-    )
-
-    parser.add_argument(
-        "-o",
-        "--oversampling",
-        action="store_true",
-        help="Apply oversampling",
-    )
-
-    parser.add_argument(
-        "-u",
-        "--undersampling",
-        action="store_true",
-        help="Apply undersampling",
-    )
-
-    parser.add_argument(
-        "-d",
-        "--dataset",
-        action="store",
-        default="all",
-        choices=["all", "fetus", "mother", "fetus-mother", "mother-fetus"],
-        help="Choose the dataset: all, fetus, mother or train with mother and test with fetus or viceversa",
-    )
-
-    args = parser.parse_args()
-
-    print(
-        f"""
-        ARGS:
-        \n
-        reload-dataset: {args.reload_dataset}
-        path: {args.path}
-        epochs: {args.epochs}
-        kfold: {args.kfold}
-        oversampling: {args.oversampling}
-        undersampling: {args.undersampling}
-        dataset: {args.dataset}
-        """
-    )
-    return args


 if __name__ == "__main__":
+    BASE_PATH = "~/Documents/womb-wise-data"
     CLASSES = ["baseline", "opcl", "yawn"]
     FEATURE_SIZE = 10
     SERIES_LENGTH = 60
-    # SINGLE_FRAME_LENGTH = FEATURE_SIZE * SERIES_LENGTH
     BATCH_SIZE = 4
     WEIGHT_DECAY = 1e-5
     LEARNING_RATE = 1e-3
@@ -275,15 +191,27 @@ if __name__ == "__main__":
     DROP_OUT = 0.0
     NUM_LAYERS = 2
     EPS = 1e-7
+    EARLY_STOPPING = True

-    TEST_NAME = "0_k1_all"
-    # TEST_NAME = "1_k1_fetus"
-    # TEST_NAME = "2_k1_mother"
-    # TEST_NAME = "3_k1_mother_fetus"
-    # TEST_NAME = "4_k1_fetus_mother"
-    # TEST_NAME = "5_k5_all"
-    # TEST_NAME = "6_k5_fetus"
-    # TEST_NAME = "7_k5_mother"
+    seed = 42
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+    device = get_device()
+
+    experiments = create_experiments("experiments.json", BASE_PATH)
+    for experiment in experiments:
+        TEST_NAME = experiment["name"]
+        K_FOLD = experiment["k_fold"]
+        OVER_SAMPLING = experiment["oversampling"]
+        UNDER_SAMPLING = experiment["undersampling"]
+        DATASET_TYPE = experiment["dataset_type"]
+        EPOCHS = experiment["epochs"]
+        PATH = experiment["path"]
+
+        print(f"Running experiment: {TEST_NAME}")

         if not os.path.exists("output/" + TEST_NAME):
             os.makedirs("output/" + TEST_NAME)
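Seeding now happens once, before the experiment loop, so a full run of the suite is reproducible end to end; each individual experiment still consumes whatever RNG state the previous ones left behind. A hedged sketch of per-experiment re-seeding, if run-order independence were wanted (not what this commit does; names mirror the code above):

    for experiment in experiments:
        # Reset all RNGs at the top of each experiment.
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        ...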
@@ -297,32 +225,13 @@ if __name__ == "__main__":
         if not os.path.exists("output/" + TEST_NAME + "/metrics"):
             os.makedirs("output/" + TEST_NAME + "/metrics")

-        # fix the seed
-        seed = 42
-        np.random.seed(seed)
-        torch.manual_seed(seed)
-        torch.cuda.manual_seed(seed)
-        torch.cuda.manual_seed_all(seed)
-
-        device = get_device()
-        args = createArgParser()
-
-        PATH = args.path
-        EPOCHS = int(args.epochs)
-        K_FOLD = int(args.kfold)
-        OVER_SAMPLING = args.oversampling
-        UNDER_SAMPLING = args.undersampling
-        EARLY_STOPPING = True
-        DATASET_TYPE = args.dataset
-
-
-        if os.path.exists("dataset.csv") and args.reload_dataset is False:
+        if os.path.exists("dataset.csv") and os.path.exists("mother.csv") and os.path.exists("fetus.csv"):
             dataset = pd.read_csv("dataset.csv")
             mother = pd.read_csv("mother.csv")
             fetus = pd.read_csv("fetus.csv")
         else:
             baseline_fetus = get_dataset(
-                PATH +"/Ultrasound_Scans/tracked_frames/",
+                PATH + "/Ultrasound_Scans/tracked_frames/",
                 "baseline",
             )
             yawn_fetus = get_dataset(
@@ -359,8 +268,6 @@ if __name__ == "__main__":
             dataset = pd.concat([mother, fetus])
             dataset.to_csv("dataset.csv")

-
-
         mother = mother.drop(columns=["top_bottom_distance"])
         fetus = fetus.drop(columns=["top_bottom_distance"])
         dataset = dataset.drop(columns=["top_bottom_distance"])
@@ -388,7 +295,8 @@ if __name__ == "__main__":

         if group.shape[0] < SERIES_LENGTH:
             group = np.vstack(
-                [group, np.zeros((SERIES_LENGTH - group.shape[0], FEATURE_SIZE))]
+                [group, np.zeros(
+                    (SERIES_LENGTH - group.shape[0], FEATURE_SIZE))]
             )

         elif group.shape[0] > SERIES_LENGTH:
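The reflow above is cosmetic; the underlying logic forces every group to exactly SERIES_LENGTH rows: shorter groups are zero-padded at the bottom, and the elif branch handles over-long groups (its body is outside this hunk, so trimming is an assumption in the sketch). A self-contained sketch of the same pad-to-length idea, under this script's constants:

    import numpy as np

    SERIES_LENGTH, FEATURE_SIZE = 60, 10

    def fit_to_length(group: np.ndarray) -> np.ndarray:
        # Pad with zero rows up to SERIES_LENGTH; trim any excess rows.
        if group.shape[0] < SERIES_LENGTH:
            pad = np.zeros((SERIES_LENGTH - group.shape[0], FEATURE_SIZE))
            return np.vstack([group, pad])
        return group[:SERIES_LENGTH]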
@@ -420,7 +328,8 @@ if __name__ == "__main__":

         if group.shape[0] < SERIES_LENGTH:
             group = np.vstack(
-                [group, np.zeros((SERIES_LENGTH - group.shape[0], FEATURE_SIZE))]
+                [group, np.zeros(
+                    (SERIES_LENGTH - group.shape[0], FEATURE_SIZE))]
             )

         elif group.shape[0] > SERIES_LENGTH:
@@ -453,7 +362,8 @@ if __name__ == "__main__":

         if group.shape[0] < SERIES_LENGTH:
             group = np.vstack(
-                [group, np.zeros((SERIES_LENGTH - group.shape[0], FEATURE_SIZE))]
+                [group, np.zeros(
+                    (SERIES_LENGTH - group.shape[0], FEATURE_SIZE))]
             )

         elif group.shape[0] > SERIES_LENGTH:
@@ -470,7 +380,6 @@ if __name__ == "__main__":
         )

         if K_FOLD == 1:
-
             x_all = [d["data"] for d in data]
             y_all = [d["label"] for d in data]

@@ -565,9 +474,9 @@ if __name__ == "__main__":
                 device,
             )

-            # save classification report to a file
             df = pd.DataFrame(classification_rep).transpose()
-            df.to_csv("output/" + TEST_NAME + "/metrics/classification_report.csv")
+            df.to_csv("output/" + TEST_NAME +
+                      "/metrics/classification_report.csv")

             torch.save(
                 trained_model.state_dict(),
@@ -588,11 +497,13 @@ if __name__ == "__main__":
             plt.xlabel("Predicted")
             plt.ylabel("Actual")

-            plt.savefig("output/" + TEST_NAME + "/confusion_matrix/confusion_matrix.png")
+            plt.savefig("output/" + TEST_NAME +
+                        "/confusion_matrix/confusion_matrix.png")

             plt.figure(figsize=(19.20, 10.80))
             plt.title("Confusion Matrix Percentage")
-            conf_matrix_percent = conf_matrix.astype('float') / conf_matrix.sum(axis=1)[:, np.newaxis] * 100
+            conf_matrix_percent = conf_matrix.astype(
+                'float') / conf_matrix.sum(axis=1)[:, np.newaxis] * 100
             sns.heatmap(
                 conf_matrix_percent,
                 annot=True,
@@ -605,7 +516,8 @@ if __name__ == "__main__":
             plt.xlabel("Predicted")
             plt.ylabel("Actual")

-            plt.savefig("output/" + TEST_NAME + "/confusion_matrix/confusion_matrix_percentage.png")
+            plt.savefig("output/" + TEST_NAME +
+                        "/confusion_matrix/confusion_matrix_percentage.png")

         else:
             x_all = [d["data"] for d in data]
@@ -659,7 +571,8 @@ if __name__ == "__main__":
                 eps=EPS,
             )

-            kf = StratifiedKFold(n_splits=K_FOLD, shuffle=True, random_state=seed)
+            kf = StratifiedKFold(
+                n_splits=K_FOLD, shuffle=True, random_state=seed)
             model_index = 0

             x = None
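StratifiedKFold preserves the class proportions of the labels in every fold, which matters for imbalanced event classes like these. A minimal sketch of the split loop, assuming x_all and y_all as built earlier in this script:

    from sklearn.model_selection import StratifiedKFold

    kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    for fold, (train_idx, val_idx) in enumerate(kf.split(x_all, y_all)):
        print(f"fold {fold}: {len(train_idx)} train / {len(val_idx)} val")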
@@ -720,7 +633,8 @@ if __name__ == "__main__":
                     device=device,
                     epochs=EPOCHS,
                     early_stopping=EARLY_STOPPING,
-                    log_dir="output/" + TEST_NAME + "/metrics/" + f"{model_index}",
+                    log_dir="output/" + TEST_NAME +
+                    "/metrics/" + f"{model_index}",
                 )

                 loss, conf_matrix, classification_rep = validation(
@@ -730,7 +644,6 @@ if __name__ == "__main__":
                     device,
                 )

-                # save classification report to a file
                 df = pd.DataFrame(classification_rep).transpose()
                 df.to_csv(
                     "output/"
@@ -742,7 +655,8 @@ if __name__ == "__main__":

                 torch.save(
                     trained_model.state_dict(),
-                    "output/" + TEST_NAME + "/weights/model_" + str(model_index) + ".pth",
+                    "output/" + TEST_NAME + "/weights/model_" +
+                    str(model_index) + ".pth",
                 )
                 plt.figure(figsize=(19.20, 10.80))
                 plt.title("Confusion Matrix")
@@ -764,7 +678,8 @@ if __name__ == "__main__":
                     + str(model_index)
                     + ".png"
                 )
-                conf_matrix_percent = conf_matrix.astype('float') / conf_matrix.sum(axis=1)[:, np.newaxis] * 100
+                conf_matrix_percent = conf_matrix.astype(
+                    'float') / conf_matrix.sum(axis=1)[:, np.newaxis] * 100
                 plt.figure(figsize=(19.20, 10.80))
                 plt.title("Confusion Matrix Percentage")
                 sns.heatmap(
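The percentage matrix divides each row by its true-class total, so cell (i, j) reads as the percentage of class i samples predicted as class j. A tiny self-contained check of the same normalization used above:

    import numpy as np

    conf_matrix = np.array([[8, 2], [1, 9]])
    conf_matrix_percent = conf_matrix.astype(
        'float') / conf_matrix.sum(axis=1)[:, np.newaxis] * 100
    print(conf_matrix_percent)  # each row sums to 100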