Three-State Protein Secondary Structure Prediction with Deep Learning

Given a protein's amino-acid sequence as input, we predict its secondary structure. We first train on more than one million pseudo-labelled sequences to obtain a pre-trained model, then fine-tune it on roughly 30,000 labelled sequences to obtain the final model used for prediction.

Original content takes effort to produce; let's learn together. Please credit the source when reposting.

Dataset example

>1UCSA
NKASVVANQLIPINTALTLIMMKAEVVTPMGIPAEEIPKLVGMQVNRAVPLGTTLMPDMVKNYE
CCCEEEECCCECCCCECCHHHEEEECCCCCCCEHHHHHHHCCCEECCCECCCCECCHHHECCCC

Residues and secondary-structure states are encoded as integers via:

seqs = {'A': 0, 'R': 1, 'N': 2, 'D': 3, 'C': 4, 'Q': 5, 'E': 6, 'G': 7, 'H': 8,
        'I': 9, 'L': 10, 'K': 11, 'M': 12, 'F': 13, 'P': 14, 'S': 15, 'T': 16,
        'W': 17, 'Y': 18, 'V': 19}
label = {'C': 0, 'H': 1, 'E': 2}

The resulting encoded arrays (first row: sequence; second row: labels):

2 11 0 15 19 19 0 2 5 10 9 14 9 2 16 0 10 16 10 9 12 12 11 0 6 19 19 16 14 12 7 9 14 0 6 6 9 14 11 10 19 7 12 5 19 2 1 0 19 14 10 7 16 16 10 12 14 3 12 19 11 2 18 6
0 0 0 2 2 2 2 0 0 0 2 0 0 0 0 2 0 0 1 1 1 2 2 2 2 0 0 0 0 0 0 0 2 1 1 1 1 1 1 1 0 0 0 2 2 0 0 0 2 0 0 0 0 2 0 0 1 1 1 2 0 0 0 0
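
For reference, here is a minimal sketch of how such an entry could be encoded (a hypothetical helper, not from the project code; it assumes the seqs and label dicts above, and padding to length 700 with residue pad 21 and label pad 3, matching the padding_idx and ignore_index used by the model below):

# Hypothetical helper; the pad length 700 is an assumption (the actual padded
# length depends on the dataset). Residue pad = 21, label pad = 3.
def encode(sequence, structure, max_len=700):
    seq_ids = [seqs[a] for a in sequence] + [21] * (max_len - len(sequence))
    lab_ids = [label[s] for s in structure] + [3] * (max_len - len(structure))
    return seq_ids, lab_ids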

These arrays are the model's input. The main script (main.py) is as follows:

import pdb
import sys
import os
from arg import getArgparse
# os.environ["CUDA_VISIBLE_DEVICES"] = '1'
import torch
from torch import nn
from network import S4PRED
from get_dataset import loadfasta
from torch.utils.data import DataLoader
import torch.optim as optim
import datetime
from sklearn.metrics import f1_score, precision_score, recall_score
curPath = os.path.abspath(os.path.dirname(__file__))
rootPath = os.path.split(curPath)[0]
sys.path.append(rootPath)

start = datetime.datetime.now()
args_dict = getArgparse()
device = torch.device(args_dict['device'])
learn_rate = args_dict['learn_rate']
pre_train_epochs = args_dict['pre_train_epochs']
fine_tuning_epochs = args_dict['fine_tuning_epochs']
fine_tuning_batch_size = args_dict['fine_tuning_batch_size']
pre_train_batch_size = args_dict['pre_train_batch_size']
save_dpath = args_dict['save_path']
# test_flag = args_dict['test_flag']
if not os.path.exists(save_dpath):
    os.mkdir(save_dpath)
criterion = nn.CrossEntropyLoss(ignore_index=3)
model = S4PRED().to(device)
optimizer = optim.Adam(model.parameters(), lr=learn_rate, betas=(0.9, 0.999))
lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 80], gamma=0.1)
pretrain_accuracy = [0.0]
fine_tuning_accuracy = [0.0]


def main():
    test_loader = DataLoader(loadfasta("test"), batch_size=1, shuffle=False, num_workers=8)
    final_model_path = os.path.join(save_dpath, 'fine-tuning-train-best.pkl')
    pre_model_path = os.path.join(save_dpath, 'pre-train-best.pkl')
    # If a fine-tuned model already exists, only run the test.
    if os.path.exists(final_model_path):
        print('Starting test...')
        test(model, test_loader)
        return
    train_loader = DataLoader(loadfasta("train"), batch_size=fine_tuning_batch_size, shuffle=True, num_workers=16)
    valid_loader = DataLoader(loadfasta("valid"), batch_size=1, shuffle=False, num_workers=8)
    if os.path.exists(pre_model_path):
        # A pre-trained model exists: go straight to fine-tuning.
        print('Load the pre-trained model...')
        model.load_state_dict(torch.load(pre_model_path))
        print('Starting fine-tuning...')
        for epoch in range(fine_tuning_epochs):
            print('##Fine-Tuning Epoch-%s' % epoch)
            train(model, train_loader)
            valid(model, valid_loader, epoch, False)
        test(model, test_loader)
    else:
        # Pre-train on the pseudo-labelled data first.
        print('Starting pre-train...')
        pre_train_loader = DataLoader(loadfasta("pre_train"), batch_size=pre_train_batch_size, shuffle=True, num_workers=32)
        for epoch in range(pre_train_epochs):
            print('##Pre-train Epoch-%s' % epoch)
            train(model, pre_train_loader)
            valid(model, valid_loader, epoch, True)
        print('Load the pre-trained model...')
        model.load_state_dict(torch.load(pre_model_path))
        print('Starting fine-tuning...')
        for epoch in range(fine_tuning_epochs):
            print('##Fine-Tuning Epoch-%s' % epoch)
            train(model, train_loader)
            valid(model, valid_loader, epoch, False)
        print('Starting test...')
        test(model, test_loader)


def train(model, train_loader):
    model.train()
    train_loss = 0
    epoch_acc = 0
    for i, data in enumerate(train_loader):
        sequence, label = data
        optimizer.zero_grad()
        output = model(sequence.to(torch.int).to(device))  # batch_size * 4 * 700
        loss = criterion(output, label.to(torch.long).to(device))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.25)
        optimizer.step()
        train_loss += loss.item()
        if len(output.shape) == 2:  # a batch of one was squeezed by the network
            output = torch.unsqueeze(output, 0)
        _, predicted = torch.max(output.data, 1)  # batch_size * 700, same layout as label
        epoch_acc += getAccuracy(predicted.to('cpu'), label.to(torch.long).to('cpu'))
    lr_scheduler.step()
    print('Train Accuracy: ', epoch_acc / len(train_loader), ' Train Loss', train_loss / len(train_loader),
          datetime.datetime.now() - start)


def valid(model, valid_loader, epoch, pretrain):
    model.eval()
    val_loss = 0
    epoch_acc = 0
    with torch.no_grad():
        for i, data in enumerate(valid_loader):
            sequence, label = data
            output = model(sequence.to(torch.int).to(device))  # batch_size * 4 * 700
            if len(output.shape) == 2:
                output = torch.unsqueeze(output, 0)
            loss = criterion(output, label.to(torch.long).to(device))
            val_loss += loss.item()
            _, predicted = torch.max(output.data, 1)  # batch_size * 700
            epoch_acc += getAccuracy(predicted.to('cpu'), label.to(torch.long).to('cpu'))
    print('Val Accuracy: ', epoch_acc / len(valid_loader), ' Val Loss', val_loss / len(valid_loader))
    # Save per-epoch, last, and best checkpoints.
    if pretrain:
        model_name = 'pre-train'
        accuracy = pretrain_accuracy
    else:
        model_name = 'fine-tuning-train'
        accuracy = fine_tuning_accuracy
    torch.save(model.state_dict(), os.path.join(save_dpath, model_name + '_{}.pkl'.format(epoch)))
    torch.save(model.state_dict(), os.path.join(save_dpath, model_name + '_last.pkl'))
    if epoch_acc / len(valid_loader) > max(accuracy):
        torch.save(model.state_dict(), os.path.join(save_dpath, model_name + '-best.pkl'))
    accuracy.append(epoch_acc / len(valid_loader))


def test(model, test_loader):
    model.eval()
    test_acc = 0
    f1 = 0
    precision = 0
    recall = 0
    # Evaluate the best fine-tuned checkpoint.
    model.load_state_dict(torch.load(os.path.join(save_dpath, 'fine-tuning-train-best.pkl')))
    with torch.no_grad():
        for i, data in enumerate(test_loader):
            sequence, label = data
            output = model(sequence.to(torch.int).to(device))  # batch_size * 4 * 700
            if len(output.shape) == 2:
                output = torch.unsqueeze(output, 0)
            _, predicted = torch.max(output.data, 1)  # batch_size * 700
            test_acc += getAccuracy(predicted.to('cpu'), label.to(torch.long).to('cpu'))
            res = get_score(predicted.to('cpu'), label.to(torch.long).to('cpu'))
            f1 += res['f1']
            precision += res['precision']
            recall += res['recall']
    print('Test Accuracy: ', test_acc / len(test_loader), '\nF1 score', f1 / len(test_loader),
          '\nPrecision score', precision / len(test_loader),
          '\nRecall score', recall / len(test_loader), datetime.datetime.now() - start)


def getAccuracy(output, label):
    # Per-sequence accuracy over non-padded positions (label 3 marks padding).
    accuracy = 0
    for i in range(label.size(0)):
        total = 0
        count = 0
        for j in range(label.size(1)):
            if label[i][j].item() != 3:
                total += 1
                if label[i][j].item() == output[i][j].item():
                    count += 1
            else:
                break
        accuracy += count / total
    return accuracy / label.size(0)


def get_score(output, label):
    # Weighted F1/precision/recall over the unpadded prefix of each sequence.
    f1 = 0
    precision = 0
    recall = 0
    for i in range(label.size(0)):
        length = label.size(1)
        for j in range(label.size(1)):
            if label[i][j] == 3:  # first padded position marks the end
                length = j
                break
        f1 += f1_score(label[i][:length], output[i][:length], average='weighted', zero_division=1)
        precision += precision_score(label[i][:length], output[i][:length], average='weighted', zero_division=1)
        recall += recall_score(label[i][:length], output[i][:length], average='weighted', zero_division=1)
    return {'f1': f1 / label.size(0), 'precision': precision / label.size(0), 'recall': recall / label.size(0)}


if __name__ == "__main__":
    main()
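A quick illustration of the padding convention (a standalone snippet, not part of the project): nn.CrossEntropyLoss with ignore_index=3 skips padded positions entirely, so the pad label never contributes to the loss or the gradients.

# Standalone illustration of ignore_index.
import torch
from torch import nn

criterion = nn.CrossEntropyLoss(ignore_index=3)
logits = torch.randn(1, 4, 5)             # batch * classes * length, as in main.py
labels = torch.tensor([[0, 1, 2, 3, 3]])  # last two positions are padding
print(criterion(logits, labels))          # averages over the three real positions only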

The network code (network.py):

import torch.nn as nn
import torch.nn.functional as F


class ResidueEmbedding(nn.Embedding):
    def __init__(self, vocab_size=21, embed_size=128, padding_idx=None):
        super().__init__(vocab_size, embed_size, padding_idx=padding_idx)


class GRUnet(nn.Module):
    def __init__(self, lstm_hdim=1024, embed_size=128, num_layers=3, bidirectional=True, lstm=False, outsize=4):
        super().__init__()
        self.lstm_hdim = lstm_hdim
        # Indices 0-19 encode residues; 21 is the padding index.
        self.embed = ResidueEmbedding(vocab_size=22, embed_size=embed_size, padding_idx=21)
        self.lstm = nn.GRU(128, 1024, num_layers=3, bidirectional=True, batch_first=True, dropout=0)
        self.outlayer = nn.Linear(lstm_hdim * 2, outsize)
        # Note: log_softmax yields log-probabilities; main.py's CrossEntropyLoss
        # applies a log-softmax of its own, so NLLLoss would match this activation
        # exactly. The model still trains as written.
        self.finalact = F.log_softmax

    def forward(self, x):
        x = self.embed(x)             # batch * length * 128
        x, _ = self.lstm(x)           # batch * length * 2048
        x = self.outlayer(x)          # batch * length * 4
        x = self.finalact(x, dim=-1)
        x = x.permute(0, 2, 1)        # batch * 4 * length, the layout the loss expects
        return x.squeeze()


class S4PRED(nn.Module):
    def __init__(self):
        super().__init__()
        self.model_1 = GRUnet()

    def forward(self, x):
        y_1 = self.model_1(x)
        return y_1
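
A quick shape check (illustrative snippet): because forward() ends with squeeze(), a batch of size 1 comes back as a 2-D tensor of shape 4 * length, which is why main.py restores the batch dimension with unsqueeze before taking the argmax.

# Illustrative shape check.
import torch
model = S4PRED()
x = torch.randint(0, 20, (2, 700))  # two encoded sequences of length 700
print(model(x).shape)               # torch.Size([2, 4, 700])
print(model(x[:1]).shape)           # torch.Size([4, 700]); the batch dim was squeezed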

Argument-parsing code (arg.py):

import argparse


def getArgparse():
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default='cpu',
                        help='Device to run on, either cpu or cuda (default: cpu)')
    parser.add_argument('--pre_train_batch_size', type=int, default=64,
                        help='Pre-training batch size (default: 64)')
    parser.add_argument('--fine_tuning_batch_size', type=int, default=32,
                        help='Fine-tuning batch size (default: 32)')
    parser.add_argument('--learn_rate', type=float, default=0.0001,
                        help='Learning rate (default: 0.0001)')
    parser.add_argument('--pre_train_epochs', type=int, default=10,
                        help='Number of pre-training epochs (default: 10)')
    parser.add_argument('--fine_tuning_epochs', type=int, default=1,
                        help='Number of fine-tuning epochs (default: 1)')
    parser.add_argument('--save_path', type=str, default='model/',
                        help='Directory in which to save model checkpoints (default: model/)')
    args = parser.parse_args()
    return vars(args)

# Example invocations:
# CUDA_VISIBLE_DEVICES=1 python main.py --device cuda --pre_train_batch_size 64 --fine_tuning_batch_size 32 --learn_rate 0.0001 --pre_train_epochs 10 --fine_tuning_epochs 5 --save_path model/first/
# CUDA_VISIBLE_DEVICES=1 python main.py --device cuda --learn_rate 0.0001 --fine_tuning_epochs 5 --save_path model/first/

Data-loading code (get_dataset.py):

import numpy as np
import torch
import torch.utils.data as data_utils

# Files for each dataset split: (sequence file, label file).
DATA_FILES = {
    'pre_train': ('dataset/pseudo/pseudo_seq.txt', 'dataset/pseudo/pseudo_lab.txt'),
    'train': ('dataset/train/train_seq.txt', 'dataset/train/train_lab.txt'),
    'valid': ('dataset/valid/valid_seq.txt', 'dataset/valid/valid_lab.txt'),
    'test': ('dataset/test/cb513_seq.txt', 'dataset/test/cb513_lab.txt'),
}


def loadfasta(split):
    seq_path, lab_path = DATA_FILES[split]
    seq = np.loadtxt(seq_path)
    lab = np.loadtxt(lab_path)
    return data_utils.TensorDataset(torch.tensor(seq), torch.tensor(lab))
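
One thing to note: np.loadtxt returns float64 arrays, so the tensors in these datasets are torch.float64; that is why main.py casts sequences with .to(torch.int) and labels with .to(torch.long) before feeding them to the model and the loss. A quick check (illustrative):

# Illustrative dtype check.
dataset = loadfasta('valid')
seq, lab = dataset[0]
print(seq.dtype, lab.dtype)  # torch.float64 torch.float64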
