
DepWiGNN - multi-hop spatial reasoning in text

If you're a CXO, founder or investor, follow me on LinkedIn and Twitter, or join the newsletter on my website. I share simplified summaries of the latest AI research and tactical advice on building AI products.


Definitions of prerequisite terms are at the end of this article, if needed.





Practical Uses


1. Startups that use AI to reason over medical images can strengthen their products with more powerful spatial reasoning features.


2. Map products can use this to fill in missing details and build richer, real-time navigation features for end users.


3. Robotics startups can use this model to improve the spatial navigation capabilities of their robots.


4. LLMs such as ChatGPT can use this approach to improve their spatial reasoning capabilities.



How this model works


The DepWiNet framework.

A pre-trained language model (PLM) is used to extract entity representations from text.

Then a homogeneous graph is constructed based on entity embeddings.


The DepWiGNN module takes this graph as input and aggregates information depth-wise, along paths, for every pair of indirectly connected nodes. The aggregated information is stored in each node as its node memory.


Then reasoning is performed on these updated node embeddings.
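The following minimal pure-Python toy sketch makes the node-memory idea concrete. It is not the authors' implementation. It follows the three DepWiGNN components described in the training section (memory initialization, long dependency collection, retrieval) and rests on several assumptions. Each spatial relation is a hand-written 2-D offset vector (the filler). Node embeddings are one-hot, so TPR unbinding is exact. The learned LSTM aggregator and feed-forward network are replaced by plain vector summation, which happens to compose offsets exactly. All names (FACTS, bind, unbind, shortest_path, build_node_memories) are hypothetical.

```python
from collections import deque

# Hypothetical toy story, already parsed into facts:
# (source, target, offset of target relative to source).
# The real model learns relation fillers from PLM embeddings instead.
FACTS = [
    ("A", "B", (1.0, 0.0)),   # B is to the right of A
    ("B", "C", (0.0, -1.0)),  # C is below B
    ("C", "D", (1.0, 0.0)),   # D is to the right of C
]


def bind(memory, filler, role):
    """TPR binding: add the outer product filler x role into the memory matrix."""
    for r, f in enumerate(filler):
        for c, e in enumerate(role):
            memory[r][c] += f * e


def unbind(memory, role):
    """TPR unbinding: memory times role vector (exact for orthonormal roles)."""
    return [sum(m * e for m, e in zip(row, role)) for row in memory]


def shortest_path(adj, src, dst):
    """Breadth-first search; returns the node sequence src -> dst, or None."""
    parent, queue = {src: None}, deque([src])
    while queue:
        node = queue.popleft()
        if node == dst:
            break
        for nxt in sorted(adj[node]):
            if nxt not in parent:
                parent[nxt] = node
                queue.append(nxt)
    if dst not in parent:
        return None
    path, node = [], dst
    while node is not None:
        path.append(node)
        node = parent[node]
    return path[::-1]


def build_node_memories(facts, dim=2):
    nodes = sorted({n for s, d, _ in facts for n in (s, d)})
    # One-hot node embeddings are orthonormal, so unbinding recovers fillers exactly.
    emb = {n: [1.0 if k == i else 0.0 for k in range(len(nodes))]
           for i, n in enumerate(nodes)}
    adj = {n: set() for n in nodes}
    mem = {n: [[0.0] * len(nodes) for _ in range(dim)] for n in nodes}

    # 1. Node memory initialization: bind each direct relation to the neighbor's embedding.
    for s, d, offset in facts:
        adj[s].add(d)
        adj[d].add(s)
        bind(mem[s], offset, emb[d])
        bind(mem[d], [-x for x in offset], emb[s])

    # 2. Long dependency collection for every indirectly connected pair.
    new_bindings = []
    for src in nodes:
        for dst in nodes:
            if src == dst or dst in adj[src]:
                continue
            path = shortest_path(adj, src, dst)
            if path is None:
                continue
            # Unbind the atomic filler of each hop along the path.
            fillers = [unbind(mem[a], emb[b]) for a, b in zip(path, path[1:])]
            # Stand-in for the LSTM aggregator + feed-forward net: sum the offsets.
            composed = [sum(column) for column in zip(*fillers)]
            new_bindings.append((src, dst, composed))
    # Bind each composed filler with the target embedding into the source memory.
    for src, dst, composed in new_bindings:
        bind(mem[src], composed, emb[dst])
    return mem, emb


memories, embeddings = build_node_memories(FACTS)
# 3. Spatial relation retrieval: unbind with the target node's embedding as the key.
print(unbind(memories["A"], embeddings["D"]))
print(unbind(memories["D"], embeddings["A"]))
```

A's memory now holds a binding for D, so the three-hop relation between A and D is read out with a single unbinding step. No stacked layers are needed. In the real model, both the fillers and their composition are learned, so retrieval is approximate rather than exact.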



How this model is different


Standard GNNs suffer from the over-smoothing problem: node representations become increasingly similar as more layers are added. This creates a paradox. Multi-hop reasoning needs more layers to capture longer dependencies between nodes, yet those dependencies are washed out as layers are added. This limits GNN performance on spatial reasoning over text.


This model addresses the problem with a new node memory that stores only depth-wise path information between nodes, instead of the breadth-wise neighborhood aggregation used by standard GNNs.


This lets the model capture long-range dependencies between nodes without stacking layers, bypassing the over-smoothing problem.
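The simplest possible GNN layer is enough to reproduce the over-smoothing effect. The sketch below assumes a parameter-free layer in which every node averages its own feature with its neighbors' features. Real GNN layers add learned weights and nonlinearities. ADJ, START, mean_aggregate and spread_after are illustrative names. Only node 4 starts with a distinctive value. The printed spread (maximum minus minimum node feature) shows how quickly stacking layers erases that distinction.

```python
def mean_aggregate(features, adj):
    """One parameter-free GNN layer: each node averages itself and its neighbors."""
    return [
        sum(features[j] for j in [i, *adj[i]]) / (1 + len(adj[i]))
        for i in range(len(features))
    ]


def spread_after(depth, features, adj):
    """Max minus min node feature after `depth` stacked layers."""
    for _ in range(depth):
        features = mean_aggregate(features, adj)
    return max(features) - min(features)


# A 5-node path graph 0-1-2-3-4; only node 4 starts with a distinctive feature.
ADJ = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}
START = [0.0, 0.0, 0.0, 0.0, 1.0]

for depth in (1, 5, 20, 100):
    print(f"{depth:>3} layers: spread = {spread_after(depth, START, ADJ):.8f}")
```

Each layer replaces every node's value with a convex combination of existing values, so the spread can never grow. On a connected graph with self-loops, the spread decays toward zero. This is the paradox described above for deep multi-hop GNNs.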



How this model was trained


  1. The problem is defined as follows: given a story description S consisting of multiple sentences, the system must answer a question Q about S by selecting the correct answer from a fixed set of candidate spatial relations.

  2. The base version of a PLM (BERT or one of its variants), with 768-dimensional embeddings, was used to extract entity representations. The model took the concatenation of the story S and the question Q as input and output an embedding for each token. These output embeddings were further projected with a single linear layer.

  3. Entities were first recognized in the input with rule-based entity recognition. In the StepGame dataset, entities are represented by single capital letters, so locating every single capital letter was sufficient. For SPARTUN and ReSQ, NLTK's RegexpParser with custom grammars was used to recognize entities.

  4. Entities and their embeddings were treated as the nodes of the graph, and an edge existed between two entities if and only if they co-occurred in the same sentence.

  5. The [CLS] token was used to build edge features as follows: if the two entities are the same (a self-loop), the edge feature is a zero tensor of the same size; otherwise, it is the last-layer hidden state of the sequence's [CLS] token.

  6. This graph was fed to a DepWiGNN module comprising three components:

    1. Node Memory Initialization - first, the memory of every node is initialized with its relations to its immediate neighbors. This uses the Tensor Product Representation (TPR) mechanism, which binds roles and fillers with outer product operations.

    2. Long Dependency Collection - for each pair of indirectly connected nodes, the shortest path between them was first found with breadth-first search (BFS). Then all atomic relation fillers along the path were unbound using the embedding of each node on the path. The collected relation fillers were aggregated with a depth aggregator (an LSTM) and passed to a feed-forward neural network, which inferred the relation filler between the source and destination nodes of the path. The resulting spatial filler was then bound with the target node embedding and added to the source node memory.

    3. Spatial Relation Retrieval - spatial information from a source node to a target node can be extracted by unbinding the spatial filler from the source node memory with a self-determined key. The key can be the target node embedding itself if the target node is easy to recognize from the question. If the target node is hard to discern, the key is a representation computed from the question token embeddings. The key was used to unbind the spatial relation from every node's memory. The result, concatenated with the source and target node embeddings, was passed to a multilayer perceptron to obtain the updated node embeddings. These updated node embeddings were then passed to the prediction module to produce the final answer.

  7. The model was trained end-to-end with the Adam optimizer. Training was stopped if the cross-entropy loss on the validation set did not improve by more than 1e-3 within 3 epochs.

  8. A PyTorch learning-rate scheduler was also applied. It reduced the learning rate by a factor of 0.1 whenever the validation cross-entropy loss improved by less than 1e-3 for 2 epochs.

  9. For the key in the Spatial Relation Retrieval component, the target node embedding was used for StepGame, since the target node is easy to recognize there. For ReSQ, a single linear layer extracted the key representation from the sum-aggregated question token embeddings.

  10. In the StepGame experiment, the model was fine-tuned on the training set and tested on the test set.

  11. For ReSQ, the model was evaluated with and without additional supervision from SPARTUN.
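Steps 3 to 5 above can be sketched for StepGame-style text in plain Python. The sketch works under the following assumptions:

- Sentences are split on `.`, `!` and `?`.
- Only story sentences (not the question) create edges.
- `cls_vector` is a placeholder list standing in for the PLM's last-layer [CLS] hidden state.

`find_entities`, `build_entity_graph` and `STORY` are hypothetical names. The NLTK grammars used for SPARTUN and ReSQ are not reproduced.

```python
import re
from itertools import combinations

# Hypothetical StepGame-style story; the question is handled separately.
STORY = "A is to the left of B. C is above B. D is below C."


def find_entities(sentence):
    """Rule-based recognition for StepGame: every standalone capital letter is an entity."""
    # dict.fromkeys removes duplicates while keeping first-seen order.
    return list(dict.fromkeys(re.findall(r"\b[A-Z]\b", sentence)))


def build_entity_graph(story, cls_vector):
    """Nodes are entities; an edge joins two entities that co-occur in one sentence.

    Edge features: a zero vector for self-loops, otherwise the [CLS] vector.
    """
    sentences = [s.strip() for s in re.split(r"[.!?]", story) if s.strip()]
    nodes, edges = [], {}
    for sentence in sentences:
        entities = find_entities(sentence)
        for entity in entities:
            if entity not in nodes:
                nodes.append(entity)
        for a, b in combinations(entities, 2):
            edges[(a, b)] = edges[(b, a)] = cls_vector
    zero = [0.0] * len(cls_vector)
    for node in nodes:
        edges[(node, node)] = zero
    return nodes, edges


nodes, edges = build_entity_graph(STORY, cls_vector=[0.1, -0.3, 0.7])
print(nodes)
print(sorted(pair for pair in edges if pair[0] != pair[1]))
```

A and C never share a sentence, so they get no edge. Their relation has to come from the depth-wise path A-B-C, which is exactly the case DepWiGNN's long dependency collection targets.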

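The stopping and learning-rate rules in steps 7 and 8 can be replayed on a sequence of validation losses. This is a plain-Python sketch of the rules as worded here, not the authors' code. The following values are assumptions, since the article does not give them:

- the initial learning rate (2e-5)
- the absolute (not relative) 1e-3 threshold
- the exact epoch counting

In an actual PyTorch loop, `torch.optim.lr_scheduler.ReduceLROnPlateau` with `factor=0.1`, `threshold=1e-3` and `threshold_mode="abs"` would handle the learning-rate part. Its `patience` argument counts bad epochs slightly differently, so check it against the rule you intend.

```python
def simulate_training(val_losses, lr=2e-5, min_delta=1e-3,
                      stop_patience=3, lr_patience=2, factor=0.1):
    """Replay validation losses; return (stop_epoch, [(epoch, lr), ...])."""
    best = float("inf")
    epochs_since_best = 0        # drives early stopping
    epochs_since_lr_change = 0   # drives the learning-rate drop
    history = []
    for epoch, loss in enumerate(val_losses):
        if best - loss > min_delta:          # improvement larger than 1e-3
            best = loss
            epochs_since_best = epochs_since_lr_change = 0
        else:
            epochs_since_best += 1
            epochs_since_lr_change += 1
        if epochs_since_lr_change >= lr_patience:
            lr *= factor                     # reduce the learning rate by 0.1x
            epochs_since_lr_change = 0
        history.append((epoch, lr))
        if epochs_since_best >= stop_patience:
            return epoch, history            # early stop
    return len(val_losses) - 1, history


losses = [1.0, 0.8, 0.7995, 0.7993, 0.7991, 0.5]
stop_epoch, history = simulate_training(losses)
print("stopped after epoch", stop_epoch)
for epoch, lr in history:
    print(epoch, lr)
```

Here the learning rate drops once, after two epochs of sub-threshold gains. Training then stops at epoch 4, before the large improvement at epoch 5 is ever seen. This shows the trade-off a patience setting makes.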

Performance


Results on StepGame

Results on ReSQ Dataset


Pre-requisite definitions


Spatial reasoning



Answering questions about spatial relations, e.g. in front of, behind, to the left of, to the right of, above, etc.


Multi-hop reasoning

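Answering questions that require combining context from multiple sentences or paragraphs to formulate the answer. For example, if A is left of B and B is above C, working out where A is relative to C takes two hops.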


GNN (Graph Neural Network)

A neural network designed to work on graphs containing data in nodes and relationships in edges.






