Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated snapshot injectivity axiom. #1475

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions prusti-tests/tests/verify/pass/pure-fn/pure_taking_self.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
use prusti_contracts::*;

#[derive(Clone, Copy)]
struct MyWrapper(u32);

impl MyWrapper {
#[pure]
#[ensures(self === MyWrapper(result))]
fn unwrap(self) -> u32 {
self.0
}
}

fn test(x: MyWrapper) -> u32 {
x.unwrap()
}

fn main() { }
41 changes: 35 additions & 6 deletions prusti-viper/src/encoder/definition_collector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,29 @@ impl<'p, 'v: 'p, 'tcx: 'v> Collector<'p, 'v, 'tcx> {
functions.sort_by_cached_key(|f| f.get_identifier());
Ok(functions)
}

fn get_domain_functions_used_in_axioms(axioms: &Vec<vir::DomainAxiom>) -> FxHashSet<String> {
struct Walker {
function_names: FxHashSet<String>,
}

impl ExprWalker for Walker {
fn walk_domain_func_app(&mut self, function_call: &vir::DomainFuncApp) {
self.function_names
.insert(function_call.domain_function.get_identifier());
}
}
let mut functions_in_axioms: FxHashSet<String> = Default::default();
for axiom in axioms {
let mut walker = Walker {
function_names: Default::default(),
};
walker.walk(&axiom.expr);
functions_in_axioms.extend(walker.function_names);
}
functions_in_axioms
}

fn get_used_domains(&self) -> Vec<vir::Domain> {
let mut domains: Vec<_> = self
.used_domains
Expand All @@ -250,7 +273,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> Collector<'p, 'v, 'tcx> {
// on an axiom, then we should also retain the axiom.
let mut used_snap_domain_function_prefixes = vec![];
let mut used_snap_domain_constructor = false;
domain.functions.retain(|function| {
for function in domain.functions.iter() {
let function_name = function.get_identifier();
let prefix = function_name.split("__").next().map(String::from);
let is_constructor = function_name.starts_with("cons");
Expand All @@ -260,22 +283,28 @@ impl<'p, 'v: 'p, 'tcx: 'v> Collector<'p, 'v, 'tcx> {
{
used_snap_domain_function_prefixes.extend(prefix);
used_snap_domain_constructor |= is_constructor;
true
} else {
false
}
});
}
domain.axioms.retain(|axiom| {
let used = used_snap_domain_function_prefixes
.iter()
.any(|prefix| axiom.name.starts_with(prefix));
let retain_type_invariant = axiom.name.ends_with("$valid") && used;
let retain_injectivity = used_snap_domain_constructor
&& axiom.name.ends_with("$injectivity");
let retain_field_axiom = used_snap_domain_constructor && used;
let retain_field_axiom =
used_snap_domain_constructor && axiom.name.ends_with("$axiom");

retain_type_invariant || retain_injectivity || retain_field_axiom
});
let functions_in_axioms =
Self::get_domain_functions_used_in_axioms(&domain.axioms);
domain.functions.retain(|function| {
let function_name = function.get_identifier();
self.used_snap_domain_functions
.contains(&function_name.clone().into())
|| functions_in_axioms.contains(&function_name)
});
}
}
domain
Expand Down
78 changes: 37 additions & 41 deletions prusti-viper/src/encoder/snapshot/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1691,10 +1691,10 @@ impl SnapshotEncoder {
// * the constructor, which takes the flattened value-only
// representation of the variant and returns an instance of the
// snapshot domain
// * the injectivity axiom for that constructor:
// * the injectivity axiom for that variant:
// ```plain
// forall _l_args..., _r_args... :: {cons(_l_args...), cons(_r_args)}
// cons(_l_args...) == cons(_r_args) ==> _l_args... == _r_args...
// forall this: Variant :: {field1(this), field2(this), ...}
// this == cons(field1(this), field2(this), ...)
// ```
// * the discriminant axiom for that constructor:
// ```plain
Expand All @@ -1704,7 +1704,7 @@ impl SnapshotEncoder {
// * field access function
// * field access axiom:
// ```plain
// forall args... :: {field(cons(arg_field, other_args...))}
// forall args... :: {cons(arg_field, other_args...)}
// field(cons(arg_field, other_args...)) == arg_field
// ```
for (variant_idx, variant) in variants.iter().enumerate() {
Expand All @@ -1715,8 +1715,6 @@ impl SnapshotEncoder {
.enumerate()
.map(|(idx, field)| vir::LocalVar::new(format!("_{idx}"), field.typ.clone()))
.collect::<Vec<vir::LocalVar>>();
// TODO: filter out Units to reduce constructor size
let has_args = !args.is_empty();

// record name to index mapping
if let Some(name) = &variant.name {
Expand Down Expand Up @@ -1745,37 +1743,6 @@ impl SnapshotEncoder {
constructor.apply(args.iter().cloned().map(Expr::local).collect())
};

if has_args {
// encode injectivity axiom of constructor
let lhs_args = encode_prefixed_args("_l");
let rhs_args = encode_prefixed_args("_r");

let lhs_call = encode_constructor_call(&lhs_args);
let rhs_call = encode_constructor_call(&rhs_args);

let mut forall_vars = vec![];
forall_vars.extend(lhs_args.iter().cloned());
forall_vars.extend(rhs_args.iter().cloned());

let conjunction = lhs_args
.iter()
.cloned()
.zip(rhs_args.iter().cloned())
.map(|(l, r)| Expr::eq_cmp(Expr::local(l), Expr::local(r)))
.conjoin();

domain_axioms.push(vir::DomainAxiom {
comment: None,
name: format!("{domain_name}${variant_idx}$injectivity"),
expr: forall_or_body(
forall_vars,
vec![vir::Trigger::new(vec![lhs_call.clone(), rhs_call.clone()])],
Expr::implies(Expr::eq_cmp(lhs_call, rhs_call), conjunction),
),
domain_name: domain_name.to_string(),
});
}

if has_multiple_variants {
// encode discriminant axiom
domain_axioms.push({
Expand All @@ -1799,12 +1766,15 @@ impl SnapshotEncoder {

let mut field_access_funcs = FxHashMap::default();

let self_local = vir::LocalVar::new("self", snapshot_type.clone());
let self_expr = Expr::local(self_local.clone());

for (field_idx, field) in variant.fields.iter().enumerate() {
// encode field access function
let field_access_func = vir::DomainFunc {
name: format!("{}${}$field${}", domain_name, variant_idx, field.name),
type_arguments: Vec::new(), // FIXME: This is most likely wrong.
formal_args: vec![vir::LocalVar::new("self", snapshot_type.clone())],
formal_args: vec![self_local.clone()],
return_type: field.typ.clone(),
unique: false,
domain_name: domain_name.to_string(),
Expand All @@ -1823,7 +1793,7 @@ impl SnapshotEncoder {
name: format!("{}${}$field${}$axiom", domain_name, variant_idx, field.name),
expr: forall_or_body(
args.clone(),
vec![vir::Trigger::new(vec![field_of_cons.clone()])],
vec![vir::Trigger::new(vec![call.clone()])],
Expr::eq_cmp(
field_of_cons.clone(),
Expr::local(args[field_idx].clone()),
Expand All @@ -1840,8 +1810,6 @@ impl SnapshotEncoder {
| ty::TyKind::Uint(_)
| ty::TyKind::Float(_)
| ty::TyKind::Char => domain_axioms.push({
let self_local = vir::LocalVar::new("self", snapshot_type.clone());
let self_expr = Expr::local(self_local.clone());
let field_of_self = field_access_func.apply(vec![self_expr.clone()]);

vir::DomainAxiom {
Expand All @@ -1866,6 +1834,34 @@ impl SnapshotEncoder {
}
}

if !args.is_empty() {
let field_access_apps: Vec<_> = variant
.fields
.iter()
.map(|f| {
let field_access_func = field_access_funcs
.get(&f.name)
.unwrap_or_else(|| panic!("No accessor for field {}", f.name));
field_access_func.apply(vec![self_expr.clone()])
})
.collect();
let expr = Expr::forall(
vec![self_local.clone()],
field_access_apps
.iter()
.map(|e| vir::Trigger::new(vec![e.clone()]))
.collect(),
Expr::eq_cmp(self_expr.clone(), constructor.apply(field_access_apps)),
);

domain_axioms.push(vir::DomainAxiom {
comment: None,
name: format!("{domain_name}${variant_idx}$injectivity"),
expr,
domain_name: domain_name.to_string(),
});
}

variant_domain_funcs.push((constructor.clone(), field_access_funcs));

// encode constructor call for this variant
Expand Down
Loading