How can the performance of a fused architecture, combining a tabular transformer and a graph neural network, be improved for representation learning on multimodal data?

Abstract

Tabular data is abundant, largely because of its storage convenience, and there is high demand for extracting useful information from it. Machine learning models called transformers address this need: they find patterns in data, learn from them, and improve their predictive abilities through that learning experience. Tabular transformers are transformers adapted to tabular data. To increase their predictive performance, we combine them with graph neural networks (GNNs), machine learning models that operate on graph data by learning from nodes and edges. A graph representation of the dataset is created and fed into the GNN. The fused architecture combines the transformer and the GNN into a single, more complex model. The aim is to improve prediction of values in the table, or of whether an edge exists in the graph, which represents whether a transaction between two users exists. We have built the architecture using a specific tabular transformer and GNN, FT-Transformer and GINe respectively; the next step is to modify the architecture with different models and different ways of using their layers, for example how many copies of a layer are created. This has the potential to be a versatile model that can be applied to different kinds of datasets. We observed a notable performance improvement when using a different GNN, PNA. The ResNet baseline also performs on par with, or slightly better than, FT-Transformer when neither is combined with a GNN. GraphSAGE in the fused model underperforms significantly, owing to its limited ability to capture even simple graph structures.
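To make the fusion idea concrete, the following is a minimal, dependency-free sketch of the pipeline the abstract describes: each table row becomes a graph node, a stand-in for the tabular transformer embeds the row, a stand-in for one GNN message-passing step aggregates neighbor embeddings over the edges (e.g. transactions between users), and the two representations are concatenated. All function names and the mean-aggregation scheme here are illustrative assumptions, not the actual FT-Transformer/GINe implementation.

```python
def row_embedding(row):
    # Stand-in for the tabular transformer: embed one table row.
    return [float(x) for x in row]

def neighbor_mean(node, edges, embeddings):
    # Stand-in for one GNN message-passing step:
    # average the embeddings of all neighbors of `node`.
    nbrs = [v for u, v in edges if u == node] + [u for u, v in edges if v == node]
    dim = len(embeddings[node])
    if not nbrs:
        return [0.0] * dim
    return [sum(embeddings[n][d] for n in nbrs) / len(nbrs) for d in range(dim)]

def fused_representation(rows, edges):
    # Fuse the two views: concatenate each node's tabular (transformer)
    # embedding with its aggregated (GNN) neighborhood embedding.
    emb = {i: row_embedding(r) for i, r in enumerate(rows)}
    return {i: emb[i] + neighbor_mean(i, edges, emb) for i in emb}

rows = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # three table rows -> three nodes
edges = [(0, 1), (1, 2)]                     # e.g. transactions between users
reps = fused_representation(rows, edges)
```

The resulting per-node vectors could then feed a prediction head for a table value, or be compared pairwise to score whether an edge (a transaction) exists between two users.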