Tuesday, July 24, 2012

Introducing auto_argument

One of the subtle bugs that often bites new Python programmers is default argument values. People expect those default values to be (re-)created and supplied to the function at each invocation. And so, when it does not happen that way, developers scratch their heads.

It is easy, though, to understand why Python acts that way. As explained in the Python Tutorial:

The default values are evaluated at the point of function definition in the defining scope.

And that:

The default value is evaluated only once. This makes a difference when the default is a mutable object such as a list, dictionary, or instances of most classes.

To prevent the default from being shared between subsequent calls, one usually writes code like this:

def f(a, L=None):
    # None is only a marker: a brand-new list is built on every call that
    # leaves L as None, so the list is never shared between invocations
    if L is None:
        L = []
    # do something with L

The if block creates a new actual default value to replace the None marker value.

To cut down on the repetitive typing and the clutter of such if blocks, I wrote a simple module, auto_argument, that inspects the function's arguments on each invocation and replaces any marker value with a new value. With callable_auto_argument, the code above becomes cleaner:

@callable_auto_argument([('L', None, list)])
def f(a, L=None):
    # do something with L

Of course, developers are very welcome to make their own decorator by subclassing one of these auto_argument classes to save redundant code. See auto_dict_argument and test_subclass in the source code for example.

And here is auto_argument in all its glory :-). The code is released under MIT license.

# Copyright 2012, Nam T. Nguyen
# Released under MIT license

'''Automatically replace a default argument with some other (potentially
dynamic) value.

The default argument is usually guarded like this::

    def func(arg=None):
        if arg is None:
            arg = dict()
        # use arg

With decorators provided in this module, one can write::

    __marker = object()

    @callable_auto_argument([('arg', __marker, dict)])
    def func(arg=__marker):
        # use arg

See :class:`callable_auto_argument`.

Also, the standard Python behavior could be thought of as::

    __marker = object()

    @passthru_auto_argument([('arg', __marker, {})])
    def func(arg=__marker):
        # ...

See :class:`passthru_auto_argument`.

These classes can be used by themselves or serve as base classes for more
customizations. For example, to eliminate repeated typings, one can subclass
``callable_auto_argument`` like this::

    class auto_dict_argument(callable_auto_argument):

        def __init__(self, *names):
            names_markers_values = []
            for name in names:
                names_markers_values.append((name, None, dict))
            super(auto_dict_argument, self).__init__(names_markers_values)

And then apply this on methods like this::

    @auto_dict_argument('arg_1', 'arg_2', 'arg_3')
    def method(arg_1=None, arg_2=None, arg_3=None):
        # arg_1, arg_2, arg_3 are new dicts unless specified otherwise
        # and these lines are no longer required
        # if arg_1 is None: arg_1 = {}
        # if arg_2 is None: arg_2 = {}
        # if arg_3 is None: arg_3 = {}

'''


import functools
import unittest


class auto_argument(object):
    '''The base class for automatic argument decorators.

    An instance is used as a decorator. On every call of the decorated
    function, each configured argument whose current value **is** (identity
    comparison) its marker object is replaced by whatever :meth:`create`
    returns.

    Subclasses must implement :meth:`create` to create the appropriate value
    for the argument.

    '''

    def __init__(self, names_markers_values):
        '''Construct an auto argument decorator with a collection of variable
        names, their markers, and their supposed values.

        The *supposed* value objects are handed to :meth:`create` to produce
        the final value.

        Args:

            names_markers_values (collection of 3-tuples): A collection of
                (string, object, object) tuples specifying (in that order) the
                names, the marker objects, and the supposed value objects.
                It is iterated on every call of the decorated function, so it
                must be re-iterable (e.g. a list, not a generator).

        '''

        self.names_markers_values = names_markers_values

    def create(self, name, current_value, value):
        '''Return the final value for the named argument.

        This final value will be used to replace what is currently passed in
        the invocation.

        Subclasses MUST override this method to provide more meaningful
        behavior.

        Args:

            name (string): The argument's name
            current_value: Its current value in this invocation
            value: The supposed value passed in during construction time

        Returns:

            Final value for this argument

        Raises:

            NotImplementedError: Always, in this base class.

        '''

        raise NotImplementedError()

    def __call__(self, orig_func):
        '''Wrap ``orig_func`` so that marker values are replaced on each call.

        Replacement rules, per configured name:

        * passed by keyword: replaced if the value is the marker;
        * passed positionally: replaced if the value is the marker, but only
          for positional parameters that declare a default;
        * not passed at all: replaced if the declared default (positional or
          keyword-only) is the marker.

        Args:

            orig_func (function): The function to decorate. It must expose
                ``__code__`` and ``__defaults__``.

        Returns:

            The wrapping function.

        '''

        # ``__code__``/``__defaults__`` are available on Python 2.6+ and on
        # Python 3, unlike the Python-2-only ``func_code``/``func_defaults``.
        co_obj = orig_func.__code__
        # ``__defaults__`` is None (not an empty tuple) without defaults
        defaults = orig_func.__defaults__ or ()
        # keyword-only defaults only exist on Python 3
        kw_defaults = getattr(orig_func, '__kwdefaults__', None) or {}
        # number of required positional arguments
        nr_required = co_obj.co_argcount - len(defaults)

        # The signature does not change between calls, so work out once
        # where each defaulted positional parameter sits in ``args``.
        positions = {}
        for i in range(nr_required, co_obj.co_argcount):
            positions[co_obj.co_varnames[i]] = i

        @functools.wraps(orig_func)
        def new_func(*args, **kw_args):
            for name, marker, value in self.names_markers_values:
                # check kw_args first; an explicit membership test keeps a
                # KeyError raised inside create() from being swallowed
                if name in kw_args:
                    if kw_args[name] is marker:
                        kw_args[name] = self.create(name, kw_args[name], value)
                    continue

                # then check positional parameters that have a default
                if name in positions:
                    i = positions[name]
                    # is it supplied in args?
                    if i < len(args):
                        if args[i] is marker:
                            # *args arrives as a tuple; make it mutable once
                            if isinstance(args, tuple):
                                args = list(args)
                            args[i] = self.create(name, args[i], value)
                    # it is not, so, check defaults
                    else:
                        default = defaults[i - nr_required]
                        if default is marker:
                            kw_args[name] = self.create(name, default, value)
                # finally, keyword-only parameters left at their default
                elif name in kw_defaults:
                    default = kw_defaults[name]
                    if default is marker:
                        kw_args[name] = self.create(name, default, value)

            # invoke orig_func with new args
            return orig_func(*args, **kw_args)

        return new_func


class callable_auto_argument(auto_argument):
    '''Auto argument whose supposed value is called to build the final value.

    The supposed value is invoked with no arguments each time a replacement
    is needed, so passing e.g. ``dict`` yields a new dict per replacement.

    '''

    def create(self, name, current_value, value):
        '''Return the result of calling ``value`` with no arguments.

        ``name`` and ``current_value`` are not used.

        '''
        factory = value
        return factory()


class passthru_auto_argument(auto_argument):
    '''Auto argument that substitutes the supposed value object itself.

    The very same object is handed out on every replacement; nothing is
    copied or called.

    '''

    def create(self, name, current_value, value):
        '''Return ``value`` unchanged.

        ``name`` and ``current_value`` are not used.

        '''
        replacement = value
        return replacement


class AutoArgumentTest(unittest.TestCase):
    '''Unit tests for the auto argument decorators.'''

    def test_keyword_1(self):
        '''One auto argument, supplied by default, position, and keyword.'''

        marker = 'replace_me'

        @passthru_auto_argument([('arg_name', marker, 'done')])
        def orig_func_0(arg_name=marker):
            return arg_name

        self.assertEqual('done', orig_func_0())
        self.assertEqual('done', orig_func_0(marker))
        self.assertEqual('not_replace', orig_func_0('not_replace'))
        self.assertEqual('done', orig_func_0(arg_name=marker))
        self.assertEqual('not_replace', orig_func_0(arg_name='not_replace'))

        # Markers are compared by identity, so reuse the very same object
        # instead of re-typing an equal string literal.
        @passthru_auto_argument([('arg_name', marker, 'done')])
        def orig_func_1(junk, arg_name=marker):
            return junk, arg_name

        self.assertEqual(('ignore', 'done'), orig_func_1('ignore'))
        self.assertEqual(('ignore', 'done'),
                orig_func_1('ignore', marker))
        self.assertEqual(('ignore', 'not_replace'),
                orig_func_1('ignore', 'not_replace'))
        self.assertEqual(('ignore', 'done'),
                orig_func_1('ignore', arg_name=marker))
        self.assertEqual(('ignore', 'not_replace'),
                orig_func_1('ignore', arg_name='not_replace'))

    def test_keyword_2(self):
        '''Two auto arguments, each with its own marker and value.'''

        marker_1 = 'replace_me'
        marker_2 = 'replace_too'

        @passthru_auto_argument([('arg_1', marker_1, 'done'),
                     ('arg_2', marker_2, 'enod')])
        def orig_func_0(arg_1=marker_1, arg_2=marker_2):
            return arg_1, arg_2

        self.assertEqual(('done', 'enod'), orig_func_0())
        self.assertEqual(('not_replace', 'enod'), orig_func_0('not_replace'))
        self.assertEqual(('done', 'not'), orig_func_0(marker_1, 'not'))
        self.assertEqual(('done', 'enod'), orig_func_0(marker_1, marker_2))
        self.assertEqual(('not_1', 'not_2'), orig_func_0('not_1', 'not_2'))

        @passthru_auto_argument([('arg_1', marker_1, 'done'),
                     ('arg_2', marker_2, 'enod')])
        def orig_func_1(junk, arg_1=marker_1, arg_2=marker_2):
            return junk, arg_1, arg_2

        self.assertEqual(('.', 'done', 'enod'), orig_func_1('.'))
        self.assertEqual(('.', 'not_replace', 'enod'),
                orig_func_1('.', 'not_replace'))
        self.assertEqual(('.', 'done', 'not'),
                orig_func_1('.', marker_1, 'not'))
        self.assertEqual(('.', 'done', 'enod'),
                orig_func_1('.', marker_1, marker_2))
        self.assertEqual(('.', 'not_1', 'not_2'),
                orig_func_1('.', 'not_1', 'not_2'))

    def test_subclass(self):
        '''A callable-based subclass hands out a new dict on every call.'''

        class auto_dict_argument(callable_auto_argument):

            def __init__(self, *names):
                names_markers_values = []
                for name in names:
                    names_markers_values.append((name, None, dict))
                super(auto_dict_argument, self).__init__(names_markers_values)

        @auto_dict_argument('arg_1', 'arg_2')
        def test_func(arg_1=None, arg_2=None):
            arg_1['1'] = 'arg_1'
            arg_2['2'] = 'arg_2'
            return (arg_1, arg_2)

        self.assertEqual(({'1': 'arg_1'}, {'2': 'arg_2'}), test_func())
        arg_1, arg_2 = test_func()
        self.assertNotEqual(id(arg_1), id(arg_2))

        # The point of the decorator: nothing is shared between calls.
        again_1, again_2 = test_func()
        self.assertIsNot(arg_1, again_1)
        self.assertIsNot(arg_2, again_2)
        self.assertEqual(({'1': 'arg_1'}, {'2': 'arg_2'}), (again_1, again_2))

        # A dict supplied by the caller is used as-is, not replaced.
        mine = {}
        got_1, got_2 = test_func(mine)
        self.assertIs(mine, got_1)
        self.assertEqual({'1': 'arg_1'}, mine)
        self.assertEqual({'2': 'arg_2'}, got_2)


# Run the unit tests in this module when it is executed as a script.
if __name__ == '__main__':
    unittest.main()

No comments:

Post a Comment