diff --git a/packages/pyright-internal/src/analyzer/codeFlowEngine.ts b/packages/pyright-internal/src/analyzer/codeFlowEngine.ts index 05456ceb25..211faa50c6 100644 --- a/packages/pyright-internal/src/analyzer/codeFlowEngine.ts +++ b/packages/pyright-internal/src/analyzer/codeFlowEngine.ts @@ -898,7 +898,8 @@ export function getCodeFlowEngine( } } - const effectiveType = typesToCombine.length > 0 ? combineTypes(typesToCombine) : undefined; + const effectiveType = + typesToCombine.length > 0 ? combineTypes(typesToCombine, undefined, evaluator) : undefined; return setCacheEntry(branchNode, effectiveType, sawIncomplete); } diff --git a/packages/pyright-internal/src/analyzer/operations.ts b/packages/pyright-internal/src/analyzer/operations.ts index afc08dfb84..7330ced158 100644 --- a/packages/pyright-internal/src/analyzer/operations.ts +++ b/packages/pyright-internal/src/analyzer/operations.ts @@ -656,7 +656,7 @@ export function getTypeOfBinaryOperation( flags | EvaluatorFlags.ExpectingInstantiableType ); - let newUnion = combineTypes([adjustedLeftType, adjustedRightType]); + let newUnion = combineTypes([adjustedLeftType, adjustedRightType], undefined, evaluator); const unionClass = evaluator.getUnionClassType(); if (unionClass && isInstantiableClass(unionClass)) { diff --git a/packages/pyright-internal/src/analyzer/types.ts b/packages/pyright-internal/src/analyzer/types.ts index 924b2302b9..ede3670185 100644 --- a/packages/pyright-internal/src/analyzer/types.ts +++ b/packages/pyright-internal/src/analyzer/types.ts @@ -12,6 +12,7 @@ import { Uri } from '../common/uri/uri'; import { ArgumentNode, ExpressionNode, NameNode, ParameterCategory } from '../parser/parseNodes'; import { ClassDeclaration, FunctionDeclaration, SpecialBuiltInClassDeclaration } from './declaration'; import { Symbol, SymbolTable } from './symbol'; +import { TypeEvaluator } from './typeEvaluatorTypes'; export const enum TypeCategory { // Name is not bound to a value of any type. @@ -3237,7 +3238,6 @@ export function removeUnbound(type: Type): Type { return type; } - export function removeFromUnion(type: Type, removeFilter: (type: Type) => boolean) { if (isUnion(type)) { const remainingTypes = type.subtypes.filter((t) => !removeFilter(t)); @@ -3265,11 +3265,28 @@ export function findSubtype(type: Type, filter: (type: UnionableType | NeverType return filter(type) ? type : undefined; } -// Combines multiple types into a single type. If the types are -// the same, only one is returned. If they differ, they -// are combined into a UnionType. NeverTypes are filtered out. -// If no types remain in the end, a NeverType is returned. -export function combineTypes(subtypes: Type[], maxSubtypeCount?: number): Type { +/** + * Combines multiple types into a single type. If the types are + * the same, only one is returned. If they differ, they + * are combined into a UnionType. NeverTypes are filtered out. + * If no types remain in the end, a NeverType is returned. + * + * if a {@link TypeEvaluator} is provided, it not only checks that + * the types aren't the same, but also prevents redundant subtypes from + * being added to the union. eg. adding `Literal[1]` to a union of `int | str` + * is useless, so the union is left as-is. when adding a supertype to a union + * that contains a subtype of it, that subtype becomes redundant and therefore + * gets removed (eg. adding `int` to `Literal[1] | str` will result in + * `int | str`). this is useful to prevent cases where a narrowed type would be + * treated as partially unknown unnecessarily (eg. `object | list[Any]`). + * + * a {@link TypeEvaluator} should not be provided in cases where the union + * intentionally contains redundant information for the purpose of autocomplete. + * i don't think there are any situations where this is supported currently, but + * it's something to keep in mind if we end up implementing + * https://github.com/DetachHead/basedpyright/issues/320 + */ +export function combineTypes(subtypes: Type[], maxSubtypeCount?: number, evaluator?: TypeEvaluator): Type { // Filter out any "Never" and "NoReturn" types. let sawNoReturn = false; @@ -3352,7 +3369,7 @@ export function combineTypes(subtypes: Type[], maxSubtypeCount?: number): Type { return UnknownType.create(); } - const newUnionType = UnionType.create(); + let newUnionType = UnionType.create(); if (typeAliasSources.size > 0) { newUnionType.typeAliasSources = typeAliasSources; } @@ -3360,9 +3377,34 @@ export function combineTypes(subtypes: Type[], maxSubtypeCount?: number): Type { let hitMaxSubtypeCount = false; expandedTypes.forEach((subtype, index) => { - if (index === 0) { - UnionType.addType(newUnionType, subtype as UnionableType); + let shouldAddType = false; + if ( + !evaluator || + !newUnionType.subtypes.length || + evaluator.assignType(newUnionType, subtype, undefined, undefined, undefined) + ) { + if (index === 0) { + UnionType.addType(newUnionType, subtype as UnionableType); + } else { + shouldAddType = true; + } + } else if (evaluator.assignType(subtype, newUnionType, undefined, undefined, undefined)) { + if (evaluator) { + const filteredType = removeFromUnion(newUnionType, (type) => evaluator.assignType(subtype, type)); + if (isUnion(filteredType)) { + newUnionType = filteredType; + } else { + newUnionType = UnionType.create(); + if (filteredType.category !== TypeCategory.Never) { + UnionType.addType(newUnionType, filteredType as UnionableType); + } + } + shouldAddType = true; + } } else { + shouldAddType = true; + } + if (shouldAddType) { if (maxSubtypeCount === undefined || newUnionType.subtypes.length < maxSubtypeCount) { _addTypeIfUnique(newUnionType, subtype as UnionableType); } else { diff --git a/packages/pyright-internal/src/tests/samples/typeNarrowingBased.py b/packages/pyright-internal/src/tests/samples/typeNarrowingBased.py new file mode 100644 index 0000000000..b34d65f44e --- /dev/null +++ b/packages/pyright-internal/src/tests/samples/typeNarrowingBased.py @@ -0,0 +1,16 @@ +from typing import Any, assert_type + + +def foo(value: object): + print(value) + if isinstance(value, list): + _ = assert_type(value, list[Any]) + _ = assert_type(value, object) + +def bar(value: object): + print(value) + if isinstance(value, list): + _ = assert_type(value, list[Any]) + else: + _ = assert_type(value, object) + _ = assert_type(value, object) \ No newline at end of file diff --git a/packages/pyright-internal/src/tests/typeEvaluatorBased.test.ts b/packages/pyright-internal/src/tests/typeEvaluatorBased.test.ts index a13d1f91b0..ad368c6ee6 100644 --- a/packages/pyright-internal/src/tests/typeEvaluatorBased.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluatorBased.test.ts @@ -117,3 +117,12 @@ test('subscript context manager types on 3.8', () => { ], }); }); + +test("useless type isn't added to union after if statement", () => { + const configOptions = new ConfigOptions(Uri.empty()); + configOptions.diagnosticRuleSet.reportAssertTypeFailure = 'error'; + const analysisResults = typeAnalyzeSampleFiles(['typeNarrowingBased.py'], configOptions); + validateResultsButBased(analysisResults, { + errors: [], + }); +});