Commit 882541a1 authored by ADRIAN  AYUSO MUNOZ's avatar ADRIAN AYUSO MUNOZ

Fix logic.

parent c1efac0d
......@@ -97,17 +97,17 @@ class HeteroGNN(torch.nn.Module):
edge_weight = data.edge_feature # Edge weights need to be reformated to match model's requirements.
keys = [key for key in edge_weight]
edge_weight2 = {}
for key in keys:
newKey = key[1]
edge_weight[newKey] = edge_weight[key]
del edge_weight[key]
edge_weight2[newKey] = edge_weight[key]
x = self.convs1(x, edge_index, edge_weight)
x = self.convs1(x, edge_index, edge_weight2)
x = forward_op(x, self.bns1)
x = forward_op(x, self.relus1)
x = forward_op(x, self.dropout1)
x = self.convs2(x, edge_index, edge_weight)
x = self.convs2(x, edge_index, edge_weight2)
x = forward_op(x, self.bns2)
return x
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment