April 02, 2024
Understanding and leveraging the 3D structures of proteins is central to a variety of biological and drug discovery tasks. While deep learning has been applied successfully to structure-based protein function prediction, current methods usually train a separate model for each task. However, the dataset for each task is small, and such a single-task strategy limits the models' performance and generalization ability. As some labeled 3D protein datasets are biologically related, combining multi-source datasets for larger-scale multi-task learning is one way to overcome this problem. In this paper, we propose a neural network model that addresses multiple tasks jointly, given 3D protein structures as input. In particular, we first construct a standard structure-based multi-task benchmark called Protein-MT, consisting of 6 biologically relevant tasks, including affinity prediction and property prediction, integrated from 4 public datasets. Then, we develop a novel graph neural network for multi-task learning, dubbed Heterogeneous Multichannel Equivariant Network (HeMeNet), which is E(3) equivariant and able to capture heterogeneous relationships between different atoms. In addition, HeMeNet achieves task-specific learning via a task-aware readout mechanism. Extensive evaluations on our benchmark verify the effectiveness of multi-task learning, and our model generally surpasses state-of-the-art models.^{1}
Proteins consist of one or more chains of amino acids and are vital to many biological systems. The 3D structure of a protein underlies its interactions with other molecules, which ultimately determine its functions. In recent years, learning-based methods have been widely applied to leverage protein 3D structures for tasks such as property prediction [1], affinity prediction [2], rigid docking [3], and antibody generation [4], owing to their higher efficiency and lower cost compared with wet-lab experimental approaches. Many of these methods rely on Graph Neural Networks (GNNs) [5], which naturally encode protein 3D structures by modeling atoms or residues as nodes and the connections between them as edges. In addition, certain GNNs are geometry-aware and designed to respect the symmetry of E(3) transformations for better predictions [6], [7].
Despite significant progress in geometry-aware GNNs for protein tasks, existing methods usually train one model per task. A clear drawback of this single-task strategy is that the model must be retrained for each new task. Moreover, annotated structural datasets are often small because acquiring protein 3D structures and labels through wet-lab experiments is expensive, especially for affinity prediction. For example, PDBbind [8] contains only 2852 protein-protein complexes with experimentally measured Protein-Protein Affinity (PPA). Because labeled structures are sparse, training a model on a small single-task dataset usually leads to poor performance and weak generalization.
To address the sparsity of labeled data, some previous works use multi-task learning, in which a single model is trained on multiple related tasks. Collecting samples from related tasks can provide additional information that improves performance [9], [10]. However, most of these works are sequence-based [11], [12], and different tasks are annotated on separate sets of samples. The work by [13] is a structure-based multi-task method, but it mainly focuses on residue-level interface prediction.
Recent research shows that some protein properties may indicate a protein's binding activity. For example, Gene Ontology (GO) terms can carry knowledge relevant to protein-protein interaction [10], and Enzyme Commission (EC) numbers and GO terms can provide molecular context for protein-ligand binding affinity (LBA) [9]. These findings suggest that some single-chain functions may benefit the prediction of complex-level affinity, and vice versa. Motivated by this observation, we propose combining affinity and property prediction datasets for joint training.
A key problem hindering the integration of structural data is the lack of models suited to the various inputs and tasks. As shown in Figure 1, many affinity prediction models [2], [14] use full-atom information at the binding interface to predict affinity, which discards information about the rest of the chain. In contrast, many function prediction models [15], [16] use only alpha-carbon atoms to predict chain-level functions, which discards the detailed atomic interactions needed for affinity prediction.
In this paper, we propose a structure-based multi-task learning paradigm. We use a heterogeneous full-atom model to deal with multiple tasks upon various 3D protein inputs. Nevertheless, accomplishing structural multi-task training is challenging. The first challenge is that there is no available benchmark. The ideal benchmark should cover a sufficient range of data and biologically related task types, with a fully labeled test set to compare how a model performs on the same input for different task outputs. The second challenge is that it is nontrivial to design a generalist model that is capable of processing the complicated 3D structures of input proteins of various types, including single-chain, protein-protein, and protein-ligand, and it should perform well across different tasks, including protein affinity and property predictions. By achieving the structure-based full-atom protein multi-task learning, we make the following contributions:
To the best of our knowledge, we are the first to propose the concept of structure-based protein multi-task learning. We carefully integrate the structures and labels from 4 public datasets with our proposed standard process and construct a new benchmark named Protein Multiple Tasks (Protein-MT), which consists of 6 representative tasks upon 3 different types of inputs.
We propose a novel model for protein structure learning, dubbed Heterogeneous Multichannel Equivariant Network (HeMeNet), which is E(3) equivariant and able to capture various relationships between different atoms owing to the heterogeneous multichannel graph construction of proteins. Additionally, we develop a task-aware readout mechanism by associating the output head of each task with a learnable task prompt for different tasks.
For the experiments on Protein-MT, HeMeNet surpasses other state-of-the-art methods in most tasks under both the single-task and multi-task settings. Particularly on the LBA and PPA tasks, we find that the multi-task HeMeNet is significantly better than its single-task counterpart.
Predicting the binding affinity and properties of proteins with computational methods is of growing interest [17], [18]. Previous research learns protein representations from information in different forms, most often the amino acid sequence [19], [20], multiple sequence alignments [21], or the 3D structure [15], [22]. Many works encode a protein's 3D structure with GNNs [15], [23]: [2] take the full-atom geometry at the interaction interface as input, and [15] take the residue-level geometry of the protein for property prediction. Our method uses the full-atom geometry of the whole protein to address affinity and property prediction tasks together.
Many equivariant GNNs have emerged recently with the inductive bias of 3D symmetry, modeling various tasks including docking, molecular function prediction and sequence design [6], [24]–[26]. To empower the model with the ability to handle the complicated full-atom geometry, some models design multi-channel equivariant message passing for atom sets, such as GMN [7] and dyMEAN [4]. We propose a powerful heterogeneous equivariant GNN capable of handling various incoming message types.
Multi-task learning exploits knowledge transfer across multiple tasks to achieve better generalization. For proteins, several works apply multi-task learning to interaction prediction and property prediction, most of them with sequence-based models. [27] design three hierarchical tasks related to Enzyme Commission numbers to train the model. [28] introduce a multi-task protein pre-training method with prompts. [11], [12] are sequence-based multi-task benchmarks for protein function prediction. [13] improves protein interaction interface prediction through structural multi-task auxiliary learning. To the best of our knowledge, we are the first to combine structure-based interaction prediction and property prediction in a multi-task setting.
Based on the observation that protein property and binding affinity tasks may benefit each other, we construct a new dataset called Protein Multiple Tasks (Protein-MT) for protein multi-task learning. Protein-MT is composed of different types of tasks on 3D protein structures: the prediction of Ligand Binding Affinity (LBA) and Protein-Protein Affinity (PPA) for two-instance complexes, and the prediction of Enzyme Commission (EC) numbers and Gene Ontology (GO) terms for single-chain structures. Specifically, the LBA and PPA tasks, which originate from the PDBbind database [8], regress the affinity value of a protein-ligand complex and a protein-protein complex, respectively. The EC task, constructed by [29], describes the catalysis of biochemical reactions; each sample carries 538 binary labels. The GO task aims to predict hierarchically related functional properties of gene products [29] in three branches: Molecular Function (MF), Biological Process (BP), and Cellular Component (CC). We treat the prediction of MF, BP, and CC as three individual tasks, resulting in six prediction tasks in total.
One key difficulty in integrating these tasks from their source datasets is that samples from one task may lack labels for the other tasks. Samples with a complete set of labels across tasks are crucial for training and evaluating multi-task learning methods. As shown in Figure 2, we propose a standard matching pipeline that transfers labels between EC and GO and also assigns EC and GO labels to the chains of complexes in LBA and PPA (the inverse direction is not possible, since LBA and PPA labels are meaningless for the single chains in EC and GO). Specifically, we use the UniProt ID to uniquely identify a protein chain^{2}. We first obtain the UniProt IDs of all protein chains in Protein-MT from the Protein Data Bank [30]. For each UniProt ID, we then determine the EC and GO properties from the labels of the corresponding chains in the EC and GO datasets, resulting in a UniProt-Property dictionary. With this dictionary, for a chain missing EC or GO labels (e.g., a chain of a complex in LBA or PPA), we supplement the missing labels by looking up its UniProt ID and retrieving any known EC and GO labels. We define a complex (from either LBA or PPA) as fully labeled if it has one affinity label (LBA or PPA) and all four function labels for each of its chains. After this matching process, we split the data into training, validation, and test sets by chain-level sequence identity, using the alignment methods commonly used in single-chain property prediction tasks [29]. For more details of the construction process and dataset statistics, please refer to Appendix 7.
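The UniProt-based label transfer described above can be sketched in a few lines of Python. This is a minimal illustration, not the authors' pipeline: the function names, the dictionary layouts, the toy chain and UniProt identifiers, and the rule of keeping the first label seen when two chains with the same UniProt ID disagree are all our assumptions.

```python
FUNC_TASKS = ("ec", "mf", "bp", "cc")

def build_uniprot_property_dict(chain_labels, chain_to_uniprot):
    """Map each UniProt ID to the EC/GO labels known for any of its chains.
    chain_labels: {chain_id: {task: label}} taken from the EC and GO datasets."""
    table = {}
    for chain, props in chain_labels.items():
        uid = chain_to_uniprot.get(chain)
        if uid is None:
            continue  # a chain without a UniProt mapping cannot be matched
        entry = table.setdefault(uid, {})
        for task, label in props.items():
            entry.setdefault(task, label)  # keep the first label seen (assumption)
    return table

def fill_missing(chain_id, known, chain_to_uniprot, table):
    """Supplement a chain's missing EC/GO labels by looking up its UniProt ID."""
    labels = dict(known)
    for task, label in table.get(chain_to_uniprot.get(chain_id), {}).items():
        labels.setdefault(task, label)  # never overwrite an existing label
    return labels

def is_fully_labeled(affinity, chain_labels_list):
    """A complex is fully labeled if it has an affinity label and all four
    function labels for every one of its chains."""
    return affinity is not None and all(
        all(t in c for t in FUNC_TASKS) for c in chain_labels_list)

# Toy example with hypothetical identifiers
chain_to_uniprot = {"ec_chain": "P1", "go_chain": "P1", "ppa_A": "P1", "ppa_B": "P2"}
chain_labels = {"ec_chain": {"ec": [3]}, "go_chain": {"mf": [7], "bp": [2], "cc": [0]}}
table = build_uniprot_property_dict(chain_labels, chain_to_uniprot)
chain_a = fill_missing("ppa_A", {}, chain_to_uniprot, table)
chain_b = fill_missing("ppa_B", {}, chain_to_uniprot, table)
ec_to_go = fill_missing("ec_chain", {"ec": [3]}, chain_to_uniprot, table)
```

In this toy case, chain `ppa_B` has no counterpart in the EC/GO data, so its complex does not count as fully labeled and could only serve as a partially labeled training sample, while the EC chain picks up GO labels through the shared UniProt ID.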
In this section, we first introduce our heterogeneous graph representation and the multi-task formulation in Section 4.1. Then, we design the architecture of the proposed HeMeNet in Section 4.2, which consists of two key components: heterogeneous multi-channel equivariant message passing and task-aware readout.
The input of our model can take several forms: either a two-instance complex (protein-ligand for LBA, protein-protein for PPA) or a single chain (for EC and GO). For consistency, we unify these inputs as a graph \({\mathcal{G}}\) composed of two sets of nodes, \({\mathcal{V}}_{r}\) and \({\mathcal{V}}_{l}\). For an LBA complex, \({\mathcal{V}}_{r}\) and \({\mathcal{V}}_{l}\) denote the receptor and the ligand, respectively; for a PPA complex, \({\mathcal{V}}_{r}\) refers to the receptor protein chain and \({\mathcal{V}}_{l}\) to the corresponding binding protein chain. For a single-chain input in the function prediction tasks, \({\mathcal{V}}_{r}\) refers to the protein chain and \({\mathcal{V}}_{l}\) is an empty set, as shown in the middle of Figure 1. We associate each node \(v_i\) with the representation \(({\boldsymbol{h}}_i, \vec{{\boldsymbol{X}}}_i)\), where \({\boldsymbol{h}}_i\in\mathbb{R}^d\) is the node feature, initialized as a learnable residue embedding, and \(\vec{{\boldsymbol{X}}}_i\in\mathbb{R}^{3\times c_i}\) holds the 3D coordinates of all \(c_i\) atoms within the node. We construct several types of edges. For residue nodes, we allow \(R\) heterogeneous edge types, including sequential edges at different offsets (\(d \in \{-2, -1, 1, 2\}\)), self-loop edges, and spatial edges; for single-atom nodes from small molecules, only spatial edges are created. Figure 3 shows a simplified example from the LBA task, where only a few nodes are drawn and self-loop edges are omitted except for the central node. Overall, we obtain a full-atom heterogeneous graph representation \({\mathcal{G}}\) for each input.
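To make the edge types concrete, the sketch below enumerates heterogeneous edges for a toy graph. It is hedged: the node tuple layout, the relation-id numbering, the use of node centroids, and the spatial rule (a plain distance threshold, set to 5.0 Å in the toy example) are illustrative assumptions, since the text does not specify how spatial edges are selected.

```python
import numpy as np

SEQ_OFFSETS = (-2, -1, 1, 2)                         # sequential offsets d from the text
REL_SEQ = {d: k for k, d in enumerate(SEQ_OFFSETS)}  # relation ids 0..3
REL_SELF, REL_SPATIAL = 4, 5                         # so R = 6 edge types in this sketch

def build_edges(nodes, centers, cutoff=8.0):
    """nodes: list of (chain_id, seq_pos, is_residue); centers: (n, 3) node centroids.
    Returns a set of (i, j, r): a message from node j to node i along edge type r."""
    edges = set()
    for i, (ci, pi, res_i) in enumerate(nodes):
        if res_i:
            edges.add((i, i, REL_SELF))              # self-loops only for residue nodes
        for j, (cj, pj, res_j) in enumerate(nodes):
            if i == j:
                continue
            d = pj - pi
            if res_i and res_j and ci == cj and d in REL_SEQ:
                edges.add((i, j, REL_SEQ[d]))        # sequential edge within one chain
            if np.linalg.norm(centers[i] - centers[j]) < cutoff:
                edges.add((i, j, REL_SPATIAL))       # spatial edge (cutoff is an assumption)
    return edges

# Three residues of chain A spaced 3.8 Å apart, plus one distant ligand atom
nodes = [("A", 0, True), ("A", 1, True), ("A", 2, True), ("L", 0, False)]
centers = np.array([[0.0, 0.0, 0.0], [3.8, 0.0, 0.0], [7.6, 0.0, 0.0], [0.0, 20.0, 0.0]])
edges = build_edges(nodes, centers, cutoff=5.0)
```

The toy graph yields 6 sequential edges, 3 self-loops, and 4 spatial edges between neighboring residues; the ligand atom lies beyond the cutoff and receives no edges.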
Given a full-atom heterogeneous graph \({\mathcal{G}}\), our goal is to design a model \({\boldsymbol{p}}=f({\mathcal{G}})\) with multiple-dimensional output \({\boldsymbol{p}}\) that is able to predict the complex-level affinity and chain-level functional properties simultaneously. By making use of our proposed dataset Protein-MT, we train the model with a partially labeled training set and test it on the fully-labeled test set. Notably, the prediction should be invariant with regard to E(3) transformation (rotation/reflection/translation) of the input coordinates. To do so, we will formulate an equivariant encoder plus an invariant output layer in our model, detailed in the next subsection.
To better cope with the 3D structures of different types for different tasks, we propose a heterogeneous multi-channel E(3) equivariant graph neural network with the ability to aggregate different relational messages. After several layers of the message passing, the node representations are transformed into task-specific representations by a task-aware readout module, generating appropriate complex-level and chain-level predictions via different task heads.
Inspired by dyMEAN [4], we use a multi-channel coordinate matrix of dynamic size to record the geometric information of each node in the input graph. We further extend this setting to heterogeneous message passing along multiple edge types to capture rich relationships between nodes. We denote the node feature and coordinates in the \(l\)-th layer as \(({\boldsymbol{h}}_i^{(l)}, \vec{{\boldsymbol{X}}}_i^{(l)})\). The messages are computed as: \[\begin{align} {\boldsymbol{m}}_{ijr} &= \phi_m\left({\boldsymbol{h}}_i^{(l)}, {\boldsymbol{h}}_j^{(l)}, \frac{T_R(\vec{{\boldsymbol{X}}}_i^{(l)}, \vec{{\boldsymbol{X}}}_j^{(l)})}{\|T_R(\vec{{\boldsymbol{X}}}_i^{(l)}, \vec{{\boldsymbol{X}}}_j^{(l)})\|_F +\epsilon}, {\boldsymbol{e}}_r\right), \\ \vec{{\boldsymbol{M}}}_{ijr} &= T_S\left(\vec{{\boldsymbol{X}}}_i^{(l)} - \frac{1}{c_j}\sum_{k=1}^{c_j}\vec{{\boldsymbol{X}}}_{j}^{(l)}(:, k), \phi_{x}({\boldsymbol{m}}_{ijr})\right), \label{E1} \end{align}\tag{1}\] where \({\boldsymbol{m}}_{ijr}\) and \(\vec{{\boldsymbol{M}}}_{ijr}\) are the invariant and equivariant messages, respectively, from node \(j\) to node \(i\) along the \(r\)-th edge type; \({\boldsymbol{e}}_r\) is the edge embedding; \(\phi_m\) and \(\phi_x\) are Multi-Layer Perceptrons (MLPs) [31] with one hidden layer; \(\|\cdot\|_F\) computes the Frobenius norm; and \(T_R\) and \(T_S\) are the adaptive multichannel geometric relation extractor and geometric message scaler, respectively. These two operators handle the mismatch in shape between \(\vec{{\boldsymbol{X}}}_i^{(l)}\) and \(\vec{{\boldsymbol{X}}}_j^{(l)}\), which arises because different nodes can contain different numbers of atoms.
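The sketch below mirrors the structure of Eq. (1) in NumPy and checks its symmetry properties numerically. It is not dyMEAN's or HeMeNet's implementation: `relation` is a hypothetical stand-in for \(T_R\) that summarizes inter-atom distances into three numbers, and the equivariant branch replaces \(T_S\) with a single scalar gate produced by \(\phi_x\). Both substitutions keep \({\boldsymbol{m}}_{ijr}\) invariant and \(\vec{{\boldsymbol{M}}}_{ijr}\) equivariant, which is the property the equation is designed to guarantee. Layer sizes, the tanh activation, and the random weights are assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
D_H, D_E, HID = 8, 4, 16  # hypothetical feature, edge-embedding and hidden sizes

def init_mlp(d_in, d_out):
    """Random weights for a one-hidden-layer MLP (as stated for phi_m and phi_x)."""
    return (rng.normal(scale=0.1, size=(d_in, HID)),
            rng.normal(scale=0.1, size=(HID, d_out)))

def mlp(params, x):
    w1, w2 = params
    return np.tanh(x @ w1) @ w2

def relation(X_i, X_j):
    """Hypothetical stand-in for T_R: summarize the (c_i, c_j) inter-atom
    distance matrix, which is E(3)-invariant, as a fixed-size vector."""
    dist = np.linalg.norm(X_i[:, :, None] - X_j[:, None, :], axis=0)
    return np.array([dist.mean(), dist.min(), dist.max()])

def message(h_i, h_j, X_i, X_j, e_r, phi_m, phi_x, eps=1e-6):
    rel = relation(X_i, X_j)
    rel = rel / (np.linalg.norm(rel) + eps)                # norm-based normalization
    m = mlp(phi_m, np.concatenate([h_i, h_j, rel, e_r]))   # invariant message m_ijr
    offset = X_i - X_j.mean(axis=1, keepdims=True)         # atoms of i minus centroid of j, (3, c_i)
    M = offset * mlp(phi_x, m)                             # stand-in for T_S: one scalar gate
    return m, M

phi_m = init_mlp(2 * D_H + 3 + D_E, D_H)
phi_x = init_mlp(D_H, 1)

# Rotating/reflecting and translating the inputs should leave m unchanged
# and transform M by the same orthogonal matrix.
h_i, h_j, e_r = rng.normal(size=D_H), rng.normal(size=D_H), rng.normal(size=D_E)
X_i, X_j = rng.normal(size=(3, 4)), rng.normal(size=(3, 1))  # 4-atom residue, 1-atom ligand node
Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))                 # random orthogonal matrix
t = rng.normal(size=(3, 1))
m1, M1 = message(h_i, h_j, X_i, X_j, e_r, phi_m, phi_x)
m2, M2 = message(h_i, h_j, Q @ X_i + t, Q @ X_j + t, e_r, phi_m, phi_x)
```

Note how nodes with different atom counts (4 vs. 1 here) are handled: the invariant branch only ever sees a fixed-size summary, and the equivariant message always has the shape of the receiving node's coordinate matrix.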
With the calculated messages, the node representation is updated by: \[\begin{align} {\boldsymbol{h}}_i^{(l+1)} &= {\boldsymbol{h}}_i^{(l)} + \sigma(\mathrm{BN}(\phi_h(\sum_{r \in R} {\boldsymbol{W}}_r\sum_{j \in {\mathcal{N}}_r(i)}{\boldsymbol{m}}_{ijr}))), \\ \displaystyle \vec{{\boldsymbol{X}}}_i^{(l+1)} &= \displaystyle \vec{{\boldsymbol{X}}}_i^{(l)} + \frac{1}{\sum_{r \in R}|\mathcal{N}_r(i)|}\sum_{r \in R}\sum_{j \in \mathcal{N}_r(i)}w_r \displaystyle \vec{{\boldsymbol{M}}}_{ijr}, \label{escxgjdf} \end{align}\tag{2}\]
where \({\boldsymbol{W}}_r\) and \(w_r\) are a learnable matrix and a learnable scalar that project the invariant and equivariant messages, respectively, for the \(r\)-th edge type; \({\mathcal{N}}_r(i)\) denotes the neighbors of node \(i\) under the \(r\)-th edge type; \(\phi_h\) is an MLP, \(\mathrm{BN}\) is batch normalization, and \(\sigma\) is an activation function. During message passing, our model gathers information from different relations into \({\boldsymbol{h}}_i\) and \(\vec{{\boldsymbol{X}}}_i\) while preserving E(3) equivariance. For further details of the components in our model, please refer to Appendix 8.
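Eq. (2) can then be sketched as a function that aggregates precomputed messages. This is a minimal illustration under stated assumptions: batch normalization is applied over all nodes without learnable affine parameters, \(\sigma\) is taken to be ReLU, \(\phi_h\) is any callable, and nodes without neighbors keep their coordinates (the division guard is ours).

```python
import numpy as np

def update_layer(h, X, msgs, W, w, phi_h, eps=1e-5):
    """h: (n, d) features; X: list of (3, c_i) coordinate matrices;
    msgs: (i, j, r, m_ijr, M_ijr) tuples from Eq. (1);
    W: (R, d_m, d_m) matrices W_r; w: (R,) scalars w_r; phi_h: callable."""
    n, d_m = h.shape[0], W.shape[1]
    agg = np.zeros((n, d_m))
    X_acc = [np.zeros_like(x) for x in X]
    deg = np.zeros(n)
    for i, j, r, m, M in msgs:      # message from node j to node i along edge type r
        agg[i] += W[r] @ m          # W_r is linear, so applying it per message equals W_r times the sum
        X_acc[i] += w[r] * M
        deg[i] += 1                 # accumulates sum_r |N_r(i)|
    z = phi_h(agg)
    z = (z - z.mean(axis=0)) / np.sqrt(z.var(axis=0) + eps)  # batch norm over nodes, no affine params
    h_new = h + np.maximum(z, 0.0)  # sigma = ReLU (assumption)
    X_new = [x + acc / max(k, 1.0) for x, acc, k in zip(X, X_acc, deg)]
    return h_new, X_new

rng = np.random.default_rng(0)
d = 4
h = rng.normal(size=(2, d))
X = [rng.normal(size=(3, 5)), rng.normal(size=(3, 1))]  # a 5-atom residue and a ligand atom
W, w = rng.normal(size=(2, d, d)), np.array([0.5, 2.0])
msgs = [(0, 1, 1, rng.normal(size=d), np.ones((3, 5))),
        (1, 0, 0, rng.normal(size=d), np.full((3, 1), 2.0))]
h_new, X_new = update_layer(h, X, msgs, W, w, phi_h=lambda a: a)
```

Because the coordinate update is a weighted average of equivariant messages added to \(\vec{{\boldsymbol{X}}}_i\), it inherits the equivariance of Eq. (1), while the feature update only touches invariant quantities.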
Task-Aware Readout After \(L\) layers of the above relational message passing, we obtain a set of E(3)-invariant node features \({\boldsymbol{H}}^{(L)}\in\mathbb{R}^{n\times d_L}\), where \(n\) is the number of nodes and \(d_L\) is the feature dimension. To correlate task-specific information with each node feature, we propose a task-aware readout function. We compute the attention weights between each node and each task-specific query, and then read out the weighted node features into dual-level representations (graph level or chain level) for each task. As shown in Figure 3, the task-aware readout module is formulated as: \[\begin{align} \boldsymbol{\alpha}_{t} &= \mathrm{Softmax}\left(\frac{{\boldsymbol{K}}{\boldsymbol{q}}_{t}}{\sqrt{d_L}}\right),\\ {\boldsymbol{f}}_t &= \mathrm{FFN}\left(\boldsymbol{\alpha}_t{\boldsymbol{V}}+ \mathrm{Linear}({\boldsymbol{q}}_t)\right), \label{E5} \end{align}\tag{3}\] where \(\boldsymbol{\alpha}_t\in[0,1]^{n}\) holds the attention values for task \(t\); \({\boldsymbol{q}}_t \in \mathbb{R}^{d_L}\) is the learnable query for task \(t\); \({\boldsymbol{K}}={\boldsymbol{H}}^{(L)}{\boldsymbol{W}}_K\in\mathbb{R}^{n\times d_L}\) and \({\boldsymbol{V}}={\boldsymbol{H}}^{(L)}{\boldsymbol{W}}_V\in\mathbb{R}^{n\times d_L}\) are the key and value matrices, respectively; \(\mathrm{FFN}\) is a feed-forward network containing layer normalization and linear layers; and \(\mathrm{Linear}({\boldsymbol{q}}_t) = {\boldsymbol{W}}_Q{\boldsymbol{q}}_t + {\boldsymbol{b}}\) projects the query before the shortcut addition. In our implementation, we apply a multi-head attention strategy by defining multiple queries for each task.
Although the attention is computed using only the invariant features \({\boldsymbol{H}}^{(L)}\), these features already incorporate geometric information from the 3D coordinates through the preceding \(L\) layers of message passing.
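A single-head version of Eq. (3) is sketched below. The FFN is reduced to a layer norm followed by one linear map and the weights are random; these choices, together with the hypothetical chain mask used to illustrate chain-level readout, are assumptions. The multi-head variant mentioned above would use several queries per task and is not shown.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    return (x - x.mean()) / np.sqrt(x.var() + eps)

def task_aware_readout(H, q_t, W_K, W_V, W_Q, b, W_ffn):
    """Single-head task-aware readout: H (n, d_L) invariant node features,
    q_t (d_L,) learnable query for task t. Returns f_t and alpha_t."""
    d_L = H.shape[1]
    K, V = H @ W_K, H @ W_V                   # key and value matrices
    scores = K @ q_t / np.sqrt(d_L)
    alpha = np.exp(scores - scores.max())     # numerically stable softmax
    alpha /= alpha.sum()
    pooled = alpha @ V + (W_Q @ q_t + b)      # weighted readout + Linear(q_t) shortcut
    return W_ffn @ layer_norm(pooled), alpha  # minimal FFN: layer norm + one linear layer

rng = np.random.default_rng(0)
n, d_L = 6, 8
H = rng.normal(size=(n, d_L))
W_K, W_V, W_Q, W_ffn = (rng.normal(size=(d_L, d_L)) for _ in range(4))
b = rng.normal(size=d_L)
queries = {t: rng.normal(size=d_L) for t in ("lba", "ppa", "ec", "mf", "bp", "cc")}

# Graph-level readout for EC, and the same readout on shuffled nodes
f_ec, alpha_ec = task_aware_readout(H, queries["ec"], W_K, W_V, W_Q, b, W_ffn)
perm = rng.permutation(n)
f_ec_perm, _ = task_aware_readout(H[perm], queries["ec"], W_K, W_V, W_Q, b, W_ffn)
# Chain-level readout: restrict the node set to one chain (hypothetical membership)
chain_a = np.array([True, True, True, False, False, False])
f_chain, _ = task_aware_readout(H[chain_a], queries["ec"], W_K, W_V, W_Q, b, W_ffn)
```

The readout is permutation invariant over nodes, so the order in which atoms or residues are listed does not affect \({\boldsymbol{f}}_t\).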
Multiple Task Heads We feed the above task-specific feature \({\boldsymbol{f}}_t\) into separate task heads implemented by MLPs, resulting in a prediction list \(({\boldsymbol{p}}_{\mathrm{lba}}, {\boldsymbol{p}}_{\mathrm{ppa}}, {\boldsymbol{p}}_{\mathrm{ec}},{\boldsymbol{p}}_{\mathrm{mf}}, {\boldsymbol{p}}_{\mathrm{bp}}, {\boldsymbol{p}}_{\mathrm{cc}})\). For the regression tasks (LBA and PPA), we use the Mean Squared Error (MSE) loss \(\mathcal{L}_{\mathrm{MSE}}\). For the classification tasks (EC, GO-MF, GO-BP, and GO-CC), we use the Binary Cross-Entropy (BCE) loss \(\mathcal{L}_{\mathrm{BCE}}\). The training loss is formulated as: \[\begin{align} \mathcal{L} &= \mathcal{L}_{\mathrm{reg}} + \lambda \mathcal{L}_{\mathrm{cls}}, \label{E2} \end{align}\tag{4}\] where \(\mathcal{L}_{\mathrm{reg}} = \boldsymbol{1}_\mathrm{lba}\mathcal{L}_{\mathrm{MSE}}({\boldsymbol{p}}_{\mathrm{lba}}) + \boldsymbol{1}_\mathrm{ppa}\mathcal{L}_{\mathrm{MSE}}({\boldsymbol{p}}_{\mathrm{ppa}})\), \(\mathcal{L}_{\mathrm{cls}} = \boldsymbol{1}_\mathrm{ec}\mathcal{L}_{\mathrm{BCE}}({\boldsymbol{p}}_{\mathrm{ec}}) + \boldsymbol{1}_\mathrm{mf}\mathcal{L}_{\mathrm{BCE}}({\boldsymbol{p}}_{\mathrm{mf}}) + \boldsymbol{1}_\mathrm{bp}\mathcal{L}_{\mathrm{BCE}}({\boldsymbol{p}}_{\mathrm{bp}}) + \boldsymbol{1}_\mathrm{cc}\mathcal{L}_{\mathrm{BCE}}({\boldsymbol{p}}_{\mathrm{cc}})\), and \(\lambda\) is a hyper-parameter that balances the two losses. To allow training on partially labeled samples, we set \(\boldsymbol{1}_\mathrm{*} = \lambda_\mathrm{*}\) if the label of task \(*\) exists and \(\boldsymbol{1}_\mathrm{*} = 0\) otherwise, where \(\lambda_\mathrm{*}\) is a per-task weight.
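The masked objective in Eq. (4) can be sketched for a single sample as follows. This is an illustration under assumptions: predictions for classification tasks are given as probabilities, each task's BCE is averaged over its label vector, and the dictionary layout and task weights are hypothetical. Missing labels simply drop out of the sum, which is exactly what the indicator \(\boldsymbol{1}_\mathrm{*}\) achieves.

```python
import math

REG_TASKS = ("lba", "ppa")
CLS_TASKS = ("ec", "mf", "bp", "cc")

def mse(pred, target):
    return (pred - target) ** 2

def bce(probs, targets, eps=1e-7):
    """Mean binary cross-entropy over one task's multi-label vector."""
    total = 0.0
    for p, y in zip(probs, targets):
        p = min(max(p, eps), 1.0 - eps)  # clip to avoid log(0)
        total -= y * math.log(p) + (1 - y) * math.log(1.0 - p)
    return total / len(targets)

def multitask_loss(preds, labels, task_weights, lam):
    """L = L_reg + lam * L_cls; task t contributes task_weights[t] * loss only
    if its label is present (the indicator 1_t in Eq. (4)), otherwise 0."""
    l_reg = sum(task_weights[t] * mse(preds[t], labels[t])
                for t in REG_TASKS if labels.get(t) is not None)
    l_cls = sum(task_weights[t] * bce(preds[t], labels[t])
                for t in CLS_TASKS if labels.get(t) is not None)
    return l_reg + lam * l_cls

weights = {t: 1.0 for t in REG_TASKS + CLS_TASKS}
preds = {"lba": 2.0, "ppa": 0.0, "ec": [0.5, 0.5], "mf": [0.9], "bp": [0.1], "cc": [0.2]}
# An LBA complex whose chain received an EC label through UniProt matching, but no GO labels
labels = {"lba": 3.0, "ec": [1, 0]}
loss = multitask_loss(preds, labels, weights, lam=0.5)  # 1.0 + 0.5 * ln 2
```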
In addition, we adopt a balanced sampling strategy to ensure that each sampled mini-batch should contain at least one sample from each task, which further accelerates the training convergence.
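One simple way to realize such a sampler is to draw one sample from each required task and fill the rest of the batch uniformly from the full pool. This is a hypothetical sketch rather than the authors' sampler; the sample-id layout, the uniform fill, and the assumption that sample ids are unique across tasks are ours. Section 5.1 phrases the constraint as at least one LBA and one PPA complex per batch, which corresponds to passing only those two groups as required tasks.

```python
import random

def balanced_batches(required, pool, batch_size, num_batches, seed=0):
    """Yield mini-batches that contain at least one sample from each task in
    `required` (task -> list of sample ids); the rest is drawn from `pool`."""
    rng = random.Random(seed)
    if batch_size < len(required):
        raise ValueError("batch_size must be at least the number of required tasks")
    for _ in range(num_batches):
        batch = [rng.choice(ids) for ids in required.values()]  # one per required task
        chosen = set(batch)
        rest = [s for s in pool if s not in chosen]
        batch += rng.sample(rest, batch_size - len(batch))      # uniform fill, no repeats
        rng.shuffle(batch)
        yield batch

samples = {"lba": ["l1", "l2", "l3"], "ppa": ["p1", "p2"], "ec": ["e1", "e2", "e3", "e4"]}
pool = [s for ids in samples.values() for s in ids]
affinity_only = list(balanced_batches({"lba": samples["lba"], "ppa": samples["ppa"]}, pool, 4, 20))
every_task = list(balanced_batches(samples, pool, 5, 20, seed=1))
```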
In this section, we will first introduce the experimental setup in Section 5.1. In Section 5.2, we evaluate our model on the proposed dataset Protein-MT for affinity and property prediction in both single-task and multi-task settings and compare it with other baseline models. In Section 5.3, we experiment with different readout strategies and compare their performance on property prediction tasks. In Section 5.4, we perform ablation experiments on different modules.
| Setting | Method | LBA RMSE\(\downarrow\) | LBA MAE\(\downarrow\) | PPA RMSE\(\downarrow\) | PPA MAE\(\downarrow\) | EC\(\uparrow\) | GO-MF\(\uparrow\) | GO-BP\(\uparrow\) | GO-CC\(\uparrow\) |
|---|---|---|---|---|---|---|---|---|---|
| ST | GCN [32] | 2.193 | 1.721 | 7.840 | 7.738 | 0.022 | 0.207 | 0.254 | 0.367 |
| ST | GAT [33] | 2.301 | 1.838 | 7.820 | 7.720 | 0.018 | 0.223 | 0.249 | 0.354 |
| ST | SchNet [34] | 2.162 | 1.692 | 7.839 | 7.729 | 0.097 | 0.311 | 0.281 | 0.431 |
| ST | GearNet* [15] | 1.957 | 1.542 | 2.004 | 1.279 | 0.716 | 0.677 | 0.252 | 0.438 |
| ST | GearNet-fullatom [15] | 2.178 | 1.716 | 2.753 | 2.709 | 0.046 | 0.212 | 0.229 | 0.471 |
| ST | EGNN [6] | 2.282 | 1.849 | 4.854 | 4.756 | 0.039 | 0.206 | 0.253 | 0.357 |
| ST | GVP [16] | 2.281 | 1.789 | 5.280 | 5.267 | 0.020 | 0.204 | 0.244 | 0.454 |
| ST | dyMEAN [4] | 2.410 | 1.987 | 7.309 | 7.182 | 0.115 | 0.436 | 0.292 | 0.477 |
| ST | HeMeNet (Ours) | 1.912 | 1.490 | 6.031 | 5.891 | 0.863 | 0.778 | 0.404 | 0.544 |
| MT | SchNet [34] | 1.763 | 1.447 | 1.216 | 1.120 | 0.093 | 0.192 | 0.264 | 0.402 |
| MT | GearNet* [15] | 2.193 | 1.863 | 1.275 | 1.035 | 0.187 | 0.203 | 0.261 | 0.379 |
| MT | GearNet-fullatom [15] | 1.839 | 1.350 | 1.821 | 1.491 | 0.047 | 0.155 | 0.258 | 0.443 |
| MT | CDConv [35] | 1.579 | 1.352 | 2.386 | 1.822 | 0.324 | 0.246 | 0.241 | 0.424 |
| MT | EGNN [6] | 1.777 | 1.441 | 0.999 | 0.821 | 0.048 | 0.169 | 0.244 | 0.352 |
| MT | GVP [16] | 1.870 | 1.572 | 0.906 | 0.758 | 0.018 | 0.168 | 0.246 | 0.360 |
| MT | dyMEAN [4] | 1.777 | 1.446 | 1.725 | 1.523 | 0.038 | 0.164 | 0.263 | 0.449 |
| MT | HeMeNet* (Ours) | 1.799 | 1.420 | 0.861 | 0.719 | 0.630 | 0.595 | 0.279 | 0.426 |
| MT | HeMeNet (Ours) | 1.730 | 1.335 | 1.087 | 0.912 | 0.810 | 0.727 | 0.379 | 0.436 |
| ST | GPT4-turbo-1106 [36] | 2.347 | 1.780 | 1.654 | 1.343 | - | - | - | - |
| MT | ESM2 [37] | 2.009 | 1.334 | 1.692 | 1.333 | 0.917 | 0.764 | 0.389 | 0.533 |
| MT | ESM2-HeMeNet | 1.867 | 1.661 | 1.846 | 1.418 | 0.921 | 0.796 | 0.455 | 0.567 |
Models marked with * are trained under the alpha-carbon-only setting. ST and MT abbreviate single-task and multi-task, respectively.
Task settings We compare the performance of HeMeNet with other models under single-task and multi-task settings using the same validation and test sets. For single-task training, models are trained on samples with labels of the corresponding task. We also remove the task-aware readout of our model for a fair comparison. For multi-task training, the models are trained on all partially labeled training samples.
We use a balanced sampler that ensures each batch contains at least one complex from LBA and one complex from PPA. We include samples with up to 15,000 atoms for training and evaluation. All models are trained for 30 epochs on 4 NVIDIA A100 GPUs. More details on the experimental settings can be found in Appendix 9.
Baselines We compare our model with nine representative baselines. GCN [32] aggregates information weighted by node degree. GAT [33] uses an attention mechanism for message passing. SchNet [34] is an invariant network with continuous-filter convolutions on the 3D molecular graph. GearNet [15] designs a relational message-passing network to capture information for protein function tasks. Its variant GearNet-fullatom [38] extends the model to the full-atom setting. CDConv [35] models the geometric sequence with a continuous-discrete convolution. Besides these invariant models, we also compare our method with equivariant models. EGNN [6] is a lightweight but effective E(n) equivariant graph neural network. GVP [16] designs an equivariant geometric vector perceptron for protein representation. dyMEAN [4] is an equivariant model for antibody design that uses a dynamic multichannel equivariant function for full-atom coordinates. We additionally compare with ESM2 [37], a general-purpose protein language model, and use three-shot prompting to test GPT-4 [36] on the two affinity prediction tasks; see Appendix 10 for the prompt details.
Evaluation For the LBA and PPA tasks, we use the commonly adopted Root Mean Square Error (RMSE) and Mean Absolute Error (MAE) as evaluation metrics [39]. For the EC and GO tasks, we use the maximum F-score (Fmax) following [15]. Each experiment is run independently three times with different random seeds.
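For reference, the three metrics can be computed as sketched below. RMSE and MAE are standard. For Fmax we assume the protein-centric definition used in CAFA-style function prediction (precision averaged over proteins with at least one score above the threshold, recall averaged over all proteins, maximized over thresholds); we believe this matches [15] but have not checked it against that code, and the threshold grid is also an assumption.

```python
import numpy as np

def rmse(pred, true):
    err = np.asarray(pred, dtype=float) - np.asarray(true, dtype=float)
    return float(np.sqrt(np.mean(err ** 2)))

def mae(pred, true):
    err = np.asarray(pred, dtype=float) - np.asarray(true, dtype=float)
    return float(np.mean(np.abs(err)))

def fmax(scores, labels, thresholds=np.linspace(0.01, 0.99, 99)):
    """Protein-centric Fmax (assumed definition). scores, labels: (n_proteins, n_classes);
    every protein is assumed to have at least one positive label."""
    labels = np.asarray(labels).astype(bool)
    best = 0.0
    for t in thresholds:
        pred = scores >= t
        tp = (pred & labels).sum(axis=1)
        n_pred = pred.sum(axis=1)
        covered = n_pred > 0  # proteins with at least one prediction at this threshold
        if not covered.any():
            continue
        precision = np.mean(tp[covered] / n_pred[covered])
        recall = np.mean(tp / labels.sum(axis=1))
        if precision + recall > 0:
            best = max(best, 2 * precision * recall / (precision + recall))
    return float(best)

labels = np.array([[1, 0, 1], [0, 1, 0]])
scores = np.array([[0.9, 0.2, 0.6], [0.1, 0.4, 0.3]])
```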
We conduct experiments under both single-task and multi-task settings. The mean results of three runs are reported in Table 1. According to the results, we draw conclusions summarized in the subsequent paragraphs.
Our model outperforms the baselines on most tasks under both settings. Under the single-task setting, our model surpasses the other models on five of the six tasks. Under the multi-task setting, our model surpasses the other models on four of the six tasks and ranks second and third on the remaining two. We also compare our model with GPT-4 and ESM2 (with multi-task head fine-tuning). Our method outperforms GPT-4 on both LBA and PPA. ESM2 performs well on the four property prediction tasks, and combining ESM2 with HeMeNet adds geometric information that further improves ESM2 on all four property prediction tasks and on LBA RMSE. Notably, under the single-task setting, only the models with heterogeneous message passing (GearNet and ours) perform well on all four property prediction tasks. Under the multi-task setting, our full-atom model benefits from joint learning and shows superior results on different tasks; two interesting observations are discussed next.
Our model benefits from the multi-task setting, especially on LBA and PPA. We observe that almost all models improve on the LBA and PPA tasks under the multi-task setting. In particular, our model substantially improves the PPA RMSE from 6.031 to 1.087 by using a training set more than ten times larger (2587 samples for single-task PPA versus 30904 for our multi-task training set). We also train our model with only alpha-carbon atoms as input (HeMeNet*), which yields the best PPA RMSE of 0.861. To further understand how information transfers between tasks, we vary the fraction of PPA samples included in training while keeping the samples from the other tasks unchanged. As shown in Figure 4, our model performs better as the number of PPA training samples increases, and it performs much better than its single-task counterpart even with few PPA training samples. These results indicate that the model handles challenging tasks (complex-level affinity prediction) better when more structural information is available (e.g., single-chain structures and their labels).
Our model performs harmonious multi-task training on property prediction tasks. We observe that when switching from the single-task to multi-task setting, baseline models experience performance degradation to some extent across the four property prediction tasks.
This is probably due to interference among diverse tasks: combining different tasks for training without careful adaptation can harm performance.
With the guidance of our task-aware readout module, our model is able to learn from multiple tasks in a task-harmonious way, while achieving performance on the property prediction tasks comparable to their single-task counterparts with the same parameter size.
| Method | EC\(\uparrow\) | GO-MF\(\uparrow\) | GO-BP\(\uparrow\) | GO-CC\(\uparrow\) |
|---|---|---|---|---|
| GearNet\(_s\) | 0.187 | 0.203 | 0.261 | 0.379 |
| GearNet\(_w\) | 0.066 | 0.164 | 0.271 | 0.414 |
| GearNet\(_t\) | 0.421 | 0.310 | 0.287 | 0.403 |
| HeMeNet\(_s\) | 0.722 | 0.558 | 0.302 | 0.413 |
| HeMeNet\(_w\) | 0.325 | 0.312 | 0.276 | 0.440 |
| HeMeNet\(_t\) | 0.810 | 0.727 | 0.379 | 0.436 |

Subscripts \(_s\), \(_w\), and \(_t\) denote the sum readout, the task-prompt weighted node feature readout [40], and our task-aware readout, respectively.
To verify the effectiveness of our task-aware readout module, we take HeMeNet and GearNet as backbones and compare the task-aware readout with two commonly used readout functions: sum readout and task-prompt weighted node features [40]. The results are presented in Table 2, from which we draw the following observations:
Our proposed task-aware readout module injects task-related information through an attention mechanism, leading to overall improvements across tasks, especially on the Enzyme Commission task.
Simple element-wise multiplication of the task prompt feature with all node features fails to provide sufficient guidance for learning across all tasks.
To further investigate the relationships between tasks, we calculate Pearson's correlation between the learned task prompts. As shown in Figure 5, the correlations between tasks within the same category (e.g., EC and MF) are high, while the correlations between tasks from different categories (e.g., LBA and BP) are low. A high correlation between prompts indicates similar attention queries and thus similar readout functions. This suggests that, with task-aware guidance, the model employs similar readout strategies for tasks from the same category and divergent strategies for tasks from different categories.
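The prompt-correlation analysis amounts to a Pearson correlation matrix over the learned task queries. The sketch assumes one query vector per task (for the multi-head readout, concatenating a task's queries into one vector is our assumption), and the toy prompts are random except that the MF prompt is constructed to be nearly equal to the EC prompt.

```python
import numpy as np

TASKS = ["lba", "ppa", "ec", "mf", "bp", "cc"]

def prompt_correlation(prompts):
    """Pearson correlation between task prompts; prompts: (num_tasks, d), one row per task."""
    return np.corrcoef(prompts)  # rows are treated as variables

rng = np.random.default_rng(1)
prompts = rng.normal(size=(len(TASKS), 32))
prompts[TASKS.index("mf")] = prompts[TASKS.index("ec")] + 0.01 * rng.normal(size=32)  # toy: MF close to EC
corr = prompt_correlation(prompts)
```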
| Method | LBA RMSE\(\downarrow\) | PPA RMSE\(\downarrow\) | EC\(\uparrow\) | GO-MF\(\uparrow\) | GO-BP\(\uparrow\) | GO-CC\(\uparrow\) |
|---|---|---|---|---|---|---|
| HeMeNet | 1.730 | 1.087 | 0.810 | 0.727 | 0.379 | 0.436 |
| - TAR | 1.905 | 1.970 | 0.722 | 0.558 | 0.302 | 0.413 |
| - \(e_r, W_r, w_r\) | 1.790 | 1.446 | 0.547 | 0.663 | 0.359 | 0.391 |
| - full-atom | 1.799 | 0.861 | 0.630 | 0.595 | 0.279 | 0.426 |
We perform ablation experiments to evaluate the necessity of different components: the task-aware readout (TAR), the relational message passing mechanism, and the full-atom geometry. Specifically, the TAR ablation replaces the task-aware readout module with a sum readout. For the \(e_r, W_r, w_r\) ablation, we remove the different edge types and the relational message-passing weights. For the full-atom ablation, we represent each residue by its alpha-carbon atom.
We present the results of the ablation studies in Table 3. The observations are as follows:
Without our task-aware readout strategy, performance degrades significantly on all tasks, indicating that the tasks can hinder each other without appropriate guidance.
Without the heterogeneous graph and the relational message passing, our model's performance drops on the property prediction tasks, especially Enzyme Commission number prediction.
Removing the full-atom geometry decreases the performance on multiple tasks, but improves it on PPA. Similar to the explanation in Section 5.2, we hypothesize that the large number of atoms in full-atom protein-protein complexes introduces excessive noise into the prediction, compared with an input containing only alpha-carbon atoms.
In this paper, we alleviate the problem of data scarcity in structure-based protein datasets through a multi-task setting. We construct a standard multi-task benchmark, Protein-MT, consisting of 6 representative tasks integrated from 4 public datasets for joint learning. We propose HeMeNet, a novel network for addressing multiple tasks in protein 3D learning, with a heterogeneous equivariant full-atom encoder and a task-aware readout module. Comprehensive experiments demonstrate the effectiveness of our model on affinity and property prediction tasks. Our work offers insights into utilizing different structural datasets to train more powerful generalist models in future research.
We adopt the dataset from [29]. For EC, it contains 19201 PDB chains covering 538 EC numbers, selected from the third and fourth levels of the EC tree. For GO, the terms with at least 50 and at most 5,000 training examples are selected; the numbers of classes in the three GO branches (MF, BP, CC) are 490, 1944 and 321, respectively. This dataset consolidates 36641 PDB chains with their GO term labels. We obtain the structures of the PDB chains from TorchDrug [41].
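The GO frequency filter is an inclusive count threshold over the training annotations. A minimal sketch, assuming each training chain maps to a set of GO term IDs; the function name and the toy chain and term IDs are hypothetical.

```python
from collections import Counter
from typing import Dict, Set


def select_go_terms(train_labels: Dict[str, Set[str]], min_count: int = 50,
                    max_count: int = 5000) -> Set[str]:
    """Keep GO terms annotated on at least min_count and at most max_count training chains."""
    counts = Counter(term for terms in train_labels.values() for term in terms)
    return {term for term, n in counts.items() if min_count <= n <= max_count}


# Toy example with small thresholds; chain and term IDs are made up.
toy = {"1abc_A": {"T1", "T2"}, "2def_B": {"T1"}, "3ghi_C": {"T1", "T3"}}
print(select_go_terms(toy, min_count=2, max_count=3))  # {'T1'}
```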
We adopt the dataset from PDBbind [8] (version 2020). It contains 5316 protein-ligand complexes in the refined set and 2852 protein-protein complexes with known binding data (in the form of \(K_d\), \(K_i\), or \(IC_{50}\) values) manually collected from the original references [8]. We obtain the structures of the complexes from PDBbind, which can be downloaded from http://www.pdbbind.org.cn.
As described in Section [sec:sec:dataconstruction], we obtain 1327 fully-labeled complexes. Specifically, we employ MMseqs2 to cluster all the chains in Protein-MT with an alignment coverage \(>30\%\) and a sequence identity of the aligned fragment \(>90\%\), leading to 33704 chain-level clusters. Then, we merge the clusters that contain chains belonging to the same complex, finally obtaining 30034 clusters. We randomly split the fully-labeled complexes into training, validation and test sets of 328, 530, and 469 complexes, respectively. For the partially labeled samples, we retain only those located in clusters different from those of the test complexes and add them to the training set, resulting in an augmented training set of 31252 samples in total. After this labeling and splitting procedure, we obtain the final Protein-MT dataset. Table 5 shows detailed statistics of Protein-MT for the different tasks. Note that the multi-task sample numbers differ slightly from those in Table 4, since we removed samples with more than 15,000 atoms.
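The cluster merging and leakage-aware split can be sketched as follows. This assumes the MMseqs2 chain-level cluster assignments are already computed (the clustering itself is not reimplemented); union-find is our choice for merging clusters, and all function names and toy IDs are hypothetical.

```python
import random
from typing import Dict, List, Tuple


class UnionFind:
    """Disjoint-set forest used to merge chain-level clusters."""

    def __init__(self) -> None:
        self.parent: Dict[int, int] = {}

    def find(self, x: int) -> int:
        self.parent.setdefault(x, x)
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]  # path halving
            x = self.parent[x]
        return x

    def union(self, a: int, b: int) -> None:
        ra, rb = self.find(a), self.find(b)
        if ra != rb:
            self.parent[rb] = ra


def merge_clusters(chain_cluster: Dict[str, int], chain_complex: Dict[str, str]) -> Dict[str, int]:
    """Map each chain to a merged cluster; clusters holding chains of one complex are merged."""
    uf = UnionFind()
    anchor: Dict[str, int] = {}  # first cluster seen for each complex
    for chain, cluster in chain_cluster.items():
        uf.find(cluster)
        cpx = chain_complex[chain]
        if cpx in anchor:
            uf.union(anchor[cpx], cluster)
        else:
            anchor[cpx] = cluster
    return {chain: uf.find(cluster) for chain, cluster in chain_cluster.items()}


def split_samples(full: List[str], partial: List[str], sample_chains: Dict[str, List[str]],
                  merged: Dict[str, int], n_train: int, n_val: int, seed: int = 0
                  ) -> Tuple[List[str], List[str], List[str]]:
    """Randomly split fully-labeled complexes, then add to training only those partially
    labeled samples whose merged clusters do not overlap with any test complex."""
    order = full[:]
    random.Random(seed).shuffle(order)
    train, val, test = order[:n_train], order[n_train:n_train + n_val], order[n_train + n_val:]
    test_clusters = {merged[ch] for s in test for ch in sample_chains[s]}
    extra = [s for s in partial
             if all(merged[ch] not in test_clusters for ch in sample_chains[s])]
    return train + extra, val, test


chain_cluster = {"a1": 0, "a2": 1, "b1": 1, "c1": 2, "d1": 3}
chain_complex = {"a1": "A", "a2": "A", "b1": "B", "c1": "C", "d1": "D"}
merged = merge_clusters(chain_cluster, chain_complex)
sample_chains = {"A": ["a1", "a2"], "B": ["b1"], "C": ["c1"], "D": ["d1"]}
train, val, test = split_samples(["A", "C"], ["B", "D"], sample_chains, merged, n_train=0, n_val=0)
print(train, val, sorted(test))  # ['D'] [] ['A', 'C']
```

In the toy run, sample B shares a merged cluster with test complex A and is therefore excluded from training; with the real data, `n_train` and `n_val` would be 328 and 530.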
Clusters after merging | Train set size | Validation set size | Test set size |
---|---|---|---|
30034 | 31252 | 530 | 469 |
Task | # Train samples | # Validation samples | # Test samples | Missing ratio | Average seq length | # Average atoms | Chains |
---|---|---|---|---|---|---|---|
multi-task | 30904 | 516 | 467 | 99.9% | 325.67 | 2537.68 | Mixed |
LBA | 3247 | 493 | 452 | 89.5% | 425.29 | 3337.21 | Multi |
PPA | 1976 | 23 | 15 | 93.6% | 846.53 | 6604.35 | Multi |
EC | 15025 | 516 | 467 | 51.4% | 331.68 | 2585.91 | Single |
GO-MF | 22997 | 516 | 467 | 25.6% | 311.18 | 2420.01 | Single |
GO-BP | 21626 | 516 | 467 | 30.0% | 305.90 | 2382.08 | Single |
GO-CC | 10543 | 516 | 467 | 65.9% | 302.47 | 2346.91 | Single |
In this section, we mainly introduce \(T_R\) and \(T_S\) from [4] and our modification for the heterogeneous graph.
The Geometric Relation Extractor \(T_R\) can handle coordinate sets with different numbers of channels. Given \(\boldsymbol{X}_i \in \mathbb{R}^{3\times c_i}\) and \(\boldsymbol{X}_j \in \mathbb{R}^{3\times c_j}\), we compute the channel-wise distance for each coordinate pair: \(D_{ij}(p,q) = \|\boldsymbol{X}_{i}(:, p) - \boldsymbol{X}_{j}(:,q)\|_{2}\). Different from [4], we use two fixed binary vectors \(\boldsymbol{w}_i \in \mathbb{R}^{c_i\times 1}\) and \(\boldsymbol{w}_j \in \mathbb{R}^{c_j\times 1}\): the weight of a channel is set to 1 if the channel contains an atom, and 0 otherwise. We also adapt the learnable attribute matrices \(\boldsymbol{A}_i \in \mathbb{R}^{c_i \times d}\) and \(\boldsymbol{A}_j \in \mathbb{R}^{c_j \times d}\) to our input by assigning a different element embedding to each channel. The final output \(\boldsymbol{R}_{ij} \in \mathbb{R}^{d \times d}\) is given by:
\[\displaystyle {\boldsymbol{R}}_{ij} = \displaystyle {\boldsymbol{A}}^T_i(\displaystyle {\boldsymbol{w}}_i \displaystyle {\boldsymbol{w}}_j^T \odot \displaystyle {\boldsymbol{D}}_{ij}) \displaystyle {\boldsymbol{A}}_j. \label{eq8}\tag{5}\]
The shape of \(\boldsymbol{R}_{ij}\) is fixed at \(d \times d\) regardless of \(c_i\) and \(c_j\), while its entries still reflect the channel-wise geometry of both coordinate sets.
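Equation 5 can be checked with a small plain-Python sketch. The attribute matrices here are fixed toy numbers standing in for the learnable element embeddings, the function names are hypothetical, and the rigid transformation is a rotation about the z axis plus a translation, one element of E(3).

```python
import math
from typing import List

Matrix = List[List[float]]
Points = List[List[float]]  # one (x, y, z) coordinate per channel, i.e. the columns of X


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a: Matrix) -> Matrix:
    return [list(row) for row in zip(*a)]


def relation_extractor(x_i: Points, x_j: Points, w_i: List[float], w_j: List[float],
                       a_i: Matrix, a_j: Matrix) -> Matrix:
    """R_ij = A_i^T (w_i w_j^T (.) D_ij) A_j; returns a d x d matrix."""
    dist = [[math.dist(p, q) for q in x_j] for p in x_i]          # D_ij, c_i x c_j
    masked = [[w_i[p] * w_j[q] * dist[p][q] for q in range(len(x_j))]
              for p in range(len(x_i))]                            # empty channels zeroed
    return matmul(matmul(transpose(a_i), masked), a_j)


def rigid_transform(points: Points, theta: float, t: List[float]) -> Points:
    """Rotate about the z axis by theta, then translate by t."""
    c, s = math.cos(theta), math.sin(theta)
    return [[c * x - s * y + t[0], s * x + c * y + t[1], z + t[2]] for x, y, z in points]


x_i = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 2.0, 0.0]]   # c_i = 3
x_j = [[1.0, 1.0, 1.0], [3.0, 0.0, -1.0]]                   # c_j = 2
w_i, w_j = [1.0, 1.0, 0.0], [1.0, 1.0]                      # third channel of x_i is empty
a_i = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]                  # c_i x d, with d = 2
a_j = [[1.0, 0.0], [0.0, 1.0]]                              # c_j x d
r = relation_extractor(x_i, x_j, w_i, w_j, a_i, a_j)
r_g = relation_extractor(rigid_transform(x_i, 0.8, [5.0, -2.0, 1.0]),
                         rigid_transform(x_j, 0.8, [5.0, -2.0, 1.0]), w_i, w_j, a_i, a_j)
print(r)
print(r_g)  # equal to r up to floating-point error
```

The output is 2 x 2 even though the inputs have 3 and 2 channels, and moving the coordinates of a zero-weight channel leaves the result unchanged.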
The Geometric Message Scaler \(T_S\) aims to generate geometric information for a coordinate set \(\boldsymbol{X}\in \mathbb{R}^{3 \times c}\) of varying channel size \(c\) from a fixed-length incoming message \(\boldsymbol{s} = \phi_x(\boldsymbol{m}_{ij}) \in \mathbb{R}^C\), where \(C=14\) is the maximum channel size among the common amino acids. \(T_S(\boldsymbol{X}, \boldsymbol{s})\) is then calculated by:
\[\displaystyle {\boldsymbol{X}}' = \displaystyle {\boldsymbol{X}}\cdot \mathrm{diag}(s'), \label{eq9}\tag{6}\]
where \(s' \in \mathbb{R}^c\) is obtained by average pooling of \(\boldsymbol{s}\) with a sliding window of size \(C-c+1\) and stride 1 (so the output has exactly \(C-(C-c+1)+1 = c\) entries), and \(\mathrm{diag}(\cdot)\) denotes the diagonal matrix with its input vector as the diagonal elements.
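The pooling and scaling in Equation 6 can be sketched directly. This is a minimal illustration under the definitions above; the function names are hypothetical and the message vector is a toy stand-in for \(\phi_x(\boldsymbol{m}_{ij})\).

```python
from typing import List

Points = List[List[float]]  # one (x, y, z) coordinate per channel, i.e. the columns of X


def sliding_average(s: List[float], c: int) -> List[float]:
    """Average-pool s (length C) with window C - c + 1 and stride 1; the result has length c."""
    window = len(s) - c + 1
    return [sum(s[k:k + window]) / window for k in range(c)]


def message_scaler(x: Points, s: List[float]) -> Points:
    """X' = X diag(s'): channel k of X is multiplied by the scalar s'[k]."""
    s_prime = sliding_average(s, len(x))
    return [[coord * s_prime[k] for coord in point] for k, point in enumerate(x)]


s = [float(k) for k in range(14)]        # C = 14; toy stand-in for phi_x(m_ij)
print(sliding_average(s, 14) == s)       # True: window 1 leaves s unchanged
print(sliding_average(s, 3))             # [5.5, 6.5, 7.5]: window 12
```

Since each channel is only multiplied by a scalar that does not depend on the coordinates, the operation commutes with any rotation or reflection applied to the points, which is the property used in Equation 8.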
In this section, we introduce the implementation details of all baselines and our model. For all models, we concatenate the hidden outputs to obtain the final output. For the multi-task setting, all models except HeMeNet use the sum readout. The feature after the readout function is fed into six two-layer MLPs to make predictions for the different tasks. All models take full-atom graphs with KNN edges as input, except GearNet and HeMeNet.
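For reference, KNN edge construction over atom coordinates can be sketched by brute force. The value of \(k\) is not stated in the text and is a free parameter here; the function name is hypothetical. For full-atom complexes with thousands of atoms, a spatial index or a library routine such as PyTorch Geometric's `knn_graph` would be the practical choice.

```python
import math
from typing import List, Sequence, Tuple

Coord = Sequence[float]


def knn_edges(coords: List[Coord], k: int) -> List[Tuple[int, int]]:
    """Directed edges (j, i) from each of node i's k nearest other nodes to node i.

    Brute force, O(N^2 log N); ties in distance are broken by the smaller index.
    """
    edges: List[Tuple[int, int]] = []
    for i, ci in enumerate(coords):
        ranked = sorted((math.dist(ci, cj), j) for j, cj in enumerate(coords) if j != i)
        edges.extend((j, i) for _, j in ranked[:k])
    return edges


atoms = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (2.0, 0.0, 0.0), (10.0, 0.0, 0.0)]
print(knn_edges(atoms, k=1))  # [(1, 0), (0, 1), (1, 2), (2, 3)]
```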
GCN [32], GAT [33] and SchNet [34]. We use the implementations in PyTorch Geometric [42] with 3 layers. For all these models, the hidden size is set to 256.
GearNet [15]. We re-implement GearNet following its official implementation, with 6 layers. The hidden size is set to 512, and the cutoff value is 4.5, following the original settings. For the multi-task setting, we use the sum readout. We use the alpha-carbon-only graph as the input for GearNet.
EGNN [6]. We re-implement EGNN with reference to its official implementation, with a 3-layer design. The hidden size is set to 256.
GVP [16]. We use the implementation in PyTorch Geometric [42] with 3 layers. The hidden size is set to 128, following the original implementation.
dyMEAN [4]. We re-implement dyMEAN with reference to its official implementation, with a 6-layer design. The hidden size is set to 256.
HeMeNet (ours). We use a 6-layer design for our model, with the hidden size set to 256. We use our task-aware readout module to generate features for the different tasks, and take the full-atom heterogeneous graph as input.
To investigate how the different tasks affect the LBA and PPA tasks, we further train our model on different task combinations (LBA & PPA plus one property prediction task, denoted LP&X). The results are shown in Table 6.
Tasks | LBA | PPA | EC | MF | BP | CC |
---|---|---|---|---|---|---|
LP&EC | 1.758 | 1.064 | 0.628 | - | - | - |
LP&MF | 1.787 | 1.190 | - | 0.671 | - | - |
LP&BP | 1.820 | 1.089 | - | - | 0.374 | - |
LP&CC | 1.798 | 1.215 | - | - | - | 0.392 |
All | 1.730 | 1.087 | 0.810 | 0.727 | 0.379 | 0.436 |
As shown in the table, each property prediction task can improve the performance on LBA and PPA compared with the single-task settings. Moreover, training on all six tasks together yields the best overall results. We suppose this is because the property prediction tasks are highly correlated with each other, while each of them captures information from a different aspect. Combining them improves the overall performance, benefiting our model through a larger sample size and greater label diversity.
In this section, we prove the equivariance of our HeMeNet encoder. Note that the task-aware readout function utilizes only the invariant features, in which geometric information has already been encoded, for the downstream invariant tasks, including affinity prediction and property prediction. Therefore, we prove the equivariance of the encoder and the invariance of the task-aware readout function.
Theorem 1. Given the input \((h_i, \boldsymbol{X}_i)\), let the output be \((\tilde{h}_i, \tilde{\boldsymbol{X}}_i) = \mathrm{H}_e(h_i, \boldsymbol{X}_i)\), where \(\mathrm{H}_e\) denotes the HeMeNet encoder. Then \(\mathrm{H}_e\) is E(3) equivariant; in other words, for any transformation \(g \in E(3)\), we have \((\tilde{h}_i, g \cdot \tilde{\boldsymbol{X}}_i) = \mathrm{H}_e(h_i, g \cdot \boldsymbol{X}_i)\).
Lemma. The geometric relation extractor is E(3) invariant. Specifically, for all \(\boldsymbol{X}_i \in \mathbb{R}^{3 \times c_i}\) and \(\boldsymbol{X}_j \in \mathbb{R}^{3 \times c_j}\), let \(\boldsymbol{R}_{ij} = \mathrm{T_R}(\boldsymbol{X}_i, \boldsymbol{X}_j)\); then for any \(g \in E(3)\) we have \(\boldsymbol{R}_{ij} = \mathrm{T_R}(g \cdot \boldsymbol{X}_i, g \cdot \boldsymbol{X}_j)\), where \(g \cdot x = Qx + t\) with \(Q \in O(3)\) and \(t \in \mathbb{R}^3\), applied to each column.
Proof. In the calculation of \(\mathbf{R}_{ij}\) (Equation 5), only \(\mathbf{D}_{ij}\) depends on the input coordinates, so the invariance of \(\mathbf{R}_{ij}\) follows from the invariance of \(\mathbf{D}_{ij}\). Let \(\mathbf{D}'_{ij}\) denote the distance matrix computed from the transformed coordinates \(g \cdot \boldsymbol{X}_i\) and \(g \cdot \boldsymbol{X}_j\). For every channel pair \((p, q)\):
\[\begin{align} \mathbf{D}'_{ij}(p,q) & = \|(Q \boldsymbol{X}_i(:,p) + t) - (Q \boldsymbol{X}_j(:,q) + t)\|_2 \\ & = \|Q (\boldsymbol{X}_i(:,p) - \boldsymbol{X}_j(:,q))\|_2 \\ & = \sqrt{[\boldsymbol{X}_i(:,p) - \boldsymbol{X}_j(:,q)]^T Q^T Q [\boldsymbol{X}_i(:,p) - \boldsymbol{X}_j(:,q)]} \\ & = \sqrt{[\boldsymbol{X}_i(:,p) - \boldsymbol{X}_j(:,q)]^T [\boldsymbol{X}_i(:,p) - \boldsymbol{X}_j(:,q)]} \\ & = \|\boldsymbol{X}_i(:,p) - \boldsymbol{X}_j(:,q)\|_2 = \mathbf{D}_{ij}(p,q), \label{eq10} \end{align}\tag{7}\]
where the fourth line uses \(Q^T Q = I\) for \(Q \in O(3)\).
The other terms in \(\mathrm{T_R}\), namely \(\boldsymbol{A}_i\), \(\boldsymbol{A}_j\), \(\boldsymbol{w}_i\) and \(\boldsymbol{w}_j\), do not change under the transformation of \(\boldsymbol{X}_i\) and \(\boldsymbol{X}_j\). Therefore, \(\boldsymbol{R}_{ij} = \mathrm{T_R}(g \cdot \boldsymbol{X}_i, g \cdot \boldsymbol{X}_j)\), i.e., \(\mathbf{R}_{ij}\) is E(3) invariant. \(\hfill\qed\)
Lemma. The geometric message scaler \(\mathrm{T_S}\) is O(3) equivariant. Specifically, for all \(\boldsymbol{X}\in \mathbb{R}^{3\times c}\) and \(\boldsymbol{s}\in \mathbb{R}^C\), let \(\boldsymbol{X}' = \mathrm{T_S}(\boldsymbol{X}, \boldsymbol{s})\); then for all \(Q \in O(3)\), we have \(Q\boldsymbol{X}' = \mathrm{T_S}(Q\boldsymbol{X}, \boldsymbol{s})\).
Proof. We can derive as follows:
\[\begin{align} \mathrm{T_S}(Q\displaystyle {\boldsymbol{X}}, \displaystyle {\boldsymbol{s}}) &= Q\displaystyle {\boldsymbol{X}}\cdot diag(s') \\ &= Q(\displaystyle {\boldsymbol{X}}\cdot diag(s'))\\ &= Q\mathrm{T_S}({\displaystyle {\boldsymbol{X}}, \displaystyle {\boldsymbol{s}}}) \\ &= Q\displaystyle {\boldsymbol{X}}' \label{eq11} \end{align}\tag{8}\] Therefore, \(Q\displaystyle {\boldsymbol{X}}' = \mathrm{T_S}(Q\displaystyle {\boldsymbol{X}}, \displaystyle {\boldsymbol{s}})\). \(\hfill\qed\)
With these lemmas, we can prove Theorem 1 now:
Proof. Since \(\boldsymbol{h}_i\), \(\boldsymbol{h}_j\) and \(\boldsymbol{e}_r\) do not change under the transformation of \(\boldsymbol{X}\), the invariance of \(\mathrm{T_R}\) gives, for all \(g \in E(3)\):
\[\begin{align} \displaystyle {\boldsymbol{m}}_{ijr} &= \phi_m(\displaystyle {\boldsymbol{h}}_i^{(l)}, \displaystyle {\boldsymbol{h}}_j^{(l)}, \frac{\mathrm{T_R}(g\cdot\displaystyle \vec{{\boldsymbol{X}}}_{i}^{(l)}, g\cdot\displaystyle \vec{{\boldsymbol{X}}}_j^{(l)})}{||\mathrm{T_R}(g\cdot\displaystyle \vec{{\boldsymbol{X}}}_{i}^{(l)}, g\cdot\displaystyle \vec{{\boldsymbol{X}}}_j^{(l)})||_F+\epsilon}) \\ &= \phi_m(\displaystyle {\boldsymbol{h}}_i^{(l)}, \displaystyle {\boldsymbol{h}}_j^{(l)}, \frac{\mathrm{T_R}(\displaystyle \vec{{\boldsymbol{X}}}_{i}^{(l)}, \displaystyle \vec{{\boldsymbol{X}}}_j^{(l)})}{||\mathrm{T_R}(\displaystyle \vec{{\boldsymbol{X}}}_{i}^{(l)}, \displaystyle \vec{{\boldsymbol{X}}}_j^{(l)})||_F+\epsilon}), \label{eq12} \end{align}\tag{9}\]
Given the invariance of \(\boldsymbol{m}_{ijr}\), for all \(g \in E(3)\), i.e., \(g\cdot \boldsymbol{X} = Q\boldsymbol{X} + \boldsymbol{t}\), we can now derive the equivariance of \(\vec{\boldsymbol{M}}_{ijr}\):
\[\begin{align} g\cdot \vec{\boldsymbol{M}}_{ijr} &= \mathrm{T_S}\Big(Q\vec{\boldsymbol{X}}_i^{(l)} + \boldsymbol{t} - \frac{1}{c_j}\sum_{k=1}^{c_j} \big(Q\vec{\boldsymbol{X}}_j^{(l)}(:,k) + \boldsymbol{t}\big), \phi_x(\boldsymbol{m}_{ijr})\Big) \\ &= \mathrm{T_S}\Big(Q\vec{\boldsymbol{X}}_i^{(l)} - \frac{1}{c_j}\sum_{k=1}^{c_j} Q\vec{\boldsymbol{X}}_j^{(l)}(:,k), \phi_x(\boldsymbol{m}_{ijr})\Big) \\ &= Q\,\mathrm{T_S}\Big(\vec{\boldsymbol{X}}_i^{(l)} - \frac{1}{c_j}\sum_{k=1}^{c_j} \vec{\boldsymbol{X}}_j^{(l)}(:,k), \phi_x(\boldsymbol{m}_{ijr})\Big) \\ &= Q\vec{\boldsymbol{M}}_{ijr}, \label{eq13} \end{align}\tag{10}\]
we can now prove the E(3) equivariance of our HeMeNet encoder \(H_e\): \[\begin{align} g\cdot{\boldsymbol{h}}_i^{(l+1)} &= {\boldsymbol{h}}_i^{(l)} + \sigma(\mathrm{BN}(\phi_h(\sum_{r \in R} {\boldsymbol{W}}_r\sum_{j \in {\mathcal{N}}_r(i)}g\cdot{\boldsymbol{m}}_{ijr}))) \\ &= {\boldsymbol{h}}_i^{(l)} + \sigma(\mathrm{BN}(\phi_h(\sum_{r \in R} {\boldsymbol{W}}_r\sum_{j \in {\mathcal{N}}_r(i)}{\boldsymbol{m}}_{ijr}))) \\ &= {\boldsymbol{h}}_i^{(l+1)}, \\ g\cdot\displaystyle \vec{{\boldsymbol{X}}}_i^{(l+1)} &= g\cdot\displaystyle \vec{{\boldsymbol{X}}}_i^{(l)} + \frac{1}{\sum_{r \in R}|\mathcal{N}_r(i)|}\sum_{r \in R}\sum_{j \in \mathcal{N}_r(i)}w_r g\cdot\displaystyle \vec{{\boldsymbol{M}}}_{ijr} \\ &=Q\vec{{\boldsymbol{X}}}_i^{(l)} + \displaystyle{\boldsymbol{t}} + \frac{1}{\sum_{r \in R}|\mathcal{N}_r(i)|}\sum_{r \in R}\sum_{j \in \mathcal{N}_r(i)}w_r Q\displaystyle \vec{{\boldsymbol{M}}}_{ijr}\\ &= Q[\displaystyle \vec{{\boldsymbol{X}}}_i^{(l)} + \frac{1}{\sum_{r \in R}|\mathcal{N}_r(i)|}\sum_{r \in R}\sum_{j \in \mathcal{N}_r(i)}w_r \displaystyle \vec{{\boldsymbol{M}}}_{ijr}] + \displaystyle {\boldsymbol{t}}\\ &= Q\displaystyle \vec{{\boldsymbol{X}}}_i^{(l+1)} + \displaystyle {\boldsymbol{t}}. \label{eq14} \end{align}\tag{11}\]
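Equations 9 to 11 can also be checked numerically on a toy graph. The sketch below keeps only the coordinate branch of one layer: the invariant message \(\phi_x(\boldsymbol{m}_{ijr})\) is replaced by fixed per-edge vectors (justified by the invariance shown in Equation 9), and the relation weights \(w_r\), coordinates and helper names are hypothetical. It is not the full HeMeNet layer.

```python
import math
from typing import Dict, List, Tuple

Points = List[List[float]]  # one (x, y, z) coordinate per channel


def sliding_average(s: List[float], c: int) -> List[float]:
    window = len(s) - c + 1
    return [sum(s[k:k + window]) / window for k in range(c)]


def message_scaler(x: Points, s: List[float]) -> Points:
    """T_S: scale channel k of x by the pooled message entry s'[k] (Equation 6)."""
    s_prime = sliding_average(s, len(x))
    return [[v * s_prime[k] for v in p] for k, p in enumerate(x)]


def centroid(x: Points) -> List[float]:
    return [sum(p[a] for p in x) / len(x) for a in range(3)]


def coordinate_update(coords: Dict[int, Points], edges: List[Tuple[int, int, int]],
                      w: Dict[int, float],
                      msg: Dict[Tuple[int, int, int], List[float]]) -> Dict[int, Points]:
    """X_i <- X_i + mean over (j, r) of w_r * T_S(X_i - centroid(X_j), msg_ijr).

    edges holds (i, j, r): node j is a neighbour of node i under relation r.
    """
    out: Dict[int, Points] = {}
    for i, x_i in coords.items():
        nbrs = [(j, r) for (ii, j, r) in edges if ii == i]
        delta = [[0.0, 0.0, 0.0] for _ in x_i]
        for j, r in nbrs:
            mu = centroid(coords[j])
            rel = [[p[a] - mu[a] for a in range(3)] for p in x_i]  # translation cancels here
            m = message_scaler(rel, msg[(i, j, r)])
            for k in range(len(x_i)):
                for a in range(3):
                    delta[k][a] += w[r] * m[k][a]
        n = max(len(nbrs), 1)  # sum over r of |N_r(i)|; guard against isolated nodes
        out[i] = [[x_i[k][a] + delta[k][a] / n for a in range(3)] for k in range(len(x_i))]
    return out


def apply_g(points: Points) -> Points:
    """A fixed element of E(3): rotation about z by 1.1 rad, then a translation."""
    c, s = math.cos(1.1), math.sin(1.1)
    return [[c * x - s * y + 4.0, s * x + c * y - 1.0, z + 2.5] for x, y, z in points]


coords = {0: [[0.0, 0.0, 0.0], [1.0, 0.5, 0.0], [0.0, 1.0, 1.0]],   # 3 channels
          1: [[2.0, 0.0, 1.0], [3.0, 1.0, 0.0]]}                     # 2 channels
edges = [(0, 1, 0), (0, 1, 1), (1, 0, 0)]        # two relation types between nodes 0 and 1
w = {0: 0.7, 1: -0.3}                            # hypothetical per-relation weights w_r
msg = {e: [0.05 * (k + 1) + 0.1 * e[2] for k in range(14)] for e in edges}

lhs = coordinate_update({i: apply_g(x) for i, x in coords.items()}, edges, w, msg)
rhs = {i: apply_g(x) for i, x in coordinate_update(coords, edges, w, msg).items()}
print(max(abs(a - b) for i in lhs for p, q in zip(lhs[i], rhs[i]) for a, b in zip(p, q)))
```

The printed maximum deviation is at floating-point round-off level: transforming the input and then updating matches updating and then transforming, because the translation cancels in the centroid-relative vectors and the rotation commutes with the per-channel scaling.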
Applying this argument layer by layer, for all \(g \in E(3)\) we have \((\tilde{h}_i, g \cdot \tilde{\boldsymbol{X}}_i) = \mathrm{H}_e(h_i, g \cdot \boldsymbol{X}_i)\). \(\hfill \qed\)
Given the input \((\tilde{h}, \tilde{\boldsymbol{X}})\), let the output be \((h', \boldsymbol{X}') = \mathrm{TAR}(\tilde{h}, \tilde{\boldsymbol{X}})\). The feature output of TAR is E(3) invariant. In other words, for any transformation \(g\in E(3)\), we have \((h', g\cdot\boldsymbol{X}') = \mathrm{TAR}(\tilde{h}, g\cdot\tilde{\boldsymbol{X}})\).
Proof. According to Equation 3, TAR computes only over the invariant features \(\tilde{h}\), and the coordinates \(\tilde{\boldsymbol{X}}\) are passed through unchanged as the coordinate output, i.e., \(\boldsymbol{X}' = \tilde{\boldsymbol{X}}\). Hence transforming the input coordinates by \(g\) leaves \(h'\) unchanged, so the features used for the downstream invariant tasks are E(3) invariant. \(\hfill \qed\)
To test whether GPT-4 can predict the binding affinity of protein-protein and protein-ligand complexes, we provide it with the sequence of the receptor and that of the ligand (as either a SMILES string or a protein sequence). The system prompt is designed as follows:
System Prompt: You are a drug assistant and should be able to help with drug discovery tasks. Given the SMILES sequence of a drug and the FASTA sequence of a protein target, you need to calculate the binding affinity score. You can think step-by-step to get the answer and call any function you want. You should try your best to estimate the affinity with tools. The output should be a float number, which is the estimated affinity score without other words.
We use a 3-shot in-context prompt as input, providing three examples, each consisting of a SMILES string, a protein sequence and the corresponding affinity:
Example 1:
CC[C@H](C)[C@H](NC(=O)OC)C(=O)N1CCC[C@H]1c1ncc(-c2ccc3cc(-c4ccc5[nH]c([C@H?]6CCCN6C(=O)[C@H?](NC(=O)OC)[C@H?](C)OC)nc5c4)ccc3c2)[nH]1,
MDSIQAEEWYFGKITRRESERLLLNAENPRGTFLVRESETTKGAYCLSVSDFDNAKGLNVKHYKIRKLDSGGFYITSRTQFNSLQQLVAYYSKHADGLCHRLTTVCP
11.52
Example 2:
[H]C1:C([H]):C(S(=O)(=O)N([H])[H]):C([H]):C([H]):C:1/N=N/N1C([H])([H])C([H])([H])C([H])([H])C([H])([H])C([H])([H])C1([H])[H],
HWGYGKHNGPEHWHKDFPIAKGERQSPVDIDTHTAKYDPSLKPLSVSYDQATSLRILNNGHAFNVEFDDSQDKAVLKGGPLDGTYRLIQFHFHWGSLDGQGSEHTVDKKKYAAELHLVHWNTKYGDFGKAVQQPDGLAVLGIFLKVGSAKPGLQKVVDVLDSIKTKGKSADFTNFDPRGLLPESLDYWTYPGSLTTPPLLECVTWIVLKEPISVSSEQVLKFRKLNFNGEGEPEELMVDNWRPAQPLKNRQIKASFK
6.5
Example 3:
[H]/C1=C(C([H])([H])C(=O)N([H])C2:C([H]):C([H]):C(S(=O)(=O)N([H])[H]):C([H]):C:2[H])C2:C([H]):C([H]):C(OC([H])([H])[H]):C([H]):C:2OC1=O,
HWGYGKHNGPEHWHKDFPIAKGERQSPVDIDTHTAKYDPSLKPLSVSYDQATSLRILNNGHAFNVEFDDSQDKAVLKGGPLDGTYRLIQFHFHWGSLDGQGSEHTVDKKKYAAELHLVHWNTKYGDFGKAVQQPDGLAVLGIFLKVGSAKPGLQKVVDVLDSIKTKGKSADFTNFDPRGLLPESLDYWTYPGSLTTPPLLECVTWIVLKEPISVSSEQVLKFRKLNFNGEGEPEELMVDNWRPAQPLKNRQIKASFK
For each (SMILES, FASTA) pair in the test messages, the prompt ends with:
Test input:
{smiles},
{fasta},
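The few-shot prompt above can be assembled programmatically. A minimal sketch, assuming a chat-style list of role/content messages; the exact string layout beyond what is shown above, the helper name, and the toy inputs are hypothetical, and no API call is made.

```python
from typing import Dict, List, Tuple

SYSTEM_PROMPT = (
    "You are a drug assistant and should be able to help with drug discovery tasks. "
    "Given the SMILES sequence of a drug and the FASTA sequence of a protein target, "
    "you need to calculate the binding affinity score. You can think step-by-step to get "
    "the answer and call any function you want. You should try your best to estimate the "
    "affinity with tools. The output should be a float number, which is the estimated "
    "affinity score without other words."
)


def build_messages(examples: List[Tuple[str, str, float]], smiles: str,
                   fasta: str) -> List[Dict[str, str]]:
    """Assemble a k-shot prompt; each example is (SMILES, FASTA, affinity)."""
    shots = [f"Example {n}:\n{ex_smiles},\n{ex_fasta}\n{affinity}"
             for n, (ex_smiles, ex_fasta, affinity) in enumerate(examples, start=1)]
    user = "\n".join(shots) + f"\nTest input:\n{smiles},\n{fasta},"
    return [{"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user}]


# Placeholder strings, not real ligands or targets.
msgs = build_messages([("CCO", "MKT", 1.5)], smiles="CCN", fasta="MAA")
print(msgs[1]["content"])
```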
Although Protein-MT establishes a new benchmark for geometric protein multi-task learning, it is hard to add more tasks to Protein-MT while maintaining the number of fully-labeled samples under our definition of fully-labeled data. Relaxing the restriction on the test set can alleviate this issue. Meanwhile, since the input is currently a mixture of single chains and complexes, we could randomly augment single-chain samples with chains from their original PDB complexes to form ‘complexes’ and label them based on their UniProt IDs. Besides, we consider only invariant tasks in this work; in future work, our model could be extended to more tasks (e.g., equivariant tasks).
Code Availability: https://github.com/hanrthu/GMSL
The UniProt dataset is the world’s leading non-redundant protein sequence and function dataset, and it identifies proteins by their UniProt IDs.