November 19, 2023
The wide variety of molecular types and sizes poses numerous challenges for the computational modeling of molecular systems in drug discovery, structural biology, quantum chemistry, and other fields [1]. To address these challenges, geometric deep learning (GDL) approaches have become increasingly important [2], [3]. In particular, Graph Neural Networks (GNNs) have demonstrated superior performance among GDL approaches [4]–[6]. GNNs treat each molecule as a graph and apply a message passing scheme to it [7]. By representing atoms or groups of atoms (e.g., functional groups) as nodes, and chemical bonds or other pairwise interactions as edges, molecular graphs naturally encode the structural information of molecules. In addition, GNNs can incorporate symmetry and achieve invariance or equivariance to transformations such as rotations, translations, and reflections [8], which further contributes to their effectiveness in molecular science applications. To better capture molecular structures and increase their expressive power, previous GNNs have used auxiliary information such as chemical properties [9]–[12], atomic pairwise distances in Euclidean space [7], [13], [14], and angular information [15]–[18].
Despite the success of GNNs, their application in the molecular sciences is still at an early stage. One reason is that current GNNs often use an inductive bias targeted at a specific type of molecular system and cannot be directly transferred to other contexts, even though all molecular structures and their interactions obey the same laws of physics. For example, GNNs designed for proteins may include operations specific to the structural characteristics of amino acids [19], [20], which are irrelevant for other types of molecules. Additionally, GNNs that incorporate comprehensive geometric information can be computationally expensive, making them difficult to scale to tasks involving a large number of molecules (e.g., high-throughput compound screening) or macromolecules (e.g., proteins and RNAs). For instance, incorporating angular information can significantly improve the performance of GNNs [15]–[18], but it also increases model complexity, requiring at least \(O(Nk^2)\) messages to be computed, where \(N\) and \(k\) denote the number of nodes and the average degree of the graph.
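A short count shows where the \(O(Nk^2)\) figure comes from. Under the simplifying assumption (ours, for illustration) that every node has exactly \(k\) neighbors, distance-based messages live on edges, while angle-based messages live on two-hop paths \(k' \to j \to i\) with \(k' \neq i\):

\[\begin{align} \#\{\text{distance messages}\} &= \sum_{i=1}^{N} |\mathcal{N}(i)| = Nk, \\ \#\{\text{angle messages}\} &= \sum_{i=1}^{N} \sum_{j \in \mathcal{N}(i)} |\mathcal{N}(j) \setminus \{i\}| = Nk(k-1) = O(Nk^2). \end{align}\]

For example, doubling the average degree from \(k=10\) to \(k=20\) doubles the distance messages but raises the angle messages from \(90N\) to \(380N\).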
To tackle these limitations, we propose a universal GNN framework, the Physics-Aware Multiplex Graph Neural Network (PAMNet), for accurate and efficient representation learning of 3D molecules, ranging from small molecules to macromolecules, in any molecular system. PAMNet introduces a physics-informed bias inspired by molecular mechanics [21], separately modeling local and non-local interactions in molecules based on different geometric information. To achieve this, we represent each molecule as a two-layer multiplex graph, where one plex contains only local interactions and the other plex additionally contains non-local interactions. PAMNet takes the multiplex graphs as input and uses different operations to incorporate the geometric information for each type of interaction. This flexibility makes PAMNet efficient, because it avoids computationally expensive operations on non-local interactions, which constitute the majority of interactions in a molecule. Additionally, a fusion module in PAMNet learns the contribution of each type of interaction and fuses them for the final feature or prediction. To preserve symmetry, PAMNet uses E(3)-invariant representations and operations when predicting scalar properties, and it is extended to predict E(3)-equivariant vectorial properties by considering the geometric vectors in molecular structures that arise from quantum mechanics.
To demonstrate the effectiveness of PAMNet, we conduct a comprehensive set of experiments on tasks involving different molecular systems, including small molecules, RNAs, and protein-ligand complexes: predicting small molecule properties, RNA 3D structures, and protein-ligand binding affinities. We compare PAMNet to state-of-the-art baselines in each task, and the results show that PAMNet outperforms the baselines in both accuracy and efficiency across all three tasks. Given the diversity of the tasks and the types of molecules involved, the superior performance of PAMNet demonstrates its versatility for various real-world scenarios.
Given any 3D molecule or molecular system, we define a multiplex graph representation as the input of our PAMNet model based on the original 3D structure (Fig. 1a). The construction of multiplex graphs is inspired by molecular mechanics [21], in which the molecular energy \(E\) is separately modeled based on local and non-local interactions (Fig. 1c). In detail, the local terms \(E_{\text{bond}}+E_{\text{angle}}+E_{\text{dihedral}}\) model local, covalent interactions including \(E_{\text{bond}}\) that depends on bond lengths, \(E_{\text{angle}}\) on bond angles, and \(E_{\text{dihedral}}\) on dihedral angles. The non-local terms \(E_{\text{vdW}}+E_{\text{electro}}\) model non-local, non-covalent interactions including van der Waals and electrostatic interactions which depend on interatomic distances. Motivated by this, we also decouple the modeling of these two types of interactions in PAMNet. For local interactions, we can define them either using chemical bonds or by finding the neighbors of each node within a relatively small cutoff distance, depending on the given task. For global interactions that contain both local and non-local ones, we define them by finding the neighbors of each node within a relatively large cutoff distance. For each type of interaction, we use a layer to represent all atoms as nodes and the interactions as edges. The resulting layers that share the same group of atoms form a two-layer multiplex graph \(G = \{G_{global}, G_{local}\}\) which represents the original 3D molecular structure (Fig. 1a).
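The construction above can be sketched in a few lines of NumPy. This is a minimal illustration, assuming that both plexes are built from distance cutoffs; the function names and the cutoff values (2 Å and 5 Å) are hypothetical placeholders rather than PAMNet's settings, and for tasks where local interactions are defined by chemical bonds, the local edge list would come from the bond table instead.

```python
import numpy as np


def radius_graph(pos: np.ndarray, cutoff: float) -> np.ndarray:
    """Directed edges (j, i) for every pair of distinct atoms closer than `cutoff`.

    pos: (N, 3) atomic coordinates. Returns an int array of shape (2, E).
    """
    dist = np.linalg.norm(pos[:, None, :] - pos[None, :, :], axis=-1)  # (N, N)
    mask = (dist < cutoff) & ~np.eye(len(pos), dtype=bool)             # drop self-loops
    j, i = np.nonzero(mask)
    return np.stack([j, i])


def build_multiplex_graph(pos: np.ndarray, local_cutoff: float = 2.0,
                          global_cutoff: float = 5.0) -> dict:
    """Two plexes over the same node set: a small-cutoff local layer and a
    large-cutoff global layer (hypothetical cutoff values). Because
    local_cutoff < global_cutoff, every local edge is also a global edge."""
    assert local_cutoff < global_cutoff
    return {
        "local": radius_graph(pos, local_cutoff),
        "global": radius_graph(pos, global_cutoff),
    }


# Toy example: three atoms on a line, 1 Å and 3 Å apart.
pos = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [4.0, 0.0, 0.0]])
G = build_multiplex_graph(pos)
print(G["local"].shape[1], G["global"].shape[1])  # prints: 2 6
```

Sharing one node set between the two plexes is what later lets the fusion module compare a node's two embeddings directly.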
To update the node embeddings in the multiplex graph \(G\), we design two message passing modules that incorporate geometric information: Global Message Passing and Local Message Passing for updating the node embeddings in \(G_{global}\) and \(G_{local}\), respectively (Fig. 1b). These message passing modules are inspired by physical principles from molecular mechanics (Fig. 1c): When modeling the molecular energy \(E\), the terms for local interactions require geometric information including interatomic distances (bond lengths) and angles (bond angles and dihedral angles), while the terms for non-local interactions only require interatomic distances as geometric information. The message passing modules in PAMNet also use geometric information in this way when modeling these interactions (Fig. 1b and Fig. 1e). Specifically, we capture the pairwise distances and angles contained within up to two-hop neighborhoods (Fig. 1d). The Local Message Passing requires the related adjacency matrix \(\boldsymbol{A}_{local}\), pairwise distances \(d_{local}\) and angles \(\theta_{local}\), while the Global Message Passing only needs the related adjacency matrix \(\boldsymbol{A}_{global}\) and pairwise distances \(d_{global}\). Each message passing module then learns the node embeddings \(\boldsymbol{z}_g\) or \(\boldsymbol{z}_l\) in \(G_{global}\) and \(G_{local}\), respectively.
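To make the two-hop geometry concrete, the sketch below enumerates every path \(k \to j \to i\) in a directed edge list and computes the angle \(\theta_{kji}\) at the middle atom \(j\). It assumes the DimeNet-style convention of measuring the angle between the bonds \((j,k)\) and \((j,i)\); the function name and this convention are our illustrative choices, and the exact set of angles PAMNet uses is specified in Methods.

```python
import numpy as np


def triplet_angles(pos: np.ndarray, edge_index: np.ndarray):
    """Enumerate two-hop paths k -> j -> i (k != i) and their angles theta_kji.

    pos: (N, 3) coordinates; edge_index: (2, E) directed edges (source j, target i).
    Returns (triplets, angles): triplets is (T, 3) with rows (k, j, i), angles in radians.
    """
    src, dst = edge_index
    triplets = []
    for j, i in zip(src, dst):            # edge j -> i
        for k in src[dst == j]:           # every edge k -> j that feeds into j
            if k != i:                    # skip the trivial back-and-forth path i -> j -> i
                triplets.append((k, j, i))
    triplets = np.array(triplets, dtype=int).reshape(-1, 3)
    k, j, i = triplets.T
    v1 = pos[k] - pos[j]                  # bond vector from j to k
    v2 = pos[i] - pos[j]                  # bond vector from j to i
    cos = np.sum(v1 * v2, axis=1) / (np.linalg.norm(v1, axis=1) * np.linalg.norm(v2, axis=1))
    return triplets, np.arccos(np.clip(cos, -1.0, 1.0))


# Toy bent molecule: center atom 0 bonded to atoms 1 and 2 at a right angle.
pos = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
edges = np.array([[1, 0, 2, 0],
                  [0, 1, 0, 2]])
tri, ang = triplet_angles(pos, edges)
print(tri.tolist(), np.degrees(ang).round(1).tolist())  # [[2, 0, 1], [1, 0, 2]] [90.0, 90.0]
```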
The operations in our message passing modules can preserve two different symmetries, E(3)-invariance and E(3)-equivariance, which are essential inductive biases for GNNs operating on graphs with geometric information [8]. E(3)-invariance is preserved when predicting E(3)-invariant scalar quantities such as energies, which remain unchanged when the molecular structure undergoes any E(3) transformation, including rotation, translation, and reflection. To preserve E(3)-invariance, the input node embeddings \(\boldsymbol{h}\) and geometric features are all E(3)-invariant, and PAMNet updates them with operations that preserve this invariance. In contrast, E(3)-equivariance is preserved when predicting E(3)-equivariant vectorial quantities such as the dipole moment, which rotate and reflect together with the molecular structure when it undergoes an E(3) transformation. To preserve E(3)-equivariance, an additional geometric vector \(\vec{v} \in \mathbb{R}^3\) is associated with each node. These geometric vectors are updated by operations inspired by quantum mechanics [22], allowing PAMNet to learn E(3)-equivariant vectorial representations. More details on E(3)-invariance, E(3)-equivariance, and our operations can be found in Methods.
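The two symmetry properties can be checked numerically. The sketch below applies a random orthogonal transformation (a rotation or a reflection) plus a translation to a toy structure and verifies that pairwise distances are invariant and that a vector-valued update is equivariant. The update `vector_update` is a hypothetical EGNN-style operation chosen for illustration, not PAMNet's quantum-mechanics-inspired update.

```python
import numpy as np


def pairwise_distances(pos: np.ndarray) -> np.ndarray:
    return np.linalg.norm(pos[:, None, :] - pos[None, :, :], axis=-1)


def vector_update(pos: np.ndarray) -> np.ndarray:
    """Hypothetical equivariant update v_i = sum_j w(d_ij) (r_i - r_j)."""
    diff = pos[:, None, :] - pos[None, :, :]       # r_i - r_j, shape (N, N, 3)
    w = np.exp(-pairwise_distances(pos))           # any invariant scalar weight works
    np.fill_diagonal(w, 0.0)                       # no self-interaction
    return np.einsum("ij,ijd->id", w, diff)        # (N, 3)


rng = np.random.default_rng(0)
pos = rng.normal(size=(5, 3))
Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))       # random orthogonal matrix: rotation or reflection
t = rng.normal(size=3)                             # random translation
moved = pos @ Q.T + t                              # r_i -> Q r_i + t for every atom

assert np.allclose(pairwise_distances(moved), pairwise_distances(pos))   # E(3)-invariant
assert np.allclose(vector_update(moved), vector_update(pos) @ Q.T)      # E(3)-equivariant
```

Because the update only uses relative positions scaled by invariant weights, translations cancel and rotations or reflections pass straight through to the output vectors.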
After updating the node embeddings \(\boldsymbol{z}_g\) and \(\boldsymbol{z}_l\) of the two layers in the multiplex graph \(G\), we design a fusion module with a two-step pooling process to combine \(\boldsymbol{z}_g\) and \(\boldsymbol{z}_l\) for downstream tasks (Figure 1b). In the first step, we design an attention pooling module based on the attention mechanism [23] for each hidden layer \(t\) in PAMNet. Since \(G_{global}\) and \(G_{local}\) contain the same set of nodes \(\{\textit{N}\}\), we apply the attention mechanism to each node \(n \in \{\textit{N}\}\) to learn the attention weights \(\alpha_g^t\) and \(\alpha_l^t\) between the embeddings of \(n\) in \(G_{global}\) and \(G_{local}\), namely \(\boldsymbol{z}_g^t\) and \(\boldsymbol{z}_l^t\). The attention weights are then treated as the importance of \(\boldsymbol{z}_g^t\) and \(\boldsymbol{z}_l^t\) to compute the combined node embedding \(\boldsymbol{z}^t\) in each hidden layer \(t\) as a weighted sum (Figure 1e). In the second step, the \(\boldsymbol{z}^t\) of all hidden layers are summed to obtain the final node embeddings. If a graph embedding is desired, we compute it as the average or the sum of the node embeddings.
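A minimal sketch of the two-step pooling follows. The scoring function (a `tanh` nonlinearity followed by a dot product with a vector `w`, then a softmax over the two plexes for each node) is our assumption for illustration; the exact attention parameterization is given in Methods, and a real implementation would likely use separate parameters per hidden layer rather than one shared `w`.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)        # numerical stability
    ex = np.exp(x)
    return ex / ex.sum(axis=axis, keepdims=True)


def attention_fuse(z_g, z_l, w):
    """Step 1 for one hidden layer t: per-node attention over the two plexes.

    z_g, z_l: (N, F) embeddings of the same N nodes in G_global and G_local.
    w: (F,) scoring vector standing in for the learned attention parameters.
    Returns the fused (N, F) embeddings z^t and the (N, 2) weights [alpha_g, alpha_l].
    """
    scores = np.stack([np.tanh(z_g) @ w, np.tanh(z_l) @ w], axis=1)   # (N, 2)
    alpha = softmax(scores, axis=1)                                    # rows sum to 1
    return alpha[:, :1] * z_g + alpha[:, 1:] * z_l, alpha


def fusion_module(layers, w, readout="sum"):
    """Step 2: sum the fused embeddings over hidden layers; optional graph readout."""
    z = sum(attention_fuse(z_g, z_l, w)[0] for z_g, z_l in layers)
    graph = z.sum(axis=0) if readout == "sum" else z.mean(axis=0)
    return z, graph


rng = np.random.default_rng(0)
layers = [(rng.normal(size=(5, 8)), rng.normal(size=(5, 8))) for _ in range(3)]  # T = 3 layers
w = rng.normal(size=8)
node_emb, graph_emb = fusion_module(layers, w)
print(node_emb.shape, graph_emb.shape)   # (5, 8) (8,)
```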
In this section, we will demonstrate the performance of our proposed PAMNet regarding two aspects: accuracy and efficiency. Accuracy denotes how well the model performs measured by the metrics corresponding to a given task. Efficiency denotes the memory consumed and the inference time spent by the model.
To evaluate the accuracy of PAMNet in learning representations of small 3D molecules, we choose QM9, which is a widely used benchmark for the prediction of 12 molecular properties of around 130k small organic molecules with up to 9 non-hydrogen atoms [24]. Mean absolute error (MAE) and mean standardized MAE (std. MAE) [15] are used for quantitative evaluation of the target properties. Besides evaluating the original PAMNet which captures geometric information within two-hop neighborhoods of each node, we also develop a "simple" PAMNet, called PAMNet-s, that utilizes only the geometric information within one-hop neighborhoods. The PAMNet models are compared with several state-of-the-art models including SchNet [13], PhysNet [14], MGCN [25], PaiNN [26], DimeNet++ [16], and SphereNet [27]. More details of the experiments can be found in Methods and Supplementary Information.
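The std. MAE metric divides each property's MAE by that property's standard deviation before averaging, so that targets with different units and scales contribute comparably. The sketch below reports it in percent, as in Table 1; computing the standard deviation from the evaluated targets is our simplifying assumption, and the reference statistics used in [15] may differ.

```python
import numpy as np


def std_mae(pred: np.ndarray, target: np.ndarray) -> float:
    """Mean standardized MAE (in %) over M properties.

    pred, target: (N_molecules, M_properties). Each property's MAE is divided by
    the standard deviation of that property (here computed from `target`),
    then the ratios are averaged over properties.
    """
    mae = np.abs(pred - target).mean(axis=0)   # (M,) per-property MAE
    sigma = target.std(axis=0)                 # (M,) per-property spread
    return float(100.0 * np.mean(mae / sigma))


# Two properties on very different scales contribute equally after standardization.
target = np.array([[0.0, 0.0], [2.0, 10.0]])
pred = target + np.array([1.0, 5.0])           # MAE = 1 and 5, std = 1 and 5
print(std_mae(pred, target))                   # 100.0
```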
We compare the performance of PAMNet with that of the baseline models above on QM9 in Table 1. PAMNet achieves 4 best and 6 second-best results among all 12 properties, while PAMNet-s achieves 3 second-best results. When evaluating the overall performance using the std. MAE across all properties, PAMNet and PAMNet-s rank first and second among all models, with 10\(\%\) and 5\(\%\) better std. MAE than the third-best model (SphereNet), respectively. The results also show that models incorporating only the atomic pairwise distance \(d\) as geometric information, such as SchNet, PhysNet, and MGCN, generally perform worse than models incorporating more geometric information, such as PaiNN, DimeNet++, SphereNet, and PAMNet. Moreover, PAMNet-s, which captures geometric information only within one-hop neighborhoods, performs worse than PAMNet, which considers two-hop neighborhoods. These findings show the importance of capturing rich geometric information when representing small 3D molecules. The superior performance of the PAMNet models demonstrates the power of separately modeling different interactions in molecules and the effectiveness of the designed message passing modules.
When predicting the dipole moment \(\mu\) as a scalar value, which is originally an E(3)-equivariant vectorial property \(\vec{\mu}\), PAMNet preserves E(3)-equivariance to first predict \(\vec{\mu}\) directly and then takes its magnitude as the final prediction. As a result, PAMNet and PAMNet-s both achieve lower MAEs (10.8 mD and 11.3 mD) than the previous best result (12 mD), achieved by PaiNN, a GNN with equivariant operations for predicting vectorial properties. Note that the remaining baselines all predict the dipole moment directly as a scalar property while preserving invariance. We also find that when PAMNet instead preserves invariance and predicts the dipole moment directly as a scalar, its MAE (24.0 mD) is much higher than that of the equivariant version. These results demonstrate that preserving equivariance is more helpful than preserving invariance for predicting dipole moments.
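The "predict the vector, then take its magnitude" strategy can be illustrated with a common construction that builds an equivariant dipole from invariant per-atom scalars, \(\vec{\mu} = \sum_i q_i \vec{r}_i\). This is a sketch of the principle only, not PAMNet's geometric-vector update; the per-atom scalars `q` stand in for model outputs such as predicted partial charges.

```python
import numpy as np


def dipole_vector(q: np.ndarray, pos: np.ndarray) -> np.ndarray:
    """mu = sum_i q_i r_i from per-atom invariant scalars q_i (e.g. partial charges).

    Shifting q to sum to zero makes mu independent of the coordinate origin.
    """
    q = q - q.mean()
    return q @ pos                                 # (3,)


def dipole_magnitude(q: np.ndarray, pos: np.ndarray) -> float:
    return float(np.linalg.norm(dipole_vector(q, pos)))


rng = np.random.default_rng(0)
pos = rng.normal(size=(6, 3))
q = rng.normal(size=6)                             # stand-in for model-predicted scalars
Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))       # random rotation/reflection
moved = pos @ Q.T + rng.normal(size=3)             # rotate/reflect, then translate

# The vector rotates with the molecule; its magnitude (the scalar target) does not change.
assert np.allclose(dipole_vector(q, moved), dipole_vector(q, pos) @ Q.T)
assert np.isclose(dipole_magnitude(q, moved), dipole_magnitude(q, pos))
```

Taking the norm only at the end keeps the directional structure available to the model while still producing the invariant scalar that QM9 reports.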
Property | Unit | SchNet | PhysNet | MGCN | PaiNN | DimeNet++ | SphereNet | PAMNet-s | PAMNet |
---|---|---|---|---|---|---|---|---|---|
\(\mu\) | mD | 21 | 52.9 | 56 | 12 | 29.7 | 24.5 | 11.3 | 10.8 |
\(\alpha\) | \(a_0^3\) | 0.124 | 0.0615 | 0.030 | 0.045 | 0.0435 | 0.0449 | 0.0466 | 0.0447 |
\(\epsilon_{\text{HOMO}}\) | meV | 47 | 32.9 | 42.1 | 27.6 | 24.6 | 22.8 | 23.9 | 22.8 |
\(\epsilon_{\text{LUMO}}\) | meV | 39 | 24.7 | 57.4 | 20.4 | 19.5 | 18.9 | 20.0 | 19.2 |
\(\Delta\epsilon\) | meV | 74 | 42.5 | 64.2 | 45.7 | 32.6 | 31.1 | 32.4 | 31.0 |
\(\left\langle R^{2}\right\rangle\) | \(a_0^2\) | 0.158 | 0.765 | 0.11 | 0.066 | 0.331 | 0.268 | 0.094 | 0.093 |
ZPVE | meV | 1.616 | 1.39 | 1.12 | 1.28 | 1.21 | 1.12 | 1.24 | 1.17 |
\(U_0\) | meV | 12 | 8.15 | 12.9 | 5.85 | 6.32 | 6.26 | 6.05 | 5.90 |
\(U\) | meV | 12 | 8.34 | 14.4 | 5.83 | 6.28 | 6.36 | 6.08 | 5.92 |
\(H\) | meV | 12 | 8.42 | 16.2 | 5.98 | 6.53 | 6.33 | 6.19 | 6.04 |
\(G\) | meV | 13 | 9.40 | 14.6 | 7.35 | 7.56 | 7.78 | 7.34 | 7.14 |
\(c_v\) | \(\frac{\mathrm{cal}}{\mathrm{mol} \mathrm{K}}\) | 0.034 | 0.0280 | 0.038 | 0.024 | 0.0230 | 0.0215 | 0.0234 | 0.0231 |
std. MAE | \(\%\) | 1.78 | 1.37 | 1.89 | 1.01 | 0.98 | 0.91 | 0.87 | 0.83 |
Beyond small molecules, we further apply PAMNet to RNA 3D structure prediction to evaluate its accuracy in learning representations of 3D macromolecules. Following previous works [28]–[30], we frame this prediction as the task of identifying accurate structural models of RNA among less accurate ones: given a group of candidate 3D structural models generated from an RNA sequence, a model that serves as a scoring function must distinguish the accurate structural models among all candidates. We use the same datasets as [30], which include a dataset for training and a benchmark for evaluation. The training dataset contains 18 relatively older and smaller experimentally determined RNA molecules [31]. Each RNA is used to generate 1000 structural models via the Rosetta FARFAR2 sampling method [29]. The evaluation benchmark contains relatively newer and larger RNAs, namely the first 21 RNAs in the RNA-Puzzles structure prediction challenge [32]. Each RNA is used to generate at least 1500 structural models using FARFAR2, where 1\(\%\) of the models are near-native (i.e., within 2\(\text{\normalfont\AA}\) RMSD of the experimentally determined native structure). In practice, each scoring function predicts, for each structural model, the root mean square deviation (RMSD) from the unknown true structure; a lower predicted RMSD indicates a more accurate structural model. We compare PAMNet with four state-of-the-art baselines: ARES [30], Rosetta (2020 version) [29], RASP [33], and 3dRNAscore [28]. Among these baselines, only ARES is deep learning-based; it is a GNN using equivariant operations. More details of the experiments are given in Methods and Supplementary Information.
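The ground-truth label behind the near-native criterion is an RMSD between a candidate and the native structure. The sketch below computes it after an optimal rigid superposition with the Kabsch algorithm, which is a standard choice but an assumption here: which atoms are compared and which tool the benchmark used are not specified in this section. The helper names and the toy coordinates are hypothetical.

```python
import numpy as np


def kabsch_rmsd(P: np.ndarray, Q: np.ndarray) -> float:
    """RMSD between two (N, 3) conformations after optimal rigid superposition."""
    P = P - P.mean(axis=0)                        # remove translation
    Q = Q - Q.mean(axis=0)
    H = P.T @ Q                                   # 3x3 cross-covariance
    U, _, Vt = np.linalg.svd(H)
    d = np.sign(np.linalg.det(Vt.T @ U.T))        # guard against an improper rotation (reflection)
    R = Vt.T @ np.diag([1.0, 1.0, d]) @ U.T       # optimal rotation mapping P onto Q
    diff = P @ R.T - Q
    return float(np.sqrt((diff ** 2).sum(axis=1).mean()))


def is_near_native(model: np.ndarray, native: np.ndarray, threshold: float = 2.0) -> bool:
    """Near-native criterion used in the text: RMSD below 2 Å."""
    return kabsch_rmsd(model, native) < threshold


rng = np.random.default_rng(0)
native = rng.normal(scale=10.0, size=(50, 3))                 # toy "native" coordinates
theta = 0.7
Rz = np.array([[np.cos(theta), -np.sin(theta), 0.0],
               [np.sin(theta),  np.cos(theta), 0.0],
               [0.0, 0.0, 1.0]])
model = native @ Rz.T + 3.0 + rng.normal(scale=0.5, size=native.shape)  # moved, noisy copy
print(round(kabsch_rmsd(model, native), 2), is_near_native(model, native))
```

Superposing first matters: without it, the rigid rotation and shift in the toy model would dominate the RMSD even though the structure itself is nearly unchanged.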
On the RNA-Puzzles evaluation benchmark, PAMNet significantly outperforms the other four scoring functions, as shown in Figure 2. When considering the best-scoring structural model of each RNA (Figure 2a), the probability that it is near-native (<2\(\text{\normalfont\AA}\) RMSD from the native structure) is 90\(\%\) with PAMNet, compared with 62, 43, 33, and 5\(\%\) for ARES, Rosetta, RASP, and 3dRNAscore, respectively. For the 10 best-scoring structural models of each RNA (Figure 2b), the probability that they include at least one near-native model is 90\(\%\) with PAMNet, compared with 81, 48, 48, and 33\(\%\) for ARES, Rosetta, RASP, and 3dRNAscore, respectively. When comparing the rank of the best-scoring near-native structural model of each RNA (Figure 2c), the geometric mean of the ranks across all RNAs is 1.7 for PAMNet, compared with 3.6, 73.0, 26.4, and 127.7 for ARES, Rosetta, RASP, and 3dRNAscore, respectively. The lower mean rank of PAMNet indicates that less effort is needed to go down its ranked list before reaching a near-native structural model. A more detailed analysis of the near-native ranking task can be found in Supplementary Figure 5.
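The three evaluation quantities above can be computed from each RNA's predicted RMSDs and true RMSDs as sketched below. The function names are hypothetical, and the aggregation (a plain mean of success indicators across RNAs and a geometric mean of ranks) follows our reading of the text.

```python
import numpy as np


def rank_metrics(pred_rmsd, true_rmsd, near_native=2.0, top_n=10):
    """Metrics for one RNA; candidates are ranked by predicted RMSD (lower is better).

    Returns (top1_near_native, topN_has_near_native, rank_of_first_near_native),
    where the rank is 1-based and None if no candidate is near-native.
    """
    order = np.argsort(pred_rmsd)
    hits = np.asarray(true_rmsd)[order] < near_native     # near-native flags in ranked order
    rank = int(np.argmax(hits)) + 1 if hits.any() else None
    return bool(hits[0]), bool(hits[:top_n].any()), rank


def summarize(per_rna):
    """Aggregate over RNAs: two success rates and the geometric mean of the ranks."""
    top1 = float(np.mean([r[0] for r in per_rna]))
    topn = float(np.mean([r[1] for r in per_rna]))
    ranks = np.array([r[2] for r in per_rna if r[2] is not None], dtype=float)
    return top1, topn, float(np.exp(np.log(ranks).mean()))


# Two toy RNAs with four and two candidate models.
rna_a = rank_metrics(np.array([0.5, 0.1, 0.9, 0.3]), np.array([1.5, 5.0, 1.0, 8.0]))
rna_b = rank_metrics(np.array([0.2, 0.4]), np.array([1.0, 3.0]))
print(rna_a, rna_b)              # (False, True, 3) (True, True, 1)
print(summarize([rna_a, rna_b])) # geometric mean rank = sqrt(3 * 1), about 1.73
```

The geometric mean keeps a single badly ranked RNA from dominating the summary, which is why it suits ranks spanning 1 to over 100.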
Category | Model | RMSE \(\downarrow\) | MAE \(\downarrow\) | SD \(\downarrow\) | R \(\uparrow\) |
---|---|---|---|---|---|
ML-based | LR | 1.675 (0.000) | 1.358 (0.000) | 1.612 (0.000) | 0.671 (0.000) |
ML-based | SVR | 1.555 (0.000) | 1.264 (0.000) | 1.493 (0.000) | 0.727 (0.000) |
ML-based | RF-Score | 1.446 (0.008) | 1.161 (0.007) | 1.335 (0.010) | 0.789 (0.003) |
CNN-based | Pafnucy | 1.585 (0.013) | 1.284 (0.021) | 1.563 (0.022) | 0.695 (0.011) |
CNN-based | OnionNet | 1.407 (0.034) | 1.078 (0.028) | 1.391 (0.038) | 0.768 (0.014) |
GNN-based | GraphDTA | 1.562 (0.022) | 1.191 (0.016) | 1.558 (0.018) | 0.697 (0.008) |
GNN-based | SGCN | 1.583 (0.033) | 1.250 (0.036) | 1.582 (0.320) | 0.686 (0.015) |
GNN-based | GNN-DTI | 1.492 (0.025) | 1.192 (0.032) | 1.471 (0.051) | 0.736 (0.021) |
GNN-based | D-MPNN | 1.493 (0.016) | 1.188 (0.009) | 1.489 (0.014) | 0.729 (0.006) |
GNN-based | MAT | 1.457 (0.037) | 1.154 (0.037) | 1.445 (0.033) | 0.747 (0.013) |
GNN-based | DimeNet | 1.453 (0.027) | 1.138 (0.026) | 1.434 (0.023) | 0.752 (0.010) |
GNN-based | CMPNN | 1.408 (0.028) | 1.117 (0.031) | 1.399 (0.025) | 0.765 (0.009) |
GNN-based | SIGN | 1.316 (0.031) | 1.027 (0.025) | 1.312 (0.035) | 0.797 (0.012) |
Ours | PAMNet | 1.263 (0.017) | 0.987 (0.013) | 1.261 (0.015) | 0.815 (0.005) |
Dataset | Model | Memory (GB) | Inference Time (s) |
---|---|---|---|
QM9 | DimeNet++ | 21.1 | 11.3 |
QM9 | SphereNet | 22.7 | 11.1 |
QM9 | PAMNet-s | 6.0 | 7.3 |
QM9 | PAMNet | 6.2 | 11.0 |
RNA-Puzzles | ARES | 13.5 | 2.1 |
RNA-Puzzles | PAMNet | 7.8 | 0.6 |
PDBbind | SIGN | 19.7 | 12.0 |
PDBbind | PAMNet | 13.1 | 1.8 |
In this experiment, we evaluate the accuracy of PAMNet in representing complexes that contain both small molecules and macromolecules. We use PDBbind, a well-known public database of experimentally measured binding affinities for protein-ligand complexes [34]. The goal is to predict the binding affinity of each complex from its 3D structure. We use the PDBbind v2016 dataset and preprocess each complex to retain only the ligand and the protein residues within 6\(\text{\normalfont\AA}\) of it, which yields structures with around 300 non-hydrogen atoms on average. To comprehensively evaluate performance, we use the root mean square error (RMSE), mean absolute error (MAE), Pearson's correlation coefficient (R), and the standard deviation (SD) in regression, following [18]. PAMNet is compared with machine learning-based methods (LR, SVR, and RF-Score [35]), CNN-based methods (Pafnucy [36] and OnionNet [37]), and GNN-based methods (GraphDTA [38], SGCN [39], GNN-DTI [40], D-MPNN [12], MAT [41], DimeNet [15], CMPNN [42], and SIGN [18]). More details of the experiments are provided in Methods and Supplementary Information.
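The four regression metrics can be sketched as follows. RMSE, MAE, and Pearson's R are standard; for SD we assume the definition common in the binding-affinity literature, namely the residual standard deviation of a linear fit of the true affinities on the predictions, and [18] should be consulted for the exact form used in the experiments.

```python
import numpy as np


def affinity_metrics(pred, true):
    """RMSE, MAE, SD and Pearson R for binding-affinity regression.

    SD is taken here as the residual standard deviation of a linear fit of the
    true affinities on the predictions (an assumed definition; see [18]).
    """
    pred = np.asarray(pred, dtype=float)
    true = np.asarray(true, dtype=float)
    err = pred - true
    rmse = float(np.sqrt(np.mean(err ** 2)))
    mae = float(np.mean(np.abs(err)))
    slope, intercept = np.polyfit(pred, true, deg=1)     # true ≈ slope * pred + intercept
    resid = true - (slope * pred + intercept)
    sd = float(np.sqrt(np.sum(resid ** 2) / (len(true) - 1)))
    r = float(np.corrcoef(pred, true)[0, 1])
    return {"RMSE": rmse, "MAE": mae, "SD": sd, "R": r}


# Predictions that are perfectly correlated with, but offset from, the labels:
# R = 1 and SD ≈ 0 even though RMSE and MAE are large.
print(affinity_metrics([1.0, 2.0, 3.0, 4.0], [3.0, 5.0, 7.0, 9.0]))
```

The example shows why the metrics are reported together: SD and R measure how well predictions track the labels up to a linear rescaling, while RMSE and MAE penalize any systematic offset.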
We list the results of all models in Table 2 and Supplementary Table 5. PAMNet achieves the best performance on all 4 evaluation metrics in our experiment. Compared with the second-best model, SIGN, PAMNet performs significantly better (p-value < 0.05). These results demonstrate the accuracy of our model in learning representations of 3D macromolecular complexes.
In general, we find that models with explicitly encoded 3D geometric information, such as DimeNet, SIGN, and PAMNet, outperform models without such information. An exception is that DimeNet does not outperform CMPNN. This might be because DimeNet is domain-specific and was originally designed for small molecules rather than macromolecular complexes. In contrast, PAMNet is flexible enough to learn representations for various types of molecular systems. The superior performance of PAMNet in predicting binding affinity relies on the separate modeling of local and non-local interactions. For protein-ligand complexes, the local interactions mainly capture interactions within the protein and within the ligand, while the non-local interactions can capture interactions between the protein and the ligand. Thus, PAMNet can effectively handle diverse interactions and achieve accurate results.
To evaluate the efficiency of PAMNet, we compare it with the best-performing baselines in each task in terms of memory consumption and inference time, and summarize the results in Table 3. Theoretically, DimeNet++, SphereNet, and SIGN all require \(O(Nk^2)\) messages in message passing, while PAMNet requires \(O(N(k_g+{k_l}^2))\) messages, where \(N\) is the number of nodes, \(k\) is the average degree of a graph, and \(k_g\) and \(k_l\) denote the average degrees of \(G_{global}\) and \(G_{local}\) in the corresponding multiplex graph \(G\). When \(k_g \sim k\) and \(k_l \ll k_g\), PAMNet involves far fewer messages. A more detailed analysis of computational complexity is included in Methods. Empirically, Table 3 shows that the PAMNet models require less memory and inference time than the best-performing baselines in all three tasks, which matches our theoretical analysis. We also compare, in Figure 3, how the memory consumption of the related models changes with the largest cutoff distance \(d\). The memory consumed by DimeNet and SIGN increases much faster than that of PAMNet as \(d\) increases. For example, at \(d=5\text{\normalfont\AA}\), PAMNet requires 80\(\%\) and 71\(\%\) less memory than DimeNet and SIGN, respectively. Thus, PAMNet is much more memory-efficient and can capture longer-range interactions than these baselines under restricted resources. The efficiency of the PAMNet models comes from the separate modeling of local and non-local interactions in 3D molecular structures: when modeling the non-local interactions, which make up the majority of all interactions, we use a relatively efficient message passing scheme that encodes only pairwise distances \(d\) as geometric information. Thus, compared with models that require more comprehensive geometric information for all interactions, PAMNet significantly reduces the computationally expensive operations. More details of the experimental settings are included in Methods.
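The message-count argument can be checked on a toy point cloud. The sketch counts one message per directed edge and one per two-hop path (\(\deg(j)-1\) paths per edge \(j \to i\)), then compares a model that uses angles on every interaction with a multiplex model that uses angles only on the local plex. This is a simplified counting model under our own assumptions (one layer, radius graphs, hypothetical cutoffs of 1.6 Å and 5 Å, random coordinates rather than a real molecule), not a measurement of the actual implementations.

```python
import numpy as np


def adjacency(pos: np.ndarray, cutoff: float) -> np.ndarray:
    """Boolean (N, N) adjacency of a radius graph without self-loops."""
    dist = np.linalg.norm(pos[:, None, :] - pos[None, :, :], axis=-1)
    return (dist < cutoff) & ~np.eye(len(pos), dtype=bool)


def count_messages(adj: np.ndarray):
    """(edge messages, triplet messages) for a symmetric adjacency matrix.

    Every directed edge j -> i carries one distance message; every path k -> j -> i
    with k != i carries one angle message, i.e. deg(j) - 1 per edge j -> i.
    """
    deg = adj.sum(axis=1)
    n_edges = int(deg.sum())
    n_triplets = int((adj * (deg[:, None] - 1)).sum())   # row index of adj is the source j
    return n_edges, n_triplets


def compare(pos, local_cutoff, global_cutoff):
    e_g, t_g = count_messages(adjacency(pos, global_cutoff))
    e_l, t_l = count_messages(adjacency(pos, local_cutoff))
    angles_everywhere = e_g + t_g            # O(N k^2): angles on every interaction
    multiplex = e_g + e_l + t_l              # O(N (k_g + k_l^2)): angles only on local edges
    return angles_everywhere, multiplex


rng = np.random.default_rng(0)
pos = rng.uniform(0.0, 12.0, size=(300, 3))  # toy point cloud, not a real molecule
print(compare(pos, local_cutoff=1.6, global_cutoff=5.0))
```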
To determine whether all components of PAMNet, including the fusion module and the message passing modules, contribute to its performance, we conduct an ablation study with PAMNet variants. To build a variant without attention pooling, we average the outputs of the message passing modules in each hidden layer. We also build variants that remove either the Local Message Passing or the Global Message Passing. All variants are evaluated on the three benchmarks: we compute the std. MAE across all properties on QM9, the geometric mean of the ranks across all RNAs on RNA-Puzzles, and the four metrics used on PDBbind. The results in Figure 4 show that every variant performs worse than the full PAMNet, which validates the contribution of each component. Detailed results for the individual properties on QM9 can be found in Supplementary Table 6.
A salient property of PAMNet is the attention mechanism in the fusion module, which takes the importance of the node embeddings in \(G_{local}\) and \(G_{global}\) of \(G\) into account when learning combined node embeddings. Recall that for each node \(n\) in the set of nodes \(\{N\}\) in \(G\), the attention pooling in the fusion module learns the attention weights \(\alpha_l\) and \(\alpha_g\) between \(n\)'s node embedding \(\boldsymbol{z}_{l}\) in \(G_{local}\) and its node embedding \(\boldsymbol{z}_{g}\) in \(G_{global}\). \(\alpha_l\) and \(\alpha_g\) serve as the importance of \(\boldsymbol{z}_{l}\) and \(\boldsymbol{z}_{g}\) when computing the combined node embedding \(\boldsymbol{z}\). To better understand the contributions of \(\boldsymbol{z}_{l}\) and \(\boldsymbol{z}_{g}\), we analyze the learned attention weights \(\alpha_l\) and \(\alpha_g\) in the three tasks. Since the node embeddings are directly related to the involved interactions, this analysis also reveals the contributions of local and global interactions to the predictions in different tasks. In each task, we take the average of all \(\alpha_l\) or \(\alpha_g\) as the overall importance of the corresponding group of interactions. We then compare the average attention weights \(\overline{\alpha_l}\) and \(\overline{\alpha_g}\), listed in Table 4. A higher attention weight in a task indicates a stronger contribution of the corresponding interactions to solving that task.
For the QM9 targets, we find that all of them have \(\overline{\alpha_l} \geq \overline{\alpha_g}\) except the electronic spatial extent \(\left\langle R^{2}\right\rangle\), indicating a stronger contribution of the local interactions, which are defined by chemical bonds in this task. This may be because QM9 contains small molecules with at most 9 non-hydrogen atoms, so local interactions can capture a considerable portion of all atomic interactions. However, when predicting \(\left\langle R^{2}\right\rangle\), we observe \(\overline{\alpha_l} < \overline{\alpha_g}\), which suggests that \(\left\langle R^{2}\right\rangle\) is mainly affected by the global interactions, here the pairwise interactions within \(10\text{\normalfont\AA}\). This is not surprising, since \(\left\langle R^{2}\right\rangle\) measures the spatial extent of the molecule's electron density and is therefore directly related to the diameter or radius of the molecule. Moreover, a previous study [43] has shown that graph properties such as diameter and radius cannot be computed by message passing-based GNNs that rely entirely on local information, and that additional global information is needed. It is therefore expected that global interactions contribute more than local interactions to predicting the electronic spatial extent.
For RNA 3D structure prediction on RNA-Puzzles and protein-ligand binding affinity prediction on PDBbind, we find \(\overline{\alpha_l} < \overline{\alpha_g}\) in both cases, indicating that global interactions play a more important role than local interactions. This is because both tasks rely heavily on global interactions, which are necessary for representing the global structure of an RNA when predicting its 3D structure, and crucial for capturing the relationships between protein and ligand when predicting binding affinity.
Weight | QM9 \(\mu\) | QM9 \(\alpha\) | QM9 \(\epsilon_{\text{HOMO}}\) | QM9 \(\epsilon_{\text{LUMO}}\) | QM9 \(\left\langle R^{2}\right\rangle\) | QM9 ZPVE | QM9 \(U_0\) | QM9 \(U\) | QM9 \(H\) | QM9 \(G\) | QM9 \(c_v\) | RNA-Puzzles | PDBbind |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
\(\overline{\alpha_l}\) | 0.64 | 0.53 | 0.50 | 0.50 | 0.29 | 0.54 | 0.60 | 0.60 | 0.60 | 0.57 | 0.58 | 0.22 | 0.34 |
\(\overline{\alpha_g}\) | 0.36 | 0.47 | 0.50 | 0.50 | 0.71 | 0.46 | 0.40 | 0.40 | 0.40 | 0.43 | 0.42 | 0.78 | 0.66 |
In this work, we tackle the limited applicability and inefficiency of previous GNNs for representation learning of molecular systems with 3D structures, and propose a universal framework, PAMNet, to accurately and efficiently learn representations of 3D molecules in any molecular system. Inspired by molecular mechanics, PAMNet explicitly models local and non-local interactions as well as their combined effects. The resulting framework incorporates rich geometric information such as distances and angles when modeling local interactions, and avoids expensive operations when modeling non-local interactions. In addition, PAMNet learns the contribution of different interactions to combine the updated node embeddings for the final output. When designing these operations, we preserve E(3)-invariance for scalar outputs and E(3)-equivariance for vectorial outputs, which broadens the range of applicable tasks. In our experiments, we compare PAMNet with state-of-the-art baselines on various tasks involving different molecular systems, including small molecules, RNAs, and protein-ligand complexes. In each task, PAMNet outperforms the corresponding baselines in both accuracy and efficiency. These results demonstrate the generalization power of PAMNet, even though non-local interactions in molecules are modeled with only pairwise distances as geometric information.
One under-investigated aspect of PAMNet is that it preserves E(3)-invariance in its operations when predicting scalar properties, while requiring additional representations and operations to preserve E(3)-equivariance for vectorial properties. Since various equivariant GNNs have been proposed that predict either scalar or vectorial properties solely by preserving equivariance, it would be worth extending the idea of PAMNet to equivariant GNNs, with the potential to further improve both accuracy and efficiency. Another interesting direction is multi-task learning: although we only evaluate PAMNet in single-task settings, it is promising for multi-task learning across diverse tasks involving molecules of varying sizes and types, which could yield better generalization. Beyond predicting physicochemical properties of molecules, PAMNet can serve as a universal building block for representation learning of molecular systems in various molecular science problems. Another promising application is self-supervised learning for molecular systems with few labeled data (e.g., RNA structures); for example, the features in one graph layer could be used to learn properties in another graph layer by exploiting the multiplex nature of PAMNet.
In this section, we will describe PAMNet in detail, including the involved features, embeddings, and operations.
The input features of PAMNet include atomic features and geometric information, as shown in Figure 1b. For atomic features, we use only atomic numbers \(Z\) for the tasks on QM9 and RNA-Puzzles following [13]–[16], [30], and use 18 chemical features, such as atomic number, hybridization, aromaticity, and partial charge, for the task on PDBbind following [18], [36]. The atomic numbers \(Z\) are represented by randomly initialized, trainable embeddings following [13]–[16]. For geometric information, we capture the needed pairwise distances and angles in the multiplex molecular graph \(G\), as shown in Figure 1d. The features (\(d\), \(\theta\)) for the distances and angles are computed with the basis functions in [15] to reduce correlations. For the prediction of vectorial properties, we use the atomic position \(\vec{r}\) as the initial geometric vector \(\vec{v}\) of each atom.
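As an illustration of a distance basis of the kind used in [15], the sketch below expands each distance \(d\) in radial Bessel functions \(\sqrt{2/c}\,\sin(n\pi d/c)/d\) multiplied by a smooth polynomial envelope that vanishes at the cutoff \(c\). It follows the form of the DimeNet radial basis as we understand it; the cutoff, number of basis functions, and envelope exponent `p` are illustrative values, normalization details in real implementations may differ, and the angular basis (spherical Bessel functions combined with spherical harmonics in [15]) is omitted.

```python
import numpy as np


def envelope(x: np.ndarray, p: int = 6) -> np.ndarray:
    """Polynomial cutoff u(x) on x = d / c with u(0) = 1 and u(1) = 0 (smoothly)."""
    return (1.0
            - (p + 1) * (p + 2) / 2 * x ** p
            + p * (p + 2) * x ** (p + 1)
            - p * (p + 1) / 2 * x ** (p + 2))


def bessel_rbf(d, cutoff: float = 5.0, num_radial: int = 6, p: int = 6) -> np.ndarray:
    """Radial Bessel basis sqrt(2/c) * sin(n*pi*d/c) / d for n = 1..num_radial.

    d: (E,) positive distances. Returns (E, num_radial); zero at and beyond the cutoff.
    """
    d = np.asarray(d, dtype=float)[:, None]           # (E, 1)
    n = np.arange(1, num_radial + 1)[None, :]         # (1, R)
    x = d / cutoff
    rbf = np.sqrt(2.0 / cutoff) * np.sin(n * np.pi * x) / d
    return rbf * envelope(x, p) * (x < 1.0)


feats = bessel_rbf([0.96, 1.5, 3.2])                  # e.g. a few interatomic distances in Å
print(feats.shape)                                    # (3, 6)
```

The envelope makes the features go smoothly to zero at the cutoff, so atoms crossing the cutoff during a trajectory do not cause jumps in the model input.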
In the message passing scheme [7], the update of node embeddings \(\boldsymbol{h}\) relies on passing the related messages \(\boldsymbol{m}\) between nodes. In PAMNet, we define the input message embeddings \(\boldsymbol{m}\) of the message passing schemes as follows: \[\begin{align} \boldsymbol{m}_{ji} &= \mathrm{MLP}_m([\boldsymbol{h}_{j} | \boldsymbol{h}_{i} | \boldsymbol{e}_{ji}]), \end{align}\] where \(i\) and \(j\) are connected nodes in \(G_{global}\) or \(G_{local}\), \(\mathrm{MLP}\) denotes a multi-layer perceptron, and \(|\) denotes concatenation. The edge embedding \(\boldsymbol{e}_{ji}\) encodes the pairwise distance \(d\) between nodes \(i\) and \(j\).
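The message embedding above can be sketched in plain Python as follows. The sketch assumes a two-layer MLP with Swish activations and no biases, uniform weight initialization, and toy dimensions; `TwoLayerMLP` and `message_embedding` are hypothetical names, and applying Swish after the last layer is an assumption.

```python
import math
import random


def swish(x: float) -> float:
    """Swish / SiLU activation x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))


class TwoLayerMLP:
    """Hypothetical two-layer perceptron with Swish activations (toy sketch, no biases)."""

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int, seed: int = 0):
        rng = random.Random(seed)
        self.w1 = [[rng.uniform(-1, 1) / math.sqrt(in_dim) for _ in range(in_dim)]
                   for _ in range(hidden_dim)]
        self.w2 = [[rng.uniform(-1, 1) / math.sqrt(hidden_dim) for _ in range(hidden_dim)]
                   for _ in range(out_dim)]

    def __call__(self, x: list) -> list:
        hidden = [swish(sum(w * xi for w, xi in zip(row, x))) for row in self.w1]
        return [swish(sum(w * hi for w, hi in zip(row, hidden))) for row in self.w2]


def message_embedding(mlp_m: TwoLayerMLP, h_j: list, h_i: list, e_ji: list) -> list:
    """m_ji = MLP_m([h_j | h_i | e_ji]); list concatenation plays the role of '|'."""
    return mlp_m(h_j + h_i + e_ji)


# Toy sizes: 4-dim node embeddings, 3-dim edge embedding, 8-dim messages.
mlp_m = TwoLayerMLP(in_dim=4 + 4 + 3, hidden_dim=16, out_dim=8)
m_ji = message_embedding(mlp_m, [0.1] * 4, [0.2] * 4, [0.5, 0.3, 0.1])
print(len(m_ji))  # 8
```

Because the sender embedding comes first in the concatenation, \(\boldsymbol{m}_{ji}\) and \(\boldsymbol{m}_{ij}\) are in general different, i.e. messages are directed.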
As depicted in Figure 1e, the Global Message Passing in each hidden layer of PAMNet consists of a message block and an update block, and updates the node embeddings \(\boldsymbol{h}\) in \(G_{global}\) using the related adjacency matrix \(\boldsymbol{A}_{global}\) and pairwise distances \(d_{global}\). The message block performs the message passing operation as follows: \[\begin{align} \boldsymbol{h}_{i}^{t} &= \boldsymbol{h}_{i}^{t-1} + \sum\nolimits_{j \in \mathcal{N}(i)} \boldsymbol{m}_{ji}^{t-1}\odot \phi_{d}(\boldsymbol{e}_{j i}), \label{node_update_g} \end{align}\tag{1}\] where \(i, j \in G_{global}\), \(\phi_{d}\) is a learnable function, \(\boldsymbol{e}_{ji}\) is the embedding of the pairwise distance \(d\) between nodes \(i\) and \(j\), and \(\odot\) denotes the element-wise product. After the message block, an update block computes the node embeddings \(\boldsymbol{h}\) for the next layer as well as the output \(\boldsymbol{z}\) of this layer. The update block is a stack of three residual blocks, each consisting of a two-layer MLP and a skip connection across the MLP. There is also a skip connection between the input of the message block and the output of the first residual block. After the residual blocks, the updated node embeddings \(\boldsymbol{h}\) are passed to the next layer. To obtain the output \(\boldsymbol{z}\) of this layer, which is combined in the fusion module, we further apply a three-layer MLP that maps the embeddings to the desired dimension.
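A minimal sketch of the message block in Equation (1), assuming \(\phi_d(\boldsymbol{e})=\boldsymbol{W}_{\boldsymbol{e}}\boldsymbol{e}\) as defined later in this section. Node embeddings, messages, and edge embeddings are stored in plain dictionaries keyed by node index or by the directed pair \((j, i)\); this data layout and the function name are illustrative assumptions, and the residual update block is omitted.

```python
def matvec(w: list, x: list) -> list:
    return [sum(a * b for a, b in zip(row, x)) for row in w]


def global_message_block(h: dict, m: dict, e: dict, w_e: list, neighbors: dict) -> dict:
    """h_i^t = h_i^{t-1} + sum_{j in N(i)} m_ji^{t-1} * (W_e e_ji), element-wise product."""
    h_new = {}
    for i, h_i in h.items():
        agg = [0.0] * len(h_i)
        for j in neighbors[i]:
            gate = matvec(w_e, e[(j, i)])  # phi_d(e_ji) = W_e e_ji
            agg = [a + mj * g for a, mj, g in zip(agg, m[(j, i)], gate)]
        h_new[i] = [hi + a for hi, a in zip(h_i, agg)]  # skip connection h^{t-1} + ...
    return h_new


# Two connected nodes with 2-dim embeddings and an identity W_e.
h = {0: [0.0, 0.0], 1: [1.0, 1.0]}
m = {(1, 0): [3.0, 4.0], (0, 1): [1.0, 1.0]}
e = {(1, 0): [1.0, 2.0], (0, 1): [0.5, 0.5]}
w_e = [[1.0, 0.0], [0.0, 1.0]]
print(global_message_block(h, m, e, w_e, {0: [1], 1: [0]}))  # {0: [3.0, 8.0], 1: [1.5, 1.5]}
```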
For the updates of node embeddings \(\boldsymbol{h}\) in \(G_{local}\), we incorporate both pairwise distances \(d_{local}\) and angles \(\theta_{local}\), as shown in Figure 1e. To capture \(\theta_{local}\), we consider neighbors up to two hops from each node. Figure 1d shows an example of the angles we consider: some angles are between one-hop edges and two-hop edges (e.g., \(\angle i j_1 k_1\)), while the others are between one-hop edges (e.g., \(\angle j_1 i j_2\)). Compared to previous GNNs [15]–[17] that incorporate only part of these angles, PAMNet encodes the geometric information more comprehensively. In the Local Message Passing, we also use a message block and an update block following the design of the Global Message Passing, as shown in Figure 1e. However, the message block is defined differently from the one in the Global Message Passing in order to encode additional angular information: \[\begin{align} \boldsymbol{m}_{ji}^{'t-1} &= \boldsymbol{m}_{ji}^{t-1} + \sum_{j' \in \mathcal{N}(i)\setminus\{j\}} \boldsymbol{m}_{j'i}^{t-1} \odot \phi_{d}(\boldsymbol{e}_{j'i}) \odot \phi_{\theta}(\boldsymbol{a}_{j'i, j i}) + \sum_{k \in \mathcal{N}(j)\setminus\{i\}} \boldsymbol{m}_{kj}^{t-1} \odot \phi_{d}(\boldsymbol{e}_{kj}) \odot \phi_{\theta}(\boldsymbol{a}_{k j, j i}), \tag{2} \\ \boldsymbol{h}_{i}^{t} &= \boldsymbol{h}_{i}^{t-1} + \sum_{j \in \mathcal{N}(i)} \boldsymbol{m}_{ji}^{'t-1}\odot \phi_{d}(\boldsymbol{e}_{ji}), \tag{3} \end{align}\] where \(i, j, k \in G_{local}\), \(\boldsymbol{e}_{ji}\) is the embedding of the pairwise distance \(d\) between nodes \(i\) and \(j\), \(\boldsymbol{a}_{k j, j i}\) is the embedding of the angle \(\theta_{k j, j i}=\angle kji\) defined by nodes \(i\), \(j\), and \(k\), and \(\phi_{d}, \phi_{\theta}\) are learnable functions. In Equation (2), we use two summation terms to separately encode the angles in different hops, together with the associated pairwise distances, to update \(\boldsymbol{m}_{ji}\).
Then, in Equation (3), the updated message embeddings \(\boldsymbol{m}_{ji}'\) are used to perform message passing. After the message block, we use the same update block as in the Global Message Passing to update the learned node embeddings.
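To make the two kinds of angles in Equation (2) concrete, the sketch below enumerates, for a toy local graph, the one-hop angles \(\angle j i j'\) around each node and the two-hop angles \(\angle k j i\) with \(k \in \mathcal{N}(j)\setminus\{i\}\), and computes their values from 3D coordinates. The function names and the toy three-atom geometry are illustrative assumptions; one-hop angles are listed once per unordered pair of neighbors, matching the \(k(k-1)/2\) count used in the complexity analysis.

```python
import math
from itertools import combinations


def angle(a: list, center: list, b: list) -> float:
    """Angle (radians) at `center` formed by points a and b."""
    u = [x - c for x, c in zip(a, center)]
    v = [x - c for x, c in zip(b, center)]
    cos = sum(x * y for x, y in zip(u, v)) / (math.hypot(*u) * math.hypot(*v))
    return math.acos(max(-1.0, min(1.0, cos)))  # clamp against round-off


def local_angles(pos: dict, neighbors: dict):
    """Return one-hop angles {(j, i, j'): angle at i} and two-hop angles {(k, j, i): angle at j}."""
    one_hop, two_hop = {}, {}
    for i, nbrs in neighbors.items():
        for j, jp in combinations(sorted(nbrs), 2):
            one_hop[(j, i, jp)] = angle(pos[j], pos[i], pos[jp])
        for j in nbrs:
            for k in neighbors[j]:
                if k != i:
                    two_hop[(k, j, i)] = angle(pos[k], pos[j], pos[i])
    return one_hop, two_hop


# Toy bent triatomic: atom 0 bonded to atoms 1 and 2 at a right angle.
pos = {0: [0.0, 0.0, 0.0], 1: [1.0, 0.0, 0.0], 2: [0.0, 1.0, 0.0]}
neighbors = {0: [1, 2], 1: [0], 2: [0]}
one_hop, two_hop = local_angles(pos, neighbors)
print(one_hop)  # {(1, 0, 2): 1.5707963267948966}
print(two_hop)
```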
The fusion module consists of two steps of pooling, as shown in Figure 1b. In the first step, attention pooling is used to learn the combined embedding \(\boldsymbol{z}^{t}\) from the output node embeddings \(\boldsymbol{z}_{g}^{t}\) and \(\boldsymbol{z}_{l}^{t}\) of each hidden layer \(t\). The detailed architecture of attention pooling is illustrated in Figure 1e. We first compute the attention weight \(\alpha_{p,i}\) on node \(i\), which measures the contribution of the results from plex (graph layer) \(p \in \{g, l\}\) of the multiplex graph \(G\): \[\begin{align} \alpha_{p,i}^{t} = \frac{\exp(\operatorname{LeakyReLU}(\boldsymbol{W}_{p}^t \boldsymbol{z}_{p,i}^t))}{\sum_{p'}\exp(\operatorname{LeakyReLU}(\boldsymbol{W}_{p'}^t \boldsymbol{z}_{p',i}^t))},\label{softmax} \end{align}\tag{4}\] where \(\boldsymbol{W}_{p}^t \in \mathbb{R}^{1\times F}\) is a learnable weight matrix specific to hidden layer \(t\) and graph layer \(p\), and \(F\) is the dimension of \(\boldsymbol{z}_{p,i}^t\). With \(\alpha_{p,i}^{t}\), we compute the combined embedding \(\boldsymbol{z}_i^t\) of node \(i\) as a weighted sum: \[\begin{align} \boldsymbol{z}_i^t = \sum\nolimits_{p} \alpha_{p,i}^{t} (\boldsymbol{W}_{p}^{'t}\boldsymbol{z}_{p,i}^t), \label{weight_sum} \end{align}\tag{5}\] where \(\boldsymbol{W}_{p}^{'t} \in \mathbb{R}^{D\times F}\) is a learnable weight matrix specific to hidden layer \(t\) and graph layer \(p\), and \(D\) is the dimension of \(\boldsymbol{z}_i^t\).
In the second step of the fusion module, we sum the combined node embeddings \(\boldsymbol{z}_i^t\) over all hidden layers to compute the final node embeddings \(\boldsymbol{y}_i=\sum\nolimits_{t=1}^{T}\boldsymbol{z}_i^t\). If a graph-level embedding \(\boldsymbol{y}\) is desired, we compute it as follows: \[\begin{align} \boldsymbol{y} = \sum\nolimits_{i=1}^{N}\sum\nolimits_{t=1}^{T} \boldsymbol{z}_i^t.\label{sum} \end{align}\tag{6}\]
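The two pooling steps of the fusion module (Equations (4) to (6)) can be sketched for a single node and layer as follows. Assumptions: plain Python lists stand in for tensors, the LeakyReLU negative slope is 0.01 (the PyTorch default, not a value stated here), and the weights in the usage example are hand-picked so that the result is easy to check.

```python
import math


def leaky_relu(x: float, slope: float = 0.01) -> float:
    return x if x > 0 else slope * x


def matvec(w: list, x: list) -> list:
    return [sum(a * b for a, b in zip(row, x)) for row in w]


def attention_pool(z: dict, w_att: dict, w_proj: dict):
    """Eqs. (4)-(5): softmax over plexes p of LeakyReLU(W_p z_p), then weighted sum of W'_p z_p."""
    scores = {p: leaky_relu(sum(a * b for a, b in zip(w_att[p], z[p]))) for p in z}
    top = max(scores.values())  # subtract the max for numerical stability
    exps = {p: math.exp(s - top) for p, s in scores.items()}
    total = sum(exps.values())
    alpha = {p: v / total for p, v in exps.items()}
    out = None
    for p in z:
        proj = [alpha[p] * v for v in matvec(w_proj[p], z[p])]
        out = proj if out is None else [o + v for o, v in zip(out, proj)]
    return out, alpha


def graph_readout(z_nodes_layers: list) -> list:
    """Eq. (6): y = sum over nodes i and hidden layers t of z_i^t."""
    y = [0.0] * len(z_nodes_layers[0][0])
    for layers in z_nodes_layers:  # one entry per node i
        for z_it in layers:        # one vector per hidden layer t
            y = [a + b for a, b in zip(y, z_it)]
    return y


identity = [[1.0, 0.0], [0.0, 1.0]]
z_i, alpha = attention_pool(
    z={"g": [2.0, 0.0], "l": [0.0, 4.0]},
    w_att={"g": [0.0, 0.0], "l": [0.0, 0.0]},  # equal scores -> alpha = 0.5 each
    w_proj={"g": identity, "l": identity},
)
print(z_i, alpha)  # [1.0, 2.0] {'g': 0.5, 'l': 0.5}
```

The attention weights are computed per node, so different atoms can rely more on local or on global interactions.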
The operations described above preserve the E(3)-invariance of the input atomic features and geometric information and can therefore predict E(3)-invariant scalar properties. To predict an E(3)-equivariant vectorial property \(\vec{u}\), we introduce an associated geometric vector \(\vec{v}_i\) for each node \(i\) and extend PAMNet to preserve E(3)-equivariance when learning \(\vec{u}\). Specifically, the associated geometric vector \(\vec{v}_i^{t}\) of node \(i\) in hidden layer \(t\) is defined as: \[\begin{align} \vec{v}_i^{t} = f_v(\{\boldsymbol{h}^{t}\}, \{\vec{r}\}),\label{vector} \end{align}\tag{7}\] where \(\{\boldsymbol{h}^{t}\}\) denotes the set of learned node embeddings of all nodes in hidden layer \(t\), \(\{\vec{r}\}\) denotes the set of position vectors of all nodes in 3D coordinate space, and \(f_v\) is a function that preserves the E(3)-equivariance of \(\vec{v}_i^{t}\) with respect to \(\{\vec{r}\}\). Equation (7) is computed after each message passing module in PAMNet.
To predict a final vectorial property \(\vec{u}\), we modify Equations (5) and (6) in the fusion module as follows: \[\begin{align} \vec{u}_{i}^{t} &= \sum\nolimits_{p} \alpha_{p,i}^{t} (\boldsymbol{W}_{p}^{'t}\boldsymbol{z}_{p,i}^t) \vec{v}_{p,i}^t,\tag{8} \\ \vec{u} &= \sum\nolimits_{i=1}^{N}\sum\nolimits_{t=1}^{T} \vec{u}_{i}^{t},\tag{9} \end{align}\] where \(\vec{v}_{p,i}^t\) is the associated geometric vector of node \(i\) on graph layer \(p\) in hidden layer \(t\), \(\vec{u}_{i}^{t}\) is the learned vector of node \(i\) in hidden layer \(t\), and \(\boldsymbol{W}_{p}^{'t} \in \mathbb{R}^{1\times F}\) is a learnable weight matrix specific to hidden layer \(t\) and graph layer \(p\). In Equation (8), we multiply \(\vec{v}_{p,i}^t\) by the learned scalar node contributions. In Equation (9), we sum all node-level vectors over all hidden layers to compute the final prediction \(\vec{u}\).
For predicting the dipole moment \(\vec{\mu}\), an E(3)-equivariant vectorial property that describes the net molecular polarity in an electric field, we design \(f_v\) in Equation (7) based on quantum mechanics [44]. The conventional way to compute the molecular dipole moment approximates the electronic charge density as concentrated at the atomic positions, giving \(\vec{\mu}=\sum\nolimits_{i}\vec{r}_{c,i}q_i\), where \(q_i\) is the partial charge of node \(i\), and \(\vec{r}_{c,i}=\vec{r}_i - (\sum\nolimits_{j}\vec{r}_j)/N\) is the position of node \(i\) relative to the geometric center. However, this approximation is not accurate enough. Instead, we use a more accurate approximation that adds dipoles at the atomic positions, as in the distributed multipole analysis (DMA) approach [22]. This gives \(\vec{\mu}=\sum\nolimits_{i}(\vec{r}_{c,i}q_i+\vec{\mu}_i)\), where \(\vec{\mu}_i\) is the associated partial dipole of node \(i\). The equation can be rewritten as \(\vec{\mu}=\sum\nolimits_{i}f_v(\vec{r}_{i})q_i\), where \(q_i\) is a scalar atomic contribution that can be modeled in an invariant fashion. By treating \(f_v(\vec{r}_{i})\) as \(\vec{v}_{i}^t\), the equation takes the same form as Equations (8) and (9) combined. We update \(\vec{v}_{i}^t\) as follows: \[\begin{align} \vec{v}_{i}^{t} =\sum\nolimits_{j \in \mathcal{N}(i)}(\vec{r}_{i} - \vec{r}_{j})\lVert \boldsymbol{m}_{i j}^{t}\rVert, \end{align}\] where \(\lVert \cdot \rVert\) denotes the L2 norm. Since \(\vec{v}_{i}^{t}\), and hence \(\vec{\mu}\), is a linear combination of the relative position vectors \(\vec{r}_{i} - \vec{r}_{j}\) with E(3)-invariant coefficients, PAMNet preserves E(3)-equivariance with respect to \(\{\vec{r}\}\) when performing the prediction.
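A minimal sketch of the dipole head, combining the geometric-vector update above with Equations (8) and (9) for a single graph layer and hidden layer. The message norms \(\lVert \boldsymbol{m}_{ij}^{t}\rVert\) and the scalar contributions \(q_i\) (standing in for \(\alpha_{p,i}^{t}\boldsymbol{W}_{p}^{'t}\boldsymbol{z}_{p,i}^t\)) are made-up invariant numbers, not outputs of a trained model, and the coordinates are a hypothetical toy geometry. The usage example checks numerically that rotating the input coordinates rotates the predicted dipole and that translating them leaves it unchanged.

```python
import math


def geometric_vectors(pos: dict, neighbors: dict, msg_norm: dict) -> dict:
    """v_i = sum_{j in N(i)} (r_i - r_j) * ||m_ij||."""
    v = {}
    for i, nbrs in neighbors.items():
        acc = [0.0, 0.0, 0.0]
        for j in nbrs:
            w = msg_norm[(i, j)]
            acc = [a + (ri - rj) * w for a, ri, rj in zip(acc, pos[i], pos[j])]
        v[i] = acc
    return v


def dipole(pos: dict, neighbors: dict, msg_norm: dict, q: dict) -> list:
    """mu = sum_i q_i * v_i (Eqs. (8)-(9) with one layer and one plex)."""
    mu = [0.0, 0.0, 0.0]
    for i, v_i in geometric_vectors(pos, neighbors, msg_norm).items():
        mu = [m + q[i] * x for m, x in zip(mu, v_i)]
    return mu


def rotate_z(p: list, theta: float) -> list:
    c, s = math.cos(theta), math.sin(theta)
    return [c * p[0] - s * p[1], s * p[0] + c * p[1], p[2]]


pos = {0: [0.0, 0.0, 0.0], 1: [0.96, 0.0, 0.0], 2: [-0.24, 0.93, 0.0]}
neighbors = {0: [1, 2], 1: [0], 2: [0]}
msg_norm = {(0, 1): 0.8, (0, 2): 0.7, (1, 0): 0.5, (2, 0): 0.6}  # invariant scalars
q = {0: -0.8, 1: 0.4, 2: 0.4}                                    # invariant scalars

mu = dipole(pos, neighbors, msg_norm, q)
rotated = {i: rotate_z(r, 0.7) for i, r in pos.items()}
shifted = {i: [x + 3.0 for x in r] for i, r in pos.items()}
print(mu)
print(dipole(rotated, neighbors, msg_norm, q), rotate_z(mu, 0.7))  # equal up to round-off
print(dipole(shifted, neighbors, msg_norm, q))                     # equal to mu up to round-off
```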
We analyze the computational complexity of PAMNet in terms of the number of messages. We denote the cutoff distances used to create edges in \(G_g\) and \(G_l\) as \(d_g\) and \(d_l\), and the average degrees in \(G_g\) and \(G_l\) as \(k_g\) and \(k_l\). In each hidden layer of PAMNet, Global Message Passing needs \(O(Nk_g)\) messages because it requires one message for each pairwise distance between a central node and its one-hop neighbors. Local Message Passing, in contrast, requires one message for each one-hop or two-hop angle around the central node. The number of angles can be estimated as follows: the \(k\) edges connected to a node define \(k(k-1)/2\) angles, which results in a complexity of \(O(Nk^2)\). Both the one-hop and the two-hop angles have this complexity, so Local Message Passing needs \(O(2N{k_l}^2)\) messages. In total, PAMNet requires \(O(Nk_g+2N{k_l}^2)\) messages in each hidden layer, while previous approaches [15]–[18], [27] require \(O(N{k_g}^2)\) messages. For 3D molecules with a roughly uniform atomic density, we have \(k_g \propto {d_g}^3\) and \(k_l \propto {d_l}^3\). With proper choices of \(d_l\) and \(d_g\), we have \(k_l \ll k_g\), in which case our model is more efficient than the related GNNs. As an example, we compare the number of messages needed in our experiments: on QM9 with \(d_g=5\text{\normalfont\AA}\), our model needs 0.5k messages per molecule on average, while DimeNet++ needs 4.3k messages. On PDBbind with \(d_l=2\text{\normalfont\AA}\) and \(d_g=6\text{\normalfont\AA}\), our model needs only 12k messages per molecule on average, while DimeNet++ needs 264k messages.
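The counting argument above can be checked on a toy point cloud. The sketch below counts directed edges (one message per pairwise distance), one-hop angles (\(k(k-1)/2\) per node), and two-hop angles (one per path \(k \to j \to i\) with \(k \neq i\)), and compares PAMNet's total with a DimeNet-style count of two-hop angles within the global cutoff. The random coordinates are an assumption standing in for a real structure: they ignore minimum interatomic distances, so the absolute numbers will not match those reported above.

```python
import math
import random
from itertools import combinations


def neighbor_lists(pos: list, cutoff: float) -> dict:
    nbrs = {i: [] for i in range(len(pos))}
    for i, j in combinations(range(len(pos)), 2):
        if math.dist(pos[i], pos[j]) <= cutoff:
            nbrs[i].append(j)
            nbrs[j].append(i)
    return nbrs


def count_edges(nbrs: dict) -> int:
    return sum(len(v) for v in nbrs.values())  # directed edges j -> i


def count_one_hop_angles(nbrs: dict) -> int:
    return sum(len(v) * (len(v) - 1) // 2 for v in nbrs.values())


def count_two_hop_angles(nbrs: dict) -> int:
    return sum(len(nbrs[j]) - 1 for i in nbrs for j in nbrs[i])  # k in N(j) \ {i}


def pamnet_messages(pos: list, d_l: float, d_g: float) -> int:
    local, global_ = neighbor_lists(pos, d_l), neighbor_lists(pos, d_g)
    return count_edges(global_) + count_one_hop_angles(local) + count_two_hop_angles(local)


def angle_based_messages(pos: list, d_g: float) -> int:
    return count_two_hop_angles(neighbor_lists(pos, d_g))


rng = random.Random(0)
pos = [[rng.uniform(0.0, 20.0) for _ in range(3)] for _ in range(300)]  # ~300 atoms, as in PDBbind
print(pamnet_messages(pos, d_l=2.0, d_g=6.0), angle_based_messages(pos, d_g=6.0))
```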
For QM9, we use the source provided by [24]. Following previous works [15]–[17], we remove about 3k molecules that fail a geometric consistency check or are difficult to converge [45]. For properties \(U_0\), \(U\), \(H\), and \(G\), we use only the atomization energies, obtained by subtracting the atomic reference energies as in [15]–[17]. For property \(\Delta \epsilon\), we follow the DFT calculation and predict it as \(\epsilon_{\mathrm{LUMO}}-\epsilon_{\mathrm{HOMO}}\). For property \(\mu\), when using our geometric vector-based approach with PAMNet, the final result is the magnitude of the predicted vector \(\vec{\mu}\). The 3D molecular structures are processed using the RDKit library [46]. Following [15], we randomly select 110000 molecules for training, 10000 for validation, and 10831 for testing. In our multiplex molecular graphs, we use chemical bonds as the edges in the local layer, and a cutoff distance (5 or 10\(\text{\normalfont\AA}\)) to create the edges in the global layer.
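The random split described above can be sketched as a seeded permutation of molecule indices. The seed and the function name are hypothetical; the split sizes are the ones stated above, and the 130831 molecules assumed here are simply their total (110000 + 10000 + 10831).

```python
import random


def random_split(num_molecules: int, n_train: int = 110000, n_val: int = 10000, seed: int = 0):
    """Shuffle indices once and cut them into train / validation / test."""
    idx = list(range(num_molecules))
    random.Random(seed).shuffle(idx)
    return idx[:n_train], idx[n_train:n_train + n_val], idx[n_train + n_val:]


train_idx, val_idx, test_idx = random_split(130831)
print(len(train_idx), len(val_idx), len(test_idx))  # 110000 10000 10831
```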
RNA-Puzzles consists of the first 21 RNAs in the RNA-Puzzles structure prediction challenge [32]. Each RNA is used to generate at least 1500 structural models using FARFAR2, where 1\(\%\) of the models are near native (i.e., within a 2\(\text{\normalfont\AA}\) RMSD of the experimentally determined native structure). Following [30], we only use the carbon, nitrogen, and oxygen atoms in RNA structures. When building multiplex graphs for RNA structures, we use cutoff distance \(d_l=2.6\text{\normalfont\AA}\) for the local interactions in \(G_{local}\) and \(d_g=20\text{\normalfont\AA}\) for the global interactions in \(G_{global}\).
For PDBbind, we use PDBbind v2016 following [18], [36], and we adopt the same data splitting method as [18] for a fair comparison. Specifically, we use the core subset, which contains 290 complexes in PDBbind v2016, for testing. The complexes that are in the refined subset but not in the core subset (3767 complexes) are split with a ratio of 9:1 for training and validation. We use log\(K_i\) as the target property, which is proportional to the binding free energy. In each complex, we exclude the protein residues that are more than 6\(\text{\normalfont\AA}\) from the ligand and remove all hydrogen atoms. The resulting complexes contain around 300 atoms on average. In our multiplex molecular graphs, we use cutoff distance \(d_l=2\text{\normalfont\AA}\) in the local layer and \(d_g=6\text{\normalfont\AA}\) in the global layer.
In our message passing operations, we define \(\phi_{d}(\boldsymbol{e})=\boldsymbol{W}_{\boldsymbol{e}}\boldsymbol{e}\) and \(\phi_{\theta}(\boldsymbol{a})=\mathrm{MLP}_{\theta}(\boldsymbol{a})\), where \(\boldsymbol{W}_{\boldsymbol{e}}\) is a weight matrix and \(\mathrm{MLP}_{\theta}\) is a multi-layer perceptron (MLP). All MLPs used in our model have two layers, taking advantage of the approximation capability of MLPs [47]. For all activation functions, we use the self-gated Swish activation function [48]. For the basis functions, we use the same parameters as in [15]. All learnable parameters are initialized with the PyTorch defaults, except the input node embeddings on QM9: \(\boldsymbol{h}\) are initialized with random values uniformly distributed between \(-\sqrt{3}\) and \(\sqrt{3}\). In all experiments, we use the Adam optimizer [49] to minimize the loss. Supplementary Table 7 lists the typical hyperparameters used in our experiments. All experiments are run on an NVIDIA Tesla V100 GPU (32 GB).
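The two less common choices above, the Swish activation and the uniform initialization of the QM9 input embeddings, are sketched below. The bounds \(\pm\sqrt{3}\) give unit variance, since a uniform distribution on \([-a, a]\) has variance \(a^2/3\). The table size of 95 atom types and the embedding width of 128 are assumptions for illustration (128 matches the hidden dimension listed for QM9), and the function names are hypothetical.

```python
import math
import random


def swish(x: float) -> float:
    """Self-gated Swish activation: x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))


def init_atom_embeddings(num_types: int = 95, dim: int = 128, seed: int = 0) -> list:
    """Embedding table with entries drawn from U(-sqrt(3), sqrt(3)), i.e. unit variance."""
    rng = random.Random(seed)
    bound = math.sqrt(3.0)
    return [[rng.uniform(-bound, bound) for _ in range(dim)] for _ in range(num_types)]


table = init_atom_embeddings()
values = [x for row in table for x in row]
mean = sum(values) / len(values)
var = sum((x - mean) ** 2 for x in values) / len(values)
print(round(mean, 3), round(var, 3))  # close to 0 and 1
```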
In our experiment on QM9, we use single-target training following [15], with a separate model for each target instead of a single shared model for all targets. The models are optimized by minimizing the mean absolute error (MAE) loss. We use a linear learning rate warm-up over 1 epoch and an exponential decay with a rate of 0.1 every 600 epochs. For validation and testing, we use an exponential moving average of the model parameters with a decay rate of 0.999. To prevent overfitting, we use early stopping on the validation loss. For properties ZPVE, \(U_0\), \(U\), \(H\), and \(G\), we use the cutoff distance \(d_g=5\text{\normalfont\AA}\) in the global layer. For the other properties, we use \(d_g=10\text{\normalfont\AA}\). We repeat our runs 3 times for each PAMNet variant following [50].
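The QM9 learning-rate schedule and parameter averaging can be sketched as below. Assumptions: the decay is applied continuously per step (a staircase variant would also fit the description), `steps_per_epoch` depends on the dataset size and batch size, the base rate of 1e-4 is the QM9 value from the hyperparameter table, and the function names are hypothetical.

```python
def learning_rate(step: int, steps_per_epoch: int, base_lr: float = 1e-4,
                  warmup_epochs: float = 1.0, decay_rate: float = 0.1,
                  decay_epochs: float = 600.0) -> float:
    """Linear warm-up over `warmup_epochs`, then decay by `decay_rate` every `decay_epochs`."""
    epoch = step / steps_per_epoch
    warmup = min(1.0, (step + 1) / (warmup_epochs * steps_per_epoch))
    return base_lr * warmup * decay_rate ** (epoch / decay_epochs)


def ema_update(ema_params: list, params: list, decay: float = 0.999) -> list:
    """Exponential moving average of the parameters, used for validation and testing."""
    return [decay * e + (1.0 - decay) * p for e, p in zip(ema_params, params)]


steps_per_epoch = 3438  # e.g. 110000 training molecules / batch size 32, rounded up
for epoch in (0, 1, 600, 900):
    print(epoch, learning_rate(epoch * steps_per_epoch, steps_per_epoch))
```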
In our experiment on RNA-Puzzles, PAMNet is optimized by minimizing the smooth L1 loss [51] between the predicted value and the ground truth. An early-stopping strategy is adopted to select the best epoch based on the validation loss.
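For reference, the smooth L1 loss is quadratic for small errors and linear for large ones. The sketch below uses the threshold \(\beta = 1\), which is the PyTorch default for `torch.nn.SmoothL1Loss`; the value of \(\beta\) used for PAMNet is not stated here, so treat it as an assumption.

```python
def smooth_l1_loss(pred: list, target: list, beta: float = 1.0) -> float:
    """Mean smooth L1 loss: 0.5*d^2/beta if |d| < beta, else |d| - 0.5*beta."""
    total = 0.0
    for p, t in zip(pred, target):
        d = abs(p - t)
        total += 0.5 * d * d / beta if d < beta else d - 0.5 * beta
    return total / len(pred)


print(smooth_l1_loss([0.5, 3.0], [0.0, 0.0]))  # (0.125 + 2.5) / 2 = 1.3125
```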
In our experiment on PDBbind, we create three weight-sharing replica networks following [52], one each for predicting the target \(G\) of the complex, the protein pocket, and the ligand. The final target is computed as \(\Delta G_{\text{complex}} = G_{\text{complex}} - G_{\text{pocket}} - G_{\text{ligand}}\). The full model is trained by minimizing the mean absolute error (MAE) loss between \(\Delta G_{\text{complex}}\) and the true values. The learning rate is decayed by a factor of 0.2 every 50 epochs. Moreover, we perform 5 independent runs following [18].
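The replica setup amounts to applying one shared model to three inputs and combining the outputs. In the sketch below, `toy_model` is a hypothetical stand-in (a sum of per-atom values) for PAMNet's graph-level output, and `step_lr` mirrors the learning-rate drop described above, using the PDBbind base rate of 1e-3 from the hyperparameter table.

```python
from typing import Callable, Sequence


def delta_g(model: Callable[[Sequence[float]], float], complex_graph, pocket_graph, ligand_graph) -> float:
    """Delta G_complex = G_complex - G_pocket - G_ligand, computed with one weight-shared model."""
    return model(complex_graph) - model(pocket_graph) - model(ligand_graph)


def mae(pred: Sequence[float], target: Sequence[float]) -> float:
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


def step_lr(epoch: int, base_lr: float = 1e-3, factor: float = 0.2, step_size: int = 50) -> float:
    """Drop the learning rate by `factor` every `step_size` epochs."""
    return base_lr * factor ** (epoch // step_size)


def toy_model(atom_values: Sequence[float]) -> float:
    # Hypothetical stand-in for the network: sums per-atom contributions.
    return float(sum(atom_values))


pred = delta_g(toy_model, [1.0, 2.0, 3.0, -4.0], [1.0, 2.0], [3.0])
print(pred, mae([pred], [-3.5]), step_lr(0), step_lr(50))
```

Sharing weights across the three replicas means the network learns a single energy-like function, and the binding term emerges only from the difference.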
To evaluate the efficiency of PAMNet, we run all models on an NVIDIA Tesla V100 GPU (32 GB) for a fair comparison. For small molecule property prediction, we use the models for predicting property \(U_0\) of QM9 as an example. We use a batch size of 128 for all models and the configurations reported in the corresponding papers. For RNA 3D structure prediction, we use PAMNet and ARES to evaluate the structural models of the RNA in puzzle 5 of the RNA-Puzzles challenge. This RNA has 6034 non-hydrogen atoms. The model settings of PAMNet and ARES are the same as those used to reproduce the best results, and we use a batch size of 8 for the predictions. For protein-ligand binding affinity prediction, we use the configurations that reproduce the best results for the related models.
The QM9 dataset is available at https://figshare.com/collections/Quantum_chemistry_structures_and_properties_of_134_kilo_molecules/978904. The datasets for RNA 3D structure prediction can be found at https://purl.stanford.edu/bn398fc4306. The PDBbind v2016 dataset is available at http://www.pdbbind.org.cn/ or https://github.com/PaddlePaddle/PaddleHelix/tree/dev/apps/drug_target_interaction/sign.
The source code of our model is publicly available on GitHub at the following repository: https://github.com/XieResearchGroup/Physics-aware-Multiplex-GNN.
This project has been funded with federal funds from the National Institute of General Medical Sciences of the National Institutes of Health (R01GM122845) and the National Institute on Aging of the National Institutes of Health (R01AG057555).
L.X. and S.Z. conceived and designed the method and the experiments. S.Z. and Y.L. prepared the data. S.Z. implemented the algorithm and performed the experiments. All authors wrote and reviewed the manuscript.
The authors declare no competing interests.
The following methods are being compared with our PAMNet in experiments:
SchNet [13] is a GNN that uses continuous-filter convolutional layers to model atomistic systems. Interatomic distances are used when designing convolutions.
PhysNet [14] uses a message passing scheme for predicting properties of chemical systems. It models chemical interactions with learnable distance-based functions.
MGCN [25] utilizes the multilevel structure in molecular systems to learn the representations of quantum interactions level by level based on GNN. The final molecular property prediction is made with the overall interaction representation.
PaiNN [26] is a GNN that extends the invariant SchNet to an equivariant form by projecting the pairwise distances via radial basis functions and iteratively updating the geometric vectors along with the scalar features.
DimeNet++ [16] is an improved version of DimeNet [15] with better accuracy and faster speed. It can also be used for non-equilibrium molecular structures.
SphereNet [27] is a GNN method that achieves local completeness by incorporating comprehensive 3D information like distance, angle, and torsion information for 3D graphs.
ARES [30] is a state-of-the-art machine learning approach for identifying accurate RNA 3D structural models from candidate ones. It is a GNN that integrates rotational equivariance into the message passing.
Rosetta [29] is a molecular modeling software package that provides tools for RNA 3D structure prediction.
RASP [33] is a full-atom knowledge-based potential with geometrical descriptors for RNA structure prediction.
3dRNAscore [28] is an all-heavy-atom knowledge-based potential that combines distance-dependent and dihedral-dependent energies for identifying native RNA structures and ranking predicted structures.
ML-based methods include linear regression (LR), support vector regression (SVR), and random forest (RF). These approaches use the inter-molecular interaction features introduced in RF-Score [35] as input for prediction.
Pafnucy [36] is a representative 3D CNN-based model that learns the spatial structure of protein-ligand complexes.
OnionNet [37] is a CNN-based method that generates 2D interaction features by considering rotation-free element-pair contacts in complexes.
GraphDTA [38] uses GNN models to learn the complex graph and a CNN to learn the protein sequence. We use the best-performing variant (GAT-GCN) for comparison.
SGCN [39] builds on the graph convolutional network [53] and leverages atomic coordinates as node positions.
GNN-DTI [40] is a distance-aware graph attention network [23] that considers 3D structural information to learn the intermolecular interactions in protein-ligand complexes.
D-MPNN [12] is a message passing neural network that incorporates edge features. The aggregation process addresses the pairwise distance information contained in edge features.
MAT [41] utilizes inter-atomic distances and employs a molecule-augmented attention mechanism based on transformers for graph representation learning.
DimeNet [15] is a message passing neural network that uses a directional message passing scheme for small molecules. Both distances and angles are used when modeling the molecular interactions.
CMPNN [42] is built based on D-MPNN and has a communicative message passing scheme between nodes and edges for better performance when learning molecular representations.
SIGN [18] is a recent state-of-the-art GNN for predicting protein-ligand binding affinity. It builds complex interaction graphs for protein-ligand complexes and integrates both distance and angle information in modeling.
For small molecule property prediction, we use the baseline results reported in the original works. For RNA 3D structure prediction, we use the baseline results in [30]. For protein-ligand binding affinity prediction, we use the baseline results in [18]. For the efficiency evaluation, we adopt the publicly available implementations of the related models: For DimeNet and DimeNet++, we adopt the implementation by PyTorch Geometric [54] at https://github.com/rusty1s/pytorch_geometric/blob/73cfaf7e09/examples/qm9_dimenet.py. For SphereNet, we use the official implementation at https://github.com/divelab/DIG. For ARES, we use the official implementation at https://zenodo.org/record/6893040. For SIGN, we use the official implementation at https://github.com/PaddlePaddle/PaddleHelix/tree/dev/apps/drug_target_interaction/sign.
For each RNA in RNA-Puzzles, we rank the structural models using PAMNet and four baseline scoring functions. For each scoring function, we select the N \(\in \{1, 10, 100\}\) best-scoring structural models for each RNA. For each RNA, scoring function, and N, Figure 5 shows the lowest RMSD among the selected structural models. The RMSD values are binned into four ranges: below 2\(\text{\normalfont\AA}\), between 2\(\text{\normalfont\AA}\) and 5\(\text{\normalfont\AA}\), between 5\(\text{\normalfont\AA}\) and 10\(\text{\normalfont\AA}\), and above 10\(\text{\normalfont\AA}\). The results show that for each RMSD threshold (2\(\text{\normalfont\AA}\), 5\(\text{\normalfont\AA}\), or 10\(\text{\normalfont\AA}\)) and each N, the number of RNAs with at least one selected model below the threshold is greater with PAMNet than with any of the four baseline scoring functions.
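The top-N evaluation above can be sketched as follows for a single RNA and scoring function. Assumptions: a lower score means a better model (true for scoring functions that predict RMSD, but the ranking must be flipped for scores where higher is better), and the function names and toy numbers are illustrative.

```python
def best_rmsd_in_top_n(scores: list, rmsds: list, n: int) -> float:
    """Lowest RMSD among the n best-scoring (lowest-score) structural models."""
    ranked = sorted(range(len(scores)), key=lambda k: scores[k])
    return min(rmsds[k] for k in ranked[:n])


def rmsd_bin(rmsd: float) -> str:
    """Quantize an RMSD (Angstrom) into the four ranges used in the figure."""
    if rmsd < 2.0:
        return "<2"
    if rmsd < 5.0:
        return "2-5"
    if rmsd < 10.0:
        return "5-10"
    return ">10"


def count_rnas_below(best_rmsds: list, threshold: float) -> int:
    """Number of RNAs whose best selected model falls below the RMSD threshold."""
    return sum(1 for r in best_rmsds if r < threshold)


scores = [0.3, 0.1, 0.2, 0.9]
rmsds = [1.5, 8.0, 4.0, 0.5]
for n in (1, 2, 4):
    best = best_rmsd_in_top_n(scores, rmsds, n)
    print(n, best, rmsd_bin(best))
```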
We compute p-values to assess the statistical significance of the difference between SIGN and PAMNet on PDBbind. As shown in Table 5, PAMNet performs significantly better than SIGN on all four metrics (p-value < 0.05).
In Table 6, we list the results of all properties on QM9 in our ablation study.
Model | RMSE \(\downarrow\) | MAE \(\downarrow\) | SD \(\downarrow\) | R \(\uparrow\) |
---|---|---|---|---|
SIGN | 1.316 (0.031) | 1.027 (0.025) | 1.312 (0.035) | 0.797 (0.012) |
PAMNet | 1.263 (0.017) | 0.987 (0.013) | 1.261 (0.015) | 0.815 (0.005) |
Significance (p-value) | 0.0122 | 0.0156 | 0.0242 | 0.0212 |
Model | \(\mu\) | \(\alpha\) | \(\epsilon_{\text{HOMO}}\) | \(\epsilon_{\text{LUMO}}\) | \(\Delta \epsilon\) | \(\left\langle R^{2}\right\rangle\) | ZPVE | \(U_0\) | \(U\) | \(H\) | \(G\) | \(c_v\) |
---|---|---|---|---|---|---|---|---|---|---|---|---|
PAMNet | 10.8 | 0.0447 | 22.8 | 19.2 | 31.0 | 0.093 | 1.17 | 5.90 | 5.92 | 6.04 | 7.14 | 0.0231 |
PAMNet w/o Attention Pooling | 11.1 | 0.0469 | 24.2 | 20.3 | 32.8 | 0.094 | 1.22 | 6.12 | 6.15 | 6.29 | 7.44 | 0.0234 |
PAMNet w/o Local MP | 13.9 | 0.0512 | 27.8 | 23.3 | 37.6 | 0.104 | 1.27 | 7.55 | 7.57 | 7.74 | 9.13 | 0.0262 |
PAMNet w/o Global MP | 21.8 | 0.0887 | 41.5 | 34.9 | 56.4 | 5.53 | 1.52 | 8.80 | 8.81 | 9.01 | 10.6 | 0.0316 |
Hyperparameter | QM9 | RNA-Puzzles | PDBbind |
---|---|---|---|
Batch Size | 32, 128 | 8 | 32 |
Hidden Dim. | 128 | 16 | 128 |
Initial Learning Rate | 1e-4 | 1e-4 | 1e-3 |
Number of Layers | 6 | 1 | 3 |
Max. Number of Epochs | 900 | 50 | 100 |