library(SingleCellExperiment)
library(zellkonverter)
library(dreamlet)
library(crumblr)
library(aplot) 
library(tidyverse) 
library(RColorBrewer)
library(scales)
library(ggtree) 
library(kableExtra)
library(tidyverse) 
library(cowplot)
# Load the single-cell RNA-seq dataset (Yazar_Science_2022 directory) as a
# SingleCellExperiment. use_hdf5 = TRUE asks zellkonverter for HDF5-backed
# assays rather than reading the matrix fully into memory.
file <- "/sc/arion/projects/CommonMind/hoffman/scRNAseq_data/Yazar_Science_2022/08984b3c-3189-4732-be22-62f1fe8f15a4.h5ad"
sce <- readH5AD(file, use_hdf5 = TRUE)

# Downstream pseudobulking reads assay "counts"; expose the AnnData "X"
# matrix under that name.
counts(sce) <- assay(sce, "X")

# Keep only cell types with sufficient counts (more than 1000 cells);
# `keep` holds their names and is used to subset throughout the script.
tab <- table(sce$cell_type) > 1000
keep <- names(which(tab))

# Compute pseudobulk profiles for the retained cell types, once per donor
# (`pb`) and once per pool (`pb_pool`). The cell subset is built a single
# time and shared by both aggregations instead of being recomputed.
sce_keep <- sce[, sce$cell_type %in% keep]

# Pseudobulk by donor
pb <- aggregateToPseudoBulk(sce_keep,
  assay = "counts",
  cluster_id = "cell_type",
  sample_id = "donor_id",
  verbose = FALSE
)

# Pseudobulk by pool
pb_pool <- aggregateToPseudoBulk(sce_keep,
  assay = "counts",
  cluster_id = "cell_type",
  sample_id = "pool_number",
  verbose = FALSE
)

# Only the pseudobulk summaries are used below; drop the single-cell objects
rm(sce, sce_keep)

# crumblr transform of the per-donor cell counts for the retained cell types.
# The result's `E` and `weights` components are used by the plots below.
cobj <- cellCounts(pb)[, keep] %>%
  crumblr()

# Bar plots of cell composition ----

# Stacked bar plot of per-sample cell-type fractions.
#
# counts: samples x cell types count table (e.g. cellCounts(pb)[, keep]).
# id_col: label for the integer sample index placed on the x-axis.
# Returns a ggplot object (one bar per row of `counts`, legend hidden).
plot_cell_fractions <- function(counts, id_col) {
  # Divide each row by its total; as.data.frame() (unlike data.frame())
  # keeps the original cell-type names instead of mangling them.
  fracs <- as.data.frame(counts / rowSums(counts))

  fracs %>%
    rowid_to_column(id_col) %>%
    pivot_longer(!all_of(id_col)) %>%
    ggplot(aes(.data[[id_col]], value, fill = name)) +
    geom_bar(stat = "identity") +
    theme(legend.position = "none", aspect.ratio = 1 / 4) +
    coord_cartesian(expand = FALSE) +
    xlab(id_col) +
    ylab("Cell fraction")
}

# By donor
plot_cell_fractions(cellCounts(pb)[, keep], "Sample")

# By pool
plot_cell_fractions(cellCounts(pb_pool)[, keep], "Pool")

# Variance partitioning of the crumblr-transformed composition, with age as a
# fixed effect and sex and pool as random effects.
form <- ~ age + (1 | sex) + (1 | pool_number)
res.vp <- fitExtractVarPartModel(cobj, form, colData(pb))

# One Set1 colour per column except the last, which is drawn grey
# (presumably the residual component, placed last by sortCols() - confirm).
cols <- c(brewer.pal(ncol(res.vp) - 1, "Set1"), "grey85")
fig.vp <- res.vp %>%
  sortCols() %>%
  plotPercentBars(col = cols)

# Test the effect of age on each cell type's transformed abundance, with sex
# and pool as random effects.
form <- ~ age + (1 | sex) + (1 | pool_number)
fit <- dream(cobj, form, colData(pb))
fit <- eBayes(fit)

# Cell-type tree from the pseudobulk data; passed to treeTest() further down.
hc <- buildClusterTreeFromPB(pb)

# Table of per-cell-type age results. select() is namespaced explicitly
# because the many attached packages in this script can mask dplyr::select.
topTable(fit, coef = "age", number = Inf) %>%
  dplyr::select(logFC, AveExpr, t, P.Value, adj.P.Val) %>%
  kbl() %>%
  kable_classic(full_width = FALSE)
# topTable() output for coef = "age" (recorded from a previous run):
#> logFC AveExpr t P.Value adj.P.Val
#> naive thymus-derived CD8-positive, alpha-beta T cell -0.0297043 0.8350022 -22.9403821 0.0000000 0.0000000
#> natural killer cell 0.0120883 2.3190839 11.9150283 0.0000000 0.0000000
#> CD14-low, CD16-positive monocyte 0.0162833 -0.4107950 9.5143970 0.0000000 0.0000000
#> effector memory CD8-positive, alpha-beta T cell 0.0117192 2.1922798 8.7885939 0.0000000 0.0000000
#> CD4-positive, alpha-beta cytotoxic T cell 0.0164453 -0.3438581 8.7555044 0.0000000 0.0000000
#> regulatory T cell 0.0073007 0.4906432 8.0950823 0.0000000 0.0000000
#> mucosal invariant T cell -0.0120322 -0.9298774 -7.7254366 0.0000000 0.0000000
#> central memory CD4-positive, alpha-beta T cell 0.0052265 2.9558465 7.0501563 0.0000000 0.0000000
#> CD14-positive monocyte 0.0108346 0.2642911 7.0014083 0.0000000 0.0000000
#> platelet 0.0055331 -2.3320299 4.4070159 0.0000117 0.0000269
#> conventional dendritic cell 0.0057286 -1.3642922 4.2402716 0.0000244 0.0000511
#> naive thymus-derived CD4-positive, alpha-beta T cell -0.0042212 2.7453926 -3.6332994 0.0002942 0.0005640
#> memory B cell -0.0033129 0.5517829 -3.1315205 0.0017937 0.0031735
#> naive B cell 0.0036105 1.2591626 2.6695028 0.0077238 0.0126892
#> plasmablast -0.0030283 -1.6912639 -1.8843197 0.0598402 0.0917550
#> plasmacytoid dendritic cell 0.0020623 -2.1977522 1.3345923 0.1823200 0.2494650
#> central memory CD8-positive, alpha-beta T cell -0.0015852 -0.0704142 -1.3282992 0.1843872 0.2494650
#> transitional stage B cell 0.0012492 0.3508902 0.8674714 0.3859196 0.4931195
#> CD16-negative, CD56-bright natural killer cell, human 0.0007168 -0.8784892 0.6066970 0.5441943 0.6560342
#> effector memory CD4-positive, alpha-beta T cell 0.0005516 0.6375663 0.5675653 0.5704646 0.6560342
#> hematopoietic precursor cell -0.0005647 -2.1414414 -0.4161833 0.6773772 0.7172313
#> gamma-delta T cell 0.0005720 -0.0653021 0.4043480 0.6860474 0.7172313
#> double negative thymocyte -0.0004358 -2.1764256 -0.3264708 0.7441382 0.7441382
# Tree-based test of the age coefficient across the cell-type tree `hc`
res <- treeTest(fit, cobj, hc, coef = "age")

# Individual panels: tree view of the test results and a forest plot
fig1 <- plotTreeTestBeta(res) + ggtitle("Age")
fig2 <- crumblr::plotForest(res, hide = TRUE)

# Combine with aplot: tree on the left, variance partition bars on the right
insert_right(insert_left(fig2, fig1), fig.vp)

# Plot each cell type's regression on age ----

# Fit a model with only sex and pool as random effects (no age term); its
# residuals are plotted against age further down.
form <- ~ (1 | sex) + (1 | pool_number)
fit <- dream(cobj, form, colData(pb))

# Plot of CLR versus age: one scatter plot per cell type, coloured by standard
# error, with a navy linear fit. The `weight` aesthetic (1 / se^2) is used
# by geom_smooth(); geom_point() ignores it.
figList <- lapply(rownames(cobj$E), function(CT) {
  # Convert the crumblr precision weights to standard errors
  df <- data.frame(
    CLR = cobj$E[CT, ],
    se = 1 / sqrt(cobj$weights[CT, ]),
    colData(pb)
  )

  ggplot(df, aes(age, CLR, color = se, weight = 1 / se^2)) +
    geom_point() +
    # theme_classic() is a complete theme and replaces any earlier theme()
    # settings, so aspect.ratio is set only after it.
    theme_classic() +
    theme(plot.title = element_text(hjust = 0.5), aspect.ratio = 1) +
    ggtitle(gsub(",", "\n", CT)) +
    scale_color_gradient(low = muted("red", 70, 20), high = "red", name = "SE") +
    geom_smooth(method = lm, formula = y ~ x, color = "navy", se = FALSE)
})

plot_grid(plotlist = figList, ncol = 3)

library(splines)

# Add two quadratic-in-age trend lines to a per-cell-type plot:
#  - green: uses the plot's mapped `weight` aesthetic (weighted fit);
#  - turquoise: sets weight = 1, which overrides the mapped weights, giving
#    an unweighted fit for comparison.
add_quadratic_fits <- function(fig) {
  fig +
    geom_smooth(
      method = "lm", formula = y ~ x + I(x^2),
      color = "green", se = FALSE
    ) +
    geom_smooth(
      method = "lm", formula = y ~ x + I(x^2), weight = 1,
      color = "turquoise", se = FALSE
    )
}

# Inspect a single cell type (index into figList)
i <- 10
add_quadratic_fits(figList[[i]])

# Same overlay for every cell type
figList2 <- lapply(figList, add_quadratic_fits)

plot_grid(plotlist = figList2, ncol = 3)

# Plot of residuals versus age: for each cell type, residuals from `fit` (as
# last assigned above) against age, coloured by standard error, with a navy
# linear fit weighted via the `weight` aesthetic (1 / se^2).

# The residual matrix does not depend on the cell type, so compute it once
# instead of once per lapply iteration.
resid_mat <- residuals(fit)

figList <- lapply(rownames(cobj$E), function(CT) {
  df <- data.frame(
    residuals = resid_mat[CT, ],
    se = 1 / sqrt(cobj$weights[CT, ]),
    colData(pb)
  )

  ggplot(df, aes(age, residuals, color = se, weight = 1 / se^2)) +
    geom_point() +
    # theme_classic() is a complete theme and replaces earlier theme()
    # settings, so aspect.ratio is set only after it.
    theme_classic() +
    theme(plot.title = element_text(hjust = 0.5), aspect.ratio = 1) +
    ggtitle(gsub(",", "\n", CT)) +
    scale_color_gradient(low = muted("red", 80, 1), high = "red", name = "SE") +
    geom_smooth(method = lm, formula = y ~ x, color = "navy", se = FALSE)
})

plot_grid(plotlist = figList, ncol = 3)