#!/usr/bin/env python
# Tests for the extract-method / extract-function refactoring provided by
# bike.refactor.extractMethod.
#
# Most tests hand a small "before" fixture to helper(), which runs the
# extraction over a selected region and returns the rewritten source, and
# then compare that with an "after" fixture.
#
# coords(line, column) appears to take a 1-based line and a 0-based column
# within the trimmed fixture (e.g. coords(3, 8)..coords(3, 12) selects the
# four characters of "pass" on the third fixture line).
#
# NOTE(review): this file was re-flowed from a copy whose line breaks and
# indentation had been collapsed.  The layout inside the triple-quoted
# fixtures (line breaks, indentation, blank lines between definitions) was
# rebuilt from the selection coordinates each test passes; spacing before
# trailing comments is kept as single spaces.  Confirm with a test run.
import setpath
import unittest
from bike.refactor.extractMethod import ExtractMethod, \
     extractMethod, coords
from bike import testdata
from bike.testutils import *
from bike.parsing.load import Cache


def assertTokensAreSame(t1begin, t1end, tokens):
    """Assert that the range [t1begin, t1end) yields exactly `tokens`.

    Walks a clone of `t1begin` forward with incr() until it compares equal
    to `t1end`, checking each deref() against the token at the same
    position, and finally checks that every expected token was consumed.
    """
    it = t1begin.clone()
    pos = 0
    while it != t1end:
        assert it.deref() == tokens[pos]
        it.incr()
        pos+=1
    assert pos == len(tokens)


def helper(src,startcoords, endcoords, newname):
    """Run extractMethod over `src` and return the resulting source text.

    `createAST` and `tmpfile` are not defined in this file; presumably they
    come from the `bike.testutils` star import -- TODO confirm.
    """
    sourcenode = createAST(src)
    extractMethod(tmpfile, startcoords, endcoords, newname)
    return sourcenode.getSource()


class TestExtractMethod(BRMTestCase):
    """Extraction of a region inside a method into a new method."""

    def test_extractsPass(self):
        srcBefore=trimLines("""
        class MyClass:
            def myMethod(self):
                pass
        """)
        srcAfter=trimLines("""
        class MyClass:
            def myMethod(self):
                self.newMethod()

            def newMethod(self):
                pass
        """)
        src = helper(srcBefore, coords(3, 8), coords(3, 12), "newMethod")
        self.assertEqual(src,srcAfter)

    def test_extractsPassWhenFunctionAllOnOneLine(self):
        srcBefore=trimLines("""
        class MyClass:
            def myMethod(self): pass # comment
        """)
        srcAfter=trimLines("""
        class MyClass:
            def myMethod(self): self.newMethod() # comment

            def newMethod(self):
                pass
        """)
        src = helper(srcBefore, coords(2, 24), coords(2, 28),"newMethod")
        self.assertEqual(src,srcAfter)

    def test_extractsPassFromForLoop(self):
        srcBefore=trimLines("""
        class MyClass:
            def myMethod(self): # comment
                for i in foo:
                    pass
        """)
        srcAfter=trimLines("""
        class MyClass:
            def myMethod(self): # comment
                for i in foo:
                    self.newMethod()

            def newMethod(self):
                pass
        """)
        src = helper(srcBefore, coords(4, 12), coords(4, 16), "newMethod")
        self.assertEqual(srcAfter, src)

    def test_newMethodHasArgumentsForUsedTemporarys(self):
        # a, b and c are bound before the selection and read inside it, so
        # the expected new method takes them as arguments; d is not bound
        # in the fixture and is not passed.
        srcBefore=trimLines("""
        class MyClass:
            def myMethod(self, c):
                a = something()
                b = somethingelse()
                print a + b + c + d
                print \"hello\"
                dosomethingelse(a, b)
        """)
        srcAfter=trimLines("""
        class MyClass:
            def myMethod(self, c):
                a = something()
                b = somethingelse()
                self.newMethod(a, b, c)
                dosomethingelse(a, b)

            def newMethod(self, a, b, c):
                print a + b + c + d
                print \"hello\"
        """)
        src = helper(srcBefore, coords(5, 8), coords(6, 21), "newMethod")
        self.assertEqual(srcAfter, src)

    def test_newMethodHasSingleArgument(self):
        srcBefore=trimLines("""
        class MyClass:
            def myMethod(self):
                a = something()
                print a
                print \"hello\"
                dosomethingelse(a, b)
        """)
        srcAfter=trimLines("""
        class MyClass:
            def myMethod(self):
                a = something()
                self.newMethod(a)
                dosomethingelse(a, b)

            def newMethod(self, a):
                print a
                print \"hello\"
        """)
        src = helper(srcBefore, coords(4, 8), coords(5, 21), "newMethod")
        self.assertEqual(srcAfter, src)

    def test_doesntHaveDuplicateArguments(self):
        # `a` is read twice in the selection but must be passed only once.
        srcBefore=trimLines("""
        class MyClass:
            def myMethod(self):
                a = 3
                print a
                print a
        """)
        srcAfter=trimLines("""
        class MyClass:
            def myMethod(self):
                a = 3
                self.newMethod(a)

            def newMethod(self, a):
                print a
                print a
        """)
        src = helper(srcBefore, coords(4, 0), coords(6, 0), "newMethod")
        self.assertEqual(srcAfter, src)

    def test_extractsQueryWhenFunctionAllOnOneLine(self):
        srcBefore=trimLines("""
        class MyClass:
            def myMethod(self, a): print a # comment
        """)
        srcAfter=trimLines("""
        class MyClass:
            def myMethod(self, a): self.newMethod(a) # comment

            def newMethod(self, a):
                print a
        """)
        src = helper(srcBefore, coords(2, 27), coords(2, 34), "newMethod")
        self.assertEqual(srcAfter, src)

    def test_worksWhenAssignmentsToTuples(self):
        srcBefore=trimLines("""
        class MyClass:
            def myMethod(self):
                a, b, c = 35, 36, 37
                print a + b
        """)
        srcAfter=trimLines("""
        class MyClass:
            def myMethod(self):
                a, b, c = 35, 36, 37
                self.newMethod(a, b)

            def newMethod(self, a, b):
                print a + b
        """)
        src = helper(srcBefore, coords(4, 8), coords(4, 19), "newMethod")
        self.assertEqual(srcAfter, src)

    def test_worksWhenUserSelectsABlockButDoesntSelectTheHangingDedent(self):
        # Same fixture as test_extractsPassFromForLoop, but the selection
        # starts at column 8, i.e. inside the leading indentation of the
        # `pass` line rather than at the statement itself.
        srcBefore=trimLines("""
        class MyClass:
            def myMethod(self): # comment
                for i in foo:
                    pass
        """)
        srcAfter=trimLines("""
        class MyClass:
            def myMethod(self): # comment
                for i in foo:
                    self.newMethod()

            def newMethod(self):
                pass
        """)
        src = helper(srcBefore, coords(4, 8), coords(4, 16), "newMethod")
        self.assertEqual(srcAfter, src)

    def test_newMethodHasSingleReturnValue(self):
        # `a` is assigned in the selection and read afterwards, so the
        # expected new method returns it.
        # NOTE(review): column 34 lies past the end of the reconstructed
        # third line; the original fixture may have had extra spaces before
        # the trailing comment -- confirm.
        srcBefore=trimLines("""
        class MyClass:
            def myMethod(self):
                a = 35 # <-- extract me
                print a
        """)
        srcAfter=trimLines("""
        class MyClass:
            def myMethod(self):
                a = self.newMethod()
                print a

            def newMethod(self):
                a = 35 # <-- extract me
                return a
        """)
        src = helper(srcBefore, coords(3, 4), coords(3, 34), "newMethod")
        self.assertEqual(srcAfter, src)

    def test_newMethodHasMultipleReturnValues(self):
        srcBefore=trimLines("""
        class MyClass:
            def myMethod(self):
                a = 35
                b = 352
                print a + b
        """)
        srcAfter=trimLines("""
        class MyClass:
            def myMethod(self):
                a, b = self.newMethod()
                print a + b

            def newMethod(self):
                a = 35
                b = 352
                return a, b
        """)
        src = helper(srcBefore, coords(3, 8), coords(4, 15), "newMethod")
        self.assertEqual(srcAfter, src)

    def test_worksWhenMovingCodeJustAfterDedent(self):
        srcBefore=trimLines("""
        class MyClass:
            def myMethod(self): # comment
                for i in foo:
                    pass
                print \"hello\"
        """)
        srcAfter=trimLines("""
        class MyClass:
            def myMethod(self): # comment
                for i in foo:
                    pass
                self.newMethod()

            def newMethod(self):
                print \"hello\"
        """)
        src = helper(srcBefore, coords(5, 8), coords(5, 21), "newMethod")
        self.assertEqual(srcAfter, src)

    def test_extractsPassWhenSelectionCoordsAreReversed(self):
        # End coordinate is passed first, start coordinate second.
        srcBefore=trimLines("""
        class MyClass:
            def myMethod(self):
                pass
        """)
        srcAfter=trimLines("""
        class MyClass:
            def myMethod(self):
                self.newMethod()

            def newMethod(self):
                pass
        """)
        src = helper(srcBefore, coords(3, 12), coords(3, 8), "newMethod")
        self.assertEqual(srcAfter, src)

    def test_extractsExpression(self):
        # The selection is the sub-expression `a * 1`, not a statement.
        srcBefore=trimLines("""
        class MyClass:
            def myMethod(self): # comment
                a = 32
                b = 2 + a * 1 + 2
        """)
        srcAfter=trimLines("""
        class MyClass:
            def myMethod(self): # comment
                a = 32
                b = 2 + self.newMethod(a) + 2

            def newMethod(self, a):
                return a * 1
        """)
        src = helper(srcBefore, coords(4, 16), coords(4, 21), "newMethod")
        self.assertEqual(srcAfter, src)

    def test_extractsExpression2(self):
        srcBefore=trimLines("""
        class MyClass:
            def myMethod(self): # comment
                g = 32
                assert output.thingy(g) == \"bah\"
        """)
        srcAfter=trimLines("""
        class MyClass:
            def myMethod(self): # comment
                g = 32
                assert self.newMethod(g) == \"bah\"

            def newMethod(self, g):
                return output.thingy(g)
        """)
        src = helper(srcBefore, coords(4, 15), coords(4, 31), "newMethod")
        self.assertEqual(srcAfter, src)


class TestExtractFunction(BRMTestCase):
    """Extraction of a region outside any class into a new function."""

    def runTarget(self, src, begincoords, endcoords, newname):
        # NOTE(review): `extractFunction` is not among the names imported
        # explicitly above and no test in this file calls runTarget; it may
        # be stale -- confirm before relying on it.
        ast = createAST(src)
        extractFunction(ast, begincoords, endcoords, newname)
        return ast

    def test_extractsFunction(self):
        srcBefore=trimLines("""
        def myFunction(): # comment
            a = 3
            c = a + 99
            b = c * 1
            print b
        """)
        srcAfter=trimLines("""
        def myFunction(): # comment
            a = 3
            b = newFunction(a)
            print b

        def newFunction(a):
            c = a + 99
            b = c * 1
            return b
        """)
        src = helper(srcBefore, coords(3, 4), coords(4, 13), "newFunction")
        self.assertEqual(srcAfter, src)

    def test_extractsAssignToAttribute(self):
        srcBefore=trimLines("""
        def simulateLoad(path):
            item = foo()
            item.decl = line
        """)
        srcAfter=trimLines("""
        def simulateLoad(path):
            item = foo()
            newFunction(item)

        def newFunction(item):
            item.decl = line
        """)
        src = helper(srcBefore, coords(3, 0), coords(4, 0), "newFunction")
        self.assertEqual(srcAfter, src)

    def test_extractsFromFirstBlockOfIfElseStatement(self):
        srcBefore=trimLines("""
        def foo():
            if bah:
                print \"hello1\"
                print \"hello2\"
            elif foo:
                pass
        """)
        srcAfter=trimLines("""
        def foo():
            if bah:
                newFunction()
                print \"hello2\"
            elif foo:
                pass

        def newFunction():
            print \"hello1\"
        """)
        src = helper(srcBefore, coords(3, 0), coords(4, 0), "newFunction")
        self.assertEqual(srcAfter, src)

    def test_extractsAugAssign(self):
        srcBefore=trimLines("""
        def foo():
            a = 3
            a += 1
            print a
        """)
        srcAfter=trimLines("""
        def foo():
            a = 3
            a = newFunction(a)
            print a

        def newFunction(a):
            a += 1
            return a
        """)
        src = helper(srcBefore, coords(3, 0), coords(4, 0), "newFunction")
        self.assertEqual(srcAfter, src)

    def test_extractsForLoopUsingLoopVariable(self):
        srcBefore=trimLines("""
        def foo():
            for i in range(1, 3):
                print i
        """)
        srcAfter=trimLines("""
        def foo():
            for i in range(1, 3):
                newFunction(i)

        def newFunction(i):
            print i
        """)
        src = helper(srcBefore, coords(3, 0), coords(4, 0), "newFunction")
        self.assertEqual(srcAfter, src)

    def test_extractWhileLoopVariableIncrement(self):
        srcBefore=trimLines("""
        def foo():
            a = 0
            while a != 3:
                a = a+1
        """)
        srcAfter=trimLines("""
        def foo():
            a = 0
            while a != 3:
                a = newFunction(a)

        def newFunction(a):
            a = a+1
            return a
        """)
        src = helper(srcBefore, coords(4, 0), coords(5, 0), "newFunction")
        self.assertEqual(srcAfter, src)

    def test_extractAssignedVariableUsedInOuterForLoop(self):
        srcBefore=trimLines("""
        def foo():
            b = 0
            for a in range(1, 3):
                b = b+1
                while b != 2:
                    print a
                    b += 1
        """)
        srcAfter=trimLines("""
        def foo():
            b = 0
            for a in range(1, 3):
                b = b+1
                while b != 2:
                    b = newFunction(a, b)

        def newFunction(a, b):
            print a
            b += 1
            return b
        """)
        src = helper(srcBefore, coords(6, 0), coords(8, 0), "newFunction")
        self.assertEqual(srcAfter, src)

    def test_extractsConditionalFromExpression(self):
        srcBefore=trimLines("""
        def foo():
            if 123+3:
                print aoue
        """)
        srcAfter=trimLines("""
        def foo():
            if newFunction():
                print aoue

        def newFunction():
            return 123+3
        """)
        src = helper(srcBefore, coords(2, 7), coords(2, 12), "newFunction")
        self.assertEqual(srcAfter, src)

    def test_extractCodeAfterCommentInMiddleOfFnDoesntRaiseParseException(self):
        srcBefore=trimLines("""
        def theFunction():
            print 1
            # comment
            print 2
        """)
        srcAfter=trimLines("""
        def theFunction():
            print 1
            # comment
            newFunction()

        def newFunction():
            print 2
        """)
        src = helper(srcBefore, coords(4, 0), coords(5, 0), "newFunction")
        self.assertEqual(srcAfter, src)

    def test_canExtractQueryFromNestedIfStatement(self):
        srcBefore=trimLines("""
        def theFunction():
            if foo: # comment
                if bah:
                    pass
        """)
        srcAfter=trimLines("""
        def theFunction():
            if foo: # comment
                if newFunction():
                    pass

        def newFunction():
            return bah
        """)
        src = helper(srcBefore, coords(3, 11), coords(3, 14), "newFunction")
        self.assertEqual(srcAfter, src)

    def test_doesntMessUpTheNextFunctionOrClass(self):
        # NOTE(review): column 34 in the first selection lies past the end
        # of the reconstructed third line; the original fixture may have
        # had wider spacing before the trailing comments -- confirm.
        srcBefore=trimLines("""
        def myFunction():
            a = 3
            print \"hello\"+a # extract me

        class MyClass:
            def myMethod(self):
                b = 12 # extract me
                c = 3 # and me
                d = 2 # and me
                print b, c
        """)
        srcAfter=trimLines("""
        def myFunction():
            a = 3
            newFunction(a)

        def newFunction(a):
            print \"hello\"+a # extract me

        class MyClass:
            def myMethod(self):
                b = 12 # extract me
                c = 3 # and me
                d = 2 # and me
                print b, c
        """)
        # extract code on one line
        src = helper(srcBefore, coords(3, 4), coords(3, 34), "newFunction")
        self.assertEqual(srcAfter, src)

        # extract code on 2 lines (most common user method)
        # The shared state is reset first so the second extraction runs
        # against the untouched "before" fixture.
        resetRoot()
        Cache.instance.reset()
        Root()
        src = helper(srcBefore, coords(3, 0), coords(4, 0), "newFunction")
        self.assertEqual(srcAfter, src)

    def test_doesntBallsUpIndentWhenTheresALineWithNoSpacesInIt(self):
        # The fourth fixture line is the "line with no spaces in it".
        srcBefore=trimLines("""
        def theFunction():
            if 1:
                pass

            pass
        """)
        srcAfter=trimLines("""
        def theFunction():
            newFunction()

        def newFunction():
            if 1:
                pass

            pass
        """)
        src = helper(srcBefore, coords(2, 4), coords(5, 8), "newFunction")
        self.assertEqual(srcAfter, src)

    def test_doesntHaveToBeInsideAFunction(self):
        srcBefore=trimLines(r"""
        a = 1
        print a + 2
        f(b)
        """)
        srcAfter=trimLines(r"""
        a = 1
        newFunction(a)

        def newFunction(a):
            print a + 2
            f(b)
        """)
        src = helper(srcBefore, coords(2, 0), coords(3, 4), "newFunction")
        self.assertEqual(srcAfter, src)

    def test_doesntBarfWhenEncountersMethodCalledOnCreatedObj(self):
        srcBefore=trimLines(r"""
        results = QueryEngine(q).foo()
        """)
        srcAfter=trimLines(r"""
        newFunction()

        def newFunction():
            results = QueryEngine(q).foo()
        """)
        src = helper(srcBefore, coords(1, 0), coords(2, 0), "newFunction")
        self.assertEqual(srcAfter, src)

    def test_worksIfNoLinesBeforeExtractedCode(self):
        srcBefore=trimLines(r"""
        print a + 2
        f(b)
        """)
        srcAfter=trimLines(r"""
        newFunction()

        def newFunction():
            print a + 2
            f(b)
        """)
        src = helper(srcBefore, coords(1, 0), coords(2, 4), "newFunction")
        self.assertEqual(srcAfter, src)


class TestGetRegionAsString(BRMTestCase):
    """ExtractMethod.getRegionToBuffer() fills `extractedLines`."""

    def test_getsHighlightedSingleLinePassStatement(self):
        src=trimLines("""
        class MyClass:
            def myMethod(self):
                pass
        """)
        sourcenode = createAST(src)
        em = ExtractMethod(sourcenode, coords(3, 8), coords(3, 12), "foobah")
        em.getRegionToBuffer()
        self.assertEqual(len(em.extractedLines), 1)
        self.assertEqual(em.extractedLines[0], "pass\n")

    def test_getsSingleLinePassStatementWhenWholeLineIsHighlighted(self):
        src=trimLines("""
        class MyClass:
            def myMethod(self):
                pass
        """)
        sourcenode = createAST(src)
        em = ExtractMethod(sourcenode, coords(3, 0), coords(3, 12), "foobah")
        em.getRegionToBuffer()
        self.assertEqual(len(em.extractedLines), 1)
        self.assertEqual(em.extractedLines[0], "pass\n")

    def test_getsMultiLineRegionWhenJustRegionIsHighlighted(self):
        src=trimLines("""
        class MyClass:
            def myMethod(self):
                print 'hello'
                pass
        """)
        region=trimLines("""
        print 'hello'
        pass
        """)
        sourcenode = createAST(src)
        em = ExtractMethod(sourcenode, coords(3, 8), coords(4, 12), "foobah")
        em.getRegionToBuffer()
        self.assertEqual(em.extractedLines, region.splitlines(1))

    def test_getsMultiLineRegionWhenRegionLinesAreHighlighted(self):
        src=trimLines("""
        class MyClass:
            def myMethod(self):
                print 'hello'
                pass
        """)
        region=trimLines("""
        print 'hello'
        pass
        """)
        sourcenode = createAST(src)
        em = ExtractMethod(sourcenode, coords(3, 0), coords(5, 0), "foobah")
        em.getRegionToBuffer()
        self.assertEqual(em.extractedLines, region.splitlines(1))

    def test_getsHighlightedSubstringOfLine(self):
        src=trimLines("""
        class MyClass:
            def myMethod(self):
                if a == 3:
                    pass
        """)
        region=trimLines("""
        a == 3
        """)
        sourcenode = createAST(src)
        em = ExtractMethod(sourcenode, coords(3, 11), coords(3, 17), "foobah")
        em.getRegionToBuffer()
        self.assertEqual(em.extractedLines, region.splitlines(1))


class TestGetTabwidthOfParentFunction(BRMTestCase):
    """ExtractMethod.getTabwidthOfParentFunction() for nested and root scope."""

    def test_getsTabwidthForSimpleMethod(self):
        src=trimLines("""
        class MyClass:
            def myMethod(self):
                pass
        """)
        sourcenode = createAST(src)
        em = ExtractMethod(sourcenode, coords(3, 11), coords(3, 17), "foobah")
        self.assertEqual(em.getTabwidthOfParentFunction(), 4)

    def test_getsTabwidthForFunctionAtRootScope(self):
        src=trimLines("""
        def myFn(self):
            pass
        """)
        sourcenode = createAST(src)
        em = ExtractMethod(sourcenode, coords(2, 0), coords(2, 9), "foobah")
        self.assertEqual(em.getTabwidthOfParentFunction(), 0)


if __name__ == "__main__":
    unittest.main()