[GNN] Graph 기반 증강

GNN
Author

김보람

Published

January 1, 2024

Graph 기반 증강

GRAN

GRAN은 그래프 불균형을 해결하기 위해 무작위 그래프 증강을 사용하는 방법

import torch
import torch.nn as nn
import torch.optim as optim

class GRAN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(GRAN, self).__init__()
        self.conv1 = torch.nn.GraphConvolution(input_dim, hidden_dim)
        self.conv2 = torch.nn.GraphConvolution(hidden_dim, output_dim)

    def forward(self, x, adjacency_matrix):
        x = F.relu(self.conv1(x, adjacency_matrix))
        x = self.conv2(x, adjacency_matrix)
        return x

# 데이터 로딩 및 전처리
# ...

# 모델 초기화
model = GRAN(input_dim, hidden_dim, output_dim)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 학습 루프
for epoch in range(num_epochs):
    # 그래프 증강 수행
    augmented_data = augment_graph(original_data)
    
    # 모델 학습
    output = model(augmented_data)
    loss = criterion(output, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

CARE_GNN

CARE-GNN은 적대적 학습을 사용하여 그래프 데이터를 향상시키는 방법

import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.nn import GCNConv

class CARE_GNN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(CARE_GNN, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return x

# 데이터 로딩 및 전처리
# ...

# 모델 초기화
model = CARE_GNN(input_dim, hidden_dim, output_dim)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 학습 루프
for epoch in range(num_epochs):
    # 적대적 학습 수행
    adversarial_data = adversarial_training(original_data)
    
    # 모델 학습
    output = model(adversarial_data)
    loss = criterion(output, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()