最近在阅读和复现各个大佬的空转论文,记录、交流学习下,如有错误,欢迎指出。

前言

首先是STAGATE,是中科院提出来的方法,具体发表在NC上,主要思路与空转普遍的思路类似,提取基因表达、空间信息和图像特征,然后进行聚类,以识别每个spot的类型。当然,STAGATE,没有用图像信息,就已经是是目前已发表论文中最好的结果了。

总体架构

总体架构如下。

总体来说模型就是一个四层的AutoEncode,两层编码器两层解码器,只是每一层都换成了GAT。将基因表达数据X输入进去再重构出来X’,损失函数自然而然的就是X和X’的MSE。值得注意的是第二层和第三层,第一层和第四层分别共用一组权重W,为转置关系,这点在图上已经表明。如果是spot级别的数据,模型就已经全部讲完了,如果是细胞级别的数据,还会构建SNN,即重新构建一个新的GAT的邻接矩阵,然后每一层的结果是新的邻接矩阵和旧邻接矩阵构成的GAT加权求和为下一层的输入。

代码

作者最初发布的是tensorflow1的代码,今年三月份又公布了torch的代码,但是torch版本没有构建SNN,在细节上与tensorflow也略有不同,比如损失函数,tensorflow中除了MSE,又加入了权重损失防止过拟合,具体的在代码中我发现的都会提到。下面我试着根据torch版本的代码来说下我对这篇论文的理解。(最好在linux系统上运行,在windows上总是会出现各种奇怪错误)

首先是数据预处理。包括数据读取,在根据论文下载数据就好。然后是Normalization,选择高表达基因,正则化,取对数。再然后是读取真实标签用于最后测评并做了可视化。

    input_dir = os.path.join('Data', section_id)adata = sc.read_visium(path=input_dir, count_file=section_id+'_filtered_feature_bc_matrix.h5')adata.var_names_make_unique()#Normalizationsc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=3000)sc.pp.normalize_total(adata, target_sum=1e4)sc.pp.log1p(adata)Ann_df = pd.read_csv(os.path.join('Data',section_id, "cluster_labels_"+section_id+'.csv'), sep=',', header=0, index_col=0)adata.obs['ground_truth'] = Ann_df.loc[adata.obs_names, 'ground_truth']plt.rcParams["figure.figsize"] = (3, 3)sc.pl.spatial(adata, img_key="hires", color=["ground_truth"])

然后是spot和spot之间的距离。距离大于0小于150的spot构建邻接矩阵,在这个范围内认为有连接,邻接矩阵为1,否则是0。以下是计算符合距离范围的spot的距离,并保存adata.uns['Spatial_Net']中。

def Cal_Spatial_Net(adata, rad_cutoff=None, k_cutoff=None, model='Radius', verbose=True):"""\Construct the spatial neighbor networks.Parameters----------adataAnnData object of scanpy package.rad_cutoffradius cutoff when model='Radius'k_cutoffThe number of nearest neighbors when model='KNN'modelThe network construction model. When model=='Radius', the spot is connected to spots whose distance is less than rad_cutoff. When model=='KNN', the spot is connected to its first k_cutoff nearest neighbors.Returns-------The spatial networks are saved in adata.uns['Spatial_Net']"""assert(model in ['Radius', 'KNN'])if verbose:print('------Calculating spatial graph...')coor = pd.DataFrame(adata.obsm['spatial'])coor.index = adata.obs.indexcoor.columns = ['imagerow', 'imagecol']if model == 'Radius':nbrs = sklearn.neighbors.NearestNeighbors(radius=rad_cutoff).fit(coor)distances, indices = nbrs.radius_neighbors(coor, return_distance=True)KNN_list = []for it in range(indices.shape[0]):KNN_list.append(pd.DataFrame(zip([it]*indices[it].shape[0], indices[it], distances[it])))if model == 'KNN':nbrs = sklearn.neighbors.NearestNeighbors(n_neighbors=k_cutoff+1).fit(coor)distances, indices = nbrs.kneighbors(coor)KNN_list = []for it in range(indices.shape[0]):KNN_list.append(pd.DataFrame(zip([it]*indices.shape[1],indices[it,:], distances[it,:])))KNN_df = pd.concat(KNN_list)KNN_df.columns = ['Cell1', 'Cell2', 'Distance']Spatial_Net = KNN_df.copy()Spatial_Net = Spatial_Net.loc[Spatial_Net['Distance']>0,]id_cell_trans = dict(zip(range(coor.shape[0]), np.array(coor.index), ))Spatial_Net['Cell1'] = Spatial_Net['Cell1'].map(id_cell_trans)Spatial_Net['Cell2'] = Spatial_Net['Cell2'].map(id_cell_trans)if verbose:print('The graph contains %d edges, %d cells.' %(Spatial_Net.shape[0], adata.n_obs))print('%.4f neighbors per cell on average.' %(Spatial_Net.shape[0]/adata.n_obs))adata.uns['Spatial_Net'] = Spatial_Net

随后是一个可视化,平均每个spot有多少个邻居。

def Stats_Spatial_Net(adata):import matplotlib.pyplot as pltNum_edge = adata.uns['Spatial_Net']['Cell1'].shape[0]Mean_edge = Num_edge/adata.shape[0]plot_df = pd.value_counts(pd.value_counts(adata.uns['Spatial_Net']['Cell1']))plot_df = plot_df/adata.shape[0]fig, ax = plt.subplots(figsize=[3,2])plt.ylabel('Percentage')plt.xlabel('')plt.title('Number of Neighbors (Mean=%.2f)'%Mean_edge)ax.bar(plot_df.index, plot_df)

下面就正式进入STAGATE的训练阶段了。

首先将是数据准备,包括两部分:根据挑选出来的邻居构建邻接矩阵和基因表达数据。

def Transfer_pytorch_Data(adata):G_df = adata.uns['Spatial_Net'].copy()cells = np.array(adata.obs_names)cells_id_tran = dict(zip(cells, range(cells.shape[0])))G_df['Cell1'] = G_df['Cell1'].map(cells_id_tran)G_df['Cell2'] = G_df['Cell2'].map(cells_id_tran)G = sp.coo_matrix((np.ones(G_df.shape[0]), (G_df['Cell1'], G_df['Cell2'])), shape=(adata.n_obs, adata.n_obs))G = G + sp.eye(G.shape[0])edgeList = np.nonzero(G)if type(adata.X) == np.ndarray:data = Data(edge_index=torch.LongTensor(np.array([edgeList[0], edgeList[1]])), x=torch.FloatTensor(adata.X))  # .todense()else:data = Data(edge_index=torch.LongTensor(np.array([edgeList[0], edgeList[1]])), x=torch.FloatTensor(adata.X.todense()))  # .todense()return data

然后构建STAGATE模型 正如前边所说四层GAT,其中h2是最后的特征向量,h4是重建的基因表达数据。

class STAGATE(torch.nn.Module):def __init__(self, hidden_dims):super(STAGATE, self).__init__()[in_dim, num_hidden, out_dim] = hidden_dimsself.conv1 = GATConv(in_dim, num_hidden, heads=1, concat=False,dropout=0, add_self_loops=False, bias=False)self.conv2 = GATConv(num_hidden, out_dim, heads=1, concat=False,dropout=0, add_self_loops=False, bias=False)self.conv3 = GATConv(out_dim, num_hidden, heads=1, concat=False,dropout=0, add_self_loops=False, bias=False)self.conv4 = GATConv(num_hidden, in_dim, heads=1, concat=False,dropout=0, add_self_loops=False, bias=False)def forward(self, features, edge_index):h1 = F.elu(self.conv1(features, edge_index))h2 = self.conv2(h1, edge_index, attention=False)self.conv3.lin_src.data = self.conv2.lin_src.transpose(0, 1)self.conv3.lin_dst.data = self.conv2.lin_dst.transpose(0, 1)self.conv4.lin_src.data = self.conv1.lin_src.transpose(0, 1)self.conv4.lin_dst.data = self.conv1.lin_dst.transpose(0, 1)h3 = F.elu(self.conv3(h2, edge_index, attention=True,tied_attention=self.conv1.attentions))h4 = self.conv4(h3, edge_index, attention=False)return h2, h4  # F.log_softmax(x, dim=-1)

具体的GAT代码不放了,详见`"Graph Attention Networks" <https://arxiv.org/abs/1710.10903>

具体训练代码如下,不同点是加了梯度截断,最后返回h2,或者说是z,也就是特征向量用于下一步聚类分析,保存到adata中。

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)loss_list = []for epoch in tqdm(range(1, n_epochs+1)):model.train()optimizer.zero_grad()z, out = model(data.x, data.edge_index)loss = F.mse_loss(data.x, out) #F.nll_loss(out[data.train_mask], data.y[data.train_mask])loss_list.append(loss)loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)optimizer.step()model.eval()z, out = model(data.x, data.edge_index)STAGATE_rep = z.to('cpu').detach().numpy()adata.obsm[key_added] = STAGATE_repif save_loss:adata.uns['STAGATE_loss'] = lossif save_reconstrction:ReX = out.to('cpu').detach().numpy()ReX[ReX<0] = 0adata.layers['STAGATE_ReX'] = ReX

最后调用了R中的mclust包进行聚类。

def mclust_R(adata, num_cluster, modelNames='EEE', used_obsm='STAGATE', random_seed=2020):"""\Clustering using the mclust algorithm.The parameters are the same as those in the R package mclust."""np.random.seed(random_seed)import rpy2.robjects as robjectsrobjects.r.library("mclust")import rpy2.robjects.numpy2rirpy2.robjects.numpy2ri.activate()r_random_seed = robjects.r['set.seed']r_random_seed(random_seed)rmclust = robjects.r['Mclust']res = rmclust(rpy2.robjects.numpy2ri.numpy2rpy(adata.obsm[used_obsm]), num_cluster, modelNames)mclust_res = np.array(res[-2])adata.obs['mclust'] = mclust_resadata.obs['mclust'] = adata.obs['mclust'].astype('int')adata.obs['mclust'] = adata.obs['mclust'].astype('category')return adata

去掉缺失值并计算ARI。tensorflow版本和后续的数据分析解析等我看明白再来记录,最后附上测试DFPFC数据库的主函数。所有代码、数据和论文可以再github上下载,欢迎交流。

import warnings
warnings.filterwarnings("ignore")import pandas as pd
import numpy as np
import scanpy as sc
import matplotlib.pyplot as plt
import os
import sys
from sklearn.metrics.cluster import adjusted_rand_score
# import sklearn
import STAGATE_pyG as STAGATE
os.environ['R_HOME'] = '/home/admin/anaconda3/envs/lib/R'
# os.environ['R_USER'] = '/home/admin/Anaconda3\Lib\site-packages/rpy2'dataset = ["151507", "151508", "151509", "151510", "151669", "151670", "151671", "151672", "151673", "151674", "151675","151676"]
knn = [7, 7, 7, 7, 5, 5, 5, 5, 7, 7, 7, 7]
ARIlist = []
for section_id, k in zip(dataset, knn):print(section_id,k)input_dir = os.path.join('Data', section_id)adata = sc.read_visium(path=input_dir, count_file=section_id+'_filtered_feature_bc_matrix.h5')adata.var_names_make_unique()#Normalizationsc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=3000)sc.pp.normalize_total(adata, target_sum=1e4)sc.pp.log1p(adata)Ann_df = pd.read_csv(os.path.join('Data',section_id, "cluster_labels_"+section_id+'.csv'), sep=',', header=0, index_col=0)adata.obs['ground_truth'] = Ann_df.loc[adata.obs_names, 'ground_truth']plt.rcParams["figure.figsize"] = (3, 3)sc.pl.spatial(adata, img_key="hires", color=["ground_truth"])STAGATE.Cal_Spatial_Net(adata, rad_cutoff=150)STAGATE.Stats_Spatial_Net(adata)adata = STAGATE.train_STAGATE(adata)sc.pp.neighbors(adata, use_rep='STAGATE')sc.tl.umap(adata)adata = STAGATE.mclust_R(adata, used_obsm='STAGATE', num_cluster=k)obs_df = adata.obs.dropna()ARI = adjusted_rand_score(obs_df['mclust'], obs_df['ground_truth'])ARIlist.append(ARI)print('Adjusted rand index = %.2f' %ARI)
print("ari mean", np.mean(ARIlist))
print("ari median", np.median(ARIlist))

空间转录组 STAGATE相关推荐

  1. Science亮点!ExSeq:完整生物组织的原位空间转录组分析

       背 景 介 绍   新一代测序技术的革新使得我们对生物体的组学信息有了更深的了解,随着转录组技术的日渐普及,尤其是单细胞转录组技术的突飞猛进,使得我们对组织内细胞异质性的认识有了很大提升,另外, ...

  2. 包教会一对一跟着CNS学单细胞测序(含空间转录组、chipseq、RNAseq、Atacseq 和外显子等)3月13日开始...

    报名成功立马送您往期所有视频预习 本班实行"包教包会.一对一指导服务",即如果报本班,不仅有同步回放视频,而且一对一指导服务,解决学完无法消化问题.学不会免费继续学,直到学会为止. ...

  3. Visium空间转录组

    Visium空间转录组是在组织原位检测全转录组基因表达的一种技术,使得我们在检测基因表达水平的同时,获得基因在组织内部空间表达的位置信息.与空间转录组相比,传统的全基因转录组或单细胞转录组测序,丢掉了 ...

  4. 【空间转录组】MIA分析

    之前讲过一篇空间转录组的文献,里面首次提出了Multimodal intersection analysis(MIA)的空间转录组分析思路. 讲解视频在B站 MIA分析可以用来评估空间上某个regio ...

  5. 热点综述 | 单细胞+空间转录组的整合分析方法总结

    目前scRNA-seq将每个转录物与单个细胞相关联,但关于这些转录物在组织中的位置信息丢失了:相反的,空间转录组学技术知道转录物的位置,却不知道是哪个细胞产生了转录物.因此,scRNA-seq与空间转 ...

  6. 空间转录组学(Spatial Transcriptomics)

    01.空间转录组技术的发展 近年来单细胞转录组测序技术的应用大大拓宽了人们的视野,使人们能够深入了解组织中细胞的构成的多样性和基因表达状态.众所周知,基因表达具有时间和空间的特异性,通过对不同时间点的 ...

  7. 非因解读 | DSP空间转录组技术揭示食管鳞状细胞癌三级淋巴样结构的预后价值及分子特征

    食管鳞状细胞癌(oesophageal squamous cell carcinoma,ESCC)是中国第三大常见恶性肿瘤.研究发现,肿瘤微环境(tumor microenvironment,TME) ...

  8. ClusterMap:用于空间基因表达的多尺度聚类分析 | 空间转录组分析工具推荐

    在空间背景下量化RNA是了解复杂组织中基因表达和调控的关键.原位转录组方法可以在完整的组织中产生空间分辨率的RNA图谱.然而,目前还缺乏一个统一的计算工具来综合分析原位转录组数据.2021年10月,N ...

  9. 空间转录组第一讲:10x空间转录组技术介绍

    最近,空间转录组学研究炙手可热.细胞及其在组织样本中的相对位置之间的关系对于理解疾病病理可能至关重要.空间转录组学是一项开创性的技术,它使科学家能够测量组织样本中的所有基因活性,并绘制出发生该活性的位 ...

最新文章

  1. pyqt5教程11:绘制外观
  2. 信息安全与网络安全,你分清了吗?
  3. Redash 9安装与配置(基于Docker方式)
  4. 使用IDM下载,不适用默认浏览器下载
  5. 为什么多个线程不可能同时抢到一把锁_HFL Redis_12_redis分布式锁的3种实现方式...
  6. python发送qq消息linux_python 调用qq邮箱 linux 执行每天自动发送邮件
  7. linux运行大端程序,ARM 平台上的Linux系统启动流程
  8. Java多线程问题之同步器CyclicBarrier
  9. 海伦公式c语言double,海伦公式
  10. HTC Vive开发笔记之手柄控制
  11. Xshell下载文件到本地
  12. excel缩字间距_如何取消字体间距 excel字体间距紧缩
  13. Python爬虫实例--新浪热搜榜[xpath语法]
  14. 无线通信-信道模型概念(Wireless Communication Overview)
  15. 如何选择计算机软考科目,计算机软考科目众多 我们该如何选择考试科目?
  16. 使用TreeMap实现ASCII排序
  17. ENVI经验|基于多源遥感影像的红树林范围提取4-面向对象分类
  18. 【沃顿商学院学习笔记】商业基础——Operation Management:02运营管理活动中的详细流程分析
  19. 052基于SSM车辆维修管理系统
  20. Galaxy 9300 刷机和获取root权限

热门文章

  1. 深入理解分布式、微服务中CAP定律和BASE理论
  2. 服务器数据库监控系统,数据库监控系统
  3. [BOI Mokia]
  4. Interlagos推土机架构处理器推出了什么
  5. SQL根据一张表数据更新另外一张表
  6. 微信支付商户动态口令在哪里找?
  7. HTML介绍 与基础操作
  8. MySql数据库事务隔离级别底层实现原理总结
  9. VC中常见API函数使用方法(经验版)
  10. S3C2440驱动开发(四)