Skip to content

Commit

Permalink
Auto merge of rust-lang#134185 - compiler-errors:impl-trait-in-bindin…
Browse files Browse the repository at this point in the history
…gs, r=oli-obk

(Re-)Implement `impl_trait_in_bindings`

This reimplements the `impl_trait_in_bindings` feature for local bindings.

"`impl Trait` in bindings" serve as a form of *trait* ascription, where the type basically functions as an infer var but additionally registering the `impl Trait`'s trait bounds for the infer type. These trait bounds can be used to enforce that predicates hold, and can guide inference (e.g. for closure signature inference):

```rust
let _: impl Fn(&u8) -> &u8 = |x| x;
```

They are implemented as an additional set of bounds that are registered when the type is lowered during typeck, and then these bounds are tied to a given `CanonicalUserTypeAscription` for borrowck. We enforce these `CanonicalUserTypeAscription` bounds during borrowck to make sure that the `impl Trait` types are sensitive to lifetimes:

```rust
trait Static: 'static {}
impl<T> Static for T where T: 'static {}

let local = 1;
let x: impl Static = &local;
//~^ ERROR `local` does not live long enough
```

r? oli-obk

cc rust-lang#63065

---

Why can't we just use TAIT inference or something? Well, TAITs in bodies have the problem that they cannot reference lifetimes local to a body. For example:

```rust
type TAIT = impl Display;
let local = 0;
let x: TAIT = &local;
//~^ ERROR `local` does not live long enough
```

That's because TAITs requires us to do *opaque type inference* which is pretty strict, since we need to remap all of the lifetimes of the hidden type to universal regions. This is simply not possible here.

---

I consider this part of the "impl trait everywhere" experiment. I'm not certain if this needs yet another lang team experiment.
  • Loading branch information
bors committed Dec 14, 2024
2 parents ed14192 + d714a22 commit f5079d0
Show file tree
Hide file tree
Showing 53 changed files with 442 additions and 52 deletions.
19 changes: 15 additions & 4 deletions compiler/rustc_ast_lowering/src/block.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use rustc_ast::{Block, BlockCheckMode, Local, LocalKind, Stmt, StmtKind};
use rustc_hir as hir;
use rustc_span::sym;
use smallvec::SmallVec;

use crate::{ImplTraitContext, ImplTraitPosition, LoweringContext};
Expand Down Expand Up @@ -82,11 +83,21 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
(self.arena.alloc_from_iter(stmts), expr)
}

/// Return an `ImplTraitContext` that allows impl trait in bindings if
/// the feature gate is enabled, or issues a feature error if it is not.
fn impl_trait_in_bindings_ctxt(&self, position: ImplTraitPosition) -> ImplTraitContext {
if self.tcx.features().impl_trait_in_bindings() {
ImplTraitContext::InBinding
} else {
ImplTraitContext::FeatureGated(position, sym::impl_trait_in_bindings)
}
}

fn lower_local(&mut self, l: &Local) -> &'hir hir::LetStmt<'hir> {
let ty = l
.ty
.as_ref()
.map(|t| self.lower_ty(t, ImplTraitContext::Disallowed(ImplTraitPosition::Variable)));
// Let statements are allowed to have impl trait in bindings.
let ty = l.ty.as_ref().map(|t| {
self.lower_ty(t, self.impl_trait_in_bindings_ctxt(ImplTraitPosition::Variable))
});
let init = l.kind.init().map(|init| self.lower_expr(init));
let hir_id = self.lower_node_id(l.id);
let pat = self.lower_pat(&l.pat);
Expand Down
10 changes: 10 additions & 0 deletions compiler/rustc_ast_lowering/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,13 @@ enum ImplTraitContext {
/// equivalent to a new opaque type like `type T = impl Debug; fn foo() -> T`.
///
OpaqueTy { origin: hir::OpaqueTyOrigin<LocalDefId> },

/// Treat `impl Trait` as a "trait ascription", which is like a type
/// variable but that also enforces that a set of trait goals hold.
///
/// This is useful to guide inference for unnameable types.
InBinding,

/// `impl Trait` is unstably accepted in this position.
FeatureGated(ImplTraitPosition, Symbol),
/// `impl Trait` is not accepted in this position.
Expand Down Expand Up @@ -1327,6 +1334,9 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
}
path
}
ImplTraitContext::InBinding => {
hir::TyKind::TraitAscription(self.lower_param_bounds(bounds, itctx))
}
ImplTraitContext::FeatureGated(position, feature) => {
let guar = self
.tcx
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_borrowck/src/type_check/canonical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
user_ty: ty::UserType<'tcx>,
span: Span,
) {
let ty::UserType::Ty(user_ty) = user_ty else { bug!() };
let ty::UserTypeKind::Ty(user_ty) = user_ty.kind else { bug!() };

// A fast path for a common case with closure input/output types.
if let ty::Infer(_) = user_ty.kind() {
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_borrowck/src/type_check/input_output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
) {
self.ascribe_user_type_skip_wf(
arg_decl.ty,
ty::UserType::Ty(user_ty),
ty::UserType::new(ty::UserTypeKind::Ty(user_ty)),
arg_decl.source_info.span,
);
}
Expand All @@ -119,7 +119,7 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
let output_decl = &body.local_decls[RETURN_PLACE];
self.ascribe_user_type_skip_wf(
output_decl.ty,
ty::UserType::Ty(user_provided_sig.output()),
ty::UserType::new(ty::UserTypeKind::Ty(user_provided_sig.output())),
output_decl.source_info.span,
);
}
Expand Down
10 changes: 7 additions & 3 deletions compiler/rustc_borrowck/src/type_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use rustc_middle::ty::visit::TypeVisitableExt;
use rustc_middle::ty::{
self, Binder, CanonicalUserTypeAnnotation, CanonicalUserTypeAnnotations, CoroutineArgsExt,
Dynamic, GenericArgsRef, OpaqueHiddenType, OpaqueTypeKey, RegionVid, Ty, TyCtxt, UserArgs,
UserType, UserTypeAnnotationIndex,
UserTypeAnnotationIndex,
};
use rustc_middle::{bug, span_bug};
use rustc_mir_dataflow::ResultsCursor;
Expand Down Expand Up @@ -370,7 +370,10 @@ impl<'a, 'b, 'tcx> Visitor<'tcx> for TypeVerifier<'a, 'b, 'tcx> {
} else {
self.cx.ascribe_user_type(
constant.const_.ty(),
UserType::TypeOf(uv.def, UserArgs { args: uv.args, user_self_ty: None }),
ty::UserType::new(ty::UserTypeKind::TypeOf(uv.def, UserArgs {
args: uv.args,
user_self_ty: None,
})),
locations.span(self.cx.body),
);
}
Expand Down Expand Up @@ -991,9 +994,10 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
for user_annotation in self.user_type_annotations {
let CanonicalUserTypeAnnotation { span, ref user_ty, inferred_ty } = *user_annotation;
let annotation = self.instantiate_canonical(span, user_ty);
if let ty::UserType::TypeOf(def, args) = annotation
if let ty::UserTypeKind::TypeOf(def, args) = annotation.kind
&& let DefKind::InlineConst = tcx.def_kind(def)
{
assert!(annotation.bounds.is_empty());
self.check_inline_const(inferred_ty, def.expect_local(), args, span);
} else {
self.ascribe_user_type(inferred_ty, annotation, span);
Expand Down
3 changes: 0 additions & 3 deletions compiler/rustc_feature/src/removed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,6 @@ declare_features! (
better implied higher-ranked implied bounds support"
)
),
/// Allows `impl Trait` in bindings (`let`, `const`, `static`).
(removed, impl_trait_in_bindings, "1.55.0", Some(63065),
Some("the implementation was not maintainable, the feature may get reintroduced once the current refactorings are done")),
(removed, import_shadowing, "1.0.0", None, None),
/// Allows in-band quantification of lifetime bindings (e.g., `fn foo(x: &'a u8) -> &'a u8`).
(removed, in_band_lifetimes, "1.23.0", Some(44524),
Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_feature/src/unstable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,8 @@ declare_features! (
(unstable, if_let_guard, "1.47.0", Some(51114)),
/// Allows `impl Trait` to be used inside associated types (RFC 2515).
(unstable, impl_trait_in_assoc_type, "1.70.0", Some(63063)),
/// Allows `impl Trait` in bindings (`let`).
(unstable, impl_trait_in_bindings, "1.64.0", Some(63065)),
/// Allows `impl Trait` as output type in `Fn` traits in return position of functions.
(unstable, impl_trait_in_fn_trait_return, "1.64.0", Some(99697)),
/// Allows associated types in inherent impls.
Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_hir/src/hir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2906,6 +2906,8 @@ pub enum TyKind<'hir> {
Path(QPath<'hir>),
/// An opaque type definition itself. This is only used for `impl Trait`.
OpaqueDef(&'hir OpaqueTy<'hir>),
/// A trait ascription type, which is `impl Trait` within a local binding.
TraitAscription(GenericBounds<'hir>),
/// A trait object type `Bound1 + Bound2 + Bound3`
/// where `Bound` is a trait or a lifetime.
TraitObject(&'hir [PolyTraitRef<'hir>], &'hir Lifetime, TraitObjectSyntax),
Expand Down
3 changes: 3 additions & 0 deletions compiler/rustc_hir/src/intravisit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -900,6 +900,9 @@ pub fn walk_ty<'v, V: Visitor<'v>>(visitor: &mut V, typ: &'v Ty<'v>) -> V::Resul
TyKind::OpaqueDef(opaque) => {
try_visit!(visitor.visit_opaque_ty(opaque));
}
TyKind::TraitAscription(bounds) => {
walk_list!(visitor, visit_param_bound, bounds);
}
TyKind::Array(ref ty, ref length) => {
try_visit!(visitor.visit_ty(ty));
try_visit!(visitor.visit_const_arg(length));
Expand Down
11 changes: 10 additions & 1 deletion compiler/rustc_hir_analysis/src/collect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use rustc_errors::{
use rustc_hir::def::DefKind;
use rustc_hir::def_id::{DefId, LocalDefId};
use rustc_hir::intravisit::{self, Visitor, walk_generics};
use rustc_hir::{self as hir, GenericParamKind, Node};
use rustc_hir::{self as hir, GenericParamKind, HirId, Node};
use rustc_infer::infer::{InferCtxt, TyCtxtInferExt};
use rustc_infer::traits::ObligationCause;
use rustc_middle::hir::nested_filter;
Expand Down Expand Up @@ -437,6 +437,15 @@ impl<'tcx> HirTyLowerer<'tcx> for ItemCtxt<'tcx> {
ty::Const::new_error_with_message(self.tcx(), span, "bad placeholder constant")
}

fn register_trait_ascription_bounds(
&self,
_: Vec<(ty::Clause<'tcx>, Span)>,
_: HirId,
span: Span,
) {
self.dcx().span_delayed_bug(span, "trait ascription type not allowed here");
}

fn probe_ty_param_bounds(
&self,
span: Span,
Expand Down
15 changes: 15 additions & 0 deletions compiler/rustc_hir_analysis/src/collect/resolve_bound_vars.rs
Original file line number Diff line number Diff line change
Expand Up @@ -852,6 +852,21 @@ impl<'a, 'tcx> Visitor<'tcx> for BoundVarContext<'a, 'tcx> {
};
self.with(scope, |this| this.visit_ty(mt.ty));
}
hir::TyKind::TraitAscription(bounds) => {
let scope = Scope::TraitRefBoundary { s: self.scope };
self.with(scope, |this| {
let scope = Scope::LateBoundary {
s: this.scope,
what: "`impl Trait` in binding",
deny_late_regions: true,
};
this.with(scope, |this| {
for bound in bounds {
this.visit_param_bound(bound);
}
})
});
}
_ => intravisit::walk_ty(self, ty),
}
}
Expand Down
26 changes: 26 additions & 0 deletions compiler/rustc_hir_analysis/src/hir_ty_lowering/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,13 @@ pub trait HirTyLowerer<'tcx> {
/// Returns the const to use when a const is omitted.
fn ct_infer(&self, param: Option<&ty::GenericParamDef>, span: Span) -> Const<'tcx>;

fn register_trait_ascription_bounds(
&self,
bounds: Vec<(ty::Clause<'tcx>, Span)>,
hir_id: HirId,
span: Span,
);

/// Probe bounds in scope where the bounded type coincides with the given type parameter.
///
/// Rephrased, this returns bounds of the form `T: Trait`, where `T` is a type parameter
Expand Down Expand Up @@ -2375,6 +2382,25 @@ impl<'tcx> dyn HirTyLowerer<'tcx> + '_ {

self.lower_opaque_ty(opaque_ty.def_id, in_trait)
}
hir::TyKind::TraitAscription(hir_bounds) => {
// Impl trait in bindings lower as an infer var with additional
// set of type bounds.
let self_ty = self.ty_infer(None, hir_ty.span);
let mut bounds = Bounds::default();
self.lower_bounds(
self_ty,
hir_bounds.iter(),
&mut bounds,
ty::List::empty(),
PredicateFilter::All,
);
self.register_trait_ascription_bounds(
bounds.clauses().collect(),
hir_ty.hir_id,
hir_ty.span,
);
self_ty
}
// If we encounter a type relative path with RTN generics, then it must have
// *not* gone through `lower_ty_maybe_return_type_notation`, and therefore
// it's certainly in an illegal position.
Expand Down
3 changes: 3 additions & 0 deletions compiler/rustc_hir_pretty/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,9 @@ impl<'a> State<'a> {
self.print_unsafe_binder(unsafe_binder);
}
hir::TyKind::OpaqueDef(..) => self.word("/*impl Trait*/"),
hir::TyKind::TraitAscription(bounds) => {
self.print_bounds("impl", bounds);
}
hir::TyKind::Path(ref qpath) => self.print_qpath(qpath, false),
hir::TyKind::TraitObject(bounds, lifetime, syntax) => {
if syntax == ast::TraitObjectSyntax::Dyn {
Expand Down
45 changes: 35 additions & 10 deletions compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ use std::slice;
use rustc_abi::FieldIdx;
use rustc_data_structures::fx::FxHashSet;
use rustc_errors::{Applicability, Diag, ErrorGuaranteed, MultiSpan, StashKey};
use rustc_hir as hir;
use rustc_hir::def::{CtorOf, DefKind, Res};
use rustc_hir::def_id::DefId;
use rustc_hir::intravisit::Visitor;
use rustc_hir::lang_items::LangItem;
use rustc_hir::{ExprKind, GenericArg, HirId, Node, QPath};
use rustc_hir::{self as hir, ExprKind, GenericArg, HirId, Node, QPath, intravisit};
use rustc_hir_analysis::hir_ty_lowering::errors::GenericsArgsErrExtend;
use rustc_hir_analysis::hir_ty_lowering::generics::{
check_generic_arg_count_for_call, lower_generic_args,
Expand All @@ -25,7 +25,7 @@ use rustc_middle::ty::fold::TypeFoldable;
use rustc_middle::ty::visit::{TypeVisitable, TypeVisitableExt};
use rustc_middle::ty::{
self, AdtKind, CanonicalUserType, GenericArgKind, GenericArgsRef, GenericParamDefKind,
IsIdentity, Ty, TyCtxt, UserArgs, UserSelfTy, UserType,
IsIdentity, Ty, TyCtxt, UserArgs, UserSelfTy,
};
use rustc_middle::{bug, span_bug};
use rustc_session::lint;
Expand Down Expand Up @@ -216,11 +216,9 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
debug!("fcx {}", self.tag());

if Self::can_contain_user_lifetime_bounds((args, user_self_ty)) {
let canonicalized =
self.canonicalize_user_type_annotation(UserType::TypeOf(def_id, UserArgs {
args,
user_self_ty,
}));
let canonicalized = self.canonicalize_user_type_annotation(ty::UserType::new(
ty::UserTypeKind::TypeOf(def_id, UserArgs { args, user_self_ty }),
));
debug!(?canonicalized);
self.write_user_type_annotation(hir_id, canonicalized);
}
Expand Down Expand Up @@ -462,13 +460,40 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
LoweredTy::from_raw(self, hir_ty.span, ty)
}

/// Walk a `hir_ty` and collect any clauses that may have come from a type
/// within the `hir_ty`. These clauses will be canonicalized with a user type
/// annotation so that we can enforce these bounds in borrowck, too.
pub(crate) fn collect_impl_trait_clauses_from_hir_ty(
&self,
hir_ty: &'tcx hir::Ty<'tcx>,
) -> ty::Clauses<'tcx> {
struct CollectClauses<'a, 'tcx> {
clauses: Vec<ty::Clause<'tcx>>,
fcx: &'a FnCtxt<'a, 'tcx>,
}

impl<'tcx> intravisit::Visitor<'tcx> for CollectClauses<'_, 'tcx> {
fn visit_ty(&mut self, ty: &'tcx hir::Ty<'tcx>) {
if let Some(clauses) = self.fcx.trait_ascriptions.borrow().get(&ty.hir_id.local_id)
{
self.clauses.extend(clauses.iter().cloned());
}
intravisit::walk_ty(self, ty)
}
}

let mut clauses = CollectClauses { clauses: vec![], fcx: self };
clauses.visit_ty(hir_ty);
self.tcx.mk_clauses(&clauses.clauses)
}

#[instrument(level = "debug", skip_all)]
pub(crate) fn lower_ty_saving_user_provided_ty(&self, hir_ty: &hir::Ty<'tcx>) -> Ty<'tcx> {
pub(crate) fn lower_ty_saving_user_provided_ty(&self, hir_ty: &'tcx hir::Ty<'tcx>) -> Ty<'tcx> {
let ty = self.lower_ty(hir_ty);
debug!(?ty);

if Self::can_contain_user_lifetime_bounds(ty.raw) {
let c_ty = self.canonicalize_response(UserType::Ty(ty.raw));
let c_ty = self.canonicalize_response(ty::UserType::new(ty::UserTypeKind::Ty(ty.raw)));
debug!(?c_ty);
self.typeck_results.borrow_mut().user_provided_types_mut().insert(hir_ty.hir_id, c_ty);
}
Expand Down
34 changes: 33 additions & 1 deletion compiler/rustc_hir_typeck/src/fn_ctxt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ use std::ops::Deref;

use hir::def_id::CRATE_DEF_ID;
use rustc_errors::DiagCtxtHandle;
use rustc_hir as hir;
use rustc_hir::def_id::{DefId, LocalDefId};
use rustc_hir::{self as hir, HirId, ItemLocalMap};
use rustc_hir_analysis::hir_ty_lowering::{HirTyLowerer, RegionInferReason};
use rustc_infer::infer;
use rustc_infer::traits::Obligation;
use rustc_middle::ty::{self, Const, Ty, TyCtxt, TypeVisitableExt};
use rustc_session::Session;
use rustc_span::symbol::Ident;
Expand Down Expand Up @@ -114,6 +115,12 @@ pub(crate) struct FnCtxt<'a, 'tcx> {

pub(super) diverging_fallback_behavior: DivergingFallbackBehavior,
pub(super) diverging_block_behavior: DivergingBlockBehavior,

/// Clauses that we lowered as part of the `impl_trait_in_bindings` feature.
///
/// These are stored here so we may collect them when canonicalizing user
/// type ascriptions later.
pub(super) trait_ascriptions: RefCell<ItemLocalMap<Vec<ty::Clause<'tcx>>>>,
}

impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
Expand Down Expand Up @@ -141,6 +148,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
fallback_has_occurred: Cell::new(false),
diverging_fallback_behavior,
diverging_block_behavior,
trait_ascriptions: Default::default(),
}
}

Expand Down Expand Up @@ -252,6 +260,30 @@ impl<'tcx> HirTyLowerer<'tcx> for FnCtxt<'_, 'tcx> {
}
}

fn register_trait_ascription_bounds(
&self,
bounds: Vec<(ty::Clause<'tcx>, Span)>,
hir_id: HirId,
_span: Span,
) {
for (clause, span) in bounds {
if clause.has_escaping_bound_vars() {
self.dcx().span_delayed_bug(span, "clause should have no escaping bound vars");
continue;
}

self.trait_ascriptions.borrow_mut().entry(hir_id.local_id).or_default().push(clause);

let clause = self.normalize(span, clause);
self.register_predicate(Obligation::new(
self.tcx,
self.misc(span),
self.param_env,
clause,
));
}
}

fn probe_ty_param_bounds(
&self,
_: Span,
Expand Down
Loading

0 comments on commit f5079d0

Please sign in to comment.