Skip to content

Commit

Permalink
TCI
Browse files Browse the repository at this point in the history
  • Loading branch information
TheColdIce committed Jan 17, 2024
1 parent a9c284d commit 14ee797
Show file tree
Hide file tree
Showing 11 changed files with 705 additions and 54 deletions.
Binary file modified gnn/__pycache__/criterions.cpython-311.pyc
Binary file not shown.
Binary file modified gnn/__pycache__/data.cpython-311.pyc
Binary file not shown.
Binary file modified gnn/__pycache__/modules.cpython-311.pyc
Binary file not shown.
23 changes: 16 additions & 7 deletions gnn/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,18 +65,27 @@ def get_data(self):


class AmlsimDataset():
def __init__(self, node_file:str, edge_file:str):
self.data = self.load_data(node_file, edge_file)
def __init__(self, node_file:str, edge_file:str, node_features:bool=False, edge_features:bool=True, node_labels:bool=False, edge_labels:bool=False, seed:int=42):
self.data = self.load_data(node_file, edge_file, node_features, edge_features, node_labels, edge_labels)

def load_data(self, node_file, edge_file):
def load_data(self, node_file, edge_file, node_features, edge_features, node_labels, edge_labels):
nodes = pd.read_csv(node_file)
edges = pd.read_csv(edge_file)
edge_index = torch.tensor(edges[['src', 'dst']].values, dtype=torch.long)
edge_index = edge_index.t().contiguous()
x = torch.tensor(nodes[['x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10']].values, dtype=torch.float)
y = torch.tensor(nodes['y'].values, dtype=torch.long)
#y = torch.nn.functional.one_hot(y.type(torch.long), num_classes=2).type(torch.float)
data = Data(x=x, edge_index=edge_index, y=y)

if node_features:
x = torch.tensor(nodes[nodes.columns[:-1]].values, dtype=torch.float)
else:
x = torch.ones(nodes.shape[0], 1)
if edge_features:
edge_attr = torch.tensor(edges[edges.columns[:-1]].values, dtype=torch.float)
if node_labels:
y = torch.tensor(nodes[nodes.columns[-1]].values, dtype=torch.long)
elif edge_labels:
y = torch.tensor(edges[edges.columns[-1]].values, dtype=torch.long)

data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
return data

def get_data(self):
Expand Down
Loading

0 comments on commit 14ee797

Please sign in to comment.