diff --git a/docs/benefits-over-pyright/fixed-context-manager-exit-types.md b/docs/benefits-over-pyright/fixed-context-manager-exit-types.md new file mode 100644 index 000000000..d776b42e6 --- /dev/null +++ b/docs/benefits-over-pyright/fixed-context-manager-exit-types.md @@ -0,0 +1,153 @@ +# fixed handling for context managers that can suppress exceptions + +## the problem + +if an exception is raised inside a context manager and its `__exit__` method returns `True`, it will be suppressed: + +```py +class SuppressError(AbstractContextManager[None, bool]): + @override + def __enter__(self) -> None: + pass + + @override + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + /, + ) -> bool: + return True +``` + +but if it returns `False` or `None`, the exception will not be suppressed: + +```py +class Log(AbstractContextManager[None, Literal[False]]): + @override + def __enter__(self) -> None: + print("entering context manager") + + @override + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + /, + ) -> Literal[False]: + print("exiting context manager") + return False +``` + +pyright will take this into account when determining reachability: + +```py +def raise_exception() -> Never: + raise Exception + +with SuppressError(): # see definition for `SuppressError` above + foo: int = raise_exception() + +# when the exception is raised, the context manager exits before foo is assigned to: +print(foo) # error: "foo" is unbound (reportPossiblyUnboundVariable) +``` + +```py +with Log(): # see definition for `Log` above + foo: int = raise_exception() + +# when the exception is raised, it does not get suppressed so this line can never run: +print(foo) # error: Code is unreachable (reportUnreachable) +``` + +however, due to [a bug in mypy](https://github.com/python/mypy/issues/8766) that [pyright blindly copied and accepted as the "standard"](https://github.com/microsoft/pyright/issues/6034#issuecomment-1738941412), a context manager will incorrectly be treated as if it never suppresses exceptions if its return type is a union of `bool | None`: + +```py +class SuppressError(AbstractContextManager[None, bool | None]): + @override + def __enter__(self) -> None: + pass + + @override + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + /, + ) -> bool | None: + return True + + +with SuppressError(): + foo: int = raise_exception() + +# this error is wrong because this line is actually reached at runtime: +print(foo) # error: Code is unreachable (reportUnreachable) +``` + +## the solution + +basedpyright introduces a new setting, `strictContextManagerExitTypes` to address this issue. when enabled, context managers where the `__exit__` dunder returns `bool | None` are treated the same way as context managers that return `bool` or `Literal[True]`. put simply, if `True` is assignable to the return type, then it's treated as if it can suppress exceptions. + +## issues with `@contextmanager` + +the reason we support disabling this fix using the `strictContextManagerExitTypes` setting is because it will cause all context managers decorated with `@contextlib.contextmanager` to be treated as if they can suppress an exception, even if they never do: + +```py +@contextmanager +def log(): + print("entering context manager") + try: + yield + finally: + print("exiting context manager") + +with log(): + foo: int = get_value() + +# basedpyright accounts for the possibility that get_value raised an exception and foo +# was never assigned to, even though this context manager never suppresses exceptions +print(foo) # error: "foo" is unbound (reportPossiblyUnboundVariable) +``` + +this is because there's no way to tell a type checker whether the function body wraps the `yield` statement inside a `try`/`except` statement, which is necessary to suppress exeptions when using the `@contextmanager` decorator: + +```py +@contextmanager +def suppress_error(): + try: + yield + except: + pass +``` + +due to this limitation in the type system, the `@contextmanager` dectorator always modifies the return type of generator functions from `Iterator[T]` to `_GeneratorContextManager[T]`, which extends `AbstractContextManager[T, bool | None]`. + +```py +# contextlib.pyi + +def contextmanager(func: Callable[_P, Iterator[_T_co]]) -> Callable[_P, _GeneratorContextManager[_T_co]]: ... + +class _GeneratorContextManager(_GeneratorContextManagerBase, AbstractContextManager[_T_co, bool | None], ContextDecorator): + ... +``` + +and since `bool | None` is used for the return type of `__exit__`, basedpyright will assume that all `@contextllib.contextmanager`'s have the ability to suppress exceptions when `strictContextManagerExitTypes` is enabled. + +as a workaround, it's recommended to instead use class context managers [like in the examples above](#the-problem) for the following reasons: + +- it forces you to be explicit about whether or not the context manager is able to suppress an exception +- it prevents you from accidentally creating a context manager that doesn't run its cleanup if an exception occurs: + ```py + @contextmanager + def suppress_error(): + print("setup") + yield + # this part won't run if an exception is raised because you forgot to use a try/finally + print("cleanup") + ``` + +if you're dealing with third party modules where the usage of `@contextmanager` decorator is unavoidable, it may be best to just disable `strictContextManagerExitTypes` instead. diff --git a/docs/configuration.md b/docs/configuration.md index d87a42bb9..71cc2ea9a 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -77,6 +77,8 @@ The following settings determine how different types should be evaluated. - **strictGenericNarrowing** [boolean]: When a type is narrowed in such a way that its type parameters are not known (eg. using an `isinstance` check), basedpyright will resolve the type parameter to the generic's bound or constraint instead of `Any`. [more info](../benefits-over-pyright/improved-generic-narrowing.md) +- **strictContextManagerExitTypes** [boolean]: Assume that a context manager could potentially suppress an exception if its `__exit__` method is typed as returning `bool | None`. [more info](../benefits-over-pyright/fixed-context-manager-exit-types.md) + ## Diagnostic Categories diagnostics can be configured to be reported as any of the following categories: diff --git a/packages/pyright-internal/src/analyzer/codeFlowEngine.ts b/packages/pyright-internal/src/analyzer/codeFlowEngine.ts index 528067194..10d050016 100644 --- a/packages/pyright-internal/src/analyzer/codeFlowEngine.ts +++ b/packages/pyright-internal/src/analyzer/codeFlowEngine.ts @@ -57,6 +57,7 @@ import { isTypeSame, isTypeVar, isTypeVarTuple, + isUnion, maxTypeRecursionCount, NeverType, OverloadedType, @@ -1964,8 +1965,18 @@ export function getCodeFlowEngine( } cmSwallowsExceptions = false; - if (isClassInstance(returnType) && ClassType.isBuiltIn(returnType, 'bool')) { - if (returnType.priv.literalValue === undefined || returnType.priv.literalValue === true) { + // valid return types here are `bool | None`. if the context manager returns `True` then it suppresses, + // meaning we only know for sure that the context manager can't swallow exceptions if its return type + // does not allow `True`. + const typesToCheck = + getFileInfo(node).diagnosticRuleSet.strictContextManagerExitTypes && isUnion(returnType) + ? returnType.priv.subtypes + : [returnType]; + const boolType = typesToCheck.find( + (type): type is ClassType => isClassInstance(type) && ClassType.isBuiltIn(type, 'bool') + ); + if (boolType) { + if (boolType.priv.literalValue === undefined || boolType.priv.literalValue === true) { cmSwallowsExceptions = true; } } diff --git a/packages/pyright-internal/src/common/configOptions.ts b/packages/pyright-internal/src/common/configOptions.ts index fb7795579..1f4439e54 100644 --- a/packages/pyright-internal/src/common/configOptions.ts +++ b/packages/pyright-internal/src/common/configOptions.ts @@ -412,6 +412,7 @@ export interface DiagnosticRuleSet { */ failOnWarnings: boolean; strictGenericNarrowing: boolean; + strictContextManagerExitTypes: boolean; reportUnreachable: DiagnosticLevel; reportAny: DiagnosticLevel; reportExplicitAny: DiagnosticLevel; @@ -443,6 +444,7 @@ export function getBooleanDiagnosticRules(includeNonOverridable = false) { DiagnosticRule.deprecateTypingAliases, DiagnosticRule.disableBytesTypePromotions, DiagnosticRule.strictGenericNarrowing, + DiagnosticRule.strictContextManagerExitTypes, ]; if (includeNonOverridable) { @@ -676,6 +678,7 @@ export function getOffDiagnosticRuleSet(): DiagnosticRuleSet { reportImplicitOverride: 'none', failOnWarnings: false, strictGenericNarrowing: false, + strictContextManagerExitTypes: false, reportUnreachable: 'hint', reportAny: 'none', reportExplicitAny: 'none', @@ -792,6 +795,7 @@ export function getBasicDiagnosticRuleSet(): DiagnosticRuleSet { reportImplicitOverride: 'none', failOnWarnings: false, strictGenericNarrowing: false, + strictContextManagerExitTypes: false, reportUnreachable: 'hint', reportAny: 'none', reportExplicitAny: 'none', @@ -908,6 +912,7 @@ export function getStandardDiagnosticRuleSet(): DiagnosticRuleSet { reportImplicitOverride: 'none', failOnWarnings: false, strictGenericNarrowing: false, + strictContextManagerExitTypes: false, reportUnreachable: 'hint', reportAny: 'none', reportExplicitAny: 'none', @@ -1023,6 +1028,7 @@ export const getRecommendedDiagnosticRuleSet = (): DiagnosticRuleSet => ({ reportImplicitOverride: 'warning', failOnWarnings: true, strictGenericNarrowing: true, + strictContextManagerExitTypes: true, reportUnreachable: 'warning', reportAny: 'warning', reportExplicitAny: 'warning', @@ -1135,6 +1141,7 @@ export const getAllDiagnosticRuleSet = (): DiagnosticRuleSet => ({ reportImplicitOverride: 'error', failOnWarnings: true, strictGenericNarrowing: true, + strictContextManagerExitTypes: true, reportUnreachable: 'error', reportAny: 'error', reportExplicitAny: 'error', @@ -1248,6 +1255,7 @@ export function getStrictDiagnosticRuleSet(): DiagnosticRuleSet { reportImplicitOverride: 'none', failOnWarnings: false, strictGenericNarrowing: false, + strictContextManagerExitTypes: false, reportUnreachable: 'hint', reportAny: 'none', reportExplicitAny: 'none', diff --git a/packages/pyright-internal/src/common/diagnosticRules.ts b/packages/pyright-internal/src/common/diagnosticRules.ts index a6d564d52..9245a2981 100644 --- a/packages/pyright-internal/src/common/diagnosticRules.ts +++ b/packages/pyright-internal/src/common/diagnosticRules.ts @@ -107,6 +107,7 @@ export enum DiagnosticRule { // basedpyright options: failOnWarnings = 'failOnWarnings', strictGenericNarrowing = 'strictGenericNarrowing', + strictContextManagerExitTypes = 'strictContextManagerExitTypes', reportUnreachable = 'reportUnreachable', reportAny = 'reportAny', reportExplicitAny = 'reportExplicitAny', diff --git a/packages/pyright-internal/src/tests/checker.test.ts b/packages/pyright-internal/src/tests/checker.test.ts index 58a959784..04308efea 100644 --- a/packages/pyright-internal/src/tests/checker.test.ts +++ b/packages/pyright-internal/src/tests/checker.test.ts @@ -170,6 +170,33 @@ test('With2', () => { TestUtils.validateResults(analysisResults, 3); }); +describe('context manager where __exit__ returns bool | None', () => { + test('strictContextManagerExitTypes=true', () => { + const configOptions = new ConfigOptions(Uri.empty()); + configOptions.diagnosticRuleSet.strictContextManagerExitTypes = true; + const analysisResults = TestUtils.typeAnalyzeSampleFiles(['withBased.py'], configOptions); + TestUtils.validateResultsButBased(analysisResults, { + hints: [ + { code: DiagnosticRule.reportUnreachable, line: 45 }, + { code: DiagnosticRule.reportUnreachable, line: 60 }, + ], + }); + }); + test('strictContextManagerExitTypes=false', () => { + const configOptions = new ConfigOptions(Uri.empty()); + configOptions.diagnosticRuleSet.strictContextManagerExitTypes = false; + const analysisResults = TestUtils.typeAnalyzeSampleFiles(['withBased.py'], configOptions); + TestUtils.validateResultsButBased(analysisResults, { + hints: [ + { code: DiagnosticRule.reportUnreachable, line: 16 }, + { code: DiagnosticRule.reportUnreachable, line: 30 }, + { code: DiagnosticRule.reportUnreachable, line: 45 }, + { code: DiagnosticRule.reportUnreachable, line: 60 }, + ], + }); + }); +}); + test('With3', () => { const analysisResults = TestUtils.typeAnalyzeSampleFiles(['with3.py']); diff --git a/packages/pyright-internal/src/tests/samples/withBased.py b/packages/pyright-internal/src/tests/samples/withBased.py new file mode 100644 index 000000000..61dc2481d --- /dev/null +++ b/packages/pyright-internal/src/tests/samples/withBased.py @@ -0,0 +1,61 @@ +import contextlib +from types import TracebackType +from typing import Literal + +class BoolOrNone(contextlib.AbstractContextManager[None]): + def __exit__( + self, + __exc_type: type[BaseException] | None, + __exc_value: BaseException | None, + __traceback: TracebackType | None, + ) -> bool | None: + ... + +def _(): + with BoolOrNone(): + raise Exception + print(1) # reachable + +class TrueOrNone(contextlib.AbstractContextManager[None]): + def __exit__( + self, + __exc_type: type[BaseException] | None, + __exc_value: BaseException | None, + __traceback: TracebackType | None, + ) -> Literal[True] | None: + ... + +def _(): + with TrueOrNone(): + raise Exception + print(1) # reachable + + +class FalseOrNone(contextlib.AbstractContextManager[None]): + def __exit__( + self, + __exc_type: type[BaseException] | None, + __exc_value: BaseException | None, + __traceback: TracebackType | None, + ) -> Literal[False] | None: + ... + +def _(): + with FalseOrNone(): + raise Exception + print(1) # unreachable + + +class OnlyNone(contextlib.AbstractContextManager[None]): + def __exit__( + self, + __exc_type: type[BaseException] | None, + __exc_value: BaseException | None, + __traceback: TracebackType | None, + ) -> None: + ... + +def _(): + with OnlyNone(): + raise Exception + print(1) # unreachable \ No newline at end of file diff --git a/packages/pyright-internal/src/tests/testUtils.ts b/packages/pyright-internal/src/tests/testUtils.ts index d9b7debce..eae163f38 100644 --- a/packages/pyright-internal/src/tests/testUtils.ts +++ b/packages/pyright-internal/src/tests/testUtils.ts @@ -260,7 +260,11 @@ export const validateResultsButBased = (allResults: FileAnalysisResult[], expect baselined: result.baselined, }) ); - const expectedResult = expectedResults[diagnosticType] ?? []; - expect(new Set(actualResult)).toEqual(new Set(expectedResult.map(expect.objectContaining))); + const expectedResult = expectedResults[diagnosticType]; + // if it's explicitly in the expected results as undefined, that means we don't care. + // if it's not in the expected results at all, then check it + if (!(diagnosticType in expectedResults) || expectedResult !== undefined) { + expect(new Set(actualResult)).toEqual(new Set((expectedResult ?? []).map(expect.objectContaining))); + } } };