diff --git a/docs/benefits-over-pyright/fixed-context-manager-exit-types.md b/docs/benefits-over-pyright/fixed-context-manager-exit-types.md
new file mode 100644
index 000000000..d776b42e6
--- /dev/null
+++ b/docs/benefits-over-pyright/fixed-context-manager-exit-types.md
@@ -0,0 +1,153 @@
+# fixed handling for context managers that can suppress exceptions
+
+## the problem
+
+if an exception is raised inside a context manager and its `__exit__` method returns `True`, it will be suppressed:
+
+```py
+class SuppressError(AbstractContextManager[None, bool]):
+ @override
+ def __enter__(self) -> None:
+ pass
+
+ @override
+ def __exit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_value: BaseException | None,
+ traceback: TracebackType | None,
+ /,
+ ) -> bool:
+ return True
+```
+
+but if it returns `False` or `None`, the exception will not be suppressed:
+
+```py
+class Log(AbstractContextManager[None, Literal[False]]):
+ @override
+ def __enter__(self) -> None:
+ print("entering context manager")
+
+ @override
+ def __exit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_value: BaseException | None,
+ traceback: TracebackType | None,
+ /,
+ ) -> Literal[False]:
+ print("exiting context manager")
+ return False
+```
+
+pyright will take this into account when determining reachability:
+
+```py
+def raise_exception() -> Never:
+ raise Exception
+
+with SuppressError(): # see definition for `SuppressError` above
+ foo: int = raise_exception()
+
+# when the exception is raised, the context manager exits before foo is assigned to:
+print(foo) # error: "foo" is unbound (reportPossiblyUnboundVariable)
+```
+
+```py
+with Log(): # see definition for `Log` above
+ foo: int = raise_exception()
+
+# when the exception is raised, it does not get suppressed so this line can never run:
+print(foo) # error: Code is unreachable (reportUnreachable)
+```
+
+however, due to [a bug in mypy](https://github.com/python/mypy/issues/8766) that [pyright blindly copied and accepted as the "standard"](https://github.com/microsoft/pyright/issues/6034#issuecomment-1738941412), a context manager will incorrectly be treated as if it never suppresses exceptions if its return type is a union of `bool | None`:
+
+```py
+class SuppressError(AbstractContextManager[None, bool | None]):
+ @override
+ def __enter__(self) -> None:
+ pass
+
+ @override
+ def __exit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_value: BaseException | None,
+ traceback: TracebackType | None,
+ /,
+ ) -> bool | None:
+ return True
+
+
+with SuppressError():
+ foo: int = raise_exception()
+
+# this error is wrong because this line is actually reached at runtime:
+print(foo) # error: Code is unreachable (reportUnreachable)
+```
+
+## the solution
+
+basedpyright introduces a new setting, `strictContextManagerExitTypes` to address this issue. when enabled, context managers where the `__exit__` dunder returns `bool | None` are treated the same way as context managers that return `bool` or `Literal[True]`. put simply, if `True` is assignable to the return type, then it's treated as if it can suppress exceptions.
+
+## issues with `@contextmanager`
+
+the reason we support disabling this fix using the `strictContextManagerExitTypes` setting is because it will cause all context managers decorated with `@contextlib.contextmanager` to be treated as if they can suppress an exception, even if they never do:
+
+```py
+@contextmanager
+def log():
+ print("entering context manager")
+ try:
+ yield
+ finally:
+ print("exiting context manager")
+
+with log():
+ foo: int = get_value()
+
+# basedpyright accounts for the possibility that get_value raised an exception and foo
+# was never assigned to, even though this context manager never suppresses exceptions
+print(foo) # error: "foo" is unbound (reportPossiblyUnboundVariable)
+```
+
+this is because there's no way to tell a type checker whether the function body wraps the `yield` statement inside a `try`/`except` statement, which is necessary to suppress exeptions when using the `@contextmanager` decorator:
+
+```py
+@contextmanager
+def suppress_error():
+ try:
+ yield
+ except:
+ pass
+```
+
+due to this limitation in the type system, the `@contextmanager` dectorator always modifies the return type of generator functions from `Iterator[T]` to `_GeneratorContextManager[T]`, which extends `AbstractContextManager[T, bool | None]`.
+
+```py
+# contextlib.pyi
+
+def contextmanager(func: Callable[_P, Iterator[_T_co]]) -> Callable[_P, _GeneratorContextManager[_T_co]]: ...
+
+class _GeneratorContextManager(_GeneratorContextManagerBase, AbstractContextManager[_T_co, bool | None], ContextDecorator):
+ ...
+```
+
+and since `bool | None` is used for the return type of `__exit__`, basedpyright will assume that all `@contextllib.contextmanager`'s have the ability to suppress exceptions when `strictContextManagerExitTypes` is enabled.
+
+as a workaround, it's recommended to instead use class context managers [like in the examples above](#the-problem) for the following reasons:
+
+- it forces you to be explicit about whether or not the context manager is able to suppress an exception
+- it prevents you from accidentally creating a context manager that doesn't run its cleanup if an exception occurs:
+ ```py
+ @contextmanager
+ def suppress_error():
+ print("setup")
+ yield
+ # this part won't run if an exception is raised because you forgot to use a try/finally
+ print("cleanup")
+ ```
+
+if you're dealing with third party modules where the usage of `@contextmanager` decorator is unavoidable, it may be best to just disable `strictContextManagerExitTypes` instead.
diff --git a/docs/configuration.md b/docs/configuration.md
index d87a42bb9..71cc2ea9a 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -77,6 +77,8 @@ The following settings determine how different types should be evaluated.
- **strictGenericNarrowing** [boolean]: When a type is narrowed in such a way that its type parameters are not known (eg. using an `isinstance` check), basedpyright will resolve the type parameter to the generic's bound or constraint instead of `Any`. [more info](../benefits-over-pyright/improved-generic-narrowing.md)
+- **strictContextManagerExitTypes** [boolean]: Assume that a context manager could potentially suppress an exception if its `__exit__` method is typed as returning `bool | None`. [more info](../benefits-over-pyright/fixed-context-manager-exit-types.md)
+
## Diagnostic Categories
diagnostics can be configured to be reported as any of the following categories:
diff --git a/packages/pyright-internal/src/analyzer/codeFlowEngine.ts b/packages/pyright-internal/src/analyzer/codeFlowEngine.ts
index 528067194..10d050016 100644
--- a/packages/pyright-internal/src/analyzer/codeFlowEngine.ts
+++ b/packages/pyright-internal/src/analyzer/codeFlowEngine.ts
@@ -57,6 +57,7 @@ import {
isTypeSame,
isTypeVar,
isTypeVarTuple,
+ isUnion,
maxTypeRecursionCount,
NeverType,
OverloadedType,
@@ -1964,8 +1965,18 @@ export function getCodeFlowEngine(
}
cmSwallowsExceptions = false;
- if (isClassInstance(returnType) && ClassType.isBuiltIn(returnType, 'bool')) {
- if (returnType.priv.literalValue === undefined || returnType.priv.literalValue === true) {
+ // valid return types here are `bool | None`. if the context manager returns `True` then it suppresses,
+ // meaning we only know for sure that the context manager can't swallow exceptions if its return type
+ // does not allow `True`.
+ const typesToCheck =
+ getFileInfo(node).diagnosticRuleSet.strictContextManagerExitTypes && isUnion(returnType)
+ ? returnType.priv.subtypes
+ : [returnType];
+ const boolType = typesToCheck.find(
+ (type): type is ClassType => isClassInstance(type) && ClassType.isBuiltIn(type, 'bool')
+ );
+ if (boolType) {
+ if (boolType.priv.literalValue === undefined || boolType.priv.literalValue === true) {
cmSwallowsExceptions = true;
}
}
diff --git a/packages/pyright-internal/src/common/configOptions.ts b/packages/pyright-internal/src/common/configOptions.ts
index fb7795579..1f4439e54 100644
--- a/packages/pyright-internal/src/common/configOptions.ts
+++ b/packages/pyright-internal/src/common/configOptions.ts
@@ -412,6 +412,7 @@ export interface DiagnosticRuleSet {
*/
failOnWarnings: boolean;
strictGenericNarrowing: boolean;
+ strictContextManagerExitTypes: boolean;
reportUnreachable: DiagnosticLevel;
reportAny: DiagnosticLevel;
reportExplicitAny: DiagnosticLevel;
@@ -443,6 +444,7 @@ export function getBooleanDiagnosticRules(includeNonOverridable = false) {
DiagnosticRule.deprecateTypingAliases,
DiagnosticRule.disableBytesTypePromotions,
DiagnosticRule.strictGenericNarrowing,
+ DiagnosticRule.strictContextManagerExitTypes,
];
if (includeNonOverridable) {
@@ -676,6 +678,7 @@ export function getOffDiagnosticRuleSet(): DiagnosticRuleSet {
reportImplicitOverride: 'none',
failOnWarnings: false,
strictGenericNarrowing: false,
+ strictContextManagerExitTypes: false,
reportUnreachable: 'hint',
reportAny: 'none',
reportExplicitAny: 'none',
@@ -792,6 +795,7 @@ export function getBasicDiagnosticRuleSet(): DiagnosticRuleSet {
reportImplicitOverride: 'none',
failOnWarnings: false,
strictGenericNarrowing: false,
+ strictContextManagerExitTypes: false,
reportUnreachable: 'hint',
reportAny: 'none',
reportExplicitAny: 'none',
@@ -908,6 +912,7 @@ export function getStandardDiagnosticRuleSet(): DiagnosticRuleSet {
reportImplicitOverride: 'none',
failOnWarnings: false,
strictGenericNarrowing: false,
+ strictContextManagerExitTypes: false,
reportUnreachable: 'hint',
reportAny: 'none',
reportExplicitAny: 'none',
@@ -1023,6 +1028,7 @@ export const getRecommendedDiagnosticRuleSet = (): DiagnosticRuleSet => ({
reportImplicitOverride: 'warning',
failOnWarnings: true,
strictGenericNarrowing: true,
+ strictContextManagerExitTypes: true,
reportUnreachable: 'warning',
reportAny: 'warning',
reportExplicitAny: 'warning',
@@ -1135,6 +1141,7 @@ export const getAllDiagnosticRuleSet = (): DiagnosticRuleSet => ({
reportImplicitOverride: 'error',
failOnWarnings: true,
strictGenericNarrowing: true,
+ strictContextManagerExitTypes: true,
reportUnreachable: 'error',
reportAny: 'error',
reportExplicitAny: 'error',
@@ -1248,6 +1255,7 @@ export function getStrictDiagnosticRuleSet(): DiagnosticRuleSet {
reportImplicitOverride: 'none',
failOnWarnings: false,
strictGenericNarrowing: false,
+ strictContextManagerExitTypes: false,
reportUnreachable: 'hint',
reportAny: 'none',
reportExplicitAny: 'none',
diff --git a/packages/pyright-internal/src/common/diagnosticRules.ts b/packages/pyright-internal/src/common/diagnosticRules.ts
index a6d564d52..9245a2981 100644
--- a/packages/pyright-internal/src/common/diagnosticRules.ts
+++ b/packages/pyright-internal/src/common/diagnosticRules.ts
@@ -107,6 +107,7 @@ export enum DiagnosticRule {
// basedpyright options:
failOnWarnings = 'failOnWarnings',
strictGenericNarrowing = 'strictGenericNarrowing',
+ strictContextManagerExitTypes = 'strictContextManagerExitTypes',
reportUnreachable = 'reportUnreachable',
reportAny = 'reportAny',
reportExplicitAny = 'reportExplicitAny',
diff --git a/packages/pyright-internal/src/tests/checker.test.ts b/packages/pyright-internal/src/tests/checker.test.ts
index 58a959784..04308efea 100644
--- a/packages/pyright-internal/src/tests/checker.test.ts
+++ b/packages/pyright-internal/src/tests/checker.test.ts
@@ -170,6 +170,33 @@ test('With2', () => {
TestUtils.validateResults(analysisResults, 3);
});
+describe('context manager where __exit__ returns bool | None', () => {
+ test('strictContextManagerExitTypes=true', () => {
+ const configOptions = new ConfigOptions(Uri.empty());
+ configOptions.diagnosticRuleSet.strictContextManagerExitTypes = true;
+ const analysisResults = TestUtils.typeAnalyzeSampleFiles(['withBased.py'], configOptions);
+ TestUtils.validateResultsButBased(analysisResults, {
+ hints: [
+ { code: DiagnosticRule.reportUnreachable, line: 45 },
+ { code: DiagnosticRule.reportUnreachable, line: 60 },
+ ],
+ });
+ });
+ test('strictContextManagerExitTypes=false', () => {
+ const configOptions = new ConfigOptions(Uri.empty());
+ configOptions.diagnosticRuleSet.strictContextManagerExitTypes = false;
+ const analysisResults = TestUtils.typeAnalyzeSampleFiles(['withBased.py'], configOptions);
+ TestUtils.validateResultsButBased(analysisResults, {
+ hints: [
+ { code: DiagnosticRule.reportUnreachable, line: 16 },
+ { code: DiagnosticRule.reportUnreachable, line: 30 },
+ { code: DiagnosticRule.reportUnreachable, line: 45 },
+ { code: DiagnosticRule.reportUnreachable, line: 60 },
+ ],
+ });
+ });
+});
+
test('With3', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['with3.py']);
diff --git a/packages/pyright-internal/src/tests/samples/withBased.py b/packages/pyright-internal/src/tests/samples/withBased.py
new file mode 100644
index 000000000..61dc2481d
--- /dev/null
+++ b/packages/pyright-internal/src/tests/samples/withBased.py
@@ -0,0 +1,61 @@
+import contextlib
+from types import TracebackType
+from typing import Literal
+
+class BoolOrNone(contextlib.AbstractContextManager[None]):
+ def __exit__(
+ self,
+ __exc_type: type[BaseException] | None,
+ __exc_value: BaseException | None,
+ __traceback: TracebackType | None,
+ ) -> bool | None:
+ ...
+
+def _():
+ with BoolOrNone():
+ raise Exception
+ print(1) # reachable
+
+class TrueOrNone(contextlib.AbstractContextManager[None]):
+ def __exit__(
+ self,
+ __exc_type: type[BaseException] | None,
+ __exc_value: BaseException | None,
+ __traceback: TracebackType | None,
+ ) -> Literal[True] | None:
+ ...
+
+def _():
+ with TrueOrNone():
+ raise Exception
+ print(1) # reachable
+
+
+class FalseOrNone(contextlib.AbstractContextManager[None]):
+ def __exit__(
+ self,
+ __exc_type: type[BaseException] | None,
+ __exc_value: BaseException | None,
+ __traceback: TracebackType | None,
+ ) -> Literal[False] | None:
+ ...
+
+def _():
+ with FalseOrNone():
+ raise Exception
+ print(1) # unreachable
+
+
+class OnlyNone(contextlib.AbstractContextManager[None]):
+ def __exit__(
+ self,
+ __exc_type: type[BaseException] | None,
+ __exc_value: BaseException | None,
+ __traceback: TracebackType | None,
+ ) -> None:
+ ...
+
+def _():
+ with OnlyNone():
+ raise Exception
+ print(1) # unreachable
\ No newline at end of file
diff --git a/packages/pyright-internal/src/tests/testUtils.ts b/packages/pyright-internal/src/tests/testUtils.ts
index d9b7debce..eae163f38 100644
--- a/packages/pyright-internal/src/tests/testUtils.ts
+++ b/packages/pyright-internal/src/tests/testUtils.ts
@@ -260,7 +260,11 @@ export const validateResultsButBased = (allResults: FileAnalysisResult[], expect
baselined: result.baselined,
})
);
- const expectedResult = expectedResults[diagnosticType] ?? [];
- expect(new Set(actualResult)).toEqual(new Set(expectedResult.map(expect.objectContaining)));
+ const expectedResult = expectedResults[diagnosticType];
+ // if it's explicitly in the expected results as undefined, that means we don't care.
+ // if it's not in the expected results at all, then check it
+ if (!(diagnosticType in expectedResults) || expectedResult !== undefined) {
+ expect(new Set(actualResult)).toEqual(new Set((expectedResult ?? []).map(expect.objectContaining)));
+ }
}
};