Align colorbar with GeoAxes subplot edges

I have a figure with 3 subplots: two of them share a colorbar, while the third has its own colorbar.

I would like the colorbars to align with the vertical limits of their respective plots, and for the top two plots to have the same vertical limits.

Searching online, I found ways to do this for a single plot, but I am stuck trying to make it work for my figure. My figure currently looks like this:

[Image: current figure]

Here is the code that produces it:

import cartopy.io.shapereader as shpreader
import cartopy.crs as ccrs
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import numpy as np


shpfilename = shpreader.natural_earth(resolution='50m',
                                      category='cultural',
                                      name='admin_0_countries')

reader = shpreader.Reader(shpfilename)

# Iterator over the country records of the Natural Earth shapefile.
countries = reader.records()

projection = ccrs.PlateCarree()

fig = plt.figure()
# Two maps side by side on the top row, plus one map spanning the whole
# bottom row (cells 3 and 4 of the 2x2 grid). The expression is wrapped in
# parentheses so it can continue onto a second line; an indented bare
# continuation line is a syntax (indentation) error.
axs = ([plt.subplot(2, 2, x + 1, projection=projection) for x in range(2)]
       + [plt.subplot(2, 2, (3, 4), projection=projection)])


def cmap_seg(cmap, value, k):
    """Return the colour of ``value`` in a ``k``-bin discretisation of ``cmap``.

    The bin edges are ``0, 1, ..., k``, i.e. bins ``[0, 1), ..., [k - 1, k)``.
    A ``BoundaryNorm`` maps these bins onto the colormap's entries.

    Args:
        cmap: a matplotlib colormap.
        value: the number to colour.
        k: number of unit-width bins, starting at 0.

    Returns:
        ``(color, cmap)``: the RGBA colour for ``value`` and the rebuilt
        colormap (the script also uses the latter to build its colorbars).

    NOTE(review): with BoundaryNorm, values >= k fall outside the last bin
    and take the colormap's "over" colour. By default that is the same as
    the top bin, so if the data run 1..k instead of 0..k-1 the two highest
    categories share a colour. Values < 0 take the "under" colour. Confirm
    the intended bin range.
    """
    # Rebuild the colormap from its own sampled colours (same entry count).
    cmaplist = [cmap(i) for i in range(cmap.N)]
    # ``matplotlib`` is the module imported at the top of the file; there is
    # no ``mpl`` alias, so ``mpl.colors`` would raise NameError.
    cmap = matplotlib.colors.LinearSegmentedColormap.from_list(
        'Custom cmap', cmaplist, cmap.N)
    bounds = np.linspace(0, k, k + 1)
    norm = matplotlib.colors.BoundaryNorm(bounds, cmap.N)
    color = cmap(norm(value))
    return color, cmap


# Colormaps and shared drawing style. None of these depend on the country,
# so they are set up once here rather than on every loop iteration.
# ``plt.get_cmap`` is used because ``matplotlib.cm.get_cmap`` was removed
# in Matplotlib 3.9.
cmap = plt.get_cmap("plasma")
cmap_blues = plt.get_cmap("Blues")
ax_extent = [-170, 180, -65, 85]
alpha = 1.0
edgecolor = "k"
linewidth = 0.5

# One entry per map: (axes, field read from the country's row of ``df``,
# colormap, number of bins passed to cmap_seg, x label).
panels = [
    (axs[0], "wgi_bin", cmap, 5, "WGI group"),
    (axs[1], "epi_bin", cmap, 5, "EPI group"),
    (axs[2], "diff", cmap_blues, 4, "difference"),
]

for country in countries:
    c_name = country.attributes["SOVEREIGNT"]
    # NOTE(review): ``df`` is not defined in this snippet; presumably it is
    # a DataFrame indexed by Natural Earth SOVEREIGNT names. A country that
    # is missing from ``df`` makes ``.loc`` raise KeyError.
    country_dat = df.loc[c_name]
    for ax, column, panel_cmap, n_bins, _label in panels:
        ax.add_geometries([country.geometry],
                          projection,
                          facecolor=cmap_seg(panel_cmap,
                                             country_dat.loc[column],
                                             n_bins)[0],
                          alpha=alpha,
                          edgecolor=edgecolor,
                          linewidth=linewidth)

# Labels and extent are per-axes settings: apply them once, after drawing.
for ax, _column, _panel_cmap, _n_bins, label in panels:
    ax.set_xlabel(label)
    ax.set_extent(ax_extent)

subplot_labels = ["WGI group", "EPI group", "Metric difference"]

# Caption each map just below its bottom edge (axes-fraction coordinates).
for ax, label in zip(axs, subplot_labels):
    ax.text(0.5, -0.07, label, va='bottom', ha='center',
            rotation='horizontal', rotation_mode='anchor',
            transform=ax.transAxes)

# Colorbar for the two top maps (5 bins of the plasma-based colormap),
# attached to the top-right axes.
sm = plt.cm.ScalarMappable(cmap=cmap_seg(cmap, 5, 5)[1], norm=plt.Normalize(0, 5))
sm.set_array([])  # public API instead of assigning the private ``_A``
cb = plt.colorbar(sm, ax=axs[1], values=[1, 2, 3, 4, 5], ticks=[1, 2, 3, 4, 5])

# Colorbar for the bottom difference map (4 bins of the Blues colormap).
sm2 = plt.cm.ScalarMappable(cmap=cmap_seg(cmap_blues, 5, 4)[1], norm=plt.Normalize(0, 4))
sm2.set_array([])
cb2 = plt.colorbar(sm2, ax=axs[2], values=[0, 1, 2, 3], ticks=[0, 1, 2, 3])

Answer

Try this:

# Replaces the original colorbar call. ``shrink=0.6`` scales the colorbar to
# 60% of its default length so it lines up more closely with the height of
# the map beside it. The value is specific to this figure's layout; adjust
# it if the layout changes.
cb = plt.colorbar(
    sm,
    ax=axs[1],
    values=list(range(1, 6)),
    ticks=list(range(1, 6)),
    shrink=0.6,
)

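Then add these lines near the end of the script, after the colorbars have been created. Adding a colorbar takes space away from the axes it is attached to, so the top-left map has to be resized afterwards to match the top-right map: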

# Give the top-left map the same vertical placement (bottom edge and height)
# as the top-right map, while keeping its own left edge and width.
left_box = axs[0].get_position()
right_box = axs[1].get_position()
axs[0].set_position(
    [left_box.x0, right_box.y0, left_box.width, right_box.height]
)

The plot should be similar to this:

[Image: sample output]