Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix context managers that return bool | None incorrectly being treated as if they can never suppress exceptions #111

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 153 additions & 0 deletions docs/benefits-over-pyright/fixed-context-manager-exit-types.md
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe something like this instead

@contextmanager
def f() -> Generator[None, BaseException, True]:
    error = yield
    return True

Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# fixed handling for context managers that can suppress exceptions

## the problem

if an exception is raised inside a context manager and its `__exit__` method returns `True`, it will be suppressed:

```py
class SuppressError(AbstractContextManager[None, bool]):
@override
def __enter__(self) -> None:
pass

@override
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
/,
) -> bool:
return True
```

but if it returns `False` or `None`, the exception will not be suppressed:

```py
class Log(AbstractContextManager[None, Literal[False]]):
@override
def __enter__(self) -> None:
print("entering context manager")

@override
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
/,
) -> Literal[False]:
print("exiting context manager")
return False
```

pyright will take this into account when determining reachability:

```py
def raise_exception() -> Never:
raise Exception

with SuppressError(): # see definition for `SuppressError` above
foo: int = raise_exception()

# when the exception is raised, the context manager exits before foo is assigned to:
print(foo) # error: "foo" is unbound (reportPossiblyUnboundVariable)
```

```py
with Log(): # see definition for `Log` above
foo: int = raise_exception()

# when the exception is raised, it does not get suppressed so this line can never run:
print(foo) # error: Code is unreachable (reportUnreachable)
```

however, due to [a bug in mypy](https://github.com/python/mypy/issues/8766) that [pyright blindly copied and accepted as the "standard"](https://github.com/microsoft/pyright/issues/6034#issuecomment-1738941412), a context manager will incorrectly be treated as if it never suppresses exceptions if its return type is a union of `bool | None`:

```py
class SuppressError(AbstractContextManager[None, bool | None]):
@override
def __enter__(self) -> None:
pass

@override
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
/,
) -> bool | None:
return True


with SuppressError():
foo: int = raise_exception()

# this error is wrong because this line is actually reached at runtime:
print(foo) # error: Code is unreachable (reportUnreachable)
```

## the solution

basedpyright introduces a new setting, `strictContextManagerExitTypes` to address this issue. when enabled, context managers where the `__exit__` dunder returns `bool | None` are treated the same way as context managers that return `bool` or `Literal[True]`. put simply, if `True` is assignable to the return type, then it's treated as if it can suppress exceptions.
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add this to the playground


## issues with `@contextmanager`

the reason we support disabling this fix using the `strictContextManagerExitTypes` setting is because it will cause all context managers decorated with `@contextlib.contextmanager` to be treated as if they can suppress an exception, even if they never do:
DetachHead marked this conversation as resolved.
Show resolved Hide resolved

```py
@contextmanager
def log():
print("entering context manager")
try:
yield
finally:
print("exiting context manager")

with log():
foo: int = get_value()

# basedpyright accounts for the possibility that get_value raised an exception and foo
# was never assigned to, even though this context manager never suppresses exceptions
print(foo) # error: "foo" is unbound (reportPossiblyUnboundVariable)
```

this is because there's no way to tell a type checker whether the function body wraps the `yield` statement inside a `try`/`except` statement, which is necessary to suppress exeptions when using the `@contextmanager` decorator:

```py
@contextmanager
def suppress_error():
try:
yield
except:
pass
```

due to this limitation in the type system, the `@contextmanager` dectorator always modifies the return type of generator functions from `Iterator[T]` to `_GeneratorContextManager[T]`, which extends `AbstractContextManager[T, bool | None]`.

```py
# contextlib.pyi

def contextmanager(func: Callable[_P, Iterator[_T_co]]) -> Callable[_P, _GeneratorContextManager[_T_co]]: ...

class _GeneratorContextManager(_GeneratorContextManagerBase, AbstractContextManager[_T_co, bool | None], ContextDecorator):
...
```

and since `bool | None` is used for the return type of `__exit__`, basedpyright will assume that all `@contextllib.contextmanager`'s have the ability to suppress exceptions when `strictContextManagerExitTypes` is enabled.

as a workaround, it's recommended to instead use class context managers [like in the examples above](#the-problem) for the following reasons:

- it forces you to be explicit about whether or not the context manager is able to suppress an exception
- it prevents you from accidentally creating a context manager that doesn't run its cleanup if an exception occurs:
```py
@contextmanager
def suppress_error():
print("setup")
yield
# this part won't run if an exception is raised because you forgot to use a try/finally
print("cleanup")
```

if you're dealing with third party modules where the usage of `@contextmanager` decorator is unavoidable, it may be best to just disable `strictContextManagerExitTypes` instead.
2 changes: 2 additions & 0 deletions docs/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ The following settings determine how different types should be evaluated.

- <a name="strictGenericNarrowing"></a> **strictGenericNarrowing** [boolean]: When a type is narrowed in such a way that its type parameters are not known (eg. using an `isinstance` check), basedpyright will resolve the type parameter to the generic's bound or constraint instead of `Any`. [more info](../benefits-over-pyright/improved-generic-narrowing.md)

- <a name="strictContextManagerExitTypes"></a> **strictContextManagerExitTypes** [boolean]: Assume that a context manager could potentially suppress an exception if its `__exit__` method is typed as returning `bool | None`. [more info](../benefits-over-pyright/fixed-context-manager-exit-types.md)

## Diagnostic Categories

diagnostics can be configured to be reported as any of the following categories:
Expand Down
15 changes: 13 additions & 2 deletions packages/pyright-internal/src/analyzer/codeFlowEngine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ import {
isTypeSame,
isTypeVar,
isTypeVarTuple,
isUnion,
maxTypeRecursionCount,
NeverType,
OverloadedType,
Expand Down Expand Up @@ -1964,8 +1965,18 @@ export function getCodeFlowEngine(
}

cmSwallowsExceptions = false;
if (isClassInstance(returnType) && ClassType.isBuiltIn(returnType, 'bool')) {
if (returnType.priv.literalValue === undefined || returnType.priv.literalValue === true) {
// valid return types here are `bool | None`. if the context manager returns `True` then it suppresses,
// meaning we only know for sure that the context manager can't swallow exceptions if its return type
// does not allow `True`.
const typesToCheck =
getFileInfo(node).diagnosticRuleSet.strictContextManagerExitTypes && isUnion(returnType)
? returnType.priv.subtypes
: [returnType];
const boolType = typesToCheck.find(
(type): type is ClassType => isClassInstance(type) && ClassType.isBuiltIn(type, 'bool')
);
if (boolType) {
if (boolType.priv.literalValue === undefined || boolType.priv.literalValue === true) {
cmSwallowsExceptions = true;
}
}
Expand Down
8 changes: 8 additions & 0 deletions packages/pyright-internal/src/common/configOptions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,7 @@ export interface DiagnosticRuleSet {
*/
failOnWarnings: boolean;
strictGenericNarrowing: boolean;
strictContextManagerExitTypes: boolean;
reportUnreachable: DiagnosticLevel;
reportAny: DiagnosticLevel;
reportExplicitAny: DiagnosticLevel;
Expand Down Expand Up @@ -443,6 +444,7 @@ export function getBooleanDiagnosticRules(includeNonOverridable = false) {
DiagnosticRule.deprecateTypingAliases,
DiagnosticRule.disableBytesTypePromotions,
DiagnosticRule.strictGenericNarrowing,
DiagnosticRule.strictContextManagerExitTypes,
];

if (includeNonOverridable) {
Expand Down Expand Up @@ -676,6 +678,7 @@ export function getOffDiagnosticRuleSet(): DiagnosticRuleSet {
reportImplicitOverride: 'none',
failOnWarnings: false,
strictGenericNarrowing: false,
strictContextManagerExitTypes: false,
reportUnreachable: 'hint',
reportAny: 'none',
reportExplicitAny: 'none',
Expand Down Expand Up @@ -792,6 +795,7 @@ export function getBasicDiagnosticRuleSet(): DiagnosticRuleSet {
reportImplicitOverride: 'none',
failOnWarnings: false,
strictGenericNarrowing: false,
strictContextManagerExitTypes: false,
reportUnreachable: 'hint',
reportAny: 'none',
reportExplicitAny: 'none',
Expand Down Expand Up @@ -908,6 +912,7 @@ export function getStandardDiagnosticRuleSet(): DiagnosticRuleSet {
reportImplicitOverride: 'none',
failOnWarnings: false,
strictGenericNarrowing: false,
strictContextManagerExitTypes: false,
reportUnreachable: 'hint',
reportAny: 'none',
reportExplicitAny: 'none',
Expand Down Expand Up @@ -1023,6 +1028,7 @@ export const getRecommendedDiagnosticRuleSet = (): DiagnosticRuleSet => ({
reportImplicitOverride: 'warning',
failOnWarnings: true,
strictGenericNarrowing: true,
strictContextManagerExitTypes: true,
reportUnreachable: 'warning',
reportAny: 'warning',
reportExplicitAny: 'warning',
Expand Down Expand Up @@ -1135,6 +1141,7 @@ export const getAllDiagnosticRuleSet = (): DiagnosticRuleSet => ({
reportImplicitOverride: 'error',
failOnWarnings: true,
strictGenericNarrowing: true,
strictContextManagerExitTypes: true,
reportUnreachable: 'error',
reportAny: 'error',
reportExplicitAny: 'error',
Expand Down Expand Up @@ -1248,6 +1255,7 @@ export function getStrictDiagnosticRuleSet(): DiagnosticRuleSet {
reportImplicitOverride: 'none',
failOnWarnings: false,
strictGenericNarrowing: false,
strictContextManagerExitTypes: false,
reportUnreachable: 'hint',
reportAny: 'none',
reportExplicitAny: 'none',
Expand Down
1 change: 1 addition & 0 deletions packages/pyright-internal/src/common/diagnosticRules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ export enum DiagnosticRule {
// basedpyright options:
failOnWarnings = 'failOnWarnings',
strictGenericNarrowing = 'strictGenericNarrowing',
strictContextManagerExitTypes = 'strictContextManagerExitTypes',
reportUnreachable = 'reportUnreachable',
reportAny = 'reportAny',
reportExplicitAny = 'reportExplicitAny',
Expand Down
27 changes: 27 additions & 0 deletions packages/pyright-internal/src/tests/checker.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,33 @@ test('With2', () => {
TestUtils.validateResults(analysisResults, 3);
});

describe('context manager where __exit__ returns bool | None', () => {
test('strictContextManagerExitTypes=true', () => {
const configOptions = new ConfigOptions(Uri.empty());
configOptions.diagnosticRuleSet.strictContextManagerExitTypes = true;
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['withBased.py'], configOptions);
TestUtils.validateResultsButBased(analysisResults, {
hints: [
{ code: DiagnosticRule.reportUnreachable, line: 45 },
{ code: DiagnosticRule.reportUnreachable, line: 60 },
],
});
});
test('strictContextManagerExitTypes=false', () => {
const configOptions = new ConfigOptions(Uri.empty());
configOptions.diagnosticRuleSet.strictContextManagerExitTypes = false;
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['withBased.py'], configOptions);
TestUtils.validateResultsButBased(analysisResults, {
hints: [
{ code: DiagnosticRule.reportUnreachable, line: 16 },
{ code: DiagnosticRule.reportUnreachable, line: 30 },
{ code: DiagnosticRule.reportUnreachable, line: 45 },
{ code: DiagnosticRule.reportUnreachable, line: 60 },
],
});
});
});

test('With3', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['with3.py']);

Expand Down
61 changes: 61 additions & 0 deletions packages/pyright-internal/src/tests/samples/withBased.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import contextlib
from types import TracebackType
from typing import Literal

class BoolOrNone(contextlib.AbstractContextManager[None]):
def __exit__(
self,
__exc_type: type[BaseException] | None,
__exc_value: BaseException | None,
__traceback: TracebackType | None,
) -> bool | None:
...

def _():
with BoolOrNone():
raise Exception
print(1) # reachable

class TrueOrNone(contextlib.AbstractContextManager[None]):
def __exit__(
self,
__exc_type: type[BaseException] | None,
__exc_value: BaseException | None,
__traceback: TracebackType | None,
) -> Literal[True] | None:
...

def _():
with TrueOrNone():
raise Exception
print(1) # reachable


class FalseOrNone(contextlib.AbstractContextManager[None]):
def __exit__(
self,
__exc_type: type[BaseException] | None,
__exc_value: BaseException | None,
__traceback: TracebackType | None,
) -> Literal[False] | None:
...

def _():
with FalseOrNone():
raise Exception
print(1) # unreachable


class OnlyNone(contextlib.AbstractContextManager[None]):
def __exit__(
self,
__exc_type: type[BaseException] | None,
__exc_value: BaseException | None,
__traceback: TracebackType | None,
) -> None:
...

def _():
with OnlyNone():
raise Exception
print(1) # unreachable
8 changes: 6 additions & 2 deletions packages/pyright-internal/src/tests/testUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,11 @@ export const validateResultsButBased = (allResults: FileAnalysisResult[], expect
baselined: result.baselined,
})
);
const expectedResult = expectedResults[diagnosticType] ?? [];
expect(new Set(actualResult)).toEqual(new Set(expectedResult.map(expect.objectContaining)));
const expectedResult = expectedResults[diagnosticType];
// if it's explicitly in the expected results as undefined, that means we don't care.
// if it's not in the expected results at all, then check it
if (!(diagnosticType in expectedResults) || expectedResult !== undefined) {
expect(new Set(actualResult)).toEqual(new Set((expectedResult ?? []).map(expect.objectContaining)));
}
}
};
Loading