The robust self-supervised learning strategy to tackle the inherent sparsity of single-cell RNA-seq data
https://zenodo.org/records/10602754
https://zenodo.org/records/10608134 https://zenodo.org/records/12741301
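Requirements: PyTorch >= 2.1.2 (written as "Python >= 2.1.2" in the original; the version number presumably refers to PyTorch, which the examples below use), scikit-learn, SciPy, Scanpy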
from scRobust import *
## Select the compute device.
## The example targets GPU 1; on a host exposing fewer than two GPUs that
## ordinal is invalid, so fall back to GPU 0 (or to the CPU without CUDA).
cuda_condition = torch.cuda.is_available()
gpu_index = 1 if torch.cuda.device_count() > 1 else 0
device = torch.device(f"cuda:{gpu_index}" if cuda_condition else "cpu")

## Create the scRobust wrapper and point it at the pre-processed dataset.
scRo = scRobust(device)
adata_path = './data/Processed_Filtered_Segerstolpe_HumanPancreas_data.h5ad'
scRo.read_adata(adata_path)

## set_vocab() hands back the gene vocabulary and the tokenizer.
gene_vocab, tokenizer = scRo.set_vocab()

## Small encoder: hidden size 64*8 = 512, one layer, eight attention heads.
scRo.set_encoder(hidden=64 * 8, n_layers=1, attn_heads=8)
## Pre-train scRobust (self-supervised stage)
scRo.set_pretraining_model(hidden=64 * 8, att_dropout=0.3)
scRo.train_SSL(
    epoch=1000,
    lr=5e-5,
    batch_size=128,
    n_ge=250,
    save_path='./weights/',
)

## Load encoder weights from a saved checkpoint
## NOTE(review): the file name mentions nGenes_200 while train_SSL above is
## called with n_ge = 250 -- confirm this is the intended checkpoint.
weight_path = './weights/Segerstolpe_CL_GE_BERT_Hid_512_Att_8_nGenes_200_ly_1_bt_128_encoder.pt'
scRo.load_encoder_weight(weight_path)
## Get cell embeddings and wrap them in an AnnData object
## (t-SNE and Leiden requested, UMAP switched off).
cell_embeddings = scRo.get_cell_embeddings(n_ge=1000, batch_size=64)
scRobust_adata = scRo.get_cell_adata(
    cell_embeddings,
    umap=False,
    tsne=True,
    leiden=True,
    n_comps=50,
    n_neighbors=10,
    n_pcs=50,
)

## Attach the labels stored on the source AnnData so they can colour the plot.
scRobust_adata.obs['label'] = scRo.adata.obs['label']

## One t-SNE plot per colouring: first the labels, then the Leiden clusters.
for colour_key in ('label', ['leiden']):
    sc.pl.tsne(scRobust_adata, color=colour_key)
## cell type annotation with pathway vectors
alpha_genes = 'VGF;SYT5;GADD45G;CITED2;FEV;C2CD4B;LOXL4;TBC1D30;C2CD4A;PTP4A3;STC2;LBH;KCNMA1;HERC5;FAP;KCNK17;GLS;RGS4;CRYBA2;APOH;MUC13'
beta_genes = 'DLK1;SYT13;GSN;PRKACB;TGFBR3;SCN1B;SUSD4;PCSK1;NPTX2;EDN3;THBD;CADM1;MEG3;PLCXD3;WSCD2;SIX3;IL17RB;SLC16A9;PDX1;HADH;PRSS23;KCNMA1;KLHDC8A;BMP5;ADRA2A;NEFM;TSPAN1;SOCS2;CASR;TIMP2;SCD5;KLF10;CDKN1A;ERRFI1;ID4;RNF187;ENTPD3;GREM1;ADCYAP1;RASSF6;DNAJB9;SYBU;LRRTM3;CAPN13;IAPP;RASD1;SORL1;ATRNL1;ID1;ABCG1;TXNRD1;CLGN;LONRF2;BHLHE41;ZNF331;RGS16'
delta_genes = 'GABRB3;LEPR;PHGR1;MS4A8;SLC7A5;EDN3;MRAP2;CDHR3;PIPOX;HADH;PRKACB;RBP4;LONRF2;VAT1L;UCP2;EHF;CALB1;SORL1;PRSS23;HHEX;DNAJB9;HAP1;FFAR4;FOXP2;FOS;AQP3;PKIB;AKAP12;PDE2A;BCHE;BHLHE41;DUSP6;MLPH;BAALC;F5'
gamma_genes = 'CALY;CDKN1A;TMEM45B;IL17RB;DPYSL3;FGB;CARD11;FOS;ID4;ID1;EGR1;PTP4A3;RASD1;DDC;ETV1;SERTM1;LMO3;THSD7A;GLT8D2;FEV;KCNJ8;EGR2;EHF;NR4A2;PEG10;PON3;ID2'
ductal_genes = 'ANXA9;C6;SCNN1A;MUC20;PPP1R1B;SLC3A1;CA4;DEFB1;SLC17A4;CDX2;SFRP5;AGR3;LEFTY1;SCTR;TTYH1;ANXA4;TRPV6;HABP2;ELF3;UGT2B15;KCNJ16;PLD1;LAD1;OLFM4;FUT3;LAMA1;SLC4A4;SYT8'
epsilon_genes = 'ACSL1,AGT,ANXA13,ARX,ASAH1,ASGR1,BMP7,C5orf4,CDKN2A,CLU,DEFB1,DIRAS3,EBF1,FFAR4,FGF14,FRZB,GHRL,HEPACAM2,KCNH2,LINC00261,NEDD9,NUAK1,PCSK1,PEG10,PHGR1,PIGK,PROX1,SAMD5,SEMA3E,SERPINA1,SERPINA10,SERPINB6,SPINK1,SPON2,SPTSSB,SUCNR1,SYNJ2,TM4SF4,TM4SF5,TMEM45B,TRNP1,VSTM2L,VTN,ZKSCAN1'
acinar_genes = 'AADAC,ABCC3,ABHD11,ABRACL,ACAA1,ACO1,ACOX1,ACSL4,ACTG1,ACTN4,ADAM9,ADAMTS1,ADD3,ADH1C,ADI1,ADM2,AHCY,AKR1B10,AKR1C1,AKR1C2,AKR1C3,AKR7A3,ALAD,ALB,ALDH2,ALDH3A2,ALDOB,AMOTL2,AMY2A,AMY2B,ANGPTL4,ANKRD36,ANKRD36B,ANO6,ANPEP,ANXA11,ANXA4,AOX1,AP4B1,AQP8,ARHGAP18,ARHGAP26,ARPC1A,ARRDC3,ARSD,ASNS,ASS1,ATF3,ATF4,ATP10B,ATP8B1,ATRN,B2M,B3GALNT2,B3GNT2,B3GNT5,B3GNT7,B4GALT1,BACE2,BAIAP2L1,BAZ1A,BCAT1,BCL3,BHLHE40,BIRC3,BLVRB,BMF,BTF3,BTG1,POU2AF3,C15orf48,C1R,C1RL,C2CD2,C3,C4orf19,C5,BNIP5,TCIM,CA12,CAPN2,CAPZA1,CASP4,CASP6,CASP7,CAV2,CBR1,CBS,CCL2,CCL20,CCL28,CCL8,CCNC,CCND2,CCNG1,CCPG1,CD151,CD24,CD2AP,CD44,CD47,CD74,CDC42EP1,CDH1,CDR2L,CDV3,CEBPD,CEL,CELA2A,CELA2B,CELA3A,CELA3B,CES2,CFB,CFI,CGNL1,CHAC1,CHMP4C,CHP1,CLDN1,CLDN10,CLDN12,CLDN2,CLDN4,CLDN7,CLIC6,CLMN,CLPS,CMPK1,CNN2,CNN3,CPA1,CPA2,CPA4,CPB1,CPM,CREB3L1,CREB5,CREG1,YBX3,CTRB1,CTRB2,CTRC,CTRL,CTSD,CTTN,CXADR,CXCL1,CXCL12,CXCL16,CXCL17,CXCL2,CXCL3,CYB5A,CYFIP1,CYP17A1,CYP1A1,CYP3A5,DAB2IP,DAG1,DBNDD1,DDAH1,DHRS3,DNAJC3,DUOX2,DUSP16,DUSP23,DUSP4,EBP,EBPL,EDEM3,EDN1,EEF1A1,EEF1B2,EEF1D,EEF1G,EEF2,EGFR,EHD4,EIF2S2,EIF2S3,EIF3D,EIF3E,EIF3F,EIF3L,EIF4E2,EIF4EBP1,EIF4EBP2,ELF3,EML4,ENAH,ENC1,EPB41L4B,EPHX1,ERBB3,ERO1A,ERP27,ESRP1,ETS2,ETV6,F11R,F3,FA2H,FAM107B,NIBAN1,NIBAN2,RETREG1,FAM13A,FAM161A,FAM162A,FAM3B,FBP1,FBXO25,FDX1,FGA,FGD6,FGG,FGL1,FKBP11,FKBP1A,FKBP9P1,NHSL2,FLRT2,FNDC4,FNIP2,FOXC1,FRMD8,FTH1,FTH1P3,FTL,FUT4,FZD5,GALNT2,GALNT3,GALNT7,GAREM1,GATA4,GATA6,GATM,GBP2,GCLC,GCNT3,GDF15,GJB1,GLRX,NOP53,GLUL,GMNN,RACK1,GOLM1,GP2,GPR160,GPRC5A,GPRC5B,GPT2,GPX2,GRB10,GSTA1,GSTA2,GSTP1,GUCD1,GULP1,H6PD,HDGF,HDHD3,HEBP1,HERPUD1,HES1,HEYL,HHLA2,HIF1A,HLA-B,HLA-C,HLA-DMA,HLA-DRA,HLA-DRB1,HLA-F,JPT1,HOMER2,HPN,HSD17B11,HSD17B4,HSPB1,HSPB8,HTATIP2,IDH1,IDH2,IER2,IFITM3,IFNGR1,IGBP1,IGFBP2,IL18,IL22RA1,IL32,CXCL8,IMPA2,IMPDH2,INSR,IQGAP2,ITGA2,ITGA6,ITGAV,JUN,JUP,KCNQ1,KDELR3,NHSL3,KIF1C,KITLG,KLF3,KLF5,KLF6,KLK1,KRAS,KRT18,KRT7,KRT8,KRTCAP3,LAD1,LAMB3,LASP1,LCN2,LDHA,LDHB,LGALS2,LGALS3,LGALS9,LGR4,LIMA1,LI
MK2,LMO7,CAMK1D,PTMAP11,LRG1,LRP10,LRRC59,LSR,LTB,LYN,LYPD1,LYZ,MACC1,MAGT1,MAL2,MAN2A1,MAP3K5,MARCKS,MARCKSL1,MBNL2,MCCC1,MCFD2,MCL1,MECOM,MET,TMT1A,MGAT4B,MGST1,MKNK2,MLEC,MLPH,MLXIP,MOB3B,MSN,MT1G,MTHFD2,MTMR12,MTUS1,MUC1,MUC20,MXI1,MXRA5,MYC,MYH9,MYL12A,MYL12B,MYO1C,MYO5B,MYO5C,NACA,NACA2,NAMPT,NCOA7,NDRG1,NEAT1,NEDD4L,NET1,NEXN,NFE2L2,NFIB,NFKBIA,NFKBIZ,NPM1,NR0B2,NR5A2,NRARP,NTN4,NUPR1,OAF,OCLN,OLFM4,OSMR,OSTC,OXA1L,P4HB,PABPC1,PABPC3,PABPC4,PACSIN2,PARD3,PARD6B,PARP4,PARVA,PBLD,PCBP2,PDCD4,PDGFA,PDLIM1,PDLIM5,PDP1,PDZK1IP1,PGM1,PGM2L1,PGM3,PHGDH,PHLDA1,PIGR,PIK3AP1,PIM3,PLA2G1B,PLIN5,PLS3,PLSCR1,PLTP,PMAIP1,PNLIP,PNLIPRP1,PNLIPRP2,PNRC1,PODXL,POLD4,POLDIP2,POLR1D,PPIA,PPIC,PPIF,PPP1R3C,PPP2R5A,SLC66A3,PRDX1,PRDX4,PRDX6,PRPS2,PRR15L,PRSS1,PRSS3,PRSS3P2,PSAT1,PSMB10,PSMB8,PTER,PTF1A,PTGFRN,PTGR1,PTMA,PTP4A1,PTP4A2,PTPRF,PTPRK,NECTIN4,PYCR1,RAB11FIP1,RAB27B,RAB3D,RAB9A,RASEF,RASGRP3,RASSF6,RBM3,RBM47,RBPJ,RBPMS,RDH10,RDX,REG1A,REG1B,REG1CP,REG3A,REG3G,REPS2,RER1,REST,RETSAT,RHOC,RLIM,RNASE1,RND1,RNF181,RNF19A,RNF213,RPL10,RPL10A,RPL10L,RPL11,RPL12,RPL13,RPL13A,RPL13AP20,RPL13AP5,RPL14,RPL17,RPL18A,RPL19,,RPL22L1,RPL23,RPL23A,RPL23P8,RPL24,RPL26,RPL27,RPL27A,RPL29,RPL30,RPL31,RPL32,RPL34,RPL35,RPL35A,RPL36,RPL36A,RPL36A-HNRNPH2,RPL36AL,RPL37,RPL37A,RPL38,RPL39,RPL4,RPL41,RPL7,RPL7A,RPL9,RPLP0,RPLP0P2,RPLP1,RPLP2,RPN2,RPS11,RPS12,RPS13,RPS15A,RPS16,RPS19,RPS2,RPS20,RPS21,RPS23,RPS24,RPS25,RPS27,RPS27A,RPS27L,RPS28,RPS29,RPS3,RPS3A,RPS4X,RPS4Y1,RPS5,RPS6,RPS6KA2,RPS7,RPS8,RPS9,RPSA,RPSA2,RRBP1,RSU1,S100A11,SAT1,SAV1,SCN9A,SDC4,SEL1L,SEL1L3,SERINC2,SERP1,SERPINA3,SERPINB1,SERPINB6,SERPINF1,SERPINI2,SETD7,SGK1,SH2D4A,SH3BP4,SHMT2,SHROOM3,SKIL,SLC12A2,SLC12A7,SLC16A7,SLC22A23,SLC25A37,SLC25A5,SLC30A2,SLC35D2,SLC35E1,SLC38A1,SLC38A2,SLC39A14,SLC39A8,SLC44A4,SLC7A11,SLFN5,SMAD3,SMIM14,SNAP23,SNHG5,SOCS3,SOD2,SORBS2,SORD,SOX4,SOX9,SP100,SPATA13,SPCS3,SPINK1,SPSB1,SPTBN1,SPTSSA,SQSTM1,SRD5A3,SRI,ITPRID2,SSR3,STAT6,STEAP4,STK17A,STK24,SUCLG2,SYNGR2,SYTL2,TACSTD2,TAGLN2,TAL
DO1,TAPBP,TBC1D1,TC2N,TCEA3,TCIRG1,TCN1,TES,TFPI,TGIF1,TJP2,TKT,TM4SF1,TM7SF2,TM9SF3,TMBIM1,TMC4,TMC5,TMED2,TMED9,TMEM123,TMEM125,TMEM165,CEMIP2,TMEM41A,TMEM51,TMEM87B,TMEM97,TMEM98,TMPRSS2,TMSB10,TMSB4X,TNFAIP1,TNFAIP3,TNFAIP8,TNFRSF10B,TNFRSF12A,TNFSF10,TOB1,TOR1AIP2,TP53I11,TP53INP1,TPD52L1,TPM1,TPST2,TPT1,TRAF4,TRAM1,TRIB1,TRIB2,TRIM47,TSPAN13,TSPAN15,TSPAN6,TST,TUBA1C,TUBB2A,TUBB3,TUBB4B,TXN,TXNRD1,UBA52,UBD,UBE2H,UGDH,UGT2B15,UNC5CL,USP53,VAMP8,VAPA,VIM,VMP1,VTN,WIPI1,XBP1,XPNPEP1,YAP1,YWHAZ,MAP3K20,ZC3H12A,ZFAND5,ZFP36L1,RBSN,ZNF114,ZNF117,ZNF217,ZNF704'
endothelial = 'A2M,ABI3,ACE,ACVRL1,ADAMTS4,ADAMTS9,AFAP1L1,AHR,ANGPT2,ANGPTL2,ANKRD11,ANKRD50,APLNR,APP,ARHGAP23,ARHGAP29,ARHGAP31,ARHGDIB,ARHGEF12,ARPC2,ATAD3C,ATP1B3,BCL6B,BDKRB2,BGN,BMPR2,BNIP2,NOL4L,TCIM,CALCRL,CAV1,CBL,CBLB,CCND1,CD34,CD36,CD59,CD81,CD9,CD93,CDC37,CDH5,CDK2AP1,CDR2L,CEBPG,CEP170,CFL1,CHSY1,CLEC14A,CLIC2,CLIC4,COL13A1,COL15A1,COL4A1,COL4A2,COL5A3,CSF2RB,CTNNA1,CTNNB1,CXCL12,CXCR4,DIPK2B,CYGB,CYP1A1,CYP1B1,CYYR1,DAB2,DAB2IP,DGKD,DHRS3,DKK3,DLC1,DOCK4,DPYSL3,DUSP6,ECE1,ECSCR,EFNA1,EFNB2,EFNB3,EHD4,ELK3,ADGRL4,EMCN,EMP1,EPAS1,ERBIN,ERG,ESAM,ESM1,ETS1,ETS2,EXOC3L2,F2R,F2RL3,FABP5,RFLNB,GASK1B,FAM43A,RIPOR1,FCN3,FHOD1,FKBP1A,FLT1,FLT4,FMNL3,FRMD4A,FRMD8,FSCN1,FSTL1,FZD4,GAB1,GATA2,GBP4,GIMAP4,GIMAP7,GIMAP8,GMFG,GNG11,ADGRF5,GPR4,ADGRG1,GPX1,GRB10,HDAC7,HEG1,HEY1,HIP1,HK1,HLA-B,HLA-E,HLX,HMGB1,HOXB6,HPCAL1,HTRA1,HTRA3,ICAM1,ID1,ID2,ID3,IFI16,IFI27,IFITM1,IFITM2,IFITM3,IGFBP4,IGFBP7,IL4R,IL6ST,INSR,ITGA1,ITGA2,ITGA5,ITGAV,ITPRIP,ITPRIPL2,JAG1,JAK1,JAM3,JUP,KBTBD2,KCNN3,KCTD20,KDR,GARRE1,DENND11,KLF13,KLHL24,LAMA4,LAMB1,LAMC1,LATS2,LBH,LDB2,LDHB,LGALS1,LHFPL6,LHFPL2,LIFR,LIMS1,LMO2,PCAT19,PTMAP11,LPAR6,ADGRL2,LRRC32,LRRC8A,LRRC8C,LUZP1,LXN,MACF1,MADCAM1,MARCKSL1,MCAM,MCF2L,MECOM,MEF2A,MEF2C,MRTFB,MLEC,MLLT1,MMP1,MMP2,MMRN2,MPRIP,MSN,MSX1,MYCN,MYCT1,MYH9,MYL12B,MYLK,NEDD9,NID1,NKX2-3,NLRC3,NOTCH1,NOTCH4,NOVA2,NQO1,NRARP,NRP1,NRP2,OLFML2A,P2RY6,PASK,PDE2A,PDGFA,PDGFB,PEA15,PECAM1,PELO,PFKP,PGM2,PGM2L1,PIM3,PITPNC1,PKP4,PLEKHG1,PLEKHO1,PLVAP,PLXNA2,PMEPA1,PNP,PODXL,PLPP3,PPP1R18,PRCP,PRDM1,PREX1,PRSS23,PTMA,PTPN12,PTPRB,PTPRE,CAVIN1,PXDN,PXN,QKI,RAB31,RAI14,RAMP2,RAP1A,RAPGEF1,RAPGEF4,RASA4,RASGRP3,RB1,RBMS1,RCC2,RDX,RELL1,REST,RGCC,RGS5,RHOJ,RILPL2,ROBO4,RSU1,S100A13,SASH1,SEC14L1,SELE,SEMA3G,SERPINE1,SERPINH1,SGK1,SH2B3,SH2D3C,SH3KBP1,SHANK3,SHE,SKAP2,SLC12A7,SLC26A2,SLC2A3,SLC38A2,SLC44A2,SLC7A11,SLCO2A1,SMAD6,SMAD7,SNAI1,SNAP23,SNRK,SNTB2,SOCS3,SOX18,SOX4,SPARC,SPARCL1,SPG7,SPON2,SPRED1,SPRY4,SPTBN1,SRGN,ST8SIA4,STC1,STOM,SULF2,SWAP70,SYNM,SYNPO,TBC1D1
,TBC1D8,TCF4,TEK,TGFBR1,TGFBR2,THBD,THBS1,TIE1,TM4SF1,TM4SF18,STING1,TMEM204,TMEM47,TMSB10,TP53I11,TP53INP2,TSC22D1,TSPAN14,TUBB6,UACA,UBE2J1,UNC5B,UPP1,UTRN,UXS1,VASH1,VIM,VWA1,VWF,WDR1,WWTR1,YES1,ZEB1,ZFP36L2,ZNF503,ZNF532'
mast = 'RASGRP1,NDRG1,CD226,CPLX2,CHGA,SCN11A,CD300A,CLNK,MRGPRX2,CNR1,CNR2,ADORA2B,CD300LF,UNC13D,FCER1A,FCGR2B,FER,FES,FGR,FOXF1,GATA1,GATA2,HAVCR1,LAT,RABGEF1,MILR1,GRP,ADGRE2,IL4R,IL13,IL13RA2,KIT,GPR15LG,LCP2,LGALS9,RHOH,RAB44,LYN,PLA2G3,ENPP3,PDPK1,PIK3CD,PIK3CG,PLSCR1,SPHK2,C12orf4,PTGDR,PTGDS,PTPN6,NECTIN2,SNX6,RAC2,S100A12,S100A13,CRLF2,SLC18A2,STXBP1,STXBP2,VAMP2,VAMP7,SYK,BTK,LAT2,NR4A3,TSLP,CBL,VAMP8,SNX4,SNAP23,CD84,VAMP3,GAB2'
psc_genes = 'ACTA2,ACTN1,ACTN4,ADAMTS1,ADAMTS2,ADAMTSL2,AEBP1,AGTR1,,ALDH1A1,ANK3,ANTXR1,ARID5B,ATN1,BGN,BMP2,BTG2,DEPP1,C11orf96,C1QTNF1,C1R,C1S,C7,CALD1,CAV1,CCBE1,CCDC3,CCDC80,CCL2,CCNI,CD63,CD81,CDC42EP3,CDKN1A,CEBPB,CEBPD,CFH,CH25H,CITED2,CLDN11,COL14A1,COL1A1,COL1A2,COL3A1,COL4A1,COL4A2,COL5A1,COL6A1,COL6A2,COL6A3,COLEC11,CPXM2,CRISPLD2,CRYAB,CSRNP1,CSRP1,CTDSPL,CCN2,CXCL12,CYBRD1,CCN1,DCN,DDR2,DKK3,DPT,DPYSL2,DSTN,DYNLL1,EDNRB,EFEMP1,EGR1,EMP1,ENG,EPAS1,FBLN1,FBLN2,FBLN5,FBN1,FHL1,FILIP1L,FOS,FOSB,FSTL1,GEM,GGT5,GLUL,GPC6,GPX3,GSN,H19,HGF,HLA-DRB1,HSPB1,IER3,IFITM2,IFITM3,IGFBP3,IGFBP4,IGFBP5,IGFBP7,IL1RL1,INMT,ITGA8,ITGA9,ITGB1,ITGB5,ITM2C,JUNB,KAZN,KLF4,LAMB1,LAMC1,LAMC3,LGALS1,LGALS3BP,LMNA,LRP1,LTBP4,LUM,MAFF,MAP1B,MARCKS,MASP1,MATN2,MFAP4,MGP,MITF,MKNK2,MMP19,MMP2,MT1M,MXRA7,MYADM,MYC,MYH10,MYL9,MYLK,NEXN,NFIA,NFIX,NR2F1,NR2F2,OLFML3,PALLD,PCDH9,PCOLCE,PDGFRA,PHLDA1,PLCXD3,PRELP,CAVIN3,PROS1,PRSS23,PTGDS,PTN,CAVIN1,QSOX1,RBMS3,RBP1,RBPMS,RGS16,RGS2,RGS4,RN7SK,RND3,RRAS,SERPINE1,SERPINF1,SFRP1,SLC38A2,SLIT3,SOCS3,SOD3,SPARC,SRPX,SSPN,SYNPO2,TCEAL4,TCF4,THBS1,THBS2,THY1,TIMP1,TIMP2,TIMP3,TIPARP,TMEM204,TNS1,TPM4,TSC22D1,TSC22D2,VCL,VIM,ZFP36,ZFP36L1'
## Names of the marker-gene sets, in the same order as `target_pathways`.
pathways = ['alpha', 'beta', 'delta', 'gamma', 'ductal', 'epsilon', 'acinar', 'endothelial', 'PSC', 'mast']
target_pathways = [
    alpha_genes, beta_genes, delta_genes, gamma_genes, ductal_genes,
    epsilon_genes, acinar_genes, endothelial, psc_genes, mast,
]

## Fail loudly instead of silently mis-pairing names and gene lists.
if len(pathways) != len(target_pathways):
    raise ValueError('pathways and target_pathways must have the same length')

## Build {pathway name: list of gene symbols}.
## The gene strings use either ';' or ',' as the separator, so pick whichever
## one the string contains. Empty tokens (the lists above contain a few ',,'
## typos) and stray whitespace are dropped so that no blank gene name is
## passed on to the annotation step.
pathway_dict = {}
for pathway_name, gene_string in zip(pathways, target_pathways):
    delimiter = ';' if ';' in gene_string else ','
    pathway_dict[pathway_name] = [
        gene.strip() for gene in gene_string.split(delimiter) if gene.strip()
    ]
## Annotate cells from the pathway gene sets; the call yields `dot_df` and the
## predicted labels `pred_y`.
dot_df, pred_y = scRo.cell_type_annotation_with_pathway(pathway_dict)

## Put true labels, predictions and `dot_df` side by side.
## NOTE(review): pd.concat(axis=1) aligns on the index; the prediction frame
## gets a default RangeIndex -- confirm it lines up with adata.obs and dot_df.
true_labels = scRo.adata.obs['label']
pred_frame = pd.DataFrame(pred_y, columns=['pred'])
result_df = pd.concat([true_labels, pred_frame, dot_df], axis=1)

## Evaluate only on cells whose true label has a pathway defined for it.
has_pathway = true_labels.isin(list(pathway_dict.keys()))
result_df = result_df[has_pathway]

## Accuracy and macro-averaged F1 of the predictions.
is_correct = result_df['pred'] == result_df['label']
acc = np.average(is_correct)
f1 = f1_score(result_df['label'], result_df['pred'], average='macro')
## Analyze HbA1c
clinical_labels = pd.read_csv('./data/Segerstolpe_HumanPancreas_clinical.csv')
sample_ids = clinical_labels['Characteristics [individual]']
## Use string index labels for the sample ids.
sample_ids.index = sample_ids.index.astype(str)

## Parse one HbA1c value per row from characters 6..8 of the
## clinical-information text. Rows that cannot be parsed (a non-string cell
## raises TypeError on slicing, a non-numeric slice raises ValueError in
## float()) fall back to 5.5. Only these two errors are caught, so unrelated
## failures and KeyboardInterrupt are no longer swallowed.
## `count` tallies the rows that used the fallback. (The flattened source does
## not show whether the increment sat inside the `except` branch; nothing
## visible reads `count` afterwards.)
HbA1c_arr = []
count = 0
for clinical_info in clinical_labels['Characteristics [clinical information]']:
    try:
        HbA1c_arr.append(float(clinical_info[6:9]))
    except (TypeError, ValueError):
        HbA1c_arr.append(5.5)
        count += 1
HbA1c_arr = np.array(HbA1c_arr)
labels = scRobust_adata.obs['label']
target_cell_type = ['alpha', 'beta', 'gamma', 'delta']

## One positional mask shared by every per-cell array below.
## Previously the embeddings were filtered by cell type only, while HbA1c,
## cell types and sample ids were also filtered by HbA1c > 0; any
## non-positive value would have produced arrays of different lengths.
## NOTE(review): like the original HbA1c_arr indexing, this assumes the rows of
## the clinical CSV are in the same order as the cells -- confirm.
selected = (HbA1c_arr > 0) & labels.isin(target_cell_type).values

cell_types = labels[selected]
HbA1c = HbA1c_arr[selected]
s_ids = sample_ids[selected]

## AnnData built from the selected cell embeddings, indexed by cell name.
adata_df = pd.DataFrame(data=cell_embeddings[selected], index=labels.index[selected])
sc_adata = sc.AnnData(adata_df)
sc_adata.obs['HbA1c'] = HbA1c
sc_adata.obs['cell_types'] = cell_types.tolist()
sc_adata.obs['sample_ids'] = s_ids.tolist()

## PCA -> neighbour graph -> t-SNE and Leiden clustering.
sc.pp.pca(sc_adata, n_comps=50)
sc.pp.neighbors(sc_adata, n_neighbors=10, n_pcs=50)
sc.tl.tsne(sc_adata)
sc.tl.leiden(sc_adata)

## One t-SNE plot per annotation column.
for colour_key in ('HbA1c', 'cell_types', 'sample_ids'):
    sc.pl.tsne(sc_adata, color=colour_key)
## Downstream task
## Drop the cells labelled as co-expression / unclassified endocrine.
excluded_labels = ['co-expression', 'unclassified endocrine']
indices = ~scRo.adata.obs['label'].isin(excluded_labels).values
scRo.adata = scRo.adata[indices]

## Number of categories of the label column after filtering.
n_classes = len(scRo.adata.obs['label'].cat.categories)

## NOTE(review): the keyword is spelled `n_clssses` exactly as in the original
## call; presumably it matches the library signature -- verify before renaming.
scRo.set_downstream_model(hidden=64 * 8, n_clssses=n_classes, att_dropout=0.3)

## train_DS returns validation metrics followed by test metrics.
(val_auc, val_loss, val_f1, val_acc,
 test_auc, test_loss, test_f1, test_acc) = scRo.train_DS(
    epoch=20, lr=5e-5, batch_size=64, n_ge=800,
)
Large scRobust with 8 layers and a large vocabulary of 42,160 tokens
from scRobust import *
import scanpy as sc
import pandas as pd
import numpy as np
## Select the compute device.
## The example targets GPU 3; on a host exposing fewer than four GPUs that
## ordinal is invalid, so fall back to GPU 0 (or to the CPU without CUDA).
cuda_condition = torch.cuda.is_available()
gpu_index = 3 if torch.cuda.device_count() > 3 else 0
device = torch.device(f"cuda:{gpu_index}" if cuda_condition else "cpu")

scRo = scRobust(device)

## The large model uses a vocabulary loaded from a CSV file.
vocab_path = './vocab/whole_human_vocab.csv'
adata_path = './data/Processed_Filtered_Segerstolpe_HumanPancreas_data.h5ad'
scRo.load_vocab(vocab_path)
## The larger model expects input that went through normalize_total -> log1p.
## Processed_Filtered_Segerstolpe_HumanPancreas_data.h5ad has already been
## pre-processed, so both steps are switched off while reading.
scRo.read_adata(adata_path, normalize_total=False, log1p=False)

## Optional safeguard (not strictly necessary): invert the stored log1p with
## expm1, then redo normalize_total -> log1p explicitly.
restored_values = np.expm1(scRo.adata.X)
scRo.adata.X = restored_values
sc.pp.normalize_total(scRo.adata, target_sum=1e4)
sc.pp.log1p(scRo.adata)
## Hyper-parameters of the large model.
d = 64
attn_heads = 8
hidden = d * attn_heads  # 64 * 8 = 512
n_layers = 8
n_ge = 400

## Encoder and pre-training head, sized from the values above.
scRo.set_encoder(hidden=hidden, n_layers=n_layers, attn_heads=attn_heads)
scRo.set_pretraining_model(hidden=hidden, att_dropout=0.3)

## `save_path` is passed to train_SSL for fine-tuning; `weight_path` is the
## checkpoint loaded into the model here.
save_path = './weights/test_w_weight_ly8_bt_128'
weight_path = './weights/Whole_Human_BERT_Hid_512_Att_8_nGenes_200_ly_8_bt_256.pt'
scRo.load_model_weight(weight_path)
## Optional: fine-tune the loaded model with the self-supervised objective.
## train_SSL returns four loss values, unpacked below.
ssl_losses = scRo.train_SSL(
    epoch=100,
    lr=0.0001,
    batch_size=128,
    n_ge=n_ge,
    save_path=save_path,
    simple=True,
)
train_cl_loss, train_ge_loss, test_cl_loss, test_ge_loss = ssl_losses
## Get cell embeddings from the large model (use_HVGs on, use_HUGs off).
cell_embeddings = scRo.get_cell_embeddings(
    n_ge=1000,
    batch_size=128,
    use_HUGs=False,
    use_HVGs=True,
    simple=True,
)

## Wrap the embeddings in an AnnData object with t-SNE and Leiden computed.
scRobust_adata = scRo.get_cell_adata(
    cell_embeddings,
    umap=False,
    tsne=True,
    leiden=True,
    n_comps=50,
    n_neighbors=10,
    n_pcs=50,
)