If you want to learn more about the training process and the math behind the GraphSAGE algorithm, I suggest you take a look at the blog post An Intuitive Explanation of GraphSAGE by Rıza Özçelik or the official GraphSAGE site.

Using GraphSAGE embeddings for a downstream classification task

The Neo4j Graph Data Science library operates entirely on heap memory to enable fast caching of the graph's topology, which contains only the relevant nodes, relationships, and weights. Graph algorithms are executed on an in-memory projected graph model that is separate from Neo4j's stored graph model.

Image from the Neo4j GDS library documentation, reposted with permission

Before you can execute any graph algorithms, you have to project the in-memory graph via the Graph Loader component. You can use either a native projection or a Cypher projection to load the in-memory graph.

In this example, you will use the native projection feature to load the in-memory graph. To start, you will project the training data and store it as a named graph in the Graph Catalog. The current implementation of the GraphSAGE algorithm supports only node features that are of type Float. For this reason, you will include the decoupled node properties ranging from embedding_0 to embedding_49 in the graph projection instead of a single property embeddings_all, which holds all the node features in the form of a list of Floats. Additionally, you will treat the projected graph as undirected.

// Collect the names of the 50 decoupled embedding properties
UNWIND range(0, 49) AS i
WITH collect('embedding_' + toString(i)) AS embeddings
// Project the training graph with undirected INTERACTS relationships
CALL gds.graph.create('train', 'Train',
    {INTERACTS: {orientation: 'UNDIRECTED'}},
    {nodeProperties: embeddings})
YIELD graphName, nodeCount, relationshipCount
RETURN graphName, nodeCount, relationshipCount

Next, you will train the GraphSAGE model. The model's hyper-parameter settings are mostly copied from the original paper. I have noticed that lowering the learning rate had the biggest impact on the downstream classification accuracy. Another important hyper-parameter is the sampleSizes parameter, where the length of the list determines the number of layers (defined as the K parameter in the paper) and the values determine how many neighbors are sampled at each layer. Find more information about the available hyper-parameters in the documentation.

// Collect the names of the 50 decoupled embedding properties
UNWIND range(0, 49) AS i
WITH collect('embedding_' + toString(i)) AS embeddings
// Train a GraphSAGE model on the projected training graph
CALL gds.beta.graphSage.train('train', {
    modelName: 'proteinModel',
    aggregator: 'pool',
    batchSize: 512,
    activationFunction: 'relu',
    epochs: 10,
    // two layers that sample 25 and 10 neighbors respectively
    sampleSizes: [25, 10],
    learningRate: 0.0000001,
    embeddingDimension: 256,
    featureProperties: embeddings})
YIELD modelInfo
RETURN modelInfo

The training process took around 20 minutes on my laptop. Once training finishes, the GraphSAGE model is stored in the model catalog. You can now use this model to induce node embeddings on any projected graph with the same node properties that were used during training. Before testing the downstream classification accuracy, you have to load the test data as an in-memory graph in the Graph Catalog.

// Collect the names of the 50 decoupled embedding properties
UNWIND range(0, 49) AS i
WITH collect('embedding_' + toString(i)) AS embeddings
// Project the test graph with the same node properties as the training graph
CALL gds.graph.create('test', 'Test',
    {INTERACTS: {orientation: 'UNDIRECTED'}},
    {nodeProperties: embeddings})
YIELD graphName, nodeCount, relationshipCount
RETURN graphName, nodeCount, relationshipCount

With the GraphSAGE model trained and both the training and test data projected as an in-memory graph, you can go ahead and calculate the f1 score using the GraphSAGE embeddings in a downstream classification model. Remember, the GraphSAGE model has not observed the test data during the training phase.
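The exact pipeline depends on the classification model you prepared earlier in the post, but the sketch below shows one way to wire everything together with the official Neo4j Python driver and scikit-learn. It streams the induced GraphSAGE embeddings from the projected train and test graphs with the gds.beta.graphSage.stream procedure and feeds them to a simple logistic regression classifier. Treat the connection details, the label property name, and the choice of classifier and F1 averaging as placeholders rather than the exact setup behind the reported score.

# A minimal sketch of the downstream classification step. Assumptions:
# the neo4j Python driver and scikit-learn are installed, the database runs
# on localhost, and every Train/Test node stores its target class in a
# (hypothetical) `label` property -- substitute the property used in your data.
import numpy as np
from neo4j import GraphDatabase
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score

driver = GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password"))

# Stream the embeddings induced by the trained 'proteinModel' for every node
# in a projected graph, together with the node's target class.
EMBEDDING_QUERY = """
CALL gds.beta.graphSage.stream($graph, {modelName: 'proteinModel'})
YIELD nodeId, embedding
RETURN embedding, gds.util.asNode(nodeId).label AS label
"""

def fetch_embeddings(graph_name):
    with driver.session() as session:
        records = session.run(EMBEDDING_QUERY, graph=graph_name).data()
    X = np.array([r["embedding"] for r in records])
    y = np.array([r["label"] for r in records])
    return X, y

# The model induces embeddings on both graphs, including the test graph
# it never saw during training.
X_train, y_train = fetch_embeddings("train")
X_test, y_test = fetch_embeddings("test")

# Any classifier works here; logistic regression keeps the example short.
clf = LogisticRegression(max_iter=1000)
clf.fit(X_train, y_train)

# Micro-averaged F1 is one reasonable choice; the post does not specify the averaging.
print("f1 score:", f1_score(y_test, clf.predict(X_test), average="micro"))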

Using the GraphSAGE embeddings as feature input to the classification model, you have improved the f1 score to 0.462. You can also try to follow the other examples in the original GraphSAGE paper to hone your graph data science skills.

Takeaways

  • Connections within your data can help you increase the accuracy of your ML models
  • The GraphSAGE algorithm can induce embeddings for new, unseen nodes without the need to retrain the model

As always, the code is available on GitHub.

[1] Hamilton, Will, Zhitao Ying, and Jure Leskovec. “Inductive representation learning on large graphs.” Advances in Neural Information Processing Systems. 2017.
