@@ -656,12 +656,88 @@ def __exit__(self, t, v, tb): return False
656656 self .fail ("ZeroDivisionError should have been raised" )
657657
658658
659+ class NestedWith (unittest .TestCase ):
660+
661+ class Dummy (object ):
662+ def __init__ (self , value = None , gobble = False ):
663+ if value is None :
664+ value = self
665+ self .value = value
666+ self .gobble = gobble
667+ self .enter_called = False
668+ self .exit_called = False
669+
670+ def __enter__ (self ):
671+ self .enter_called = True
672+ return self .value
673+
674+ def __exit__ (self , * exc_info ):
675+ self .exit_called = True
676+ self .exc_info = exc_info
677+ if self .gobble :
678+ return True
679+
680+ class CtorRaises (object ):
681+ def __init__ (self ): raise RuntimeError ()
682+
683+ class EnterRaises (object ):
684+ def __enter__ (self ): raise RuntimeError ()
685+ def __exit__ (self , * exc_info ): pass
686+
687+ class ExitRaises (object ):
688+ def __enter__ (self ): pass
689+ def __exit__ (self , * exc_info ): raise RuntimeError ()
690+
691+ def testNoExceptions (self ):
692+ with self .Dummy () as a , self .Dummy () as b :
693+ self .assertTrue (a .enter_called )
694+ self .assertTrue (b .enter_called )
695+ self .assertTrue (a .exit_called )
696+ self .assertTrue (b .exit_called )
697+
698+ def testExceptionInExprList (self ):
699+ try :
700+ with self .Dummy () as a , self .CtorRaises ():
701+ pass
702+ except :
703+ pass
704+ self .assertTrue (a .enter_called )
705+ self .assertTrue (a .exit_called )
706+
707+ def testExceptionInEnter (self ):
708+ try :
709+ with self .Dummy () as a , self .EnterRaises ():
710+ self .fail ('body of bad with executed' )
711+ except RuntimeError :
712+ pass
713+ else :
714+ self .fail ('RuntimeError not reraised' )
715+ self .assertTrue (a .enter_called )
716+ self .assertTrue (a .exit_called )
717+
718+ def testExceptionInExit (self ):
719+ body_executed = False
720+ with self .Dummy (gobble = True ) as a , self .ExitRaises ():
721+ body_executed = True
722+ self .assertTrue (a .enter_called )
723+ self .assertTrue (a .exit_called )
724+ self .assertNotEqual (a .exc_info [0 ], None )
725+
726+ def testEnterReturnsTuple (self ):
727+ with self .Dummy (value = (1 ,2 )) as (a1 , a2 ), \
728+ self .Dummy (value = (10 , 20 )) as (b1 , b2 ):
729+ self .assertEquals (1 , a1 )
730+ self .assertEquals (2 , a2 )
731+ self .assertEquals (10 , b1 )
732+ self .assertEquals (20 , b2 )
733+
659734def test_main ():
660735 run_unittest (FailureTestCase , NonexceptionalTestCase ,
661736 NestedNonexceptionalTestCase , ExceptionalTestCase ,
662737 NonLocalFlowControlTestCase ,
663738 AssignmentTargetTestCase ,
664- ExitSwallowsExceptionTestCase )
739+ ExitSwallowsExceptionTestCase ,
740+ NestedWith )
665741
666742
667743if __name__ == '__main__' :
0 commit comments