HeMeNet: Heterogeneous Multichannel Equivariant Network for Protein Multi-task Learning

Ting Chen\(^1\) \(^*\)
\(^1\)Tsinghua University, \(^2\)Renmin University of China, \(^3\)Ant Group
\(^*\) Corresponding author


Abstract

Understanding and leveraging the 3D structures of proteins is central to a variety of biological and drug discovery tasks. While deep learning has been applied successfully to structure-based protein function prediction, current methods usually train a distinct model for each task. However, each individual task is small in size, and such a single-task strategy hinders the models’ performance and generalization ability. As some labeled 3D protein datasets are biologically related, combining multi-source datasets for larger-scale multi-task learning is one way to overcome this problem. In this paper, we propose a neural network model to address multiple tasks jointly upon the input of 3D protein structures. In particular, we first construct a standard structure-based multi-task benchmark called Protein-MT, consisting of 6 biologically relevant tasks, including affinity prediction and property prediction, integrated from 4 public datasets. Then, we develop a novel graph neural network for multi-task learning, dubbed Heterogeneous Multichannel Equivariant Network (HeMeNet), which is E(3) equivariant and able to capture heterogeneous relationships between different atoms. Besides, HeMeNet achieves task-specific learning via a task-aware readout mechanism. Extensive evaluations on our benchmark verify the effectiveness of multi-task learning, and our model generally surpasses state-of-the-art models.

1 Introduction

Proteins consist of one or more chains of amino acids, and they are vital in many biological systems. The 3D structure of a protein sets the foundation of its interaction with other molecules, which finally determines its functions. In recent years, learning-based methods have been applied widely to leverage the 3D structures of proteins for various tasks such as property prediction [1], affinity prediction [2], rigid docking [3], and antibody generation [4], owing to their superior efficiency and lower cost compared to wet-lab experimental approaches. A major part of learning-based methods resort to Graph Neural Networks (GNNs) [5], which naturally encode the 3D structures of proteins by modeling atoms or residues as nodes and the connections in between as edges. In addition, certain GNNs are geometry-aware and designed to capture the symmetry of E(3) transformations for better predictions [6], [7].

Figure 1: Comparison of different models with tasks. Full-atom models (left) predict binding affinity with interface information; Alpha-Carbon models (right) predict protein functions with chain information. They need to be retrained for each task. HeMeNet (middle) supports various full-atom input information and predicts all six tasks simultaneously. We omit the edges for simplicity.

Despite significant progress in geometry-aware GNNs for protein tasks, existing methods usually employ one model for one task. A clear drawback of such a single-task training strategy is that the model must be re-trained for each new task. Moreover, structural datasets with annotations are often limited in size due to the high cost of acquiring protein 3D structures and labels via wet-lab experiments, especially for affinity prediction. For example, in PDBbind [8], only 2852 complexes are experimentally annotated for Protein-Protein Affinity (PPA). Due to the sparsity of labeled structure samples, training a model on a single-task dataset of small size usually leads to degraded performance and poor generalization.

To deal with the sparsity of labeled data, some previous works leverage multi-task learning, which designs a single model and trains it on multiple related tasks. Pooling samples from related tasks brings in more information and improves performance [9], [10]. However, most of these works are sequence-based [11], [12], and their task formulations rely on separate annotations for different samples. The work by [13] is a structure-based multi-task method, but it mainly focuses on residue-level interface prediction.

Recent research shows that some protein properties potentially imply the protein’s binding activity. For example, Gene Ontology terms contain knowledge relevant to protein-protein interaction [10], and Enzyme Commission and Gene Ontology annotations can provide molecular context for protein-ligand binding affinity (LBA) [9]. This indicates that some single-chain functions may benefit the prediction of complex-level affinity and vice versa. Motivated by this, we propose combining affinity and property prediction datasets in a joint training framework.

A key problem hindering structural data integration is the lack of appropriate models for various inputs and tasks. As shown in Figure 1, many affinity prediction models [2], [14] utilize full-atom information at the binding interface to predict affinity, which discards information from the rest of the chain, while many function prediction models [15], [16] use only alpha-carbon coordinates to predict chain-level functions, which discards the detailed atomic interaction information needed for affinity prediction.

In this paper, we propose a structure-based multi-task learning paradigm. We use a heterogeneous full-atom model to deal with multiple tasks upon various 3D protein inputs. Nevertheless, accomplishing structural multi-task training is challenging. The first challenge is that there is no available benchmark. The ideal benchmark should cover a sufficient range of data and biologically related task types, with a fully labeled test set to compare how a model performs on the same input for different task outputs. The second challenge is that it is nontrivial to design a generalist model capable of processing the complicated 3D structures of input proteins of various types, including single chains, protein-protein complexes, and protein-ligand complexes, while performing well across different tasks, including protein affinity and property prediction. By achieving structure-based full-atom protein multi-task learning, we make the following contributions:

  • To the best of our knowledge, we are the first to propose the concept of structure-based protein multi-task learning. We carefully integrate the structures and labels from 4 public datasets with our proposed standard process and construct a new benchmark named Protein Multiple Tasks (Protein-MT), which consists of 6 representative tasks upon 3 different types of inputs.

  • We propose a novel model for protein structure learning, dubbed Heterogeneous Multichannel Equivariant Network (HeMeNet), which is E(3) equivariant and able to capture various relationships between different atoms owing to the heterogeneous multichannel graph construction of proteins. Additionally, we develop a task-aware readout mechanism by associating the output head of each task with a learnable task prompt for different tasks.

  • For the experiments on Protein-MT, HeMeNet surpasses other state-of-the-art methods in most tasks under both the single-task and multi-task settings. Particularly on the LBA and PPA tasks, we find that the multi-task HeMeNet is significantly better than its single-task counterpart.

2 Related Works

2.0.0.1 Protein Interaction and Property Prediction

Predicting the binding affinity and properties of proteins with computational methods is of growing interest [17], [18]. Previous research learns protein representations from different forms of information, most works taking the amino acid sequence [19], [20], multiple sequence alignments [21] or the 3D structure [15], [22] as input. Many works encode a protein’s 3D structure with GNNs [15], [23]: [2] take the full-atom geometry at the interaction interface, and [15] take the residue-level geometry of the protein for property prediction. Our method utilizes the full-atom geometry of the whole protein to address affinity and property prediction tasks together.

2.0.0.2 Equivariant GNNs

Many equivariant GNNs have emerged recently with the inductive bias of 3D symmetry, modeling various tasks including docking, molecular function prediction and sequence design [6], [24]–[26]. To empower the model with the ability to handle the complicated full-atom geometry, some models design multi-channel equivariant message passing for atom sets, such as GMN [7] and dyMEAN [4]. We propose a powerful heterogeneous equivariant GNN capable of handling various incoming message types.

Figure 2: Construction of Protein-MT. We first extract the UniProt ID for each chain and construct a UniProt-Property dictionary to map the UniProt ID of each protein chain with EC and GO-MF, GO-BP, GO-CC labels annotated in the EC and GO datasets. With this dictionary, we can extract each chain’s UniProt ID and map it with its labels. The complex with one affinity label and all property labels for each chain is defined as fully-labeled. We take most of the fully-labeled data for val/test and most of the partially labeled data for training.

2.0.0.3 Protein Multi-Task Learning

Multi-task learning takes advantage of knowledge transfer across multiple tasks, achieving better generalization performance. In the protein domain, several works leverage multi-task learning for interaction prediction and property prediction, most of which are sequence-based. [27] design three hierarchical tasks related to the Enzyme Commission number to train the model. [28] introduce a multi-task protein pre-training method with prompts. [11], [12] are sequence-based multi-task benchmarks for protein function prediction. [13] improves protein interaction interface prediction with structural multi-task auxiliary learning. To the best of our knowledge, we are the first to combine structure-based interaction prediction and property prediction in a multi-task setting.

3 New Dataset: Protein-MT

Based on the observation that protein property and binding affinity tasks may benefit each other, we construct a new dataset called Protein Multiple Tasks (Protein-MT) for protein multi-task learning. Protein-MT is composed of different types of tasks on 3D protein structures: the prediction of Ligand Binding Affinity (LBA) and Protein-Protein Affinity (PPA) based on two-instance complexes and the prediction of Enzyme Commission (EC) number and Gene Ontology (GO) terms based on single-chain structures. Particularly, the LBA and PPA tasks, originating from the PDBbind database [8], aim at regressing the affinity value of a protein-ligand complex and a protein-protein complex, respectively. The EC task, constructed by [29], describes the catalysis of biochemical reactions; each sample carries 538 binary class labels. The GO task aims to predict the hierarchically related functional properties of gene products [29]: Molecular Function (MF), Biological Process (BP), and Cellular Component (CC). We treat the prediction of MF, BP and CC as three individual tasks, resulting in six different prediction tasks in total.

One key difficulty in integrating these tasks from their sourced datasets is that samples from one task may lack labels for the other tasks. It is crucial to obtain samples with a complete set of labels across tasks for the training and evaluation of multi-task learning methods. As shown in Figure 2, we propose a standard matching pipeline that enables us to transfer the labels between EC and GO, and assign EC and GO labels for the chains of complexes in LBA and PPA as well (it is impossible to conduct the inverse direction since it is meaningless to assign LBA or PPA for those single chains in EC and GO). Specifically, we utilize the UniProt ID to uniquely identify a protein chain. We first obtain the UniProt IDs of all protein chains in Protein-MT from the Protein Data Bank [30]. For each UniProt ID, we then determine the EC and GO properties based on the labels of the corresponding chains in the EC and GO datasets, resulting in a UniProt-Property dictionary. With this dictionary, for a chain missing EC or GO labels (e.g., a chain of a complex in LBA or PPA), we can supplement the missing labels by looking up its UniProt ID in the UniProt-Property dictionary to retrieve any known EC and GO labels. We define a complex (from either LBA or PPA) as fully-labeled if the complex has one affinity label (LBA or PPA) and four function labels for each of its chains. After our above matching process, we formulate the train/validation/test split in terms of the chain-level sequence identity through the alignment methods commonly used in single-chain property prediction tasks [29]. For more details of the construction process and dataset statistics, please refer to Appendix 7.
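To make the matching pipeline concrete, the following minimal sketch shows how a UniProt-Property dictionary can be built and used to supplement missing chain labels and to test the fully-labeled condition. The record layout and helper names (e.g., `build_uniprot_property_dict`) are illustrative assumptions rather than the released preprocessing code.

```python
from collections import defaultdict

def build_uniprot_property_dict(ec_chains, go_chains):
    """Map each UniProt ID to the EC / GO-MF / GO-BP / GO-CC labels of its
    chains in the EC and GO source datasets. Each input record is assumed to
    look like {"uniprot_id": ..., "ec": [...], "mf": [...], "bp": [...], "cc": [...]}."""
    table = defaultdict(lambda: {"ec": set(), "mf": set(), "bp": set(), "cc": set()})
    for rec in list(ec_chains) + list(go_chains):
        entry = table[rec["uniprot_id"]]
        for key in ("ec", "mf", "bp", "cc"):
            entry[key].update(rec.get(key, []))
    return table

def supplement_chain_labels(chain, table):
    """Fill missing EC/GO labels of a complex chain via its UniProt ID."""
    props = table.get(chain["uniprot_id"])
    if props is None:
        return chain                      # UniProt ID never seen: labels stay missing
    for key in ("ec", "mf", "bp", "cc"):
        chain.setdefault(key, sorted(props[key]))
    return chain

def is_fully_labeled(complex_sample):
    """Fully-labeled: one affinity label and all four property labels on every chain."""
    has_affinity = (complex_sample.get("lba") is not None
                    or complex_sample.get("ppa") is not None)
    chains_ok = all(all(c.get(k) for k in ("ec", "mf", "bp", "cc"))
                    for c in complex_sample["chains"])
    return has_affinity and chains_ok
```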

4 Methodology

In this section, we first introduce our heterogeneous graph representation and the multi-task formulation in Section 4.1. Then, we design the architecture of the proposed HeMeNet in Section 4.2, which consists of two key components: heterogeneous multi-channel equivariant message passing and task-aware readout.

4.1 Heterogeneous Graph Representation and Task Formulation

The input of our model is of various types. It could be either a two-instance complex (protein-ligand for LBA and protein-protein for PPA) or a single chain (for EC and GO). Here, for consistency, we unify these two different kinds of input as a graph \({\mathcal{G}}\) composed of two sets of nodes \({\mathcal{V}}_{r}\) and \({\mathcal{V}}_{l}\). For the LBA complex input, \({\mathcal{V}}_{r}\) and \({\mathcal{V}}_{l}\) denote the receptor and the ligand, respectively; for the PPA complex input, \({\mathcal{V}}_{r}\) refers to the receptor protein chain and \({\mathcal{V}}_{l}\) to the corresponding binding protein chain; and for the single-chain (function prediction) input, \({\mathcal{V}}_{r}\) refers to the protein chain and \({\mathcal{V}}_{l}\) becomes an empty set, as shown in the middle of Figure 1. We associate each node \(v_i\) with the representation \(({\boldsymbol{h}}_i, \vec{{\boldsymbol{X}}}_i)\), where \({\boldsymbol{h}}_i\in\mathbb{R}^d\) denotes the node feature, initialized as a learnable residue embedding, and \(\vec{{\boldsymbol{X}}}_i\in\mathbb{R}^{3\times c_i}\) indicates the 3D coordinates of all \(c_i\) atoms within the node. As for edge construction, we include various types of edges. In detail, for residue nodes, we allow \(R\) heterogeneous types of edge connections, including sequential edges of different distances (\(d = \{-2, -1, 1, 2\}\)), self-loop edges, and spatial edges; for single-atom nodes from small molecules, only spatial edges are created. We present a simplified example from the LBA task in Figure 3, where we only draw a few nodes and omit the self-loop edges except for the central node for simplicity. Overall, we obtain a full-atom heterogeneous graph representation \({\mathcal{G}}\) for each input.
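The following sketch illustrates the residue part of this graph construction: node features as residue tokens, multi-channel coordinates of varying size, and the heterogeneous edge types. The edge-type ids, the 8 Å spatial cutoff, and the helper name `build_hetero_graph` are assumptions for illustration; small-molecule atom nodes (which receive spatial edges only) are omitted.

```python
import torch

# Illustrative ids for the R heterogeneous relations
# (sequential offsets -2/-1/+1/+2, self-loop, spatial).
EDGE_TYPES = {"seq-2": 0, "seq-1": 1, "seq+1": 2, "seq+2": 3, "self": 4, "spatial": 5}

def build_hetero_graph(residues, spatial_cutoff=8.0):
    """residues: list of dicts with 'res_id' (int token) and 'coords'
    (tensor of shape (3, c_i) holding all c_i atom coordinates of the residue)."""
    h = torch.tensor([r["res_id"] for r in residues])        # node tokens (to be embedded)
    X = [r["coords"] for r in residues]                       # multi-channel coordinates, varying c_i
    centers = torch.stack([x.mean(dim=1) for x in X])         # per-node centroid, shape (n, 3)

    src, dst, etype = [], [], []
    n = len(residues)
    for i in range(n):
        src.append(i); dst.append(i); etype.append(EDGE_TYPES["self"])
        for off, name in [(-2, "seq-2"), (-1, "seq-1"), (1, "seq+1"), (2, "seq+2")]:
            j = i + off
            if 0 <= j < n:                                     # sequential neighbor within the chain
                src.append(j); dst.append(i); etype.append(EDGE_TYPES[name])
    # spatial edges between residue centroids within the cutoff
    dist = torch.cdist(centers, centers)
    for i, j in (dist < spatial_cutoff).nonzero():
        if i != j:
            src.append(j.item()); dst.append(i.item()); etype.append(EDGE_TYPES["spatial"])
    return h, X, torch.tensor([src, dst]), torch.tensor(etype)
```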

4.1.0.1 Task Formulation

Given a full-atom heterogeneous graph \({\mathcal{G}}\), our goal is to design a model \({\boldsymbol{p}}=f({\mathcal{G}})\) with multiple-dimensional output \({\boldsymbol{p}}\) that is able to predict the complex-level affinity and chain-level functional properties simultaneously. By making use of our proposed dataset Protein-MT, we train the model with a partially labeled training set and test it on the fully-labeled test set. Notably, the prediction should be invariant with regard to E(3) transformation (rotation/reflection/translation) of the input coordinates. To do so, we will formulate an equivariant encoder plus an invariant output layer in our model, detailed in the next subsection.

4.2 HeMeNet: Heterogeneous Multi-channel Equivariant Network

To better cope with the 3D structures of different types for different tasks, we propose a heterogeneous multi-channel E(3) equivariant graph neural network with the ability to aggregate different relational messages. After several layers of the message passing, the node representations are transformed into task-specific representations by a task-aware readout module, generating appropriate complex-level and chain-level predictions via different task heads.

Figure 3: Overview of our pipeline. Left: HeMeNet takes two-instance complexes or a single chain as input and predicts complex-level affinity and chain-level properties simultaneously. Middle: An example of the heterogeneous graph and the relational equivariant message passing. We only annotate a small part of our multi-channel full-atom graph for simplicity. Each edge is bidirectional, and we only mark the incoming edge arrow and self-loop for the center node. Right: Task-aware readout module. We take a task prompt as the query for each task, generating attention maps for all the nodes to get a multi-level readout for different downstream tasks.

4.2.0.1 Heterogeneous Multi-channel Equivariant Message Passing

Inspired by dyMEAN [4], we leverage a multi-channel coordinate matrix with dynamic size to record the geometric information of a node in an input graph. Moreover, we extend the setting to heterogeneous message passing along multiple types of edges to capture rich relationships between nodes. We denote the node feature and coordinates as \(({\boldsymbol{h}}_i^{(l)}, \vec{{\boldsymbol{X}}}_i^{(l)})\) in the \(l\)-th layer. The message passing is calculated as: \[\begin{align} {\boldsymbol{m}}_{ijr} &= \phi_m({\boldsymbol{h}}_i^{(l)}, {\boldsymbol{h}}_j^{(l)}, \frac{T_R(\displaystyle \vec{{\boldsymbol{X}}}_i^{(l)}, \displaystyle \vec{{\boldsymbol{X}}}_j^{(l)})}{||T_R(\displaystyle \vec{{\boldsymbol{X}}}_i^{(l)}, \displaystyle \vec{{\boldsymbol{X}}}_j^{(l)})||_F +\epsilon}, {\boldsymbol{e}}_r), \\ \displaystyle \vec{{\boldsymbol{M}}}_{ijr} &= T_S(\displaystyle \vec{{\boldsymbol{X}}}_i^{(l)} - \frac{1}{c_j}\sum_{k=1}^{c_j}\displaystyle \vec{{\boldsymbol{X}}}_{j}^{(l)}(:, k), \phi_{x}({\boldsymbol{m}}_{ijr})), \label{E1} \end{align}\tag{1}\] where, \({\boldsymbol{m}}_{ijr}\) and \(\vec{{\boldsymbol{M}}}_{ijr}\) are separately the invariant and equivariant messages from node \(j\) to \(i\) along the \(r\)-th edge; \({\boldsymbol{e}}_r\) is the edge embedding feature; \(\phi_m, \phi_x\) are Multi-Layer Perceptrons (MLPs) [31] with one hidden layer; \(||\cdot||_F\) computes the Frobenius norm, \(T_R\) and \(T_S\) are the adaptive multichannel geometric relation extractor and geometric message scaler, in order to deal with the issue incurred by the varying shape of \(\vec{{\boldsymbol{X}}}_i^{(l)}\) and \(\vec{{\boldsymbol{X}}}_j^{(l)}\) since the number of atoms could be different for different nodes. With the calculated messages, the node representation is updated by: \[\begin{align} {\boldsymbol{h}}_i^{(l+1)} &= {\boldsymbol{h}}_i^{(l)} + \sigma(\mathrm{BN}(\phi_h(\sum_{r \in R} {\boldsymbol{W}}_r\sum_{j \in {\mathcal{N}}_r(i)}{\boldsymbol{m}}_{ijr}))), \\ \displaystyle \vec{{\boldsymbol{X}}}_i^{(l+1)} &= \displaystyle \vec{{\boldsymbol{X}}}_i^{(l)} + \frac{1}{\sum_{r \in R}|\mathcal{N}_r(i)|}\sum_{r \in R}\sum_{j \in \mathcal{N}_r(i)}w_r \displaystyle \vec{{\boldsymbol{M}}}_{ijr}, \label{escxgjdf} \end{align}\tag{2}\]

where, \({\boldsymbol{W}}_r, w_r\) are a learnable matrix and a learnable scalar to project invariant and equivariant messages, respectively, for the \(r\)-th kind of edge; \({\mathcal{N}}_r(i)\) denotes the neighbor nodes of \(i\) regarding the \(r\)-th kind of edges; \(\phi_h\) is an MLP, \(\mathrm{BN}\) is the batch normalization operation, and \(\sigma\) is an activation function. During the message-passing process, our model gathers information from different relations for \({\boldsymbol{h}}_i\) and \(\vec{{\boldsymbol{X}}}_i\), ensuring E(3) equivariance. For further details of the components in our model, please refer to Appendix 8.
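A minimal PyTorch sketch of one heterogeneous multi-channel message-passing layer implementing Eqs. (1)-(2) is given below. The relation extractor `t_r` and message scaler `t_s` are passed in as callables (see Appendix 8), the hidden sizes and the per-edge Python loop are illustrative simplifications, and the actual model vectorizes this computation.

```python
import torch
import torch.nn as nn

class HeteroMCLayer(nn.Module):
    """Sketch of one layer of Eqs. (1)-(2). t_r: (X_i, X_j) -> (d_geo, d_geo)
    invariant relation matrix; t_s: (relative coords, scale message) -> equivariant message."""

    def __init__(self, d, d_geo, num_relations, t_r, t_s, eps=1e-8):
        super().__init__()
        self.phi_m = nn.Sequential(nn.Linear(2 * d + d_geo * d_geo + d, d), nn.SiLU(), nn.Linear(d, d))
        self.phi_x = nn.Sequential(nn.Linear(d, d), nn.SiLU(), nn.Linear(d, 14))  # C = 14 channels
        self.phi_h = nn.Sequential(nn.Linear(d, d), nn.SiLU(), nn.Linear(d, d))
        self.W_r = nn.Parameter(torch.randn(num_relations, d, d) / d ** 0.5)      # per-relation matrix
        self.w_r = nn.Parameter(torch.ones(num_relations))                        # per-relation scalar
        self.edge_emb = nn.Embedding(num_relations, d)                            # edge features e_r
        self.bn, self.act = nn.BatchNorm1d(d), nn.SiLU()
        self.t_r, self.t_s, self.eps = t_r, t_s, eps

    def forward(self, h, X, edge_index, edge_type):
        # h: (n, d); X: list of (3, c_i); edge_index: (2, E) with rows (source j, target i)
        n, d = h.shape
        msgs_h = [[] for _ in range(n)]
        msgs_x = [[] for _ in range(n)]
        for (j, i), r in zip(edge_index.t().tolist(), edge_type.tolist()):
            R = self.t_r(X[i], X[j])                            # E(3)-invariant relation, (d_geo, d_geo)
            R = R / (R.norm(p="fro") + self.eps)                # Frobenius normalization, Eq. (1)
            m = self.phi_m(torch.cat([h[i], h[j], R.flatten(), self.edge_emb.weight[r]]))
            rel = X[i] - X[j].mean(dim=1, keepdim=True)         # translation-free offset, (3, c_i)
            M = self.t_s(rel, self.phi_x(m))                    # equivariant message, (3, c_i)
            msgs_h[i].append(self.W_r[r] @ m)
            msgs_x[i].append(self.w_r[r] * M)
        agg = torch.stack([torch.stack(ms).sum(0) if ms else torch.zeros(d) for ms in msgs_h])
        h_new = h + self.act(self.bn(self.phi_h(agg)))          # invariant update, Eq. (2)
        X_new = [x + (torch.stack(ms).sum(0) / len(ms) if ms else 0)
                 for x, ms in zip(X, msgs_x)]                   # equivariant update, Eq. (2)
        return h_new, X_new
```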

Task-Aware Readout. After \(L\) layers of the above relational message passing, we attain a set of E(3) invariant node features \({\boldsymbol{H}}^{(L)}\in\mathbb{R}^{n\times d_L}\), where \(n\) is the number of nodes and \(d_L\) is the feature dimension. To correlate the task-specific information with each node feature, we propose a task-aware readout function. We compute the attention weights between each node and each task-specific query, and then read out all weighted node features into dual-level representations: graph level or chain level for each task. As shown in Figure 3, the task-aware readout module is formulated as: \[\begin{align} \displaystyle \boldsymbol{\alpha}_{t} &= \mathrm{Softmax}(\frac{\displaystyle {\boldsymbol{K}}{\boldsymbol{q}}_{t}}{\sqrt{d_L}}),\\ \displaystyle {\boldsymbol{f}}_t &= \mathrm{FFN} (\displaystyle \boldsymbol{\alpha}_t{\boldsymbol{V}}+ \mathrm{Linear}(\displaystyle {\boldsymbol{q}}_t)), \label{E5} \end{align}\tag{3}\] where, \(\boldsymbol{\alpha}_t\in[0,1]^{n}\) defines the attention values for task \(t\); \(\displaystyle {\boldsymbol{q}}_t \in \mathbb{R}^{d_L}\) is the learnable query for task \(t\); \(\displaystyle {\boldsymbol{K}}=\displaystyle {\boldsymbol{H}}^{(L)}\displaystyle {\boldsymbol{W}}_K\in\mathbb{R}^{n\times d_L}\) and \(\displaystyle {\boldsymbol{V}}=\displaystyle {\boldsymbol{H}}^{(L)}\displaystyle {\boldsymbol{W}}_V\in\mathbb{R}^{n\times d_L}\) are the key and value matrices, respectively; \(\mathrm{FFN}\) is the feed-forward network containing layer normalization and linear layers; \(\mathrm{Linear}(\displaystyle {\boldsymbol{q}}_t) = \displaystyle {\boldsymbol{W}}_Q \displaystyle {\boldsymbol{q}}_t + \displaystyle {\boldsymbol{b}}\) is applied before the shortcut addition. In our implementation, we apply a multi-head attention strategy by defining multiple queries for each task. Although we compute the attention using only the invariant features \({\boldsymbol{H}}^{(L)}\), these features have already incorporated the geometric information of the 3D coordinates during the previous \(L\) layers of message passing.
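A single-head sketch of the task-aware readout in Eq. (3) is shown below. The multi-head queries used in the paper and the way chain-level readout is restricted to one chain's nodes (here via an optional boolean mask) are simplifying assumptions.

```python
import torch
import torch.nn as nn

class TaskAwareReadout(nn.Module):
    """Each task owns a learnable query that attends over the invariant node features."""

    def __init__(self, d, num_tasks):
        super().__init__()
        self.q = nn.Parameter(torch.randn(num_tasks, d))    # one query per task
        self.W_k = nn.Linear(d, d, bias=False)
        self.W_v = nn.Linear(d, d, bias=False)
        self.W_q = nn.Linear(d, d)                           # Linear(q_t) before the shortcut
        self.ffn = nn.Sequential(nn.LayerNorm(d), nn.Linear(d, d), nn.SiLU(), nn.Linear(d, d))

    def forward(self, H, node_mask=None):
        # H: (n, d) invariant node features after L message-passing layers
        K, V = self.W_k(H), self.W_v(H)                      # (n, d)
        scores = self.q @ K.t() / K.shape[-1] ** 0.5         # (num_tasks, n)
        if node_mask is not None:                            # e.g. restrict to one chain's nodes
            scores = scores.masked_fill(~node_mask, float("-inf"))
        alpha = scores.softmax(dim=-1)                       # attention over nodes, per task
        return self.ffn(alpha @ V + self.W_q(self.q))        # (num_tasks, d) task-specific features
```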

Multiple Task Heads. We feed the above task-specific feature \(\displaystyle {\boldsymbol{f}}_t\) into different task heads implemented by MLPs, resulting in a prediction list \((\displaystyle {\boldsymbol{p}}_{\mathrm{lba}}, {\boldsymbol{p}}_{\mathrm{ppa}}, {\boldsymbol{p}}_{\mathrm{ec}},{\boldsymbol{p}}_{\mathrm{mf}}, {\boldsymbol{p}}_{\mathrm{bp}}, {\boldsymbol{p}}_{\mathrm{cc}})\). For regression tasks (LBA and PPA), we use the Mean Square Error (MSE) loss \(\mathcal{L}_{\mathrm{MSE}}\). For classification tasks (EC, GO-MF, GO-BP and GO-CC), we use the Binary Cross Entropy (BCE) loss \(\mathcal{L}_{\mathrm{BCE}}\). The training loss is formulated as: \[\begin{align} \mathcal{L} &= \mathcal{L}_{\mathrm{reg}} + \lambda \mathcal{L}_{\mathrm{cls}}, \label{E2} \end{align}\tag{4}\] where \(\mathcal{L}_{\mathrm{reg}} = \displaystyle \boldsymbol{1}_\mathrm{lba}\mathcal{L}_{\mathrm{MSE}}(\displaystyle {\boldsymbol{p}}_{\mathrm{lba}}) + \displaystyle \boldsymbol{1}_\mathrm{ppa}\mathcal{L}_{\mathrm{MSE}}(\displaystyle {\boldsymbol{p}}_{\mathrm{ppa}})\), \(\mathcal{L}_{\mathrm{cls}} = \displaystyle \boldsymbol{1}_\mathrm{ec}\mathcal{L}_{\mathrm{BCE}}(\displaystyle {\boldsymbol{p}}_{\mathrm{ec}}) + \displaystyle \boldsymbol{1}_\mathrm{mf}\mathcal{L}_{\mathrm{BCE}}(\displaystyle {\boldsymbol{p}}_{\mathrm{mf}}) + \displaystyle \boldsymbol{1}_\mathrm{bp}\mathcal{L}_{\mathrm{BCE}}(\displaystyle {\boldsymbol{p}}_{\mathrm{bp}}) + \displaystyle \boldsymbol{1}_\mathrm{cc}\mathcal{L}_{\mathrm{BCE}}(\displaystyle {\boldsymbol{p}}_{\mathrm{cc}})\), and \(\lambda\) is a hyper-parameter to balance the trade-off between the losses. To allow training on partially labeled samples, if the label of task \(*\) exists, then \(\displaystyle \boldsymbol{1}_\mathrm{*} = \lambda_\mathrm{*}\); otherwise \(\displaystyle \boldsymbol{1}_\mathrm{*} = 0\). In addition, we adopt a balanced sampling strategy to ensure that each sampled mini-batch contains at least one sample from each task, which further accelerates training convergence.
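The masked multi-task objective of Eq. (4) can be sketched as follows; the per-task weights and the assumption that the classification heads output logits are illustrative.

```python
import torch.nn.functional as F

# Hypothetical per-task weights lambda_*; the actual values are hyper-parameters.
TASK_WEIGHTS = {"lba": 1.0, "ppa": 1.0, "ec": 1.0, "mf": 1.0, "bp": 1.0, "cc": 1.0}
REG_TASKS, CLS_TASKS = ("lba", "ppa"), ("ec", "mf", "bp", "cc")

def multitask_loss(preds, labels, lam=1.0):
    """preds/labels: dicts mapping task name to tensors. A missing label is None,
    which zeroes that task's indicator as in Eq. (4)."""
    reg = sum(TASK_WEIGHTS[t] * F.mse_loss(preds[t], labels[t])
              for t in REG_TASKS if labels.get(t) is not None)
    cls = sum(TASK_WEIGHTS[t] * F.binary_cross_entropy_with_logits(preds[t], labels[t])
              for t in CLS_TASKS if labels.get(t) is not None)
    # sum() over an empty generator returns 0, which is safe to add to a tensor
    return reg + lam * cls
```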

5 Experiments

In this section, we will first introduce the experimental setup in Section 5.1. In Section 5.2, we evaluate our model on the proposed dataset Protein-MT for affinity and property prediction in both single-task and multi-task settings and compare it with other baseline models. In Section 5.3, we experiment with different readout strategies and compare their performance on property prediction tasks. In Section 5.4, we perform ablation experiments on different modules.

5.1 Experimental Setup


Table 1: The mean results over three runs on the fully-labeled test set. We select representative invariant and equivariant models for affinity prediction and property prediction. The upper half reports results for the single-task setting, and the lower half for the multi-task setting. The best results are marked in bold and the second-best results are underlined. In the multi-task setting, the models are trained with the same size as their corresponding single-task counterparts.
Method LBA RMSE\(\downarrow\) LBA MAE\(\downarrow\) PPA RMSE\(\downarrow\) PPA MAE\(\downarrow\) EC\(\uparrow\) GO-MF\(\uparrow\) GO-BP\(\uparrow\) GO-CC\(\uparrow\)
Single-task setting:
GCN [32] 2.193 1.721 7.840 7.738 0.022 0.207 0.254 0.367
GAT [33] 2.301 1.838 7.820 7.720 0.018 0.223 0.249 0.354
SchNet [34] 2.162 1.692 7.839 7.729 0.097 0.311 0.281 0.431
GearNet* [15] 1.957 1.542 2.004 1.279 0.716 0.677 0.252 0.438
GearNet-fullatom [15] 2.178 1.716 2.753 2.709 0.046 0.212 0.229 0.471
EGNN [6] 2.282 1.849 4.854 4.756 0.039 0.206 0.253 0.357
GVP [16] 2.281 1.789 5.280 5.267 0.020 0.204 0.244 0.454
dyMEAN [4] 2.410 1.987 7.309 7.182 0.115 0.436 0.292 0.477
HeMeNet (Ours) 1.912 1.490 6.031 5.891 0.863 0.778 0.404 0.544
Multi-task setting:
SchNet [34] 1.763 1.447 1.216 1.120 0.093 0.192 0.264 0.402
GearNet* [15] 2.193 1.863 1.275 1.035 0.187 0.203 0.261 0.379
GearNet-fullatom [15] 1.839 1.350 1.821 1.491 0.047 0.155 0.258 0.443
CDConv [35] 1.579 1.352 2.386 1.822 0.324 0.246 0.241 0.424
EGNN [6] 1.777 1.441 0.999 0.821 0.048 0.169 0.244 0.352
GVP [16] 1.870 1.572 0.906 0.758 0.018 0.168 0.246 0.360
dyMEAN [4] 1.777 1.446 1.725 1.523 0.038 0.164 0.263 0.449
HeMeNet* (Ours) 1.799 1.420 0.861 0.719 0.630 0.595 0.279 0.426
HeMeNet (Ours) 1.730 1.335 1.087 0.912 0.810 0.727 0.379 0.436
Language-model baselines:
GPT4-turbo-1106 (ST) [36] 2.347 1.780 1.654 1.343 - - - -
ESM2 (MT) [37] 2.009 1.334 1.692 1.333 0.917 0.764 0.389 0.533
ESM2-HeMeNet (MT) 1.867 1.661 1.846 1.418 0.921 0.796 0.455 0.567

* represents models trained under the alpha-carbon-atom-only setting. ST and MT are abbreviations of single-task and multi-task, respectively.

Task settings. We compare the performance of HeMeNet with other models under single-task and multi-task settings using the same validation and test sets. For single-task training, models are trained on samples with labels for the corresponding task, and we remove the task-aware readout from our model for a fair comparison. For multi-task training, the models are trained on all partially labeled training samples.

We use a balanced sampler for each batch to sample at least one complex from LBA and one complex from PPA. We include samples with up to 15,000 atoms for training and evaluation. All models are trained for 30 epochs on 4 NVIDIA A100 GPUs. More details on experimental settings can be found in Appendix 9.
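A minimal sketch of the balanced batch sampler described above is given below; the class name and the way the remaining batch slots are filled are assumptions, not the exact training code. It would be passed to a PyTorch DataLoader via the `batch_sampler` argument.

```python
import random
from torch.utils.data import Sampler

class BalancedBatchSampler(Sampler):
    """Every mini-batch contains at least one LBA complex and one PPA complex;
    the remaining slots are drawn from the full (partially labeled) training set."""

    def __init__(self, lba_ids, ppa_ids, all_ids, batch_size):
        self.lba_ids, self.ppa_ids, self.all_ids = lba_ids, ppa_ids, all_ids
        self.batch_size = batch_size

    def __iter__(self):
        pool = self.all_ids[:]
        random.shuffle(pool)
        step = self.batch_size - 2                 # two slots reserved for LBA and PPA
        for start in range(0, len(pool), step):
            rest = pool[start:start + step]
            if not rest:
                break
            yield [random.choice(self.lba_ids), random.choice(self.ppa_ids)] + rest

    def __len__(self):
        return max(1, len(self.all_ids) // (self.batch_size - 2))
```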

Baselines. We compare our model with nine representative baselines. GCN [32] aggregates information weighted by the degree of nodes. GAT [33] utilizes an attention mechanism for message passing. SchNet [34] is an invariant network with continuous-filter convolution on the 3D molecular graph. GearNet [15] designs a relational message-passing network to capture information for protein function tasks. Its variant GearNet-fullatom [38] extends the model to the full-atom setting. CDConv [35] models the geometric sequence with a continuous-discrete convolution. Besides the previous invariant models, we also compare our method with equivariant models. EGNN [6] is a lightweight but effective E(n) equivariant graph neural network. GVP [16] designs an equivariant geometric vector perceptron for protein representation. dyMEAN [4] is an equivariant model for antibody design; it uses a dynamic multi-channel equivariant function for full-atom coordinates. ESM2 [37] is a general-purpose protein language model. We also conduct three-shot prompting to test GPT-4 [36] on the two affinity prediction tasks; see Appendix 10 for the prompt details.

Evaluation. For the LBA and PPA tasks, we employ the commonly used Root Mean Square Error (RMSE) and Mean Absolute Error (MAE) as evaluation metrics [39]. For the EC and GO tasks, we use the maximum F-score (Fmax) following [15]. Each experiment is run independently three times with different random seeds.
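For reference, the sketch below shows one common way to compute RMSE, MAE, and the protein-centric Fmax; the threshold granularity and edge-case handling may differ from the evaluation scripts of [15].

```python
import numpy as np

def rmse(pred, target):
    return float(np.sqrt(np.mean((np.asarray(pred) - np.asarray(target)) ** 2)))

def mae(pred, target):
    return float(np.mean(np.abs(np.asarray(pred) - np.asarray(target))))

def f_max(probs, labels, thresholds=np.linspace(0.0, 1.0, 101)):
    """Protein-centric maximum F-score used for EC/GO.
    probs, labels: arrays of shape (num_proteins, num_classes); labels are 0/1.
    Precision is averaged only over proteins with at least one positive prediction."""
    probs, labels = np.asarray(probs), np.asarray(labels)
    best = 0.0
    for t in thresholds:
        pred = probs >= t
        tp = (pred & (labels > 0)).sum(axis=1)
        has_pred = pred.sum(axis=1) > 0
        if not has_pred.any():
            continue
        precision = (tp[has_pred] / pred[has_pred].sum(axis=1)).mean()
        recall = (tp / np.maximum(labels.sum(axis=1), 1)).mean()
        if precision + recall > 0:
            best = max(best, 2 * precision * recall / (precision + recall))
    return best
```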

5.2 Results on Protein-MT

We conduct experiments under both single-task and multi-task settings. The mean results of three runs are reported in Table 1. According to the results, we draw conclusions summarized in the subsequent paragraphs.

Our model outperforms the baselines on most of the tasks under both settings. Under the single-task setting, our model surpasses other models on five of the six tasks. Under the multi-task setting, our model surpasses other models on four of the six tasks, with the remaining two tasks reaching second and third place, respectively. We also compare our model with GPT-4 and ESM2 (with multi-task head fine-tuning). Our method outperforms GPT-4 on both LBA and PPA. ESM2 performs well on the four property prediction tasks, and combining ESM2 with HeMeNet further improves ESM2's overall performance by adding geometric information. Notably, under the single-task setting, only the models with heterogeneous message passing (GearNet and ours) perform well on all four property prediction tasks. Under the multi-task setting, our full-atom model, benefiting from joint learning, shows superior results across different tasks; two main observations are discussed next.

Our model benefits from the multi-task setting, especially on LBA and PPA. We observe that almost all models improve their performance on the LBA and PPA tasks under the multi-task setting. In particular, our model significantly improves the PPA RMSE from 6.031 to 1.087 by utilizing a training set that is more than ten times larger (2587 samples for PPA single-task versus 30904 for our multi-task training set). We also train our model with alpha-carbon atoms only (HeMeNet*) as input, resulting in the best PPA RMSE of 0.861. To further understand the internal transfer of information within the tasks, we choose several ratios to include different amounts of samples from the PPA task while keeping the samples from other tasks unchanged. As shown in Figure 4, our model performs better as the number of PPA training samples increases, and it performs much better than its single-task counterpart even with a small number of training samples. These results demonstrate that the model can handle challenging tasks (complex-level affinity prediction) better when more structural information is available (e.g., single-chain structures and their labels).

Our model performs harmonious multi-task training on property prediction tasks. We observe that when switching from the single-task to the multi-task setting, baseline models experience some degree of performance degradation across the four property prediction tasks. This is probably caused by task interference: combining diverse tasks for training without careful adaptation can harm performance. With the guidance of our task-aware readout module, our model is able to learn from multiple tasks in a task-harmonious way, achieving performance on the property prediction tasks comparable to that of its single-task counterparts with the same parameter size.

Table 2: Comparison of different readout functions for multi-task learning. \(s\), \(w\), and \(t\) represent sum, weighted feature and task-aware readout, respectively.
Method EC\(\uparrow\) GO-MF\(\uparrow\) GO-BP\(\uparrow\) GO-CC\(\uparrow\)
GearNet\(_s\) 0.187 0.203 0.261 0.379
GearNet\(_w\) 0.066 0.164 0.271 0.414
GearNet\(_t\) 0.421 0.310 0.287 0.403
HeMeNet\(_s\) 0.722 0.558 0.302 0.413
HeMeNet\(_w\) 0.325 0.312 0.276 0.440
HeMeNet\(_t\) 0.810 0.727 0.379 0.436

5.3 Comparison of Different Readout Methods

To verify the effectiveness of our task-aware readout module, we take HeMeNet and GearNet as backbones and compare the task-aware readout method with two commonly used readout functions: sum readout and task-prompt weighted node features [40]. The results are presented in Table 2. We draw the following observations:

Our proposed task-aware readout module injects task-related information using an attention mechanism, leading to overall improvements across various tasks, especially on the Enzyme Commission task.

Simple element-wise multiplication of the task prompt feature with all node features fails to provide sufficient guidance for learning across all tasks.

Figure 4: PPA performance under different ratios of PPA training samples.

To further investigate the relationship between multiple tasks, we calculate Pearson's correlation between the task prompts. As shown in Figure 5, the correlations between tasks within the same category (e.g., EC and MF) are high, while the correlations between tasks from different categories (e.g., LBA and BP) are low. A high correlation between prompts indicates similar attention queries, leading to similar readout functions. This suggests that, with task-aware guidance, the model employs similar readout strategies for tasks from the same category and divergent strategies for tasks from different categories.


Table 3: Ablation study for different components in HeMeNet.
Method LBA\(\downarrow\) PPA\(\downarrow\) EC\(\uparrow\) GO-MF\(\uparrow\) GO-BP\(\uparrow\) GO-CC\(\uparrow\)
HeMeNet 1.730 1.087 0.810 0.727 0.379 0.436
- TAR 1.905 1.970 0.722 0.558 0.302 0.413
- \(e_r, W_r, w_r\) 1.790 1.446 0.547 0.663 0.359 0.391
- full-atom 1.799 0.861 0.630 0.595 0.279 0.426

5.4 Ablation Study

We perform ablation experiments to evaluate the necessity of different components, including the task-aware readout, the relational message-passing mechanism, and the full-atom geometry. Specifically, the TAR ablation replaces the task-aware readout module with a sum readout. For \(e_r, W_r\) and \(w_r\), we remove the different types of edges and the relational message-passing weights. For the full-atom ablation, we represent the coordinates of residues by their alpha-carbon atoms.

We present the results for ablation studies in Table 3, the observations are as follows:

Without our task-aware readout strategy, significant performance degradation is observed on all tasks, indicating that the tasks can hinder each other without appropriate guidance.

Without the heterogeneous graph and the relational message passing, our model’s performance drops on property prediction tasks, especially the Enzyme Commission number prediction.

Removing the full-atom geometry decreases the performance on multiple tasks. However, it improves our model's performance on PPA. Similar to the explanation in Section 5.2, we suppose that the large number of atoms in a full-atom protein-protein complex introduces excessive noise into the prediction compared with alpha-carbon-only input.

6 Conclusion

In this paper, we alleviate the problem of sparse data in structural protein datasets through a multi-task setting. We construct a standard multi-task benchmark, Protein-MT, consisting of 6 representative tasks integrated from 4 public datasets for joint learning. We propose a novel network called HeMeNet to address multiple tasks in protein 3D learning, with a novel heterogeneous equivariant full-atom encoder and a task-aware readout module. Comprehensive experiments demonstrate our model's strong performance on affinity and property prediction tasks. Our work brings insights for utilizing different structural datasets to train a more powerful generalist model in future research.

7 Data

7.1 Dataset sources

7.1.1 Enzyme Commission and Gene Ontology

We adopt the data set from [29]. This data set contains 19201 PDB chains from 538 EC numbers, selected from the third and fourth levels of the EC tree. The GO terms with at least 50 and not more than 5,000 training examples are selected. The number of classes in each branch of GO (MF, BP, CC) is 490, 1944 and 321, respectively. This data set consolidates 36641 PDB chains with their GO term labels. We obtain the structures of the PDB chains from TorchDrug [41].

7.1.2 LBA and PPA

We adopt the data set from PDBbind [8] (version 2020). This dataset contains 5316 protein-ligand complexes in the refined set and 2852 protein-protein complexes with known binding data (in the form of \(K_d, \;K_i, \;IC_{50}\) values) manually collected from the original references [8]. We obtain the structure of the complexes from PDBbind. The PDBbind dataset can be downloaded from http://www.pdbbind.org.cn.

7.2 Protein-MT statistics

As described in Section 3, we obtain 1327 fully-labeled complexes. Specifically, we employ MMseqs2 to cluster all the chains in Protein-MT with an alignment coverage \(>30\%\) and a sequence identity of the aligned fragment \(>90\%\), leading to 33704 chain-level clusters. Then, we merge the clusters that contain chains belonging to the same complex and finally obtain 30034 clusters. We randomly split the fully-labeled complexes into training, validation and test sets, with 328, 530, and 469 complexes, respectively. For the partially labeled samples, we only retain those located in clusters different from the above test complexes and add them to the training set, resulting in an augmented training set with a total of 31252 samples. After this labeling and splitting procedure, we obtain the new dataset Protein-MT. Table 5 shows the detailed statistics of Protein-MT for different tasks. Note that the sample numbers for multi-task differ slightly from those in Table 4 since we removed samples with more than 15,000 atoms.
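The cluster-merging and leakage-filtering steps can be sketched as follows; the data layout (dictionaries from chain ids to MMseqs2 cluster ids and from complex ids to chain lists) is an illustrative assumption, not the released preprocessing code.

```python
from collections import defaultdict

def merge_chain_clusters(chain_to_cluster, complex_to_chains):
    """Union-find sketch: merge chain-level clusters whenever two clusters
    contain chains of the same complex (the 33704 -> 30034 step)."""
    parent = {c: c for c in set(chain_to_cluster.values())}

    def find(c):
        while parent[c] != c:
            parent[c] = parent[parent[c]]      # path halving
            c = parent[c]
        return c

    def union(a, b):
        ra, rb = find(a), find(b)
        if ra != rb:
            parent[ra] = rb

    for chains in complex_to_chains.values():
        clusters = [chain_to_cluster[ch] for ch in chains if ch in chain_to_cluster]
        for a, b in zip(clusters, clusters[1:]):
            union(a, b)

    merged, chain_to_root = defaultdict(set), {}
    for chain, cluster in chain_to_cluster.items():
        root = find(cluster)
        chain_to_root[chain] = root
        merged[root].add(chain)
    return merged, chain_to_root

def keep_for_training(sample_chains, chain_to_root, test_roots):
    """Retain a partially labeled sample only if none of its chains share a
    merged cluster with any test complex."""
    return all(chain_to_root.get(ch) not in test_roots for ch in sample_chains)
```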

Table 4: Dataset split. The fully-labeled data are randomly divided into the train, validation and test sets. Partially labeled samples located in the clusters different from the above test complexes are retained and added to the training set.
Clusters after merging Train set size Validation set size Test set size
30034 31252 530 469
Table 5: Dataset details for different tasks. We summarize the number of samples that contains a specific task’s annotation in Protein-MT.
Task Name #Train samples #Val samples #Test samples Missing ratio Avg. seq length Avg. # atoms Chains
multi-task 30904 516 467 99.9% 325.67 2537.68 Mixed
LBA 3247 493 452 89.5% 425.29 3337.21 Multi
PPA 1976 23 15 93.6% 846.53 6604.35 Multi
EC 15025 516 467 51.4% 331.68 2585.91 Single
GO-MF 22997 516 467 25.6% 311.18 2420.01 Single
GO-BP 21626 516 467 30.0% 305.90 2382.08 Single
GO-CC 10543 516 467 65.9% 302.47 2346.91 Single

8 Details for HeMeNet components

In this section, we mainly introduce \(T_R\) and \(T_S\) from [4] and our modification for the heterogeneous graph.

The Geometric Relation Extractor \(T_R\) can deal with coordinate sets with different numbers of channels. Given \(\displaystyle {\boldsymbol{X}}_i \in \mathbb{R}^{3\times c_i}\) and \(\displaystyle {\boldsymbol{X}}_j \in \mathbb{R}^{3\times c_j}\), we can compute the channel-wise distance for each coordinate pair: \(D_{ij}(p,q) = ||\displaystyle {\boldsymbol{X}}_{i}(:, p) - \displaystyle {\boldsymbol{X}}_{j}(:,q)||_{2}\). Different from [4], we use two fixed binary vectors \(\displaystyle {\boldsymbol{w}}_i \in \mathbb{R}^{c_i\times 1}\) and \(\displaystyle {\boldsymbol{w}}_j \in \mathbb{R}^{c_j\times 1}\): if a channel contains an atom, its weight is set to 1, otherwise 0. We also adjust the learnable attribute matrices \(\displaystyle {\boldsymbol{A}}_i \in \mathbb{R}^{c_i \times d}\) and \(\displaystyle {\boldsymbol{A}}_j \in \mathbb{R}^{c_j \times d}\) to suit our input, assigning a different element embedding to each channel. The final output \(\displaystyle {\boldsymbol{R}}_{ij} \in \mathbb{R}^ {d \times d}\) is given by:

\[\displaystyle {\boldsymbol{R}}_{ij} = \displaystyle {\boldsymbol{A}}^T_i(\displaystyle {\boldsymbol{w}}_i \displaystyle {\boldsymbol{w}}_j^T \odot \displaystyle {\boldsymbol{D}}_{ij}) \displaystyle {\boldsymbol{A}}_j. \label{eq8}\tag{5}\]

Note that \(\displaystyle {\boldsymbol{R}}_{ij}\) has a fixed shape of \(d \times d\) regardless of \(c_i\) and \(c_j\), while remaining aware of the channel-wise geometry.

The Geometric Message Scaler \(T_S\) aims to scale a coordinate set \(\displaystyle {\boldsymbol{X}}\in \mathbb{R}^{3 \times c}\) of varying channel size with the fixed-length incoming message \(s = \phi_x(\displaystyle {\boldsymbol{m}}_{ij}) \in \mathbb{R}^C\), where \(C=14\) is the maximum channel size of the common amino acids. Then, \(T_S(\displaystyle {\boldsymbol{X}}, \displaystyle {\boldsymbol{s}})\) is calculated by:

\[\displaystyle {\boldsymbol{X}}' = \displaystyle {\boldsymbol{X}}\cdot \mathrm{diag}(s'), \label{eq9}\tag{6}\]

where \(s' \in \mathbb{R}^c\) is the average pooling of \(\displaystyle {\boldsymbol{s}}\) with a sliding window of size \(C-c+1\) and stride 1, and \(\mathrm{diag}(s')\) is a diagonal matrix with \(s'\) as its diagonal elements.
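A direct sketch of Eqs. (5)-(6) is given below; the tensor shapes follow the definitions above, and the function names are illustrative.

```python
import torch

def geometric_relation_extractor(Xi, Xj, Ai, Aj, wi, wj):
    """T_R of Eq. (5). Xi: (3, c_i), Xj: (3, c_j); Ai: (c_i, d), Aj: (c_j, d) are the
    learnable channel attribute matrices; wi: (c_i, 1), wj: (c_j, 1) are the fixed
    binary existence masks. Returns an E(3)-invariant (d, d) matrix."""
    D = torch.cdist(Xi.t(), Xj.t())            # (c_i, c_j) pairwise channel distances
    return Ai.t() @ ((wi @ wj.t()) * D) @ Aj   # (d, d)

def geometric_message_scaler(X, s):
    """T_S of Eq. (6). X: (3, c), s: (C,) with C = 14; s is average-pooled with a
    sliding window of size C - c + 1 and stride 1 to match the c channels."""
    c, C = X.shape[1], s.shape[0]
    window = C - c + 1
    s_prime = torch.stack([s[k:k + window].mean() for k in range(c)])   # (c,)
    return X * s_prime                          # per-channel scaling, O(3)-equivariant
```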

9 Implementation Details and hyperparameters

In this section, we introduce the implementation details of all baselines and our model. For all models, we concatenate the hidden outputs to form the final output. For the multi-task setting, all models except HeMeNet use the sum readout method. The features after the readout function are fed into six two-layer MLPs to make predictions for the different tasks. The inputs to all models are full-atom graphs with KNN edges, except for GearNet and HeMeNet.

GCN [32], GAT [33] and SchNet [34]. We take the implementation in PyTorch Geometric [42], with a 3-layer design. For all the models, the hidden size is set to 256.

GearNet [15]. We re-implement GearNet with reference to its official implementation, with a six-layer design. The hidden size is set to 512, and the cutoff value is 4.5 following the original settings. For the multi-task setting, we take the sum readout method. We use the alpha-carbon-only graph as the input for GearNet.

EGNN [6]. We re-implement EGNN with reference to its official implementation, with a 3-layer design. The hidden size is set to 256.

GVP [16]. We take the implementation in PyTorch Geometric [42], with a 3-layer design. The hidden size is set to 128 following the original implementation.

dyMEAN [4]. We re-implement dyMEAN with reference to its official implementation, with a 6-layer design. The hidden size is set to 256.

HeMeNet (ours). We take a 6-layer design for our model, and the hidden size is set to 256. We take our task-aware readout module to generate features for different tasks. We use the full-atom heterogeneous graph for HeMeNet as the input.

10 More fine-grained task combinations

To find out the effect of different tasks on the LBA and PPA tasks, we further train our model on different task combinations (LBA&PPA plus one property task). The results are shown in Table 6.

Table 6: Results on different task combinations.
Tasks LBA PPA EC MF BP CC
LP&EC 1.758 1.064 0.628 - - -
LP&MF 1.787 1.190 - 0.671 - -
LP&BP 1.820 1.089 - - 0.374 -
LP&CC 1.798 1.215 - - - 0.392
All 1.730 1.087 0.810 0.727 0.379 0.436

As shown in the table, the different property prediction tasks can improve the performance of LBA and PPA prediction compared with the single-task setting. Nevertheless, training on all six tasks together yields the best overall results. We suppose this is because the property prediction tasks share a high correlation with each other, and each of them contains information from different aspects. Combining them improves the overall performance, benefiting our model by increasing the sample size and label diversity.

11 Equivariance of HeMeNet

In this section, we prove the equivariance of our HeMeNet encoder. Note that the task-aware readout function only utilizes the invariant features, in which the geometric information has already been encoded, for the downstream invariant tasks, including affinity prediction and property prediction. Therefore, we prove the equivariance of the encoder and the invariance of the task-aware readout function.

11.0.0.1 Theorem 1. Equivariance of the HeMeNet encoder

Given the input (\(h_i\), \(\displaystyle {\boldsymbol{X}}_i\)), we have the output (\(\tilde{h}_i\), \(\tilde{\displaystyle {\boldsymbol{X}}_i}\)) = \(\mathrm{H}_e(h_i,\displaystyle {\boldsymbol{X}}_i)\), where \(\mathrm{H}_e\) abbreviates the HeMeNet encoder. \(\mathrm{H}_e\) possesses the property of E(3) equivariance; in other words, for any transformation \(g \in E(3)\), we have (\(\tilde{h}_i\), \(g \cdot \tilde{\displaystyle {\boldsymbol{X}}_i}\)) = \(\mathrm{H}_e(h_i,g \cdot \displaystyle {\boldsymbol{X}}_i)\).

11.0.0.2 Lemma 1.

The geometric relation extractor is E(3) invariant. Specifically, \(\forall \displaystyle {\boldsymbol{X}}_i \in \mathbb{R}^{3 \times c_i}\), \(\displaystyle {\boldsymbol{X}}_j \in \mathbb{R}^{3 \times c_j}\), suppose \(\displaystyle {\boldsymbol{R}}_{ij} = \mathrm{T_R}(\displaystyle {\boldsymbol{X}}_i, \displaystyle {\boldsymbol{X}}_j)\), for any \(g \in E(3)\), we have \(\displaystyle {\boldsymbol{R}}_{ij} = \mathrm{T_R}(g \cdot \displaystyle {\boldsymbol{X}}_i, g \cdot \displaystyle {\boldsymbol{X}}_j)\), where \(g = \displaystyle {\boldsymbol{Q}}x + t, Q \in O(3), t \in \mathbb{R}^3\).

Proof. Since in the calculation of \(\mathbf{R}_{ij}\), only \(\mathbf{D}_{ij}\) depends on the input \(\mathbf{X}\) (as shown in Equation 5), the invariance of \(\mathbf{R}_{ij}\) is equivalent to the invariance of \(\mathbf{D}_{ij}\). We now prove the invariance of \(\mathbf{D}_{ij}\):

\[\begin{align} ||(Q \displaystyle {\boldsymbol{X}}_i(:,p) + t) - (Q \displaystyle {\boldsymbol{X}}_j(:,q) + t)||_2 & = ||Q (\displaystyle {\boldsymbol{X}}_i(:,p) - \displaystyle {\boldsymbol{X}}_j(:,q))||_2 \\ & = \sqrt{[\displaystyle {\boldsymbol{X}}_i(:,p) - \displaystyle {\boldsymbol{X}}_j(:,q)]^TQ^TQ[\displaystyle {\boldsymbol{X}}_i(:,p) - \displaystyle {\boldsymbol{X}}_j(:,q)]} \\ & = \sqrt{[\displaystyle {\boldsymbol{X}}_i(:,p) - \displaystyle {\boldsymbol{X}}_j(:,q)]^T[\displaystyle {\boldsymbol{X}}_i(:,p) - \displaystyle {\boldsymbol{X}}_j(:,q)]} \\ & = ||\displaystyle {\boldsymbol{X}}_i(:,p) - \displaystyle {\boldsymbol{X}}_j(:,q)||_2 = \mathbf{D}_{ij}(p,q). \label{eq10} \end{align}\tag{7}\]

The other terms in \(\mathrm{T_R}\), such as \(\displaystyle {\boldsymbol{A}}_i\) and \(\displaystyle {\boldsymbol{w}}_i\), do not change under the transformation of \(\displaystyle {\boldsymbol{X}}\). Therefore, \(\mathbf{R}_{ij}\) is E(3) invariant, i.e., \(\displaystyle {\boldsymbol{R}}_{ij} = \mathrm{T_R}(g \cdot \displaystyle {\boldsymbol{X}}_i, g \cdot \displaystyle {\boldsymbol{X}}_j)\). \(\hfill\qed\)

11.0.0.3 Lemma 2.

The geometric message scaler \(\mathrm{T_S}\) is O(3) equivariant. Specifically, \(\forall \displaystyle {\boldsymbol{X}}\in \mathbb{R}^{3\times c}, s\in \mathbb{R}^C\), suppose \(X' = \mathrm{T_S}(\displaystyle {\boldsymbol{X}}, \displaystyle {\boldsymbol{s}})\), \(\forall Q \in O(3)\), we have \(Q\displaystyle {\boldsymbol{X}}' = \mathrm{T_S}(Q\displaystyle {\boldsymbol{X}}, \displaystyle {\boldsymbol{s}})\).

Proof. We can derive as follows:

\[\begin{align} \mathrm{T_S}(Q\displaystyle {\boldsymbol{X}}, \displaystyle {\boldsymbol{s}}) &= Q\displaystyle {\boldsymbol{X}}\cdot diag(s') \\ &= Q(\displaystyle {\boldsymbol{X}}\cdot diag(s'))\\ &= Q\mathrm{T_S}({\displaystyle {\boldsymbol{X}}, \displaystyle {\boldsymbol{s}}}) \\ &= Q\displaystyle {\boldsymbol{X}}' \label{eq11} \end{align}\tag{8}\] Therefore, \(Q\displaystyle {\boldsymbol{X}}' = \mathrm{T_S}(Q\displaystyle {\boldsymbol{X}}, \displaystyle {\boldsymbol{s}})\). \(\hfill\qed\)

With these lemmas, we can prove Theorem 1 now:

Proof. Since \(\displaystyle {\boldsymbol{h}}_i, \displaystyle {\boldsymbol{h}}_j, \displaystyle{\boldsymbol{e}}_r\) will not change with respect to the transformation of \(\displaystyle {\boldsymbol{X}}\), \(\forall g \in E(3)\), we can get:

\[\begin{align} \displaystyle {\boldsymbol{m}}_{ijr} &= \phi_m(\displaystyle {\boldsymbol{h}}_i^{(l)}, \displaystyle {\boldsymbol{h}}_j^{(l)}, \frac{\mathrm{T_R}(g\cdot\displaystyle \vec{{\boldsymbol{X}}}_{i}^{(l)}, g\cdot\displaystyle \vec{{\boldsymbol{X}}}_j^{(l)})}{||\mathrm{T_R}(g\cdot\displaystyle \vec{{\boldsymbol{X}}}_{i}^{(l)}, g\cdot\displaystyle \vec{{\boldsymbol{X}}}_j^{(l)})||_F+\epsilon}, {\boldsymbol{e}}_r) \\ &= \phi_m(\displaystyle {\boldsymbol{h}}_i^{(l)}, \displaystyle {\boldsymbol{h}}_j^{(l)}, \frac{\mathrm{T_R}(\displaystyle \vec{{\boldsymbol{X}}}_{i}^{(l)}, \displaystyle \vec{{\boldsymbol{X}}}_j^{(l)})}{||\mathrm{T_R}(\displaystyle \vec{{\boldsymbol{X}}}_{i}^{(l)}, \displaystyle \vec{{\boldsymbol{X}}}_j^{(l)})||_F+\epsilon}, {\boldsymbol{e}}_r), \label{eq12} \end{align}\tag{9}\]

and with the invariance of \(\displaystyle {\boldsymbol{m}}_{ijr}\), \(\forall g \in E(3)\), that is, \(\forall g\cdot \displaystyle {\boldsymbol{X}}= Q\displaystyle {\boldsymbol{X}}+ \displaystyle {\boldsymbol{t}}\), we can now derive the equivariance of \(\displaystyle \vec{{\boldsymbol{M}}}_{ijr}\):

\[\begin{align} g\cdot \displaystyle \vec{{\boldsymbol{M}}}_{ijr} &= \mathrm{T_S}((Q\displaystyle \vec{{\boldsymbol{X}}}_i^{(l)} + \displaystyle {\boldsymbol{t}}) - \frac{1}{c_j}\sum_{k=1}^{c_j} (Q\displaystyle \vec{{\boldsymbol{X}}}_j^{(l)}(:,k) + \displaystyle {\boldsymbol{t}}), \phi_x(\displaystyle {\boldsymbol{m}}_{ijr})) \\ &= \mathrm{T_S}(Q\displaystyle \vec{{\boldsymbol{X}}}_i^{(l)} - \frac{1}{c_j}\sum_{k=1}^{c_j} Q\displaystyle \vec{{\boldsymbol{X}}}_j^{(l)}(:,k), \phi_x(\displaystyle {\boldsymbol{m}}_{ijr})) \\ & = Q\mathrm{T_S}(\displaystyle \vec{{\boldsymbol{X}}}_i^{(l)} - \frac{1}{c_j}\sum_{k=1}^{c_j} \displaystyle \vec{{\boldsymbol{X}}}_j^{(l)}(:,k), \phi_x(\displaystyle {\boldsymbol{m}}_{ijr})) = Q\displaystyle \vec{{\boldsymbol{M}}}_{ijr}, \label{eq13} \end{align}\tag{10}\]

we can now prove the E(3) equivariance of our HeMeNet encoder \(H_e\): \[\begin{align} g\cdot{\boldsymbol{h}}_i^{(l+1)} &= {\boldsymbol{h}}_i^{(l)} + \sigma(\mathrm{BN}(\phi_h(\sum_{r \in R} {\boldsymbol{W}}_r\sum_{j \in {\mathcal{N}}_r(i)}g\cdot{\boldsymbol{m}}_{ijr}))) \\ &= {\boldsymbol{h}}_i^{(l)} + \sigma(\mathrm{BN}(\phi_h(\sum_{r \in R} {\boldsymbol{W}}_r\sum_{j \in {\mathcal{N}}_r(i)}{\boldsymbol{m}}_{ijr}))) \\ &= {\boldsymbol{h}}_i^{(l+1)}, \\ g\cdot\displaystyle \vec{{\boldsymbol{X}}}_i^{(l+1)} &= g\cdot\displaystyle \vec{{\boldsymbol{X}}}_i^{(l)} + \frac{1}{\sum_{r \in R}|\mathcal{N}_r(i)|}\sum_{r \in R}\sum_{j \in \mathcal{N}_r(i)}w_r g\cdot\displaystyle \vec{{\boldsymbol{M}}}_{ijr} \\ &=Q\vec{{\boldsymbol{X}}}_i^{(l)} + \displaystyle{\boldsymbol{t}} + \frac{1}{\sum_{r \in R}|\mathcal{N}_r(i)|}\sum_{r \in R}\sum_{j \in \mathcal{N}_r(i)}w_r Q\displaystyle \vec{{\boldsymbol{M}}}_{ijr}\\ &= Q[\displaystyle \vec{{\boldsymbol{X}}}_i^{(l)} + \frac{1}{\sum_{r \in R}|\mathcal{N}_r(i)|}\sum_{r \in R}\sum_{j \in \mathcal{N}_r(i)}w_r \displaystyle \vec{{\boldsymbol{M}}}_{ijr}] + \displaystyle {\boldsymbol{t}}\\ &= Q\displaystyle \vec{{\boldsymbol{X}}}_i^{(l+1)} + \displaystyle {\boldsymbol{t}}. \label{eq14} \end{align}\tag{11}\]

Therefore, \(\forall g \in E(3)\), (\(\tilde{h}_i\), \(g \cdot \tilde{\displaystyle {\boldsymbol{X}}_i}\)) = \(\mathrm{H}_e(h_i,g \cdot \displaystyle {\boldsymbol{X}}_i)\). \(\hfill \qed\)
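Theorem 1 can also be checked numerically: encode a randomly transformed input and compare it against the transformed encoding of the original input. The sketch below assumes an `encoder` callable with the interface of Section 4.2 (invariant features plus a list of multi-channel coordinate matrices); it is a test utility, not part of the model.

```python
import torch

def random_e3():
    """Sample a random E(3) transformation (Q, t) with Q in O(3)."""
    Q, _ = torch.linalg.qr(torch.randn(3, 3))
    t = torch.randn(3, 1)
    return Q, t

def check_equivariance(encoder, h, X, atol=1e-4):
    """Empirical check: encoder(h, g.X) must equal (h', g.X') where
    (h', X') = encoder(h, X), i.e. invariant features unchanged and
    coordinates transformed consistently."""
    Q, t = random_e3()
    h1, X1 = encoder(h, X)
    h2, X2 = encoder(h, [Q @ x + t for x in X])
    feat_ok = torch.allclose(h1, h2, atol=atol)
    coord_ok = all(torch.allclose(Q @ x1 + t, x2, atol=atol) for x1, x2 in zip(X1, X2))
    return feat_ok and coord_ok
```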

11.0.0.4 Theorem 2. Invariance of the Task-aware Readout

Given the input \((\tilde{h}, \tilde{\displaystyle {\boldsymbol{X}}})\), we have the output \((h', \displaystyle {\boldsymbol{X}}') = \mathrm{TAR}(\tilde{h}, \tilde{\displaystyle {\boldsymbol{X}}})\). TAR possesses the property of E(3) invariance. In other words, for any transformation \(g\in E(3)\), we have \((h', \displaystyle {\boldsymbol{X}}') = \mathrm{TAR}(\tilde{h}, g\cdot\tilde{\displaystyle {\boldsymbol{X}}})\).

Proof. According to Equation 3, TAR only operates on the invariant features \(\tilde{h}\); the coordinate information \(\tilde{\displaystyle {\boldsymbol{X}}}\) is passed through unchanged as the final output. Therefore, TAR is E(3) invariant. \(\hfill \qed\)

12 GPT4 Prompt information

To test how GPT-4 can be used to predict the binding affinity of protein-protein and protein-ligand complexes, we provide the sequences of the receptor and the ligand (in the form of either SMILES or a protein sequence). The system prompt is designed as follows:

System Prompt: You are a drug assistant and should be able to help with drug discovery tasks. Given the SMILES sequence of a drug and the FASTA sequence of a protein target, you need to calculate the binding affinity score. You can think step-by-step to get the answer and call any function you want. You should try your best to estimate the affinity with tools. The output should be a float number, which is the estimated affinity score without other words.

We implement a 3-shot in-context prompt as the input, providing three SMILES/sequence and affinity examples as follows (a sketch of assembling the full prompt is given after the examples):

Example 1:

CC[C@H](C)[C@H](NC(=O)OC)C(=O)N1CCC[C@H] 1c1ncc(-c2ccc3cc(-c4ccc5[nH]c([C@H?]6CCCN6C(=O) [C@H?](NC(=O)OC)[C@H?](C)OC)nc5c4)ccc3c2) [nH]1,

MDSIQAEEWYFGKITRRESERLLLNAENPRGTFLVR ESETTKGAYCLSVSDFDNAKGLNVKHYKIRKLDS GGFYITSRTQFNSLQQLVAYYSKHADGLCHRLTT VCP

11.52

Example 2:

[H]C1:C([H]):C(S(=O)(=O)N([H])[H]):C([H]):C([H]):C: 1/N=N/N1C([H])([H])C([H])([H])C([H])([H])C([H])([H]) C([H])([H])C1([H])[H],

HWGYGKHNGPEHWHKDFPIAKGERQSPVDIDTHT AKYDPSLKPLSVSYDQATSLRILNNGHAFNVE FDDSQDKAVLKGGPLDGTYRLIQFHFHWGSL DGQGSEHTVDKKKYAAELHLVHWNTKYGDFG KAVQQPDGLAVLGIFLKVGSAKPGLQKVVDV LDSIKTKGKSADFTNFDPRGLLPESLDYWTY PGSLTTPPLLECVTWIVLKEPISVSSEQVLKFRK LNFNGEGEPEELMVDNWRPAQPLKNRQIKASFK

6.5

Example 3:

[H]/C1=C(C([H])([H])C(=O)N([H])C2:C([H]):C([H]) :C(S(=O)(=O)N([H])[H]):C([H]):C:2[H])C2:C([H]): C([H]):C(OC([H]) ([H])[H]):C([H]):C:2OC1=O,

HWGYGKHNGPEHWHKDFPIAKGERQSPVDIDTHT AKYDPSLKPLSVSYDQATSLRILNNGHAFNVE FDDSQDKAVLKGGPLDGTYRLIQFHFHWGSLD GQGSEHTVDKKKYAAELHLVHWNTKYGDFGKA VQQPDGLAVLGIFLKVGSAKPGLQKVVDVLDS IKTKGKSADFTNFDPRGLLPESLDYWTYPGSL TTPPLLECVTWIVLKEPISVSSEQVLKFRKLN FNGEGEPEELMVDNWRPAQPLKNRQIKASFK
For each (smiles, fasta) pair in the test set:

Test input:

{smiles},

{fasta},
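A sketch of how such a 3-shot request could be assembled with the OpenAI chat-completions client is given below. The exact message layout, model identifier, and parsing of the reply are assumptions consistent with the description above, not the paper's released script.

```python
from openai import OpenAI

SYSTEM_PROMPT = (
    "You are a drug assistant and should be able to help with drug discovery tasks. "
    "Given the SMILES sequence of a drug and the FASTA sequence of a protein target, "
    "you need to calculate the binding affinity score. You can think step-by-step to "
    "get the answer and call any function you want. You should try your best to "
    "estimate the affinity with tools. The output should be a float number, which is "
    "the estimated affinity score without other words."
)

def predict_affinity(smiles, fasta, few_shot, model="gpt-4-1106-preview"):
    """few_shot: list of (ligand, target_sequence, affinity_string) triples,
    e.g. the three examples listed above. Examples are sent as alternating
    user/assistant turns (an assumed layout)."""
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    for ligand, target, affinity in few_shot:
        messages.append({"role": "user", "content": f"{ligand},\n{target}"})
        messages.append({"role": "assistant", "content": affinity})
    messages.append({"role": "user", "content": f"{smiles},\n{fasta}"})
    response = OpenAI().chat.completions.create(model=model, messages=messages)
    return float(response.choices[0].message.content.strip())
```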

13 Limitations and future work

Although Protein-MT creates a new benchmark for geometric protein multi-task learning, it is hard to add more tasks to Protein-MT while maintaining the fully-labeled sample size under our definition of fully-labeled data. Relaxing the restrictions on the test set can alleviate this issue. Meanwhile, the input is currently a mixture of single chains and complexes; we can randomly augment single-chain samples from their original PDB complexes to form ‘complexes’ and label them based on their UniProt IDs. Besides, we only consider invariant tasks in this work; our model can be extended to more tasks (e.g., equivariant tasks) in future work.

References

[1]
Guangyu Wang, Xiaohong Liu, Kai Wang, Yuanxu Gao, Gen Li, Daniel T Baptista-Hon, Xiaohong Helena Yang, Kanmin Xue, Wa Hou Tai, Zeyu Jiang, et al. Deep-learning-enabled protein-protein interaction analysis for prediction of sars-cov-2 infectivity and variant evolution. Nature Medicine, pages 1–12, 2023.
[2]
Shuangli Li, Jingbo Zhou, Tong Xu, Liang Huang, Fan Wang, Haoyi Xiong, Weili Huang, Dejing Dou, and Hui Xiong. Structure-aware interactive graph neural networks for the prediction of protein-ligand binding affinity. In Proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery & Data Mining, pages 975–985, 2021.
[3]
Octavian-Eugen Ganea, Xinyuan Huang, Charlotte Bunne, Yatao Bian, Regina Barzilay, Tommi S. Jaakkola, and Andreas Krause. Independent SE(3)-equivariant models for end-to-end rigid protein docking. In International Conference on Learning Representations, 2022.
[4]
Xiangzhe Kong, Wenbing Huang, and Yang Liu. End-to-end full-atom antibody design. arXiv preprint arXiv:2302.00203, 2023.
[5]
Keyulu Xu, Weihua Hu, Jure Leskovec, and Stefanie Jegelka. How powerful are graph neural networks? In International Conference on Learning Representations, 2019.
[6]
Vı́ctor Garcia Satorras, Emiel Hoogeboom, and Max Welling. E(n) equivariant graph neural networks. In Marina Meila and Tong Zhang, editors, Proceedings of the 38th International Conference on Machine Learning, volume 139 of Proceedings of Machine Learning Research, pages 9323–9332. PMLR, 18–24 Jul 2021.
[7]
Wenbing Huang, Jiaqi Han, Yu Rong, Tingyang Xu, Fuchun Sun, and Junzhou Huang. Equivariant graph mechanics networks with constraints. In International Conference on Learning Representations, 2022.
[8]
Renxiao Wang, Xueliang Fang, Yipin Lu, and Shaomeng Wang. The pdbbind database: Collection of binding affinities for protein-ligand complexes with known three-dimensional structures. Journal of Medicinal Chemistry, 47(12):2977–2980, 2004. PMID: 15163179.
[9]
Chengxin Zhang, Xi Zhang, Peter L Freddolino, and Yang Zhang. Biolip2: an updated structure database for biologically relevant ligand–protein interactions. Nucleic Acids Research, page gkad630, 2023.
[10]
Xiaoxu Wang, Yijia Zhang, Peixuan Zhou, and Xiaoxia Liu. A supervised protein complex prediction method with network representation learning and gene ontology knowledge. BMC bioinformatics, 23(1):300, 2022.
[11]
Minghao Xu, Zuobai Zhang, Jiarui Lu, Zhaocheng Zhu, Yangtian Zhang, Ma Chang, Runcheng Liu, and Jian Tang. Peer: a comprehensive and multi-task benchmark for protein sequence understanding. Advances in Neural Information Processing Systems, 35:35156–35173, 2022.
[12]
Henriette Capel, Robin Weiler, Maurits Dijkstra, Reinier Vleugels, Peter Bloem, and K Anton Feenstra. Proteinglue multi-task benchmark suite for self-supervised protein modeling. Scientific Reports, 12(1):16047, 2022.
[13]
Henriette Capel, K Anton Feenstra, and Sanne Abeln. Multi-task learning to leverage partially annotated data for ppi interface prediction. Scientific Reports, 12(1):10487, 2022.
[14]
Xiangzhe Kong, Wenbing Huang, and Yang Liu. Generalist equivariant transformer towards 3d molecular interaction learning. arXiv preprint arXiv:2306.01474, 2023.
[15]
Zuobai Zhang, Minghao Xu, Arian Rokkum Jamasb, Vijil Chenthamarakshan, Aurelie Lozano, Payel Das, and Jian Tang. Protein representation learning by geometric structure pretraining. In The Eleventh International Conference on Learning Representations, 2023.
[16]
Bowen Jing, Stephan Eismann, Patricia Suriana, Raphael John Lamarre Townshend, and Ron Dror. Learning from protein structure with geometric vector perceptrons. In International Conference on Learning Representations, 2021.
[17]
Huiwen Wang, Haoquan Liu, Shangbo Ning, Chengwei Zeng, and Yunjie Zhao. Physical Chemistry Chemical Physics, 24(17):10124–10133, May 2022.
[18]
Yingwen Zhao, Jun Wang, Jian Chen, Xiangliang Zhang, Maozu Guo, and Guoxian Yu. A literature review of gene function prediction by modeling gene ontology. Frontiers in genetics, 11:400, 2020.
[19]
Ethan C Alley, Grigory Khimulya, Surojit Biswas, Mohammed AlQuraishi, and George M Church. Unified rational protein engineering with sequence-based deep representation learning. Nature Methods, 16(12):1315–1322, December 2019.
[20]
Roshan Rao, Nicholas Bhattacharya, Neil Thomas, Yan Duan, Peter Chen, John Canny, Pieter Abbeel, and Yun Song. Evaluating protein transfer learning with tape. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d'Alché-Buc, E. Fox, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 32. Curran Associates, Inc., 2019.
[21]
Roshan M Rao, Jason Liu, Robert Verkuil, Joshua Meier, John Canny, Pieter Abbeel, Tom Sercu, and Alexander Rives. Msa transformer. In Marina Meila and Tong Zhang, editors, Proceedings of the 38th International Conference on Machine Learning, volume 139 of Proceedings of Machine Learning Research, pages 8844–8856. PMLR, 18–24 Jul 2021.
[22]
Pedro Hermosilla, Marco Schäfer, Matej Lang, Gloria Fackelmann, Pere-Pau Vázquez, Barbora Kozlikova, Michael Krone, Tobias Ritschel, and Timo Ropinski. Intrinsic-extrinsic convolution and pooling for learning on 3d protein structures. In International Conference on Learning Representations, 2021.
[23]
Vladimir Gligorijević, P. Douglas Renfrew, Tomasz Kosciólek, Julia Koehler Leman, Daniel Berenberg, Tommi Vatanen, Chris Chandler, Bryn C. Taylor, Ian Fisk, Hera Vlamakis, et al. Structure-based protein function prediction using graph convolutional networks. Nature Communications, 12, 2021.
[24]
Nathaniel Thomas, Tess Smidt, Steven Kearnes, Lusann Yang, Li Li, Kai Kohlhoff, and Patrick Riley. Tensor field networks: Rotation- and translation-equivariant neural networks for 3d point clouds. arXiv preprint arXiv:1802.08219, 2018.
[25]
Johannes Gasteiger, Janek Groß, and Stephan Günnemann. Directional message passing for molecular graphs. In International Conference on Learning Representations, 2020.
[26]
Johannes Brandstetter, Rob Hesselink, Elise van der Pol, Erik J Bekkers, and Max Welling. Geometric and physical quantities improve e(3) equivariant message passing. In International Conference on Learning Representations, 2022.
[27]
Zhenkun Shi, Rui Deng, Qianqian Yuan, Zhitao Mao, Ruoyu Wang, Haoran Li, Xiaoping Liao, and Hongwu Ma. Enzyme commission number prediction and benchmarking with hierarchical dual-core multitask learning framework. Research, 6:0153, 2023.
[28]
Zeyuan Wang, Qiang Zhang, Shuang-Wei HU, Haoran Yu, Xurui Jin, Zhichen Gong, and Huajun Chen. Multi-level protein structure pre-training via prompt learning. In The Eleventh International Conference on Learning Representations, 2023.
[29]
Vladimir Gligorijević, P Douglas Renfrew, Tomasz Kosciolek, Julia Koehler Leman, Daniel Berenberg, Tommi Vatanen, Chris Chandler, Bryn C Taylor, Ian M Fisk, Hera Vlamakis, et al. Structure-based protein function prediction using graph convolutional networks. Nature communications, 12(1):3168, 2021.
[30]
Helen M Berman, John Westbrook, Zukang Feng, Gary Gilliland, Talapady N Bhat, Helge Weissig, Ilya N Shindyalov, and Philip E Bourne. The protein data bank. Nucleic acids research, 28(1):235–242, 2000.
[31]
Matt W Gardner and SR Dorling. Artificial neural networks (the multilayer perceptron)—a review of applications in the atmospheric sciences. Atmospheric environment, 32(14-15):2627–2636, 1998.
[32]
Thomas N. Kipf and Max Welling. Semi-supervised classification with graph convolutional networks. In International Conference on Learning Representations, 2017.
[33]
Petar Veličković, Guillem Cucurull, Arantxa Casanova, Adriana Romero, Pietro Liò, and Yoshua Bengio. Graph attention networks. In International Conference on Learning Representations, 2018.
[34]
Kristof Schütt, Pieter-Jan Kindermans, Huziel Enoc Sauceda Felix, Stefan Chmiela, Alexandre Tkatchenko, and Klaus-Robert Müller. Schnet: A continuous-filter convolutional neural network for modeling quantum interactions. In I. Guyon, U. Von Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 30. Curran Associates, Inc., 2017.
[35]
Hehe Fan, Zhangyang Wang, Yi Yang, and Mohan Kankanhalli. Continuous-discrete convolution for geometry-sequence modeling in proteins. In The Eleventh International Conference on Learning Representations, 2023.
[36]
Josh Achiam, Steven Adler, Sandhini Agarwal, Lama Ahmad, Ilge Akkaya, Florencia Leoni Aleman, Diogo Almeida, Janko Altenschmidt, Sam Altman, Shyamal Anadkat, et al. Gpt-4 technical report. arXiv preprint arXiv:2303.08774, 2023.
[37]
Zeming Lin, Halil Akin, Roshan Rao, Brian Hie, Zhongkai Zhu, Wenting Lu, Nikita Smetanin, Robert Verkuil, Ori Kabeli, Yaniv Shmueli, et al. Evolutionary-scale prediction of atomic-level protein structure with a language model. Science, 379(6637):1123–1130, 2023.
[38]
Zuobai Zhang, Minghao Xu, Aurélie Lozano, Vijil Chenthamarakshan, Payel Das, and Jian Tang. Physics-inspired protein encoder pre-training via siamese sequence-structure diffusion trajectory prediction. arXiv preprint arXiv:2301.12068, 2023.
[39]
Raphael Townshend, Martin Vögele, Patricia Suriana, Alex Derry, Alexander Powers, Yianni Laloudakis, Sidhika Balachandar, Bowen Jing, Brandon Anderson, Stephan Eismann, et al. Atom3d: Tasks on molecules in three dimensions. In J. Vanschoren and S. Yeung, editors, Proceedings of the Neural Information Processing Systems Track on Datasets and Benchmarks, volume 1. Curran, 2021.
[40]
Zemin Liu, Xingtong Yu, Yuan Fang, and Xinming Zhang. Graphprompt: Unifying pre-training and downstream tasks for graph neural networks. In Proceedings of the ACM Web Conference 2023, pages 417–428, 2023.
[41]
Zhaocheng Zhu, Chence Shi, Zuobai Zhang, Shengchao Liu, Minghao Xu, Xinyu Yuan, Yangtian Zhang, Junkun Chen, Huiyu Cai, Jiarui Lu, et al. Torchdrug: A powerful and flexible machine learning platform for drug discovery, 2022.
[42]
Matthias Fey and Jan E. Lenssen. Fast graph representation learning with PyTorch Geometric. In ICLR Workshop on Representation Learning on Graphs and Manifolds, 2019.

  1. Code Availability: https://github.com/hanrthu/GMSL↩︎

  2. The UniProt database is the world’s leading non-redundant protein sequence and function resource; it identifies proteins by their UniProt IDs.↩︎