[GNN] Graph 기반 증강
GNN
Graph 기반 증강
GRAN
GRAN은 그래프 불균형을 해결하기 위해 무작위 그래프 증강을 사용하는 방법
import torch
import torch.nn as nn
import torch.optim as optim
class GRAN(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim):
super(GRAN, self).__init__()
self.conv1 = torch.nn.GraphConvolution(input_dim, hidden_dim)
self.conv2 = torch.nn.GraphConvolution(hidden_dim, output_dim)
def forward(self, x, adjacency_matrix):
= F.relu(self.conv1(x, adjacency_matrix))
x = self.conv2(x, adjacency_matrix)
x return x
# 데이터 로딩 및 전처리
# ...
# 모델 초기화
= GRAN(input_dim, hidden_dim, output_dim)
model = nn.CrossEntropyLoss()
criterion = optim.Adam(model.parameters(), lr=0.001)
optimizer
# 학습 루프
for epoch in range(num_epochs):
# 그래프 증강 수행
= augment_graph(original_data)
augmented_data
# 모델 학습
= model(augmented_data)
output = criterion(output, labels)
loss
optimizer.zero_grad()
loss.backward() optimizer.step()
CARE_GNN
CARE-GNN은 적대적 학습을 사용하여 그래프 데이터를 향상시키는 방법
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.nn import GCNConv
class CARE_GNN(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim):
super(CARE_GNN, self).__init__()
self.conv1 = GCNConv(input_dim, hidden_dim)
self.conv2 = GCNConv(hidden_dim, output_dim)
def forward(self, x, edge_index):
= F.relu(self.conv1(x, edge_index))
x = self.conv2(x, edge_index)
x return x
# 데이터 로딩 및 전처리
# ...
# 모델 초기화
= CARE_GNN(input_dim, hidden_dim, output_dim)
model = nn.CrossEntropyLoss()
criterion = optim.Adam(model.parameters(), lr=0.001)
optimizer
# 학습 루프
for epoch in range(num_epochs):
# 적대적 학습 수행
= adversarial_training(original_data)
adversarial_data
# 모델 학습
= model(adversarial_data)
output = criterion(output, labels)
loss
optimizer.zero_grad()
loss.backward() optimizer.step()