diff --git a/dmsr.py b/dmsr.py index 0faa52e7f896fe786a8d4768d66184f83bdd9843..68d6eff77f4dd873ba084a7389b8ea10c1f5941b 100644 --- a/dmsr.py +++ b/dmsr.py @@ -97,17 +97,17 @@ class HeteroGNN(torch.nn.Module): edge_weight = data.edge_feature # Edge weights need to be reformated to match model's requirements. keys = [key for key in edge_weight] + edge_weight2 = {} for key in keys: newKey = key[1] - edge_weight[newKey] = edge_weight[key] - del edge_weight[key] + edge_weight2[newKey] = edge_weight[key] - x = self.convs1(x, edge_index, edge_weight) + x = self.convs1(x, edge_index, edge_weight2) x = forward_op(x, self.bns1) x = forward_op(x, self.relus1) x = forward_op(x, self.dropout1) - x = self.convs2(x, edge_index, edge_weight) + x = self.convs2(x, edge_index, edge_weight2) x = forward_op(x, self.bns2) return x