Пошаговое руководство для GNNExplainer для объяснения узлов и графиков, реализованных в PyTorch Geometric.

Графическая нейронная сеть (GNN) - это тип нейронной сети, которая может быть непосредственно применена к данным с графической структурой. В моем предыдущем посте я кратко познакомился с GNN. Читатели могут быть перенаправлены к этому посту для получения более подробной информации.

Многие исследования показали, что GNN способна понимать графики, но то, как и почему работает GNN, по-прежнему остается загадкой для большинства людей. В отличие от CNN, где мы можем извлекать активацию каждого уровня для визуализации решений сети, в GNN трудно получить содержательное объяснение того, какие функции сеть изучила. Почему GNN определяет, что узел относится к классу A, а не к классу B? Почему GNN определяет, что график - это химическое вещество или молекула? Кажется, что GNN видит некоторую полезную структурную информацию и делает выводы на основании этих наблюдений. Но теперь проблема в том, какие наблюдения видит GNN?

Что такое GNNExplainer?

GNNExplainer представлен в этой статье.

Короче говоря, он пытается построить сеть, чтобы узнать, чему научилась GNN.

Главный принцип GNNExplainer заключается в сокращении избыточной информации в графе, которая не влияет напрямую на решения. Чтобы объяснить граф, мы хотим знать, какие ключевые функции или структуры в графе влияют на решения нейронной сети. Если функция важна, то прогноз следует в значительной степени изменить, удалив или заменив эту функцию чем-то другим. С другой стороны, если удаление или изменение признака не влияет на результат прогнозирования, признак считается несущественным и, следовательно, не должен включаться в пояснение к графику.

Как это работает?

Основная цель GNNExplainer - создать минимальный граф, объясняющий решение для узла или графа. Для достижения этой цели задача может быть определена как поиск подграфа в графе вычислений, который минимизирует разницу в оценках прогнозов с использованием всего графа вычислений и минимального графа. В статье этот процесс сформулирован как максимизация взаимной информации (MI) между минимальным графом Gs и графом вычислений G:

Кроме того, есть второстепенная задача: граф должен быть минимальным. Хотя об этом также упоминалось в первой задаче, нам также нужен метод, чтобы сформулировать эту цель. В документе это решается путем добавления потерь для количества краев. Следовательно, потеря для GNNExplainer - это буквально комбинация потери предсказания и потери размера края.

Объяснение Задачи

В документе обсуждаются три типа объяснений: объяснение для одного узла, объяснение для одного класса узлов и объяснение для графа. Основное различие заключается в графиках вычислений.

Для объяснения отдельного узла граф вычислений - это его соседи с k шагами, где k - количество сверток в модели.

Для объяснения класса узлов в документе предлагается выбрать ссылочный узел и использовать тот же метод для вычисления объяснения. Контрольный узел можно выбрать, взяв узел, характеристики которого наиболее близки к средним характеристикам всех других узлов того же класса.

Для объяснения всего графа граф вычислений становится объединением графов вычислений для всех узлов в графе. Это делает граф вычислений эквивалентным всему входному графу.

Подход по маске

Изучение минимального графа Gs осуществляется путем изучения маски для ребер и маски для признаков. То есть для каждого ребра в графе вычислений существует значение в edge_mask, которое определяет важность ребра. Точно так же для каждого объекта в объекте узла свойство feature_mask определяет, важен ли объект для окончательного решения.

Краткое содержание

Со всеми этими концепциями мы можем резюмировать все для GNNExplainer:

  1. Нам нужно извлечь граф вычислений, который является соседом из k переходов для классификации узлов, или весь граф для классификации графов.
  2. Инициализируйте edge_mask для каждого ребра в графе вычислений и маску объекта для каждого измерения объекта.
  3. Постройте нейронную сеть, которая изучает edge_mask и feature_mask с потерями, описанными выше.
  4. Используйте edge_mask и feature_mask, чтобы уменьшить граф вычислений до минимального графа.

Реализация GNNExplainer в Pytorch

Это все, что нам нужно знать, прежде чем мы сможем реализовать GNNExplainer. Подводя итог, мы пытаемся изучить edge_mask и node_feature_mask, которые удаляют некоторые ребра и особенности из графа вычислений, минимизируя разницу в оценке прогноза, результирующий граф является минимальным графом, который объясняет решение узла или графа.

Я собираюсь реализовать это в Pytorch Geometric (PyG). Одним из больших преимуществ PyG является то, что он очень часто обновляется и имеет множество реализаций текущих моделей. К моему удивлению, я обнаружил, что GNNExplainer уже реализован в библиотеке PyG, что экономит мне много времени. Хотя он работает только для объяснений узлов, благодаря открытому исходному коду, нетрудно изменить его, чтобы он работал и для объяснений графов.

Объяснение узлов

Для начала нам сначала нужно установить PyG. GNNExplainer еще не находится в их текущей версии (PyG 1.4.4), но коды уже выпущены на Github. Итак, чтобы получить GNNExplainer, вы должны клонировать его из репозитория Github и установить оттуда.

Пример кода представлен на сайте PyG. За ним легко следить, поэтому я не собираюсь показывать код в этом посте. Но детали реализации - это то, что мы хотим проверить и в дальнейшем использовать для классификации графов.

Я собираюсь отследить код, основываясь на моем кратком изложении выше. Пример кода передает индекс узла вместе с полной матрицей функций и списком ребер в модуль GNNExplainer.

explainer = GNNExplainer(model, epochs=200)node_idx = 10node_feat_mask, edge_mask = explainer.explain_node(node_idx, x, edge_index)

То, что происходит в GNNExplainer, мы обсуждали в предыдущем разделе.

  1. Извлеките граф вычислений

Чтобы объяснить узел, нам сначала нужно получить его граф вычислений с k шагами. Это делается с помощью метода _ _subgraph __ () в PyG.

x, edge_index, hard_edge_mask, kwargs = self.__subgraph__(
            node_idx, x, edge_index, **kwargs)

Hard_edge_mask удаляет все остальные ребра за пределами k-шаговой окрестности.

2. Маски инициализируются методом __set_mask __ () и применяются ко всем слоям сети.

self.__set_masks__(x, edge_index)         
def __set_masks__(self, x, edge_index, init="normal"):         
    (N, F), E = x.size(), edge_index.size(1)          
    std = 0.1         
    self.node_feat_mask = torch.nn.Parameter(torch.randn(F) * 0.1)                
    std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N))           
    self.edge_mask = torch.nn.Parameter(torch.randn(E) * std)                 
    for module in self.model.modules():             
        if isinstance(module, MessagePassing):                          
            module.__explain__ = True
            module.__edge_mask__ = self.edge_mask

3. Первоначальное прогнозирование выполняется с использованием обученной модели, затем прогноз используется в качестве метки для обучения GNNExplainer.

# Get the initial prediction.         
with torch.no_grad():             
    log_logits = self.model(x=x, edge_index=edge_index, **kwargs) 
    pred_label = log_logits.argmax(dim=-1)          
# Train GNNExplainer
for epoch in range(1, self.epochs + 1):                  
    optimizer.zero_grad()             
    h = x * self.node_feat_mask.view(1, -1).sigmoid()             
    log_logits = self.model(x=h, edge_index=edge_index, **kwargs)              
    loss = self.__loss__(0, log_logits, pred_label)             
    loss.backward()             
    optimizer.step()

4. Убыток определен

def __loss__(self, node_idx, log_logits, pred_label):         
      loss = -log_logits[node_idx, pred_label[node_idx]]          
      m = self.edge_mask.sigmoid()         
      loss = loss + self.coeffs['edge_size'] * m.sum()         
      ent = -m * torch.log(m + EPS) - (1 - m) * torch.log(1 - m + EPS)         
      loss = loss + self.coeffs['edge_ent'] * ent.mean()          
      m = self.node_feat_mask.sigmoid()         
      loss = loss + self.coeffs['node_feat_size'] * m.sum()         
      ent = -m * torch.log(m + EPS) - (1 - m) * torch.log(1 - m + EPS)         
      loss = loss + self.coeffs['node_feat_ent'] * ent.mean()          return loss

Объяснение графов

Текущая реализация PyG предназначена только для объяснения узлов. Но, понимая принципы, лежащие в основе, нетрудно переписать функцию для объяснения графика.

Нам нужно заменить всего несколько функций: 1) нам нужно заменить функцию __subgraph__, чтобы получить граф вычислений для всего графа. 2) нам нужно установить маски для всего графа. 3) Нам нужно изменить функцию потерь, чтобы вычислить потери для графиков.

Полная реализация кода доступна по этой ссылке на Github.

Заключение

GNNExplainer предоставляет основу для визуализации того, что узнала модель GNN. Однако фактический результат объяснения может быть недостаточно хорош для объяснения огромного графа, поскольку пространство поиска для оптимального объяснения экспоненциально больше, чем меньшее. Вместо подбора нейронной сети могут также применяться другие методы поиска, чтобы найти оптимальное объяснение, заимствующее те же концепции, а эффективность еще предстоит доказать.

Ссылка:

GNNExplainer: Генерация объяснений для графических нейронных сетей, https://arxiv.org/abs/1903.03894

Pytorch Geometric, https://pytorch-geometric.readthedocs.io/en/latest/