library(tidyverse)
library(Matrix)
library(ggcorrplot)
library(ggrepel)
# PsychAD results table
file = "/sc/arion/projects/CommonMind/hoffman/dreamlet_analysis/PsychAD_r0/topTable_PsychAD_r0.tsv.gz"
tab.psychad = file %>% read_tsv()

# Mathys, et al. results table (`file` is reused for the second path)
file = "/sc/arion/projects/CommonMind/hoffman/dreamlet_analysis/Mathys_2023/Mathys_2023_Major_Cell_Type_ADdiag2typesAD.tsv"
tab.mathys = file %>% read_tsv()

# Cell type order; used further down to arrange the heatmap rows
ctorder = c('EN_L2_3_IT', 'EN_L3_5_IT_1', 'EN_L3_5_IT_2', 'EN_L3_5_IT_3', 'EN_L5_6_NP', 'EN_L6_CT', 'EN_L6_IT', 'EN_NF', 'IN_ADARB2', 'IN_LAMP5', 'IN_PVALB', 'IN_PVALB_CHC', 'IN_SST', 'IN_VIP', 'Oligo', 'OPC', 'Astro', 'Micro_PVM', 'CD8_T', 'PC', 'VLMC','Endo')

# Compare microglia: assay "Micro_PVM" in PsychAD vs. assay "Mic" in Mathys, et al.
pair = c(psychad = "Micro_PVM", mathys = "Mic")

tab.psychad.sub = tab.psychad %>% 
        filter(assay == pair[['psychad']])

tab.mathys.sub = tab.mathys %>% 
        filter(assay == pair[['mathys']])

# Columns shared by both tables get the suffix .x (PsychAD) or .y (Mathys)
tab = inner_join(tab.psychad.sub, tab.mathys.sub, by="ID")

# Label each ID by the dataset(s) in which it passes adjusted p < 0.05.
# Later assignments overwrite earlier ones, so "Both" takes precedence.
tab$Signif = "no"
i = with(tab, adj.P.Val.x < 0.05)
tab$Signif[i] = "PsychAD"
i = with(tab, adj.P.Val.y < 0.05)
tab$Signif[i] = "Mathys, et al."
i = with(tab, adj.P.Val.x < 0.05 & adj.P.Val.y < 0.05)
tab$Signif[i] = "Both"

tab$Signif = factor(tab$Signif, c("no", "PsychAD", "Mathys, et al.", "Both"))

# Only IDs significant in both datasets are highlighted
col = c("grey70", "grey70", "grey70", "red2")
names(col) = levels(tab$Signif)

# Regress Mathys logFC on PsychAD logFC
fit = lm(logFC.y ~ logFC.x, tab)
fit.summary = summary(fit)

# Row 2 is the slope (logFC.x), column 4 its p-value
pv = coef(fit.summary)[2,4]
rvalue = format(fit.summary$r.squared, digits=3)

# Report the fitted p-value rather than a hard-coded string; values below
# 1e-16 are shown as a bound
ptxt = if( isTRUE(pv < 1e-16) ) "p < 1e-16" else paste("p =", format(pv, digits=3))
txt = paste0(ptxt, "\nR2 = ", rvalue)

# Sort by decreasing PsychAD p-value so the most significant points are
# drawn last, i.e. on top
fig = tab %>% 
      arrange(-P.Value.x) %>%
      ggplot(aes(logFC.x, logFC.y, color=Signif)) +
          coord_fixed() +
          theme_classic() +
          xlab("logFC from PsychAD (Micro_PVM)") +
          ylab("logFC from Mathys, et al. (Microglia)") +
          geom_abline(color="red") +
          geom_hline(yintercept = 0, color="grey40", linetype='dashed') +
          geom_vline(xintercept = 0, color="grey40", linetype='dashed') +
          geom_point() + 
          ggtitle("Compare logFC from microglia") +
          theme(plot.title = element_text(hjust = 0.5)) +
          scale_color_manual(values=col) +
          annotate("text", x=-.6, y=1.6, label=txt) 

# Label the IDs significant in both datasets.
# label.padding was dropped: it is an argument of geom_label_repel(), not
# geom_text_repel(), where it is only reported as an unknown parameter.
fig = fig +
      geom_text_repel(data = tab %>% filter(Signif == "Both"), aes(x = logFC.x, y = logFC.y, label=ID), box.padding=.6, min.segment.length=.1, color="black")
fig

# All pairs

# Spearman correlation of logFC between every pair of assays.
#
# tab: data frame with (at least) the columns `assay`, `ID` and `logFC`.
#      Rows of two assays are matched on `ID`.
#
# Returns a data.frame with one row per unordered pair of assays and the
# columns `CT1`, `CT2` and `cor`. `cor` is NA when the two assays share no
# `ID`. A zero-row data.frame is returned when `tab` holds fewer than two
# assays.
jointMatrix = function( tab ){

  stopifnot(all(c("assay", "ID", "logFC") %in% colnames(tab)))

  cellTypes = unique(tab$assay)
  cellTypes = cellTypes[!is.na(cellTypes)]

  # combn() fails with fewer than two items; there are no pairs to report
  if( length(cellTypes) < 2 ){
    return(data.frame(CT1 = cellTypes[0], CT2 = cellTypes[0], cor = numeric(0)))
  }

  # Keep only the columns needed for the join and split once per assay,
  # instead of re-filtering the full table twice for every pair
  slim = data.frame(ID = tab$ID, logFC = tab$logFC, stringsAsFactors = FALSE)
  byAssay = split(slim, as.character(tab$assay))

  # Each row of grd holds the indices of one pair of assays
  grd = t(combn(length(cellTypes), 2))

  df = lapply( seq_len(nrow(grd)), function(k){

    ct1 = cellTypes[grd[k,1]]
    ct2 = cellTypes[grd[k,2]]

    # Inner join on ID; logFC becomes logFC.x (ct1) and logFC.y (ct2)
    tab.join = merge(byAssay[[as.character(ct1)]],
                     byAssay[[as.character(ct2)]], by="ID")

    # cor() cannot be evaluated on empty input
    rho = if( nrow(tab.join) == 0 ){
      NA_real_
    }else{
      with(tab.join, cor(logFC.x, logFC.y, method="spearman"))
    }

    data.frame(CT1 = ct1, CT2 = ct2, cor = rho)
  })

  do.call(rbind, df)
}

# Arrange pairwise values into a symmetric matrix.
#
# df:     data.frame with the columns `CT1` and `CT2` naming the two members
#         of each pair (as returned by jointMatrix()) and a column of values.
# column: name of the column of `df` holding the values.
#
# Returns a dense numeric matrix whose rows and columns are named by the
# labels found in `CT1` and `CT2`, in order of first appearance. Pairs absent
# from `df` are 0, values of a repeated pair are summed, and the diagonal is
# set to 1.
convertToMatrix = function(df, column){

  stopifnot(all(c("CT1", "CT2", column) %in% colnames(df)))

  # Work on character labels so factor input is not turned into integer codes
  ct1 = as.character(df$CT1)
  ct2 = as.character(df$CT2)
  cellTypes = unique(c(ct1, ct2))
  n = length(cellTypes)

  i = match(ct1, cellTypes)
  j = match(ct2, cellTypes)
  x = as.numeric(df[[column]])

  # Store every pair in the upper triangle whichever way round it was given,
  # so the result does not depend on the order of CT1 and CT2
  lo = pmin(i, j)
  hi = pmax(i, j)

  C = matrix(0, n, n, dimnames = list(cellTypes, cellTypes))
  for( k in seq_along(x) ){
    C[lo[k], hi[k]] = C[lo[k], hi[k]] + x[k]
  }

  # Mirror the upper triangle into the lower one
  C[lower.tri(C)] = t(C)[lower.tri(C)]
  diag(C) = 1

  C
}

# Mathys
res.Mathys = jointMatrix(tab.mathys)

# PsychAD
res.PsychAD = jointMatrix(tab.psychad)

# Merged: prefix every assay with its dataset so that cell types of the two
# datasets stay distinct and can be told apart again afterwards
tab.merge = rbind(tab.mathys %>% mutate(assay = paste0('Mathys.', assay)),
  tab.psychad %>% mutate(assay = paste0('PsychAD.', assay)))

res.merge = jointMatrix(tab.merge)
library(pheatmap)

# res.merge$Dataset1 = sapply(strsplit(res.merge$CT1, '\\.'), function(x) x[2])
# res.merge$Dataset2 = sapply(strsplit(res.merge$CT2, '\\.'), function(x) x[2])

# Spearman correlation (jointMatrix() calls cor() with method="spearman")
C = convertToMatrix(res.merge, "cor")

# Dataset of each row/column: the text before the first '.'
ids = sub("\\..*$", "", rownames(C))

# i: entries of the first dataset, j: entries of the last dataset.
# C[-i, -j] therefore keeps the rows of the last dataset (PsychAD) and the
# columns of the first one (Mathys, et al.)
i = which(ids == ids[1])
j = which(ids == ids[length(ids)])

bks = seq(-.3, .3, by=.05)
# pheatmap() expects one colour per interval, i.e. one fewer than breaks;
# this also centres white on 0
col = colorRampPalette(c("blue", "white", "red"))(length(bks) - 1)

# Drop the dataset prefix, keeping the full cell type label even if it
# contains a '.'
colnames(C) = sub("^[^.]*\\.", "", colnames(C))
rownames(C) = sub("^[^.]*\\.", "", rownames(C))

# Rows: PsychAD, columns: Mathys, et al.
# pheatmap() has no breaklist/xlab/ylab arguments; legend ticks are set with
# legend_breaks.
# NOTE(review): values outside the range of bks are not assigned a colour;
# confirm that all correlations fall within [-0.3, 0.3].
pheatmap(C[-i, -j], breaks=bks, cellheight=15, cellwidth = 15, color=col, legend_breaks=seq(-.3, .3, by=0.3), main = "Spearman")

# Same heatmap with fixed row and column order and no clustering
A = C[-i, -j]
idx = c( "Exc", "Inh","Oli", "Opc", "Ast", "Mic", "Vas")
A = A[ctorder,idx]
pheatmap(A, breaks=bks, cellheight=15, cellwidth = 15, color=col, legend_breaks=seq(-.3, .3, by=0.3), main = "Spearman", cluster_cols=FALSE, cluster_rows=FALSE)

knitr::knit_exit()