Commit 882541a1 authored by ADRIAN  AYUSO MUNOZ's avatar ADRIAN AYUSO MUNOZ

Fix logic.

parent c1efac0d
...@@ -97,17 +97,17 @@ class HeteroGNN(torch.nn.Module): ...@@ -97,17 +97,17 @@ class HeteroGNN(torch.nn.Module):
edge_weight = data.edge_feature # Edge weights need to be reformated to match model's requirements. edge_weight = data.edge_feature # Edge weights need to be reformated to match model's requirements.
keys = [key for key in edge_weight] keys = [key for key in edge_weight]
edge_weight2 = {}
for key in keys: for key in keys:
newKey = key[1] newKey = key[1]
edge_weight[newKey] = edge_weight[key] edge_weight2[newKey] = edge_weight[key]
del edge_weight[key]
x = self.convs1(x, edge_index, edge_weight) x = self.convs1(x, edge_index, edge_weight2)
x = forward_op(x, self.bns1) x = forward_op(x, self.bns1)
x = forward_op(x, self.relus1) x = forward_op(x, self.relus1)
x = forward_op(x, self.dropout1) x = forward_op(x, self.dropout1)
x = self.convs2(x, edge_index, edge_weight) x = self.convs2(x, edge_index, edge_weight2)
x = forward_op(x, self.bns2) x = forward_op(x, self.bns2)
return x return x
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment