def subplot(plt, (Y, X), (sz_y, sz_x) = (10, 10)): raises "SyntaxError: invalid syntax" in Python 3

I am trying to run code that was originally written for Python 2 under Python 3. The code block is:

def draw_bbox(plt, ax, rois, fill=False, linewidth=2, edgecolor=[1.0, 0.0, 0.0], **kwargs):
    for i in range(rois.shape[0]):
        roi = rois[i,:].astype(np.int)
        ax.add_patch(plt.Rectangle((roi[0], roi[1]),
            roi[2] - roi[0], roi[3] - roi[1],
            fill=False, linewidth=linewidth, edgecolor=edgecolor, **kwargs))

def subplot(plt, (Y, X), (sz_y, sz_x) = (10, 10)):
    plt.rcParams['figure.figsize'] = (X*sz_x, Y*sz_y)
    fig, axes = plt.subplots(Y, X)
    return fig, axes

and the error I receive is:

 File "<ipython-input-7-9e2eb5f0d3ab>", line 8
    def subplot(plt, (Y, X), (sz_y, sz_x) = (10, 10)):
                     ^
SyntaxError: invalid syntax

How can I fix this? The code is from this notebook in the v-coco repo: https://github.com/s-gupta/v-coco/blob/master/V-COCO.ipynb


Answer

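Python 3 removed tuple parameter unpacking (PEP 3113), so a function can no longer destructure a tuple directly in its parameter list. Accept each tuple as a single parameter and unpack it inside the function body instead: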

def subplot(plt, yx, sz=(10, 10)):
    Y, X = yx        # number of subplot rows and columns
    sz_y, sz_x = sz  # size of each panel in inches (height, width)
    plt.rcParams['figure.figsize'] = (X*sz_x, Y*sz_y)
    fig, axes = plt.subplots(Y, X)
    return fig, axes
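Because callers still pass each tuple as one positional argument, existing calls such as `subplot(plt, (2, 3))` keep working unchanged. The same PEP 3113 rule also applies to `lambda`. The sketch below uses a hypothetical helper, `figsize_for`, that mirrors the size computation in `subplot` so it runs without matplotlib; the `pairs` data is made up for illustration:

```python
# Python 2 only; both of these are a SyntaxError in Python 3:
#     def figsize_for((Y, X), (sz_y, sz_x)=(10, 10)): ...
#     sorted(pairs, key=lambda (n, s): s)

def figsize_for(yx, sz=(10, 10)):
    """Hypothetical helper: the figsize that subplot() would set."""
    Y, X = yx        # unpack the grid shape inside the body
    sz_y, sz_x = sz  # unpack the per-panel size the same way
    return (X * sz_x, Y * sz_y)

# Call sites written for the Python 2 version work unchanged.
print(figsize_for((2, 3)))          # (30, 20)
print(figsize_for((1, 2), (5, 4)))  # (8, 5)

# A lambda cannot unpack its parameters either; index the tuple instead.
pairs = [(1, 'b'), (2, 'a')]
by_name = sorted(pairs, key=lambda p: p[1])
print(by_name)                      # [(2, 'a'), (1, 'b')]
```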
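If you run the notebook with a recent NumPy, `draw_bbox` is likely to fail next: `np.int` was deprecated in NumPy 1.20 and removed in 1.24, so `astype(np.int)` raises `AttributeError`. Since `np.int` was only an alias for the builtin `int`, replacing it does not change the result. The original also hard-codes `fill=False` and ignores its own `fill` argument. A minimal sketch of a ported version, assuming `rois` is an N×4 array of `[x1, y1, x2, y2]` boxes (the layout the original indexing implies):

```python
import numpy as np

def draw_bbox(plt, ax, rois, fill=False, linewidth=2, edgecolor=(1.0, 0.0, 0.0), **kwargs):
    """Add one rectangle patch to ax per row of rois ([x1, y1, x2, y2])."""
    rois = np.asarray(rois)  # also accept nested lists
    for i in range(rois.shape[0]):
        # np.int no longer exists in NumPy >= 1.24; the builtin int is equivalent.
        roi = rois[i, :].astype(int)
        ax.add_patch(plt.Rectangle((roi[0], roi[1]),
                                   roi[2] - roi[0], roi[3] - roi[1],
                                   fill=fill,  # honour the argument instead of hard-coding False
                                   linewidth=linewidth, edgecolor=edgecolor, **kwargs))
```

The default `edgecolor` is written as a tuple rather than a list only to avoid a mutable default argument; matplotlib accepts either form as an RGB color.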