From 882541a1054240dddd2247abc3d5226756fba4aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20Ayuso=20Mu=C3=B1oz?= Date: Wed, 28 Jun 2023 15:43:58 +0200 Subject: [PATCH] Fix logic. --- dmsr.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dmsr.py b/dmsr.py index 0faa52e..68d6eff 100644 --- a/dmsr.py +++ b/dmsr.py @@ -97,17 +97,17 @@ class HeteroGNN(torch.nn.Module): edge_weight = data.edge_feature # Edge weights need to be reformated to match model's requirements. keys = [key for key in edge_weight] + edge_weight2 = {} for key in keys: newKey = key[1] - edge_weight[newKey] = edge_weight[key] - del edge_weight[key] + edge_weight2[newKey] = edge_weight[key] - x = self.convs1(x, edge_index, edge_weight) + x = self.convs1(x, edge_index, edge_weight2) x = forward_op(x, self.bns1) x = forward_op(x, self.relus1) x = forward_op(x, self.dropout1) - x = self.convs2(x, edge_index, edge_weight) + x = self.convs2(x, edge_index, edge_weight2) x = forward_op(x, self.bns2) return x -- 2.24.1