Adding a decorator to a library function without modifying its code

I have a third-party library function that is used in many places in my code, and I don't want to go and change every call site. So I am thinking that monkey patching this function would let me apply the decorator I want. I just want to add some exception handling with a retry mechanism to the function.
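For reference, the "exception handling with retry" behaviour described here can be written as an ordinary decorator first, and monkey patching is then only the question of where to apply it. The sketch below is a minimal example under stated assumptions: `retry`, `max_retries`, `delay` and `flaky` are hypothetical names, not part of any library mentioned in the question, and it assumes Python 3.

```python
import functools
import time


def retry(max_retries=3, delay=0.0, exceptions=(Exception,)):
    """Hypothetical retry decorator: call func, retrying up to max_retries extra times."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(max_retries + 1):
                try:
                    return func(*args, **kwargs)
                except exceptions:
                    if attempt == max_retries:
                        raise  # out of retries: propagate the last error unchanged
                    print("attempt %d failed, retrying" % (attempt + 1))
                    time.sleep(delay)
        return wrapper
    return decorator


# Usage: a hypothetical function that fails twice, then succeeds.
calls = {"n": 0}


@retry(max_retries=3)
def flaky():
    calls["n"] += 1
    if calls["n"] < 3:
        raise IOError("temporary failure")
    return "connected"


result = flaky()
```

Because the attempt counter lives inside `wrapper`, every call starts with a fresh retry budget.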

With monkey patching, the problem is recursion.

The following is the folder structure directly underneath my project folder:

plugins
    __init__.py
    hooks
        __init__.py
        base_hook.py
        custom_ssh_hook.py
        ssh_hook.py

This is what my code looks like.

hooks/__init__.py

from plugins.hooks import ssh_hook
from plugins.hooks import custom_ssh_hook

# Replace SSHHook.get_conn with the custom version for every caller
print("running monkey patching ssh hook get_conn")
ssh_hook.SSHHook.get_conn = custom_ssh_hook.get_conn
print("completed monkey patching ssh hook get_conn")

hooks/base_hook.py

class BaseHook(object):
    def __init__(self, source):
        pass

hooks/ssh_hook.py

from plugins.hooks.base_hook import BaseHook

class SSHHook(BaseHook):
    def __init__(self, source, timeout=10):
        super(SSHHook, self).__init__(source)

    def get_conn(self):
        print("SSH Hook")

hooks/custom_ssh_hook.py

from plugins.hooks.ssh_hook import SSHHook

call_count = 0
def get_conn(self):
    global call_count
    call_count += 1
    if call_count > 1:
        print("Not A good Idea, you are trying recursion")
        return
    try:
        print("custom ssh Hook")
        return SSHHook.get_conn(self)
    except Exception as e:
        print("retry mechanism")
        raise e

I am not able to print the variable originalHook, no matter how I import it. I tried the following:

from plugins.hooks import originalHook
from . import * 

The get_conn in custom_ssh_hook is being called recursively.

How do I call the original get_conn of SSHHook?
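The recursion happens because `SSHHook.get_conn(self)` inside the custom function is looked up when the call runs, not when the module is imported. By then the class attribute has already been replaced, so the lookup finds the custom function itself. The sketch below reproduces this with a hypothetical `Hook` class standing in for `SSHHook` (Python 3, where the error is `RecursionError`).

```python
class Hook(object):
    """Hypothetical stand-in for SSHHook."""
    def get_conn(self):
        return "original"


def patched_get_conn(self):
    # Hook.get_conn is resolved when this line RUNS, i.e. after the patch below,
    # so it finds patched_get_conn itself and keeps calling it.
    return Hook.get_conn(self)


Hook.get_conn = patched_get_conn

try:
    Hook().get_conn()
    recursed = False
except RecursionError:
    recursed = True
```

The call counter in custom_ssh_hook only hides the symptom; the fix is to keep a reference to the original function that was taken before the patch.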

Answer

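I found the solution: I added one more attribute to the class I am patching, original_func, which keeps a reference to the original get_conn. The replacement calls that saved reference instead of SSHHook.get_conn, so it no longer calls itself.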

Monkey Patching

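from airflow.contrib.hooks import ssh_hook

# Save a reference to the original method BEFORE replacing it
ssh_hook.SSHHook.original_func = ssh_hook.SSHHook.get_conn

from operators.bmo_ssh_hook import get_conn_with_retry

# Every SSHHook instance now uses the retrying version
ssh_hook.SSHHook.get_conn = get_conn_with_retry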
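A variant of the same idea captures the original method in a closure instead of storing it as an extra attribute on the class, so nothing else on the class changes. This is a sketch under assumptions: `SSHHook` here is a hypothetical stand-in rather than Airflow's class, `patch_with_retry` is a hypothetical helper, and the code assumes Python 3.

```python
import functools


class SSHHook(object):
    """Hypothetical stand-in for the library class; not Airflow's real SSHHook."""
    def get_conn(self):
        return "real connection"


def patch_with_retry(cls, name, max_retries=3):
    """Hypothetical helper: wrap cls.<name> with retries, capturing the original first."""
    original = getattr(cls, name)  # looked up ONCE, before the class is modified

    @functools.wraps(original)
    def wrapper(self, *args, **kwargs):
        for attempt in range(max_retries + 1):
            try:
                # Calls the captured function, never cls.<name>, so no recursion
                return original(self, *args, **kwargs)
            except Exception:
                if attempt == max_retries:
                    raise

    setattr(cls, name, wrapper)
    return original


original_get_conn = patch_with_retry(SSHHook, "get_conn")
conn = SSHHook().get_conn()
```

`functools.wraps` also exposes the original as `SSHHook.get_conn.__wrapped__`, which is handy if the patch ever needs to be undone.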

Patched Method (operators/bmo_ssh_hook.py)

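from airflow.contrib.hooks.ssh_hook import SSHHook

max_retry = 3


def get_conn_with_retry(self):
    # Call the saved original get_conn; on failure, retry up to max_retry times.
    # The attempt counter is local, so it resets on every call (a module-level
    # global stays exhausted after the first max_retry failures).
    for attempt in range(max_retry + 1):
        try:
            # Not SSHHook.get_conn: that name now points at this function
            return SSHHook.original_func(self)
        except Exception:
            if attempt == max_retry:
                raise  # out of retries: re-raise with the original traceback
            print("tried %s times, failed to connect retrying again" % (attempt + 1))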
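To check the retry-by-loop pattern without a real SSH server, the sketch below applies the same patching steps (save the original, then replace it) to a hypothetical `FakeSSHHook` whose `get_conn` fails a fixed number of times. `FakeSSHHook` and its `failures`/`calls` attributes are illustrative assumptions, not part of Airflow.

```python
class FakeSSHHook(object):
    """Hypothetical stand-in for SSHHook: get_conn fails the first `failures` calls."""
    def __init__(self, failures):
        self.failures = failures
        self.calls = 0

    def get_conn(self):
        self.calls += 1
        if self.calls <= self.failures:
            raise IOError("connection refused")
        return "conn"


max_retry = 3


def get_conn_with_retry(self):
    for attempt in range(max_retry + 1):
        try:
            return FakeSSHHook.original_func(self)
        except Exception:
            if attempt == max_retry:
                raise
            print("tried %s times, failed to connect retrying again" % (attempt + 1))


# Save the original first, then replace it
FakeSSHHook.original_func = FakeSSHHook.get_conn
FakeSSHHook.get_conn = get_conn_with_retry

ok_hook = FakeSSHHook(failures=2)
ok_result = ok_hook.get_conn()  # succeeds on the third attempt

bad_hook = FakeSSHHook(failures=10)
try:
    bad_hook.get_conn()  # gives up after 1 + max_retry attempts
    gave_up = False
except IOError:
    gave_up = True
```

Each hook instance gets its own full set of retries, which the recursive version with a global counter did not guarantee.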
