Commit 448ca4e4 authored by ADRIAN  AYUSO MUNOZ's avatar ADRIAN AYUSO MUNOZ

Name change

parent ab14a632
......@@ -227,7 +227,7 @@ def train(model, dataloaders, optimizer, epochs):
plt.xlabel('Iteration')
plt.xticks(range(0, epochs+1, int(epochs/10)))
plt.yticks((0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85))
plt.savefig('plots/deepSnapPred/metrics/loss.svg', format='svg', dpi=1200)
plt.savefig('plots/behor/metrics/loss.svg', format='svg', dpi=1200)
plt.clf()
return t_accu, v_accu, e_accu, best_model
......@@ -310,7 +310,7 @@ def main(epochs, hidden_dim, lr, weight_decay, dropout):
_, _, _, model = train(model, dataloaders, optimizer, epochs)
print("Finished training at", datetime.now().strftime("%H:%M:%S"))
torch.save(model.state_dict(),"./models/modelDeepSnapPred")
torch.save(model.state_dict(), "models/behor")
# Testing
model.eval()
......@@ -322,8 +322,8 @@ def main(epochs, hidden_dim, lr, weight_decay, dropout):
labels = [item for sublist in true_labels for item in sublist]
pure_predictions = [item for sublist in pure_pred_labels for item in sublist]
plot_roc(labels, pure_predictions, keys[0], "deepSnapPred/")
plot_prc(torch.tensor(labels), pure_predictions, keys[0], "deepSnapPred/")
plot_roc(labels, pure_predictions, keys[0], "behor/")
plot_prc(torch.tensor(labels), pure_predictions, keys[0], "behor/")
return model
......
......@@ -107,10 +107,10 @@ def deepSnapMetrics(model):
pure_predictions = [item for sublist in [preds,predsN] for item in sublist]
labels = torch.tensor([item for sublist in [labels1,labels2] for item in sublist])
fpr, tpr, label1 = plot_roc(labels, pure_predictions,('disorder', 'dis_dru_the', 'drug'), "deepSnapPred/", "RepoDB")
recall, precision, label2 = plot_prc(labels, pure_predictions, ('disorder', 'dis_dru_the', 'drug'), "deepSnapPred/", "RepoDB")
fpr, tpr, label1 = plot_roc(labels, pure_predictions,('disorder', 'dis_dru_the', 'drug'), "behor/", "RepoDB")
recall, precision, label2 = plot_prc(labels, pure_predictions, ('disorder', 'dis_dru_the', 'drug'), "behor/", "RepoDB")
plotMetrics("deepSnapPred", fpr, tpr, label1, recall, precision, label2)
plotMetrics("behor", fpr, tpr, label1, recall, precision, label2)
return label1, label2
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment