Scatterplot matrix with marginal probability distributions in seaborn

It is straightforward to make scatterplot matrices with seaborn's `pairplot`. `jointplot` also lets you combine a scatter plot with marginal probability distributions, but only for a single pair of variables.

Although the option `diag_kind='kde'` lets you plot the probability distributions on the diagonal (useful when `x_vars` and `y_vars` are the same), I want to combine both approaches to get marginal probability distributions in a scatterplot matrix. Something like this:

[Image: example of the desired plot]

How do I get marginal probability distributions in a matrix scatter plot in seaborn as shown in my screenshot above?

Answer

Many thanks to mwaskom for the guidance.

As you suggested, I built my own matplotlib figure and drew the seaborn plots onto its axes, following the relevant part of the documentation (original link not preserved).

def basic_conf(f, a, xin, yin, x, y):
    """Create and style one scatter panel of the matrix.

    Parameters
    ----------
    f : figure the panel is added to (via ``f.add_subplot``).
    a : subplot spec (e.g. a GridSpec cell) handed to ``add_subplot``.
    xin : column index of the panel; only column 0 keeps its y tick
        labels and a visible y label.
    yin : row index of the panel (not used here).
    x : name of the x variable (not used here).
    y : text used as the y-axis label.

    Returns
    -------
    The newly created axes.
    """
    panel = f.add_subplot(a)
    panel.tick_params(axis='both', which='major', labelsize=10)
    for side in ("right", "top"):
        panel.spines[side].set_visible(False)

    # x axis: scatter panels never show x tick labels or an x label.
    panel.set_xticklabels([])
    panel.set_xlabel(" ", fontsize=0).set_visible(False)

    # y axis: only the first column keeps tick labels and a visible label.
    # For other columns the label Text is hidden first; the set_ylabel call
    # below only updates that same Text object's text/size, so it stays hidden.
    if xin != 0:
        panel.set_yticklabels([])
        panel.set_ylabel(" ", fontsize=0).set_visible(False)
    panel.set_ylabel(y, fontsize=10)
    return panel
def xhist_conf(f, a, x):
    """Create and style an axes for the marginal histogram of an x variable.

    Only the bottom spine and the x tick labels remain visible; the y tick
    labels, y ticks and y label are hidden.  Unlike ``basic_conf`` and
    ``yhist_conf`` this does not call ``tick_params``, so tick label sizes
    come from the current rcParams.

    Parameters
    ----------
    f : figure the axes is added to (via ``f.add_subplot``).
    a : subplot spec handed to ``add_subplot``.
    x : text used as the x-axis label.

    Returns
    -------
    The newly created axes.
    """
    hist_ax = f.add_subplot(a)
    for side in ("right", "left", "top"):
        hist_ax.spines[side].set_visible(False)
    hist_ax.set_yticklabels([])
    hist_ax.yaxis.set_ticks_position('none')
    hist_ax.set_xlabel(x, fontsize=10)
    hidden_label = hist_ax.set_ylabel(" ")
    hidden_label.set_visible(False)
    return hist_ax

def yhist_conf(f, a, y):
    """Create and style an axes for the marginal histogram of a y variable.

    The right, top and bottom spines are hidden, both sets of tick labels
    are blanked, x ticks are removed and both axis labels are hidden, so
    only the left spine with its ticks remains.

    Parameters
    ----------
    f : figure the axes is added to (via ``f.add_subplot``).
    a : subplot spec handed to ``add_subplot``.
    y : name of the y variable (not used here).

    Returns
    -------
    The newly created axes.
    """
    hist_ax = f.add_subplot(a)
    hist_ax.tick_params(axis='both', which='major', labelsize=10)
    for side in ("right", "top", "bottom"):
        hist_ax.spines[side].set_visible(False)

    hist_ax.set_xticklabels([])
    hist_ax.set_yticklabels([])
    hist_ax.xaxis.set_ticks_position('none')

    x_label = hist_ax.set_xlabel(" ", fontsize='xx-small')
    x_label.set_visible(False)
    y_label = hist_ax.set_ylabel(" ", fontsize=0)
    y_label.set_visible(False)
    return hist_ax

def includer(ax, x, y, data=None):
    """Write the Pearson correlation of columns *x* and *y* onto *ax*.

    The text ``ρ = <r>`` (two decimals) is placed at (0.1, 0.9) in axes
    coordinates.

    Parameters
    ----------
    ax : axes to annotate.
    x, y : column names looked up in *data*.
    data : mapping/table indexed by column name; defaults to the
        module-level ``concat_convert`` used by the rest of the script.

    Returns
    -------
    The Pearson r that was written on the axes.
    """
    import numpy as np  # local import keeps this helper self-contained

    if data is None:
        data = concat_convert
    xs = np.asarray(data[x], dtype=float)
    ys = np.asarray(data[y], dtype=float)
    # Drop pairs where either value is missing. pearsonr does not skip NaN,
    # whereas seaborn's regplot drops missing pairs before plotting, so
    # without this the annotation could read "nan" for a panel that shows points.
    keep = ~(np.isnan(xs) | np.isnan(ys))
    r, _ = stats.pearsonr(xs[keep], ys[keep])
    ax.text(0.1, 0.9, f'ρ = {r:.2f}', transform=ax.transAxes)
    return r


# Variables along the matrix columns (x) and rows (y).  The mathtext strings
# are used both as axis labels and as column names in ``concat_convert``.
x_vars = ["$P_{LA}$", "$R^{Ao}_P$", "$C^{Ao}_P$", "$R^{Ao}_S$",
          "$B_{VAD}$", "$A_{VAD}$", "HR", "EF"]
y_vars = ["${Q}^{avg}_{M}$", "${Q}^{max}_{M}$", "${Q}^{avg}_{Ao}$",
          "${Q}^{max}_{Ao}$", "${Q}^{avg}_{VAD}$", "${Q}^{max}_{VAD}$",
          "$Q_{RAT}$"]


sns.set(context="paper", font_scale=1.75, style="ticks")


f = plt.figure(figsize=(18, 16), dpi=600)
# One row per y variable plus a bottom row of x marginals, and one column per
# x variable plus a right-hand column of y marginals.  Derived from the lists
# (previously hard-coded as 8 x 9) so the grid follows any change to them.
n_rows, n_cols = len(y_vars), len(x_vars)
gs = f.add_gridspec(n_rows + 1, n_cols + 1)
plt.rcParams['font.size'] = '10'
plt.rcParams['xtick.labelsize'] = '8'

# Keep a handle to every axes created below.  Calling f.add_subplot() again on
# an occupied GridSpec cell does NOT return the existing axes in recent
# Matplotlib (>= 3.4): it stacks a new, empty axes on top and hides the plot.
scatter_axes = {}
xhist_axes = {}
yhist_axes = {}

with sns.axes_style("ticks"):
    # Scatter + regression panels, annotated with Pearson's r.
    for col, x in enumerate(x_vars):
        for row, y in enumerate(y_vars):
            ax = basic_conf(f, gs[row, col], col, row, x, y)
            sns.regplot(ax=ax, data=concat_convert, x=x, y=y,
                        scatter_kws={'s': 4})
            includer(ax, x, y)
            scatter_axes[row, col] = ax

    # Bottom row: marginal distribution of each x variable.
    for col, x in enumerate(x_vars):
        ax = xhist_conf(f, gs[n_rows, col], x)
        sns.histplot(ax=ax, data=concat_convert, x=x, kde=True)
        xhist_axes[col] = ax

    # Right-hand column: marginal distribution of each y variable.
    for row, y in enumerate(y_vars):
        ax = yhist_conf(f, gs[row, n_cols], y)
        sns.histplot(ax=ax, data=concat_convert, y=y, kde=True)
        yhist_axes[row] = ax

# Fix the x range of the scatter panels in column 2 ($C^{Ao}_P$).
for row in range(n_rows):
    scatter_axes[row, 2].set_xlim((0.001, 0.0014))

# Scientific-notation x tick labels on the bottom-row marginals of
# column 0 ($P_{LA}$) and column 5 ($A_{VAD}$).
xhist_axes[0].ticklabel_format(style='sci', scilimits=(0, 0), axis='x')
xhist_axes[5].ticklabel_format(style='sci', scilimits=(0, 0), axis='x')

And it gives me exactly what I want:

[Image: resulting plot]

Many thanks.

EDIT: Updated with the final code snippet and the resulting plot.

Leave a Reply

Your email address will not be published. Required fields are marked *