# Copyright (c) 2002-2005 LOGILAB S.A. (Paris, FRANCE).
# http://www.logilab.fr/ -- mailto:contact@logilab.fr
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation; either version 2 of the License, or (at your option) any later
# version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along with
# this program; if not, write to the Free Software Foundation, Inc.,
# 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
"""unittest generator

Visits an astng tree and writes one skeleton ``unittest_<module>.py`` file
per visited module, containing test cases for its top-level classes and
plain functions.
"""

__revision__ = "$Id: test_generator.py,v 1.11 2005/04/15 11:00:49 syt Exp $"

import os, sys
import stat
from os.path import basename, join, exists

from logilab.common import astng
from logilab.common.astng.utils import is_interface, is_exception, \
     IgnoreChild, get_raises, get_returns, LocalsVisitor
from logilab.common.configuration import OptionsProviderMixIn

from pyreverse.utils import FilterMixIn
from pyreverse.extensions.testutils import *

PGM = basename(sys.argv[0])

# Permissions of generated test files: rwxr-xr-x. They stay executable, as
# with the former 0777, but are no longer world-writable: os.chmod ignores
# the umask, so 0777 let any local user modify generated test code. Built
# from stat constants so the module parses on both Python 2 and Python 3
# (the bare 0777 literal is a SyntaxError on Python 3).
TEST_FILE_MODE = (stat.S_IRWXU | stat.S_IRGRP | stat.S_IXGRP
                  | stat.S_IROTH | stat.S_IXOTH)


class UnittestGenerator(OptionsProviderMixIn, FilterMixIn, LocalsVisitor):
    """Visitor writing unittest skeletons from an astng tree.

    For each visited module a ``unittest_<module>.py`` file is created in the
    configured test directory. Top-level classes (other than interfaces and
    exceptions) and plain functions each get a test case. Every visited
    function or method gets one "known values" test per value reported by
    ``get_returns`` (or a single one expecting None), plus one test per
    exception reported by ``get_raises``.
    """

    name = 'test generation'

    options = (("test-dir",
                {'action' :"store", 'type' : "string", 'dest' : "test_dir",
                 'metavar' : "<directory>", "default" : 'test',
                 'help' : "put generated tests in <directory>."}),
               ("header-file",
                {'default': '', 'type' : 'string', 'action' : 'store',
                 'metavar' : '<file>', 'dest': 'header',
                 'help' : 'insert commented content of <file> in each tests '
                          'file generated (for instance, licence).'}),
               ("backup-files",
                {'default': 1, 'type' : 'yn', 'metavar' : '<y or n>',
                 'dest': 'backup',
                 'help' : 'do backup files of existing tests (renamed with '
                          'a ~ suffix).'}),
               ) + FilterMixIn.options

    def __init__(self):
        FilterMixIn.__init__(self)
        OptionsProviderMixIn.__init__(self)
        LocalsVisitor.__init__(self)
        # file object of the unit test module currently being written
        self._stream = None
        # name of the test case class most recently opened in self._stream
        self._test_class = None
        # header written at the top of each generated file; visit_project
        # recomputes it, this default keeps visit_module usable without it
        self.header = ''

    def visit_project(self, node):
        """visit an astng.Project node

        optionally compute the additional header from the file given by the
        header-file option
        """
        if self.config.header:
            self.header = wrap_file(self.config.header)
        else:
            self.header = ''

    def visit_module(self, node):
        """visit an astng.Module node

        open a new unit test file named unittest_<module>.py in the test
        directory, first renaming an existing one with a ~ suffix when the
        backup-files option is set
        """
        name = node.name
        if name.endswith('__init__') and not name.startswith('__init__'):
            # "pkg.sub.__init__" -> "pkg_sub": drop the 9-char ".__init__"
            filename = name[:-9].replace('.', '_')
        else:
            filename = name.replace('.', '_')
        filename = join(self.config.test_dir, 'unittest_%s.py' % filename)
        if self.config.backup and exists(filename):
            os.rename(filename, '%s~' % filename)
        self._stream = open(filename, 'w+')
        os.chmod(filename, TEST_FILE_MODE)
        self._stream.write(open_unittest(PGM, node.name, self.header,
                                         doc=node.doc))
        self._test_class = None

    def leave_module(self, node):
        """leave an astng.Module node

        close the current unit test file
        """
        self._stream.write(close_unittest())
        self._stream.close()
        self._stream = None

    def visit_class(self, node):
        """visit an astng.Class node

        open a new test case, "<ClassName>TC", whose setup statement
        instantiates the class as self.o

        raise IgnoreChild (through _test_filter) for filtered classes
        """
        self._test_filter(node)
        test = '%sTC' % node.name
        self._test_class = test
        setup = 'self.o = %s(%s)' % (node.name, _get_class_init_args(node))
        self._stream.write(open_testcase(test, node.doc, teardown='pass',
                                         setup=setup))

    def visit_function(self, node):
        """visit an astng.Function node

        open a new test case if the node is not a method (methods are tested
        inside their class test case), then write one "known values" test per
        value reported by get_returns and one test per exception reported by
        get_raises

        raise IgnoreChild (through _test_filter) for filtered functions
        """
        self._test_filter(node)
        if isinstance(node.parent.get_frame(), astng.Class):
            # methods are called on the instance created in the setup
            prefix = 'self.o.'
        else:
            # "some_function" -> "SomeFunctionTC"
            names = [part.capitalize() for part in node.name.split('_')]
            self._test_class = test = '%sTC' % ''.join(names)
            self._stream.write(open_testcase(test, node.doc))
            prefix = ''
        args = node.format_args()
        call = '%s%s(%s)' % (prefix, node.name, args)
        returns = get_returns(node)
        if not returns:
            # nothing reported: generate a single test expecting None
            returns = [None]
        for index, returned in enumerate(returns):
            # same semantics as the former "x and x.as_string() or 'None'"
            if returned:
                expected = returned.as_string() or 'None'
            else:
                expected = 'None'
            test_name = _name('known_values', node.name, str(index + 1))
            self._stream.write(open_func(test_name, node.doc))
            self._stream.write(assert_equal(call, expected))
        for exception in get_raises(node):
            if hasattr(exception, 'name'):
                exc_name = exception.name
            else:
                # no .name attribute: use the source of the first child node
                exc_name = exception.getChildNodes()[0].as_string()
            test_name = _name('raise', node.name, exc_name)
            self._stream.write(open_func(test_name))
            self._stream.write(assert_raise(exc_name, '%s%s, %s' % (
                prefix, node.name, args)))

    def _test_filter(self, node):
        """raise IgnoreChild if the node should not be processed

        skipped nodes are: names rejected by the filter options, classes not
        defined at module level, interfaces, exceptions, and __init__
        functions; return 1 otherwise
        """
        if not self.filter(node.name):
            raise IgnoreChild()
        # take only first level classes which are not interface nor exception
        if isinstance(node, astng.Class) and (
            (not isinstance(node.parent.get_frame(), astng.Module))
            or is_interface(node) or is_exception(node)):
            raise IgnoreChild()
        # ignore constructor
        if isinstance(node, astng.Function) and node.name == '__init__':
            raise IgnoreChild()
        return 1


def _name(group, name, tid):
    """build a test function name: test_<group>_<name>_<tid>"""
    return 'test_%s_%s_%s' % (group, name, tid)


def _get_class_init_args(node):
    """return a formatted string for class constructor arguments

    use the class's own __init__ if it defines one, else the one found via
    get_ancestor_for_method; return '' when neither can be formatted
    """
    try:
        return node.locals['__init__'].format_args()
    except KeyError:
        try:
            return node.get_ancestor_for_method('__init__').format_args()
        except Exception:
            # best effort: no usable inherited __init__ -> no arguments.
            # Exception (not a bare except) lets KeyboardInterrupt and
            # SystemExit propagate.
            return ''