Wednesday, July 18, 2012

Enforcing concrete classes in Python

Interface-based programming is one of the two principles that the Gang of Four book advocates. This style of programming makes the contract between a provider and a consumer crystal clear while keeping the implementation details safely hidden and easy to modify.

Hence, we can find the concepts of interface and abstract class in virtually all object-oriented programming languages. An interface provides a pure contract without any implementation. A concrete class implements every method in that contract. An abstract class lies somewhere in between, with some methods implemented and others left unimplemented.

Python, in its purity, does not facilitate this separation of concerns. The language has no dedicated interface or abstract class construct. The usual way to define one is to raise NotImplementedError in the methods that subclasses need to implement.

Even though this usually works fine in practice, it does not force a concrete subclass to implement every unimplemented method. When a concrete subclass forgets to implement some methods, the Python interpreter still happily runs the code until it hits the raise statement.
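A minimal sketch of this failure mode, using hypothetical Shape and Square classes that are not part of the original code:

```python
class Shape(object):
    def area(self):
        # Placeholder that subclasses are expected to override.
        raise NotImplementedError()


class Square(Shape):
    def __init__(self, side):
        self.side = side
    # area() was forgotten here, yet the class definition is accepted.


square = Square(2)  # instantiation succeeds as well

try:
    square.area()
    failed_at_call_time = False
except NotImplementedError:
    # The problem only surfaces when the method is finally called.
    failed_at_call_time = True
```

Nothing complains until area() is actually invoked, which may be deep inside a rarely exercised code path.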

So I wrote a quick decorator to ensure that if a class is marked concrete, none of its methods simply raises NotImplementedError. A SyntaxError is raised if any method (even an inherited one) is left unimplemented. The check happens at class definition time.

To be clear, Python does have Abstract Base Classes (the abc module), but they are different from the regular interface/abstract class concept.
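For comparison, here is a small sketch of the abc approach, assuming Python 3.4 or later (where abc.ABC is available) and using hypothetical Shape and Square classes. With abc, the incomplete subclass is rejected when it is instantiated, not when the class is defined, which is one way it differs from the decorator below.

```python
import abc


class Shape(abc.ABC):
    @abc.abstractmethod
    def area(self):
        """Return the area of the shape."""


class Square(Shape):
    # area() is not overridden; defining the class is still allowed.
    pass


try:
    Square()
    instantiated = True
except TypeError:
    # abc refuses to instantiate a class that still has abstract methods.
    instantiated = False
```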

'''Set of (class) decorators to help with interface programming.

Examples::

    @concrete
    class subclass(parent):
        ...

Released to the public domain.

Nam T. Nguyen, 2012

'''

import dis
import types
import unittest


def concrete(orig_class):
    '''A decorator to ensure that a concrete class has no un-implemented
    methods.

    An un-implemented method is defined similarly to::

        def method(...):
            raise NotImplementedError()

    This, when translated to bytecode, looks like::

        LOAD_GLOBAL       0 (NotImplementedError)
        CALL_FUNCTION     0
        RAISE_VARARGS     1
        ...

    Or, for a bare ``raise NotImplementedError``::

        LOAD_GLOBAL       0 (NotImplementedError)
        RAISE_VARARGS     1
        ...

    The check here looks for these two patterns.

    '''

    for name in dir(orig_class):
        func = getattr(orig_class, name)
        # correct type?
        if type(func) not in (types.FunctionType, types.MethodType):
            continue
        # check if first name is NotImplementedError
        if len(func.func_code.co_names) < 1 or \
                func.func_code.co_names[0] != 'NotImplementedError':
            continue
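        # and RAISE_VARARGS right after it: at offset 3 for a bare
        # ``raise NotImplementedError``, or at offset 6 when the
        # exception class is called first
        for position in (3, 6):
            # skip offsets that fall outside the bytecode string
            if len(func.func_code.co_code) <= position:
                continue
            opcode = ord(func.func_code.co_code[position])
            if dis.opname[opcode] == 'RAISE_VARARGS':
                raise SyntaxError('Function %s.%s must be implemented.' % (
                        orig_class.__name__, func.func_name))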

    return orig_class


class ConcreteTest(unittest.TestCase):

    def test_raise(self):
        class test1:
            def method(self):
                raise NotImplementedError
        self.assertRaises(SyntaxError, concrete, test1)

        class test2:
            def method(self):
                raise NotImplementedError()
        self.assertRaises(SyntaxError, concrete, test2)

    def test_not_raise(self):
        class test:
            def method(self):
                return NotImplementedError()
        concrete(test)

    def test_subclass(self):
        class parent(object):
            def override_me(self):
                raise NotImplementedError()
        class test(parent):
            def leave_it(self):
                pass
        self.assertRaises(SyntaxError, concrete, test)


if __name__ == '__main__':
    unittest.main()
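The listing above relies on Python 2 details (func_code, func_name, and fixed byte offsets into co_code). Below is a rough sketch of the same idea for Python 3, assuming dis.get_instructions is available (Python 3.4+). The names looks_unimplemented, concrete3 and _SKIPPED are hypothetical, and the set of opcodes to skip is an assumption based on recent CPython releases, since bytecode is not stable across versions.

```python
import dis
import inspect

# Bookkeeping opcodes that some CPython versions emit around the pattern
# (assumed list; adjust for the interpreter in use).
_SKIPPED = {"RESUME", "NOP", "PUSH_NULL", "PRECALL", "CACHE"}


def looks_unimplemented(func):
    """Return True if func starts with ``raise NotImplementedError``
    or ``raise NotImplementedError()``."""
    ops = [ins for ins in dis.get_instructions(func)
           if ins.opname not in _SKIPPED]
    if not ops:
        return False
    first = ops[0]
    if first.opname != "LOAD_GLOBAL" or first.argval != "NotImplementedError":
        return False
    rest = ops[1:]
    # An optional call of the exception class (CALL, CALL_FUNCTION, ...).
    if rest and rest[0].opname.startswith("CALL"):
        rest = rest[1:]
    return bool(rest) and rest[0].opname == "RAISE_VARARGS"


def concrete3(cls):
    """Class decorator: reject cls if any method, inherited or not,
    is still an unimplemented placeholder."""
    for name, member in inspect.getmembers(cls, inspect.isfunction):
        if looks_unimplemented(member):
            raise SyntaxError("Function %s.%s must be implemented."
                              % (cls.__name__, name))
    return cls


class Parent:
    def override_me(self):
        raise NotImplementedError()


class Good(Parent):
    def override_me(self):
        return 1


class Lazy(Parent):
    def leave_it(self):
        pass


good_accepted = concrete3(Good) is Good
try:
    concrete3(Lazy)
    lazy_rejected = False
except SyntaxError:
    lazy_rejected = True
```

This sketch only matches the two forms named in the docstring of concrete; a raise that passes a message argument to NotImplementedError would not be detected.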
