Under the hood of graph networks

Graph networks appeared in 2016 and were popular in 2018, but I was not able to find many articles on this topic on Habr. So, after going through blogs and publications, I decided to put together a short (and, I hope, understandable) summary. Let's look at what operations exist on a graph, what operations are used in graph networks, what pooling means for a graph, and so on.

P.S. I have huge problems with proofreading text, so if you see a typo or an error, please let me know. Thank you!

What is a graph?

You are probably familiar with the word graph and what it means (no, not the Count of Monte Cristo; in Russian the two words are spelled the same). But for readers who are not, or who have forgotten, here are a few words. A graph is a mathematical structure consisting of a set of vertices (nodes) connected by edges (links). The vertices of a graph can represent entities, and the edges can represent relationships or connections between these entities. A simple example: you have a page on a social network (you are a vertex), and there are other people on this social network (also vertices). You are friends with some of them (you are connected by edges), and your friends are friends with other people (connected to them by edges too), and so on. I think the idea is clear.

Example graph

Graph in neural networks

Graphs are a universal tool for modeling structures and phenomena in many domains: social networks, molecular structures, physical systems, protein-protein interactions, knowledge graphs. In graph neural networks (GNNs), the graph is the data structure on which all analysis and processing is built. To work with graphs in a neural network setting, we store information about the vertices, the edges and the graph as a whole in the form of embeddings for each of these parts.

But what do graphs have to do with visual and text data? Both pictures and text can be represented as graphs. An image can be thought of as a graph in which each pixel is connected to its neighboring pixels. Text can be thought of as a sequential graph in which each word is connected to the previous and the next one.

There are three main types of problems that can be solved using graph neural networks:

  1. Prediction at the level of the whole graph: we aim to predict a property of the entire graph, for example the odor of a molecule. The analogy in the world of pictures is image classification.

  2. Prediction at the node level: we aim to predict a property of a specific element of the graph. For example, when analyzing a social network, we can try to predict the interests of a user. For pictures, the analogy is segmentation.

  3. Prediction at the edge level: we aim to predict a property of the connection between two elements of the graph, for example the relationship between objects in a scene or the likelihood that a certain connection exists. The analogy is the scene understanding task.

How to store the graph?

  1. As an adjacency matrix. The first disadvantage is that the matrix can be sparse when the data is weakly connected, which wastes memory. The second is that the same connectivity can be written as many different adjacency matrices (one per ordering of the vertices), and there is no guarantee that a network will give the same result for all of them, although it should.

  2. A more economical option is an adjacency list: for every edge we store the pair of vertices (i, j) that it connects.
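To make the two storage options concrete, here is a minimal sketch in NumPy. The toy graph, the relabeling `perm` and the helper `to_adjacency` are assumptions made up for illustration. It shows that relabeling the vertices changes the adjacency matrix although the connectivity stays the same, and compares how many numbers each representation stores.

```python
import numpy as np

# Hypothetical toy graph: 4 vertices, undirected edges stored as (i, j) pairs.
edge_list = [(0, 1), (0, 2), (2, 3)]
num_nodes = 4

def to_adjacency(edges, n):
    """Build a dense symmetric adjacency matrix from an edge list."""
    A = np.zeros((n, n), dtype=int)
    for i, j in edges:
        A[i, j] = 1
        A[j, i] = 1
    return A

A = to_adjacency(edge_list, num_nodes)

# Relabel the vertices: the connectivity is the same, the matrix is not.
perm = [3, 1, 0, 2]  # new labels for the old vertices 0, 1, 2, 3
relabeled = [(perm[i], perm[j]) for i, j in edge_list]
A_perm = to_adjacency(relabeled, num_nodes)

print(A)
print(A_perm)
print("same matrix:", np.array_equal(A, A_perm))
print("numbers stored:", A.size, "(matrix) vs", 2 * len(edge_list), "(edge list)")
```

The gap between the two sizes grows quickly: the matrix needs n² entries, while the edge list grows only with the number of edges.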

Layer Pass

Processing Graph Data in GNN

Consider a simple GNN (graph neural network, remember this abbreviation!) that does not change the connectivity of the graph. As input we receive vertex embeddings and pass them through an MLP layer, and as output we get updated embeddings for the vertices. We do the same for the edge embeddings and for the embedding of the entire graph. As a result, we get a graph with the same adjacency list; only the embeddings have changed.

For classification at the vertex level, a linear classifier is applied to the embedding of each vertex. But what if we have no information about the vertices and still need to solve a vertex-level problem? Say we only have information about the edges. In that case we apply an operation called pooling. For each vertex, we take the embeddings of its edges, collect them into a matrix and apply an aggregation to them, for example a sum. This gives us embeddings for the vertices, which we then pass through a linear classifier as usual.

The same scheme works in the opposite situation, when we have information about the vertices but not about the edges, and the problem is posed at the edge level. If we solve a problem at the level of the whole graph and only have vertex embeddings, we apply global pooling, that is, we aggregate all the available information into a single embedding.
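Here is a small sketch of these aggregation steps, assuming sum as the aggregation function. The toy graph, the embedding size and the random, untrained classifier weights `W` are assumptions made up for illustration. Edge embeddings are summed into vertex embeddings, the vertices are classified with a linear layer, and global pooling collapses everything into one graph embedding.

```python
import numpy as np

edge_list = [(0, 1), (0, 2), (2, 3)]             # hypothetical toy graph
num_nodes = 4
rng = np.random.default_rng(0)
edge_emb = rng.normal(size=(len(edge_list), 8))  # one 8-dim embedding per edge

def pool_edges_to_nodes(edges, edge_emb, n):
    """Sum the embeddings of all edges that touch each vertex."""
    node_emb = np.zeros((n, edge_emb.shape[1]))
    for e, (i, j) in enumerate(edges):
        node_emb[i] += edge_emb[e]
        node_emb[j] += edge_emb[e]
    return node_emb

node_emb = pool_edges_to_nodes(edge_list, edge_emb, num_nodes)

# Linear classifier over 3 classes; the weights are random, not trained.
W = rng.normal(size=(8, 3))
pred = (node_emb @ W).argmax(axis=1)

# Global pooling: one embedding for the whole graph.
graph_emb = node_emb.sum(axis=0)

print(node_emb.shape, pred.shape, graph_emb.shape)
```

Mean or max would work as aggregation functions too. The important property is that the result does not depend on the order in which the edges are listed.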

To take more information into account, each vertex can collect information from its neighbors by taking their embeddings and applying some function to them (for example, a sum). This step is called message passing, and it can be combined with pooling, for example pooling first and passing messages afterwards. In a large graph, nodes that are far apart may never receive information about each other even after several message passing steps, and adding more steps is computationally expensive. To solve this problem, you can use a so-called master node, or context vector, which is connected to all nodes and edges of the graph. This gives a more informative representation of the graph.
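The effect of a master node can be seen in a tiny experiment. This is a sketch under simplifying assumptions: messages are plain sums with no learned weights, the features are one-hot so that we can track where information came from, and `add_master_node` is a hypothetical helper that connects one extra node to every existing node.

```python
import numpy as np

# Path graph 1 - 0 - 2 - 3: nodes 1 and 3 are three hops apart.
A = np.array([[0, 1, 1, 0],
              [1, 0, 0, 0],
              [1, 0, 0, 1],
              [0, 0, 1, 0]], dtype=float)

def message_passing_step(A, H):
    """Each node adds the sum of its neighbours' embeddings to its own."""
    return H + A @ H

def add_master_node(A):
    """Return the adjacency matrix with one extra node linked to all others."""
    n = A.shape[0]
    A_m = np.ones((n + 1, n + 1))
    A_m[:n, :n] = A
    A_m[n, n] = 0
    return A_m

# One-hot features: column j of the result tells how much of node j arrived.
H2 = message_passing_step(A, message_passing_step(A, np.eye(4)))
A_m = add_master_node(A)
H2_master = message_passing_step(A_m, message_passing_step(A_m, np.eye(5)))

print("node 1 has heard from node 3 after 2 steps:", H2[1, 3] > 0)
print("the same with a master node:", H2_master[1, 3] > 0)
```

Without the master node, two steps are not enough for nodes 1 and 3 to exchange information. With it, any two nodes are at most two hops apart.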

Embedding nodes into a low-dimensional space

Okay, let’s think about it. Let’s say we have a social network where nodes represent users and edges represent connections between them (for example, friendship). The task is to find out a person’s interests.

How do we vectorize a node? One of the first methods for embedding graph nodes into a low-dimensional space is DeepWalk (there are other methods, for example node2vec). The essence of the method is as follows: first, we select a random user and begin a random walk through the network, moving from the user to a random friend (a random node the user is connected to). We then repeat this process many times, creating many random walks across the network.

Once we have collected enough random walks, we apply a learning algorithm such as SkipGram, which trains a model to predict the neighboring users for each user in the network. As a result of this training, each user gets a vector that represents them in the feature space.

Below are a few words about the SkipGram algorithm, for those who are hearing it for the first time or have forgotten.

The essence of the algorithm is to predict the context for a given node. It works like this:

  1. Data preparation: each random walk is treated as a sentence whose words are graph nodes. These sequences of nodes act as the training corpus.

  2. Creating training pairs: from the walks obtained with DeepWalk we build pairs (center node, context node), where the context nodes are the ones that appear near the center node in a walk.

  3. Model training: the training pairs are fed to the model, which updates the vector representations of the nodes so as to minimize the error of predicting the context from the center node. In our example, this is the error of predicting a user's friends.

  4. Using the trained vectors: after training is completed, each node in the graph has its own vector representation.
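The data side of this pipeline can be sketched with the standard library alone. The toy social graph, the walk length, the number of walks and the window size below are assumptions chosen for illustration, and the embedding training itself is left out.

```python
import random

# Hypothetical friendship graph: user -> list of friends.
graph = {
    "ann": ["bob", "eve"],
    "bob": ["ann", "eve", "dan"],
    "eve": ["ann", "bob"],
    "dan": ["bob"],
}

def random_walk(graph, start, length, rng):
    """Walk `length` nodes, each time moving to a random friend."""
    walk = [start]
    for _ in range(length - 1):
        walk.append(rng.choice(graph[walk[-1]]))
    return walk

def skipgram_pairs(walk, window):
    """Build (center, context) pairs from the nodes within `window` steps."""
    pairs = []
    for i, center in enumerate(walk):
        lo, hi = max(0, i - window), min(len(walk), i + window + 1)
        pairs.extend((center, walk[j]) for j in range(lo, hi) if j != i)
    return pairs

rng = random.Random(0)
walks = [random_walk(graph, node, 5, rng) for node in graph for _ in range(10)]
pairs = [p for w in walks for p in skipgram_pairs(w, window=2)]

print(walks[0])
print(len(pairs), "training pairs, for example", pairs[:3])
```

The walks can then be passed to any SkipGram implementation as if they were sentences, for instance a Word2Vec model in skip-gram mode.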

Vectorizing vertices with methods such as DeepWalk makes sense even if we already have information about the properties of the vertices. The vector representation lets the vertices be used in various machine learning algorithms, such as classification, clustering and predicting connections between vertices. In addition, vectorization preserves the topological relationships between vertices, which is important for understanding and analyzing the structure of the graph.

Anyone familiar with NLP will think of Word2Vec. Essentially, DeepWalk and Word2Vec are analogues in the sense that they both use neural networks to learn vector representations.

Computing modules

Graph networks have 3 main modules with which the architecture is built.

  1. Propagation: we propagate and aggregate information across the graph (roughly speaking, this is the forward method of the model)

  2. Sampling

  3. Pooling

Let's look at each of them in more detail, starting with propagation. Three types of operators can be used here: convolutional, recurrent and skip connections (which forward information from earlier steps to later ones).

GNN Components and Examples

GCN – Graph Convolutional Networks

The challenge is to adapt convolutions from other domains to graphs.

The GCN model was introduced in “Semi-Supervised Classification with Graph Convolutional Networks” by Thomas Kipf and Max Welling in 2017.

In GCN, the convolution operation is applied to the features of graph nodes based on their structure. Formally, GCN is defined as follows:

Pass formula for GCN:

$$H^{(l+1)} = \sigma\left(\tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} H^{(l)} W^{(l)}\right)$$

The formula does not look clear at first glance, I fully agree. Let's figure out why it looks the way it does.

  1. Ã is the adjacency matrix of the graph with self-loops: Ã = A + I, where I is the identity matrix. Adding I lets each node take its own features into account when aggregating over its neighbors.

  2. D̃ is the diagonal degree matrix of Ã: each element D̃_ii is the sum of the elements in the i-th row of Ã. This matrix is used to normalize the aggregated features.

  3. H^(l) is the matrix of node features at the l-th layer, with one row of features per node.

  4. W^(l) is the weight matrix of the l-th layer, which is applied to the aggregated features.

As a result, the convolution operation in GCN aggregates information from neighboring nodes taking their features into account, normalizes this information using the degrees of the nodes, applies a linear transformation with the weight matrix and applies a nonlinear activation function to obtain the representation of the nodes at the next layer.
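The whole layer fits into a few lines of NumPy. This is a minimal dense sketch, not an efficient implementation: ReLU is assumed as the activation function, the toy graph and feature sizes are made up, and the weights are random rather than trained.

```python
import numpy as np

def gcn_layer(A, H, W):
    """One GCN propagation step: ReLU(D^-1/2 (A + I) D^-1/2 H W)."""
    A_tilde = A + np.eye(A.shape[0])         # add self-loops
    d = A_tilde.sum(axis=1)                  # degrees of A_tilde
    D_inv_sqrt = np.diag(d ** -0.5)          # D^-1/2
    A_hat = D_inv_sqrt @ A_tilde @ D_inv_sqrt
    return np.maximum(A_hat @ H @ W, 0.0)    # ReLU as the nonlinearity

A = np.array([[0, 1, 1, 0],
              [1, 0, 0, 0],
              [1, 0, 0, 1],
              [0, 0, 1, 0]], dtype=float)
rng = np.random.default_rng(0)
H0 = rng.normal(size=(4, 5))   # 4 nodes, 5 input features each
W0 = rng.normal(size=(5, 2))   # maps 5 features to 2 (random, untrained)

H1 = gcn_layer(A, H0, W0)
print(H1.shape)
```

Real libraries store Ã as a sparse matrix, since a dense n × n matrix quickly becomes too large.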

Great, it seems to have become a little clearer, but not quite. Why does multiplication by D happen twice? Why is D raised to the -½ power?

Let’s throw out H and W for simplicity (it’s clear here, everything is the same as in ordinary neural networks – we multiply the weights by the features). What remains is normalization, which is expressed by multiplying three matrices.

So, here we go:

The matrix D̃ is raised to the power -½ so that each node's degree ends up in the denominator, split evenly between the two sides of Ã. Multiplying Ã by D̃^(-½) on the left scales its rows, and multiplying by D̃^(-½) on the right scales its columns, so the entry for the edge (i, j) becomes Ã_ij / sqrt(d_i · d_j). This reduces the influence of nodes with a large number of neighbors and keeps the scale of the features from growing with the degree. Normalizing in both directions, rather than only by rows, also keeps the matrix symmetric, so the degrees of both endpoints of an edge are taken into account, and it makes the convolution operation more stable.
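A quick numerical check of this normalization, using a star graph in which node 0 is connected to three leaves. The graph is an assumption picked to make the degrees differ. The sketch compares the two-sided normalization with plain row normalization.

```python
import numpy as np

# Star graph: node 0 is connected to nodes 1, 2 and 3.
A = np.array([[0, 1, 1, 1],
              [1, 0, 0, 0],
              [1, 0, 0, 0],
              [1, 0, 0, 0]], dtype=float)

A_tilde = A + np.eye(4)
d = A_tilde.sum(axis=1)                    # degrees with self-loops: 4, 2, 2, 2
D_inv_sqrt = np.diag(d ** -0.5)

A_sym = D_inv_sqrt @ A_tilde @ D_inv_sqrt  # two-sided normalization
A_row = A_tilde / d[:, None]               # row normalization only

# Each entry of A_sym is A_tilde[i, j] / sqrt(d_i * d_j).
by_hand = A_tilde / np.sqrt(np.outer(d, d))

print(np.round(A_sym, 3))
print("matches the entry-wise formula:", np.allclose(A_sym, by_hand))
print("two-sided version is symmetric:", np.allclose(A_sym, A_sym.T))
print("row-only version is symmetric:", np.allclose(A_row, A_row.T))
```

With row normalization the edge between the hub and a leaf gets a different weight depending on the direction, while the two-sided version gives both directions the same weight.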

GAT – Graph Attention Network

GAT is a graph neural network model that uses an attention mechanism to aggregate information from neighboring nodes in a graph.

Unlike the classical GCN, which computes a weighted sum of the neighbors' features with fixed weights determined by the graph structure, GAT uses attention to compute the weights dynamically from the content of the nodes. This allows the model to pay more attention to the more important neighbors. Each neighbor of a node gets its own weight depending on the context, which makes the model more flexible and more powerful on a variety of graphs.

The hidden state of node i is obtained using the following scary formulas:

Pass formula for GAT:

$$\alpha_{ij} = \frac{\exp\left(\mathrm{LeakyReLU}\left(a^{T}[W h_i \,\|\, W h_j]\right)\right)}{\sum_{k \in N_i} \exp\left(\mathrm{LeakyReLU}\left(a^{T}[W h_i \,\|\, W h_k]\right)\right)}, \qquad h_i' = \sigma\Big(\sum_{j \in N_i} \alpha_{ij} W h_j\Big)$$

Here a is the weight vector that is trained together with the model, α_ij is the resulting attention weight of neighbor j for node i, N_i is the set of neighbors of node i, and [Wh_i ∣∣ Wh_j] is the concatenation of two vectors.

Stay calm! Everything is in order.

  • We first apply the learnable parameter matrix W to the vector representations of nodes h_i and h_j. This allows the model to learn more informative representations of the nodes.

  • Next, we concatenate the two resulting vectors to consider the information of nodes i and j together. This allows the model to take the interaction between these nodes into account when calculating the attention weights.

  • Next, we multiply by the weight vector a. This allows the model to estimate the importance of the connection between nodes i and j in the graph.

  • Finally, we apply LeakyReLU (an activation function); the two-story fraction is simply a softmax over the neighbors of node i.

That was a lot of text! Let's recap:

For a pair of nodes: we multiply each node's features by the weight matrix (it is trained) → concatenate the results for the two nodes → multiply by the trained weight vector that scores the importance of the connection between these two nodes → activation function → softmax.
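The recap above can be written out for all pairs of nodes at once. This is a single-head dense sketch under stated assumptions: the trainable weight vector is called `a` and the resulting attention weights `alpha` (names chosen here), every node also attends to itself, all weights are random instead of trained, and tanh is an arbitrary choice for the final nonlinearity.

```python
import numpy as np

def leaky_relu(x, slope=0.2):
    return np.where(x > 0, x, slope * x)

def gat_attention(A, H, W, a):
    """Attention weights alpha[i, j] over the neighbours of i (and i itself)."""
    Z = H @ W                                    # W h_i for every node
    f = Z.shape[1]
    # a^T [z_i || z_j] splits into a[:f] . z_i + a[f:] . z_j
    scores = leaky_relu((Z @ a[:f])[:, None] + (Z @ a[f:])[None, :])
    mask = (A + np.eye(A.shape[0])) > 0          # only neighbours and self
    scores = np.where(mask, scores, -np.inf)
    scores = scores - scores.max(axis=1, keepdims=True)  # numerical stability
    e = np.exp(scores)
    return e / e.sum(axis=1, keepdims=True)      # softmax over each row

A = np.array([[0, 1, 1, 0],
              [1, 0, 0, 0],
              [1, 0, 0, 1],
              [0, 0, 1, 0]], dtype=float)
rng = np.random.default_rng(0)
H = rng.normal(size=(4, 5))   # input node features
W = rng.normal(size=(5, 3))   # shared weight matrix (random, untrained)
a = rng.normal(size=6)        # attention vector of length 2 * 3

alpha = gat_attention(A, H, W, a)
H_new = np.tanh(alpha @ (H @ W))  # weighted sum of neighbours + nonlinearity
print(np.round(alpha, 3))
```

Each row of `alpha` sums to 1 and is zero wherever there is no edge, so every node distributes its attention only among its own neighborhood.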

In graph networks, multi-head attention can also be used (roughly speaking, when we perform such an operation for a pair of nodes N times, where N is the number of heads).

Multi-head attention in a GNN

Recurrent operator

The recurrent operator is similar to the recurrent layers in recurrent neural networks (RNNs), where information is updated incrementally and depends on the previous state. However, unlike an RNN, where the sequence is determined by time, in a GNN the sequence is determined by the graph structure.

Formally, the recurrent update operator for node i and its neighbors j at the l-th layer of a GNN can be expressed as follows:

Pass formula for GRN

The recurrent update operator updates the representation of node i by aggregating the representations of its neighbors j and applying a nonlinear activation function.

Looks pretty simple!

The disadvantage of the operator is that we may lose information.

When we focus on representing nodes in a graph, we want each node to have its own unique and informative representation. However, when using the recurrent operator, the distribution of the nodes' hidden state values may become too smooth and uniform. Different nodes end up with more or less the same hidden state values, which makes their representations hard to distinguish from each other. This can lead to a loss of information about significant differences between nodes and reduce the model's ability to tell them apart. The reason lies in the update process itself, which tends toward a stable state: through repeated updates, the information of neighboring nodes gradually spreads throughout the graph, and the hidden states converge to certain values.
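This smoothing effect is easy to reproduce in a toy setting. The sketch below is a deliberate simplification and not the recurrent operator itself: it assumes each update simply replaces a node's state with the mean over the node and its neighbors, with no learned weights and no nonlinearity, on a made-up four-node graph.

```python
import numpy as np

# Path graph 1 - 0 - 2 - 3.
A = np.array([[0, 1, 1, 0],
              [1, 0, 0, 0],
              [1, 0, 0, 1],
              [0, 0, 1, 0]], dtype=float)
A_tilde = A + np.eye(4)
P = A_tilde / A_tilde.sum(axis=1, keepdims=True)  # mean over self + neighbours

H = np.array([[1.0], [0.0], [5.0], [-3.0]])       # very different initial states
spread = []                                        # max - min across the nodes
for step in range(50):
    spread.append(float(H.max() - H.min()))
    H = P @ H
final_spread = float(H.max() - H.min())

print("spread at steps 0, 5, 10:", [round(spread[i], 4) for i in (0, 5, 10)])
print("spread after 50 steps:", final_spread)
```

The gap between the largest and the smallest state shrinks with every update, and after enough steps all nodes carry practically the same value.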

There are also several works that use gating mechanisms such as GRU and LSTM. For example, in GGNN (Gated Graph Neural Network) a node first aggregates messages from its neighbors. Then a GRU-like update function combines the information from the other nodes with the node's state at the previous time step to update the hidden state of each node. The same goes for LSTM-based variants.

To be honest, I have a dislike for recurrent operators, so let me not go into details.

Skip connection

A large number of layers does not always help improve the quality of a neural network. Information is lost and a lot of noise is added. To solve this problem, they came up with a skip connection (remember ResNet). The idea is to add short connections that bypass one or more layers of the neural network, allowing input data or transformed versions of it to pass directly to subsequent layers.

For example, DeepGCNs took the idea from ResNet and DenseNet.

Skip connection

Another example is Highway GCN, which uses highway connection blocks to transfer information efficiently through several layers of the network. The basic idea is that these blocks use gates, which decide what proportion of the information should pass through the block unchanged and what proportion should be transformed.

The approach formula is presented below:

Skip connection (Highway GCN):

$$y = H(x, W_h) \odot T(x, W_t) + x \odot \left(1 - T(x, W_t)\right)$$

where T is a gate (a function) that takes the vector x as input and transforms it using the parameters W_t (they are trained),

and H is a function that transforms x using the parameters W_h (also trained).

This looks a little unclear, especially the right-hand term. Let's look at it:

The left term, H(x) multiplied by the gate T(x), controls how much of the information is updated using the W_h parameters.

The right term, x multiplied by 1 − T(x), is the part of the input vector x that is passed on without being updated. It determines how much of the original information is retained unchanged.
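A gate of this kind takes only a few lines. In this sketch the transformation H is assumed to be a tanh layer and the gate T a sigmoid layer with a bias `b_t`. These choices, the shapes and the random weights are illustrative assumptions; in a Highway GCN the transformation would be a graph convolution.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def highway(x, W_h, W_t, b_t=0.0):
    """Mix the transformed input with the untouched input using a gate."""
    H = np.tanh(x @ W_h)          # transformation H(x, W_h)
    T = sigmoid(x @ W_t + b_t)    # gate T(x, W_t), values in (0, 1)
    return H * T + x * (1.0 - T)  # updated part + carried-over part

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))       # 4 nodes, 3 features each
W_h = rng.normal(size=(3, 3))     # random, untrained weights
W_t = rng.normal(size=(3, 3))

y = highway(x, W_h, W_t)
y_closed = highway(x, W_h, W_t, b_t=-50.0)  # gate ~0: input passes through
y_open = highway(x, W_h, W_t, b_t=50.0)     # gate ~1: input fully transformed

print(np.allclose(y_closed, x), np.allclose(y_open, np.tanh(x @ W_h)))
```

Pushing the gate toward 0 or 1 shows the two extremes: the block either copies its input or behaves like an ordinary layer. In between, it blends the two element by element.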

Sampling modules

When the graph becomes huge, sampling helps select a subset of nodes or edges for processing. After selecting a subset of data, information propagation occurs only on this subset, which makes calculations faster and more efficient.

  1. Node sampling. The idea is to select a subset of neighbors for each node. GraphSAGE samples a fixed number of neighbors, keeping the neighborhood size of each node between 2 and 50.

    PinSage selects neighbors by importance. By simulating random walks starting from the target node, it selects the T nodes with the largest normalized visit counts (the visit counts of a node normalized over the walks).

  2. Layer sampling. We keep a fixed number of nodes at each layer.

  3. Subgraph sampling. We sample whole subgraphs and propagate information inside them, for example with the subgraph generation used in GraphSAINT.
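Node sampling with a fixed neighborhood size can be sketched as follows. The toy graph and the helper `sample_neighbors` are hypothetical, and sampling with replacement for nodes that have fewer than k neighbors is an assumption of this sketch rather than something stated above.

```python
import random

# Hypothetical graph: node 0 is a hub, the others have few neighbours.
graph = {0: [1, 2, 3, 4, 5], 1: [0], 2: [0, 3], 3: [0, 2], 4: [0], 5: [0]}

def sample_neighbors(graph, node, k, rng):
    """Return exactly k neighbours of `node`.

    Samples without replacement when the node has at least k neighbours,
    and with replacement otherwise, so the output size is always k.
    """
    nbrs = graph[node]
    if len(nbrs) >= k:
        return rng.sample(nbrs, k)
    return [rng.choice(nbrs) for _ in range(k)]

rng = random.Random(0)
sampled = {node: sample_neighbors(graph, node, 2, rng) for node in graph}
print(sampled)
```

Because every node ends up with the same number of sampled neighbors, the aggregation step can work on fixed-size batches no matter how large the hub nodes are.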

Types of graphs

In real problems, graph data is a complex mix of different kinds of information. The graph can be:

  1. Directed graphs. In this case you can keep two sets of weights, W_p and W_c, for the forward and reverse directions respectively.

  2. Heterogeneous graphs. Nodes can represent different entities, and edges can represent different types of relationships between these entities. Roughly speaking, the graph does not have a single data format; it contains complex relationships between objects of different natures. What to do in this case? Metapaths come to the rescue. A metapath is a sequence of node and edge types that defines a path or relationship between entities in a graph. A metapath captures the similarity of two nodes that may not be directly connected. With this approach we get several homogeneous graphs, which we already know how to work with. Here is an example from recommendation systems: user – rates – movie – genre – movie. There are also edge-based methods.

  3. Dynamic graphs. Nodes and edges change over time. Structural-RNN and ST-GCN extend the graph in time by adding temporal connections to it, and then apply ordinary GNN methods to the extended graph. DCRNN and STGCN first collect spatial information and then feed it to sequence-to-sequence or RNN models.
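To illustrate how a metapath turns a heterogeneous graph into a homogeneous one, here is a sketch for the movie – genre – movie part of the example. The movie-to-genre table is made-up data; multiplying it by its own transpose counts, for each pair of movies, how many genres they share.

```python
import numpy as np

# Hypothetical data: rows are movies, columns are genres.
movie_genre = np.array([
    [1, 0],   # movie 0 belongs to genre 0
    [1, 1],   # movie 1 belongs to genres 0 and 1
    [0, 1],   # movie 2 belongs to genre 1
])

# Metapath movie - genre - movie: count the shared genres for each pair.
shared = movie_genre @ movie_genre.T
np.fill_diagonal(shared, 0)            # ignore the trivial path back to itself
A_meta = (shared > 0).astype(int)      # homogeneous movie-to-movie graph

print(A_meta)
```

Movies 0 and 2 have no genre in common, so they stay unconnected, while movie 1 is linked to both. The resulting matrix is an ordinary adjacency matrix that any of the methods above can consume.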

Tools

I will list a couple of libraries for working with graph networks:

  1. PyTorch Geometric

  2. Deep Graph Library (DGL)

  3. Graph Nets

  4. Spektral

