diff --git a/Project.toml b/Project.toml index a9c50e8..7884d33 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "HerbGrammar" uuid = "4ef9e186-2fe5-4b24-8de7-9f7291f24af7" authors = ["Sebastijan Dumancic ", "Jaap de Jong ", "Nicolae Filat ", "Piotr Cichoń "] -version = "0.1.0" +version = "0.2.0" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" @@ -11,11 +11,11 @@ TreeView = "39424ebd-4cf3-5550-a685-96706a953f40" HerbCore = "2b23ba43-8213-43cb-b5ea-38c12b45bd45" [compat] -AbstractTrees = "0.4" -DataStructures = "0.17,0.18" -TreeView = "0.5" -HerbCore = "0.1.0" -julia = "1.8" +AbstractTrees = "^0.4" +DataStructures = "^0.18" +TreeView = "^0.5" +HerbCore = "^0.2.0" +julia = "^1.8" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/README.md b/README.md index 6d29e91..637c3cd 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,16 @@ -# Grammar.jl [![Build Status](https://github.com/Herb-AI/HerbGrammar.jl/actions/workflows/CI.yml/badge.svg?branch=master)](https://github.com/Herb-AI/HerbGrammar.jl/actions/workflows/CI.yml?query=branch%3Amaster) +[![Dev-Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://Herb-AI.github.io/Herb.jl/dev) + +# HerbGrammar.jl This package contains functionality for declaring grammars for the Herb Program Synthesis framework. +For full documentation, please see the [`Herb.jl` documentation](https://herb-ai.github.io/Herb.jl/dev/). + ## Getting started -To use this project, initialize the project with +For a quick tutorial on how to get started with `HerbGrammar.jl`, have a look at our [introductory tutorial](https://herb-ai.github.io/Herb.jl/dev/get_started/). +If you want to help develop this project, initialize the project with ```shell julia --project=. 
-e 'using Pkg; Pkg.instantiate()' ``` diff --git a/src/HerbGrammar.jl b/src/HerbGrammar.jl index dd0bd3a..927b56f 100644 --- a/src/HerbGrammar.jl +++ b/src/HerbGrammar.jl @@ -12,10 +12,6 @@ include("rulenode_operators.jl") include("utils.jl") include("nodelocation.jl") - -include("cfg/cfg.jl") -include("cfg/probabilistic_cfg.jl") - include("csg/csg.jl") include("csg/probabilistic_csg.jl") @@ -25,18 +21,13 @@ export ContextFree, ContextSensitive, - ContextFreeGrammar, - ContextSensitiveGrammar, AbstractRuleNode, RuleNode, Hole, NodeLoc, - ProbabilisticCFG, - @cfgrammar, - expr2cfgrammar, max_arity, isterminal, iseval, @@ -56,12 +47,11 @@ export @csgrammar, expr2csgrammar, - cfg2csg, clearconstraints!, addconstraint!, + merge_grammars!, @pcfgrammar, - expr2pcfgrammar, @pcsgrammar, expr2pcsgrammar, @@ -77,9 +67,6 @@ export containedin, subsequenceof, has_children, - store_cfg, - read_cfg, - read_pcfg, store_csg, read_csg, read_pcsg, diff --git a/src/cfg/cfg.jl b/src/cfg/cfg.jl deleted file mode 100644 index 675ebf5..0000000 --- a/src/cfg/cfg.jl +++ /dev/null @@ -1,155 +0,0 @@ -""" - ContextFreeGrammar <: Grammar - -Represents a context-free grammar and its production rules. -Consists of: - -- `rules::Vector{Any}`: A list of RHS of rules (subexpressions). -- `types::Vector{Symbol}`: A list of LHS of rules (types, all symbols). -- `isterminal::BitVector`: A bitvector where bit `i` represents whether rule `i` is terminal. -- `iseval::BitVector`: A bitvector where bit `i` represents whether rule i is an eval rule. -- `bytype::Dict{Symbol,Vector{Int}}`: A dictionary that maps a type to all rules of said type. -- `domains::Dict{Symbol, BitVector}`: A dictionary that maps a type to a domain bitvector. - The domain bitvector has bit `i` set to true iff the `i`th rule is of this type. -- `childtypes::Vector{Vector{Symbol}}`: A list of types of the children for each rule. - If a rule is terminal, the corresponding list is empty. 
-- `log_probabilities::Union{Vector{Real}, Nothing}`: A list of probabilities for each rule. - If the grammar is non-probabilistic, the list can be `nothing`. - -Use the [`@cfgrammar`](@ref) macro to create a [`ContextFreeGrammar`](@ref) object. -Use the [`@pcfgrammar`](@ref) macro to create a [`ContextFreeGrammar`](@ref) object with probabilities. -For context-sensitive grammars, see [`ContextSensitiveGrammar`](@ref). - -""" -mutable struct ContextFreeGrammar <: Grammar - rules::Vector{Any} # list of RHS of rules (subexpressions) - types::Vector{Union{Symbol, Nothing}} # list of LHS of rules (types, all symbols) - isterminal::BitVector # whether rule i is terminal - iseval::BitVector # whether rule i is an eval rule - bytype::Dict{Symbol,Vector{Int}} # maps type to all rules of said type - domains::Dict{Symbol,BitVector} # maps type to a domain bitvector - childtypes::Vector{Vector{Symbol}} # list of types of the children for each rule. Empty if terminal - log_probabilities::Union{Vector{Real}, Nothing} # list of probabilities for the rules if this is a probabilistic grammar -end - -""" - expr2cfgrammar(ex::Expr)::ContextFreeGrammar - -A function for converting an `Expr` to a [`ContextFreeGrammar`](@ref). -If the expression is hardcoded, you should use the [`@cfgrammar`](@ref) macro. -Only expressions in the correct format (see [`@cfgrammar`](@ref)) can be converted. - -### Example usage: - -```@example -grammar = expr2cfgrammar( - begin - R = x - R = 1 | 2 - R = R + R - end -) -``` -""" -function expr2cfgrammar(ex::Expr)::ContextFreeGrammar - rules = Any[] - types = Symbol[] - bytype = Dict{Symbol,Vector{Int}}() - for e ∈ ex.args - if isa(e, Expr) - if e.head == :(=) - s = e.args[1] # name of return type - rule = e.args[2] # expression? 
- rvec = Any[] - parse_rule!(rvec, rule) - for r ∈ rvec - push!(rules, r) - push!(types, s) - bytype[s] = push!(get(bytype, s, Int[]), length(rules)) - end - end - end - end - alltypes = collect(keys(bytype)) - is_terminal = [isterminal(rule, alltypes) for rule ∈ rules] - is_eval = [iseval(rule) for rule ∈ rules] - childtypes = [get_childtypes(rule, alltypes) for rule ∈ rules] - domains = Dict(type => BitArray(r ∈ bytype[type] for r ∈ 1:length(rules)) for type ∈ alltypes) - return ContextFreeGrammar(rules, types, is_terminal, is_eval, bytype, domains, childtypes, nothing) -end - -""" - @cfgrammar - -A macro for defining a [`ContextFreeGrammar`](@ref). - -### Example usage: -```julia -grammar = @cfgrammar begin - R = x - R = 1 | 2 - R = R + R -end -``` - -### Syntax: - -- Literals: Symbols that are already defined in Julia are considered literals, such as `1`, `2`, or `π`. - For example: `R = 1`. -- Variables: A variable is a symbol that is not a nonterminal symbol and not already defined in Julia. - For example: `R = x`. -- Functions: Functions and infix operators that are defined in Julia or the `Main` module can be used - with the default evaluator. For example: `R = R + R`, `R = f(a, b)`. -- Combinations: Multiple rules can be defined on a single line in the grammar definition using the `|` symbol. - For example: `R = 1 | 2 | 3`. -- Iterators: Another way to define multiple rules is by providing a Julia iterator after a `|` symbol. - For example: `R = |(1:9)`. - -### Related: - -- [`@csgrammar`](@ref) uses the same syntax to create [`ContextSensitiveGrammar`](@ref)s. -- [`@pcfgrammar`](@ref) uses a similar syntax to create probabilistic [`ContextFreeGrammar`](@ref)s. 
-""" -macro cfgrammar(ex) - return expr2cfgrammar(ex) -end - -parse_rule!(v::Vector{Any}, r) = push!(v, r) - -function parse_rule!(v::Vector{Any}, ex::Expr) - # Strips `LineNumberNode`s from the expression - Base.remove_linenums!(ex) - - if ex.head == :call && ex.args[1] == :| - terms = _expand_shorthand(ex.args) - - for t in terms - parse_rule!(v, t) - end - else - push!(v, ex) - end -end - -function _expand_shorthand(args::Vector{Any}) - # expand a rule using the `|` symbol: - # `X = |(1:3)`, `X = 1|2|3`, `X = |([1,2,3])` - # these should all be equivalent and should expand to - # the following 3 rules: `X = 1`, `X = 2`, and `X = 3` - if args[1] != :| - throw(ArgumentError("Tried to parse: $ex as a shorthand rule, but it is not a shorthand rule.")) - end - - if length(args) == 2 - to_expand = args[2] - if to_expand.args[1] == :(:) - expanded = collect(to_expand.args[2]:to_expand.args[3]) # (1:3) case - else - expanded = to_expand.args # ([1,2,3]) case - end - elseif length(args) == 3 - expanded = args[2:end] # 1|2|3 case - else - throw(ArgumentError("Failed to parse shorthand for rule: $ex")) - end -end diff --git a/src/cfg/probabilistic_cfg.jl b/src/cfg/probabilistic_cfg.jl deleted file mode 100644 index 0a180e4..0000000 --- a/src/cfg/probabilistic_cfg.jl +++ /dev/null @@ -1,111 +0,0 @@ -""" -Function for converting an `Expr` to a [`ContextFreeGrammar`](@ref) with probabilities. -If the expression is hardcoded, you should use the `@pcfgrammar` macro. -Only expressions in the correct format (see [`@pcfgrammar`](@ref)) can be converted. 
- -### Example usage: - -```@example -grammar = expr2pcsgrammar( - begin - 0.5 : R = x - 0.3 : R = 1 | 2 - 0.2 : R = R + R - end -) -``` -""" -function expr2pcfgrammar(ex::Expr)::ContextFreeGrammar - rules = Any[] - types = Symbol[] - probabilities = Real[] - bytype = Dict{Symbol,Vector{Int}}() - for e ∈ ex.args - if e isa Expr - if e.head == :(=) - left = e.args[1] # name of return type and probability - if left isa Expr && left.head == :call && left.args[1] == :(:) - p = left.args[2] # Probability - s = left.args[3] # Return type - rule = e.args[2].args[2] # extract rule from block expr - - rvec = Any[] - parse_rule!(rvec, rule) - for r ∈ rvec - push!(rules, r) - push!(types, s) - # Divide the probability of this line by the number of rules it defines. - push!(probabilities, p / length(rvec)) - bytype[s] = push!(get(bytype, s, Int[]), length(rules)) - end - else - @error "Rule without probability encountered in probabilistic grammar. Rule ignored." - end - end - end - end - alltypes = collect(keys(bytype)) - # Normalize probabilities for each type - for t ∈ alltypes - total_prob = sum(probabilities[i] for i ∈ bytype[t]) - if !(total_prob ≈ 1) - @warn "The probabilities for type $t don't add up to 1, so they will be normalized." - for i ∈ bytype[t] - probabilities[i] /= total_prob - end - end - end - - log_probabilities = [log(x) for x ∈ probabilities] - is_terminal = [isterminal(rule, alltypes) for rule in rules] - is_eval = [iseval(rule) for rule in rules] - childtypes = [get_childtypes(rule, alltypes) for rule in rules] - domains = Dict(type => BitArray(r ∈ bytype[type] for r ∈ 1:length(rules)) for type ∈ alltypes) - return ContextFreeGrammar(rules, types, is_terminal, is_eval, bytype, domains, childtypes, log_probabilities) -end - - -""" - @pcfgrammar - -A macro for defining a probabilistic [`ContextFreeGrammar`](@ref). 
- -### Example usage: -```julia -grammar = @pcfgrammar begin - 0.5 : R = x - 0.3 : R = 1 | 2 - 0.2 : R = R + R -end -``` - -### Syntax: - -The syntax of rules is identical to the syntax used by [`@cfgrammar`](@ref): - -- Literals: Symbols that are already defined in Julia are considered literals, such as `1`, `2`, or `π`. - For example: `R = 1`. -- Variables: A variable is a symbol that is not a nonterminal symbol and not already defined in Julia. - For example: `R = x`. -- Functions: Functions and infix operators that are defined in Julia or the `Main` module can be used - with the default evaluator. For example: `R = R + R`, `R = f(a, b)`. -- Combinations: Multiple rules can be defined on a single line in the grammar definition using the `|` symbol. - For example: `R = 1 | 2 | 3`. -- Iterators: Another way to define multiple rules is by providing a Julia iterator after a `|` symbol. - For example: `R = |(1:9)`. - -Every rule is also prefixed with a probability. -Rules and probabilities are separated using the `:` symbol. -If multiple rules are defined on a single line, the probability is equally divided between the rules. -The sum of probabilities for all rules of a certain non-terminal symbol should be equal to 1. -The probabilities are automatically scaled if this isn't the case. - - -### Related: - -- [`@pcsgrammar`](@ref) uses the same syntax to create probabilistic [`ContextSensitiveGrammar`](@ref)s. -- [`@cfgrammar`](@ref) uses a similar syntax to create non-probabilistic [`ContextFreeGrammar`](@ref)s. -""" -macro pcfgrammar(ex) - return expr2pcfgrammar(ex) -end \ No newline at end of file diff --git a/src/csg/csg.jl b/src/csg/csg.jl index 4e23afe..375414b 100644 --- a/src/csg/csg.jl +++ b/src/csg/csg.jl @@ -1,8 +1,8 @@ """ - ContextSensitiveGrammar <: Grammar + ContextSensitiveGrammar <: AbstractGrammar Represents a context-sensitive grammar. -Extends [`Grammar`](@ref) with constraints. +Extends [`AbstractGrammar`](@ref) with constraints. 
Consists of: @@ -21,9 +21,8 @@ Consists of: Use the [`@csgrammar`](@ref) macro to create a [`ContextSensitiveGrammar`](@ref) object. Use the [`@pcsgrammar`](@ref) macro to create a [`ContextSensitiveGrammar`](@ref) object with probabilities. -For context-free grammars, see [`ContextFreeGrammar`](@ref). """ -mutable struct ContextSensitiveGrammar <: Grammar +mutable struct ContextSensitiveGrammar <: AbstractGrammar rules::Vector{Any} types::Vector{Union{Symbol, Nothing}} isterminal::BitVector @@ -35,6 +34,16 @@ mutable struct ContextSensitiveGrammar <: Grammar constraints::Vector{Constraint} end +ContextSensitiveGrammar( + rules::Vector{<:Any}, + types::Vector{<:Union{Symbol, Nothing}}, + isterminal::Union{BitVector, Vector{Bool}}, + iseval::Union{BitVector, Vector{Bool}}, + bytype::Dict{Symbol, Vector{Int}}, + domains::Dict{Symbol, BitVector}, + childtypes::Vector{Vector{Symbol}}, + log_probabilities::Union{Vector{<:Real}, Nothing} +) = ContextSensitiveGrammar(rules, types, isterminal, iseval, bytype, domains, childtypes, log_probabilities, Constraint[]) """ expr2csgrammar(ex::Expr)::ContextSensitiveGrammar @@ -56,7 +65,30 @@ grammar = expr2csgrammar( ``` """ function expr2csgrammar(ex::Expr)::ContextSensitiveGrammar - return cfg2csg(expr2cfgrammar(ex)) + rules = Any[] + types = Symbol[] + bytype = Dict{Symbol,Vector{Int}}() + for e ∈ ex.args + if isa(e, Expr) + if e.head == :(=) + s = e.args[1] # name of return type + rule = e.args[2] # expression? 
+ rvec = Any[] + parse_rule!(rvec, rule) + for r ∈ rvec + push!(rules, r) + push!(types, s) + bytype[s] = push!(get(bytype, s, Int[]), length(rules)) + end + end + end + end + alltypes = collect(keys(bytype)) + is_terminal::Vector{Bool} = [isterminal(rule, alltypes) for rule ∈ rules] + is_eval::Vector{Bool} = [iseval(rule) for rule ∈ rules] + childtypes::Vector{Vector{Symbol}} = [get_childtypes(rule, alltypes) for rule ∈ rules] + domains = Dict(type => BitArray(r ∈ bytype[type] for r ∈ 1:length(rules)) for type ∈ alltypes) + return ContextSensitiveGrammar(rules, types, is_terminal, is_eval, bytype, domains, childtypes, nothing) end @@ -91,30 +123,54 @@ end ### Related: -- [`@cfgrammar`](@ref) uses the same syntax to create [`ContextFreeGrammar`](@ref)s. - [`@pcsgrammar`](@ref) uses a similar syntax to create probabilistic [`ContextSensitiveGrammar`](@ref)s. """ macro csgrammar(ex) return expr2csgrammar(ex) end -""" - cfg2csg(g::ContextFreeGrammar)::ContextSensitiveGrammar +macro cfgrammar(ex) + return expr2csgrammar(ex) +end -Converts a [`ContextFreeGrammar`](@ref) to a [`ContextSensitiveGrammar`](@ref) without any [`Constraint`](@ref)s. 
-""" -function cfg2csg(g::ContextFreeGrammar)::ContextSensitiveGrammar - return ContextSensitiveGrammar( - g.rules, - g.types, - g.isterminal, - g.iseval, - g.bytype, - g.domains, - g.childtypes, - g.log_probabilities, - [] - ) +parse_rule!(v::Vector{Any}, r) = push!(v, r) + +function parse_rule!(v::Vector{Any}, ex::Expr) + # Strips `LineNumberNode`s from the expression + Base.remove_linenums!(ex) + + if ex.head == :call && ex.args[1] == :| + terms = _expand_shorthand(ex.args) + + for t in terms + parse_rule!(v, t) + end + else + push!(v, ex) + end +end + +function _expand_shorthand(args::Vector{Any}) + # expand a rule using the `|` symbol: + # `X = |(1:3)`, `X = 1|2|3`, `X = |([1,2,3])` + # these should all be equivalent and should expand to + # the following 3 rules: `X = 1`, `X = 2`, and `X = 3` + if args[1] != :| + throw(ArgumentError("Tried to parse: $ex as a shorthand rule, but it is not a shorthand rule.")) + end + + if length(args) == 2 + to_expand = args[2] + if to_expand.args[1] == :(:) + expanded = collect(to_expand.args[2]:to_expand.args[3]) # (1:3) case + else + expanded = to_expand.args # ([1,2,3]) case + end + elseif length(args) == 3 + expanded = args[2:end] # 1|2|3 case + else + throw(ArgumentError("Failed to parse shorthand for rule: $ex")) + end end """ @@ -131,4 +187,19 @@ clearconstraints!(grammar::ContextSensitiveGrammar) = empty!(grammar.constraints function Base.display(rulenode::RuleNode, grammar::ContextSensitiveGrammar) return rulenode2expr(rulenode, grammar) -end \ No newline at end of file +end + +""" + merge_grammars!(merge_to::AbstractGrammar, merge_from::AbstractGrammar) + +Adds all rules and constraints from `merge_from` to `merge_to`. 
+""" +function merge_grammars!(merge_to::AbstractGrammar, merge_from::AbstractGrammar) + for i in eachindex(merge_from.rules) + expression = :($(merge_from.types[i]) = $(merge_from.rules[i])) + add_rule!(merge_to, expression) + end + for i in eachindex(merge_from.constraints) + addconstraint!(merge_to, merge_from.constraints[i]) + end +end diff --git a/src/csg/probabilistic_csg.jl b/src/csg/probabilistic_csg.jl index 3610b00..a8c6522 100644 --- a/src/csg/probabilistic_csg.jl +++ b/src/csg/probabilistic_csg.jl @@ -17,7 +17,94 @@ grammar = expr2pcsgrammar( ``` """ function expr2pcsgrammar(ex::Expr)::ContextSensitiveGrammar - cfg2csg(expr2pcfgrammar(ex)) + rules = Any[] + types = Symbol[] + probabilities = Real[] + bytype = Dict{Symbol,Vector{Int}}() + for e ∈ ex.args + if e isa Expr + maybe_rules = parse_probabilistic_rule(e) + isnothing(maybe_rules) && continue # if rules is nothing, skip + s, prvec = maybe_rules + + for (p, r) ∈ prvec + push!(rules, r) + push!(types, s) + push!(probabilities, p) + bytype[s] = push!(get(bytype, s, Int[]), length(rules)) + end + end + end + alltypes = collect(keys(bytype)) + # Normalize probabilities for each type + for t ∈ alltypes + total_prob = sum(probabilities[i] for i ∈ bytype[t]) + if !(total_prob ≈ 1) + @warn "The probabilities for type $t don't add up to 1, so they will be normalized." + for i ∈ bytype[t] + probabilities[i] /= total_prob + end + end + end + + log_probabilities = [log(x) for x ∈ probabilities] + is_terminal = [isterminal(rule, alltypes) for rule in rules] + is_eval = [iseval(rule) for rule in rules] + childtypes = [get_childtypes(rule, alltypes) for rule in rules] + domains = Dict(type => BitArray(r ∈ bytype[type] for r ∈ 1:length(rules)) for type ∈ alltypes) + + normalize!(ContextSensitiveGrammar(rules, types, is_terminal, is_eval, bytype, domains, childtypes, log_probabilities)) +end + +""" +Parses a single (potentially shorthand) derivation rule of a probabilistic [`ContextSensitiveGrammar`](@ref). 
+Returns `nothing` if the rule is not probabilistic, otherwise a `Tuple` of its type and a +`Vector` of probability-rule pairs it expands into. +""" +function parse_probabilistic_rule(e::Expr) + prvec = Tuple{Real, Any}[] + if e.head == :(=) + left = e.args[1] # name of return type and probability + if left isa Expr && left.head == :call && left.args[1] == :(:) + p = left.args[2] # Probability + s = left.args[3] # Return type + rule = e.args[2].args[2] # extract rule from block expr + + rvec = Any[] + parse_rule!(rvec, rule) + for r ∈ rvec + # Divide the probability of this line by the number of rules it defines. + push!(prvec, (p / length(rvec), r)) + end + + return s, prvec + else + @error "Rule without probability encountered in probabilistic grammar. Rule ignored." + return nothing + end + end +end + + +""" +A function for normalizing the probabilities of a probabilistic [`ContextSensitiveGrammar`](@ref). +If the optional `type` argument is provided, only the rules of that type are normalized. +""" +function normalize!(g::ContextSensitiveGrammar, type::Union{Symbol, Nothing}=nothing) + probabilities = map(exp, g.log_probabilities) + types = isnothing(type) ? keys(g.bytype) : [type] + + for t ∈ types + total_prob = sum(probabilities[i] for i ∈ g.bytype[t]) + if !(total_prob ≈ 1) + for i ∈ g.bytype[t] + probabilities[i] /= total_prob + end + end + end + + g.log_probabilities = map(log, probabilities) + return g end """ @@ -58,9 +145,12 @@ The probabilities are automatically scaled if this isn't the case. ### Related: -- [`@pcfgrammar`](@ref) uses the same syntax to create probabilistic [`ContextFreeGrammar`](@ref)s. - [`@csgrammar`](@ref) uses a similar syntax to create non-probabilistic [`ContextSensitiveGrammar`](@ref)s. 
""" macro pcsgrammar(ex) return expr2pcsgrammar(ex) +end + +macro pcfgrammar(ex) + return expr2pcsgrammar(ex) end \ No newline at end of file diff --git a/src/grammar_base.jl b/src/grammar_base.jl index d7be8f9..cbe53e2 100644 --- a/src/grammar_base.jl +++ b/src/grammar_base.jl @@ -45,63 +45,63 @@ function get_childtypes(rule::Any, types::AbstractVector{Symbol}) return retval end -Base.getindex(grammar::Grammar, typ::Symbol) = grammar.bytype[typ] +Base.getindex(grammar::AbstractGrammar, typ::Symbol) = grammar.bytype[typ] """ - nonterminals(grammar::Grammar)::Vector{Symbol} + nonterminals(grammar::AbstractGrammar)::Vector{Symbol} -Returns a list of the nonterminals or types in the [`Grammar`](@ref). +Returns a list of the nonterminals or types in the [`AbstractGrammar`](@ref). """ -nonterminals(grammar::Grammar)::Vector{Symbol} = collect(keys(grammar.bytype)) +nonterminals(grammar::AbstractGrammar)::Vector{Symbol} = collect(keys(grammar.bytype)) """ - return_type(grammar::Grammar, rule_index::Int)::Symbol + return_type(grammar::AbstractGrammar, rule_index::Int)::Symbol Returns the type of the production rule at `rule_index`. """ -return_type(grammar::Grammar, rule_index::Int) = grammar.types[rule_index] +return_type(grammar::AbstractGrammar, rule_index::Int) = grammar.types[rule_index] """ - child_types(grammar::Grammar, rule_index::Int) + child_types(grammar::AbstractGrammar, rule_index::Int) Returns the types of the children (nonterminals) of the production rule at `rule_index`. """ -child_types(grammar::Grammar, rule_index::Int) = grammar.childtypes[rule_index] +child_types(grammar::AbstractGrammar, rule_index::Int) = grammar.childtypes[rule_index] """ - get_domain(g::Grammar, type::Symbol)::BitVector + get_domain(g::AbstractGrammar, type::Symbol)::BitVector Returns the domain for the hole of a certain type as a `BitVector` of the same length as the number of rules in the grammar. Bit `i` is set to `true` iff rule `i` is of type `type`. !!! 
info Since this function can be intensively used when exploring a program space defined by a grammar, - the outcomes of this function are precomputed and stored in the `domains` field in a [`Grammar`](@ref). + the outcomes of this function are precomputed and stored in the `domains` field in an [`AbstractGrammar`](@ref). """ -get_domain(g::Grammar, type::Symbol)::BitVector = deepcopy(g.domains[type]) +get_domain(g::AbstractGrammar, type::Symbol)::BitVector = deepcopy(g.domains[type]) """ - get_domain(g::Grammar, rules::Vector{Int})::BitVector + get_domain(g::AbstractGrammar, rules::Vector{Int})::BitVector Takes a domain `rules` defined as a vector of ints and converts it to a domain defined as a `BitVector`. """ -get_domain(g::Grammar, rules::Vector{Int})::BitVector = BitArray(r ∈ rules for r ∈ 1:length(g.rules)) +get_domain(g::AbstractGrammar, rules::Vector{Int})::BitVector = BitArray(r ∈ rules for r ∈ 1:length(g.rules)) """ - isterminal(grammar::Grammar, rule_index::Int)::Bool + isterminal(grammar::AbstractGrammar, rule_index::Int)::Bool Returns true if the production rule at `rule_index` is terminal, i.e., does not contain any nonterminal symbols. """ -isterminal(grammar::Grammar, rule_index::Int)::Bool = grammar.isterminal[rule_index] +isterminal(grammar::AbstractGrammar, rule_index::Int)::Bool = grammar.isterminal[rule_index] """ - iseval(grammar::Grammar)::Bool + iseval(grammar::AbstractGrammar)::Bool Returns true if any production rules in grammar contain the special _() eval function. 
@@ -109,11 +109,11 @@ Returns true if any production rules in grammar contain the special _() eval fun evaluate immediately functionality is not yet supported by most of Herb.jl """ -iseval(grammar::Grammar)::Bool = any(grammar.iseval) +iseval(grammar::AbstractGrammar)::Bool = any(grammar.iseval) """ - iseval(grammar::Grammar, index::Int)::Bool + iseval(grammar::AbstractGrammar, index::Int)::Bool Returns true if the production rule at rule_index contains the special _() eval function. @@ -121,18 +121,18 @@ Returns true if the production rule at rule_index contains the special _() eval evaluate immediately functionality is not yet supported by most of Herb.jl """ -iseval(grammar::Grammar, index::Int)::Bool = grammar.iseval[index] +iseval(grammar::AbstractGrammar, index::Int)::Bool = grammar.iseval[index] """ - log_probability(grammar::Grammar, index::Int)::Real + log_probability(grammar::AbstractGrammar, index::Int)::Real Returns the log probability for the rule at `index` in the grammar. !!! warning If the grammar is not probabilistic, a warning is displayed, and a uniform probability is assumed. """ -function log_probability(grammar::Grammar, index::Int)::Real +function log_probability(grammar::AbstractGrammar, index::Int)::Real if !isprobabilistic(grammar) @warn "Requesting probability in a non-probabilistic grammar.\nUniform distribution is assumed." # Assume uniform probability @@ -142,7 +142,7 @@ function log_probability(grammar::Grammar, index::Int)::Real end """ - probability(grammar::Grammar, index::Int)::Real + probability(grammar::AbstractGrammar, index::Int)::Real Return the probability for a rule in the grammar. Use [`log_probability`](@ref) whenever possible. @@ -150,7 +150,7 @@ Use [`log_probability`](@ref) whenever possible. !!! warning If the grammar is not probabilistic, a warning is displayed, and a uniform probability is assumed. 
""" -function probability(grammar::Grammar, index::Int)::Real +function probability(grammar::AbstractGrammar, index::Int)::Real if !isprobabilistic(grammar) @warn "Requesting probability in a non-probabilistic grammar.\nUniform distribution is assumed." # Assume uniform probability @@ -160,30 +160,30 @@ function probability(grammar::Grammar, index::Int)::Real end """ - isprobabilistic(grammar::Grammar)::Bool + isprobabilistic(grammar::AbstractGrammar)::Bool -Function returns whether a [`Grammar`](@ref) is probabilistic. +Function returns whether a [`AbstractGrammar`](@ref) is probabilistic. """ -isprobabilistic(grammar::Grammar)::Bool = !(grammar.log_probabilities ≡ nothing) +isprobabilistic(grammar::AbstractGrammar)::Bool = !(grammar.log_probabilities ≡ nothing) """ - nchildren(grammar::Grammar, rule_index::Int)::Int + nchildren(grammar::AbstractGrammar, rule_index::Int)::Int Returns the number of children (nonterminals) of the production rule at `rule_index`. """ -nchildren(grammar::Grammar, rule_index::Int)::Int = length(grammar.childtypes[rule_index]) +nchildren(grammar::AbstractGrammar, rule_index::Int)::Int = length(grammar.childtypes[rule_index]) """ - max_arity(grammar::Grammar)::Int + max_arity(grammar::AbstractGrammar)::Int -Returns the maximum arity (number of children) over all production rules in the [`Grammar`](@ref). +Returns the maximum arity (number of children) over all production rules in the [`AbstractGrammar`](@ref). """ -max_arity(grammar::Grammar)::Int = maximum(length(cs) for cs in grammar.childtypes) +max_arity(grammar::AbstractGrammar)::Int = maximum(length(cs) for cs in grammar.childtypes) -function Base.show(io::IO, grammar::Grammar) +function Base.show(io::IO, grammar::AbstractGrammar) for i in eachindex(grammar.rules) println(io, i, ": ", grammar.types[i], " = ", grammar.rules[i]) end @@ -191,7 +191,7 @@ end """ - add_rule!(g::Grammar, e::Expr) + add_rule!(g::AbstractGrammar, e::Expr) Adds a rule to the grammar. 
@@ -204,21 +204,25 @@ The syntax is identical to the syntax of [`@csgrammar`](@ref) and [`@cfgrammar`] !!! warning Calls to this function are ignored if a rule is already in the grammar. """ -function add_rule!(g::Grammar, e::Expr) - if e.head == :(=) +function add_rule!(g::AbstractGrammar, e::Expr) + if e.head == :(=) && typeof(e.args[1]) == Symbol s = e.args[1] # Name of return type rule = e.args[2] # expression? rvec = Any[] parse_rule!(rvec, rule) for r ∈ rvec - if r ∈ g.rules - continue + # Only add a rule if it does not exist yet. Check for existance + # with strict equality so that true and 1 are not considered + # equal. that means we can't use `in` or `∈` for equality checking. + if !any(r === rule for rule ∈ g.rules) + push!(g.rules, r) + push!(g.iseval, iseval(rule)) + push!(g.types, s) + g.bytype[s] = push!(get(g.bytype, s, Int[]), length(g.rules)) end - push!(g.rules, r) - push!(g.iseval, iseval(rule)) - push!(g.types, s) - g.bytype[s] = push!(get(g.bytype, s, Int[]), length(g.rules)) end + else + throw(ArgumentError("Invalid rule: $e. Rules must be of the form `Symbol = Expr`")) end alltypes = collect(keys(g.bytype)) @@ -230,15 +234,27 @@ function add_rule!(g::Grammar, e::Expr) return g end +""" +Adds a probabilistic derivation rule. +""" +function add_rule!(g::AbstractGrammar, p::Real, e::Expr) + isprobabilistic(g) || throw(ArgumentError("adding a probabilistic rule to a non-probabilistic grammar")) + len₀ = length(g.rules) + add_rule!(g, e) + len₁ = length(g.rules) + nnew = len₁ - len₀ + append!(g.log_probabilities, repeat([log(p / nnew)], nnew)) + normalize!(g) +end """ - remove_rule!(g::Grammar, idx::Int) + remove_rule!(g::AbstractGrammar, idx::Int) Removes the rule corresponding to `idx` from the grammar. In order to avoid shifting indices, the rule is replaced with `nothing`, and all other data structures are updated accordingly. 
""" -function remove_rule!(g::Grammar, idx::Int) +function remove_rule!(g::AbstractGrammar, idx::Int) type = g.types[idx] g.rules[idx] = nothing g.iseval[idx] = false @@ -259,7 +275,7 @@ end """ - cleanup_removed_rules!(g::Grammar) + cleanup_removed_rules!(g::AbstractGrammar) Removes any placeholders for previously deleted rules. This means that indices get shifted. @@ -269,7 +285,7 @@ This means that indices get shifted. [`AbstractRuleNode`](@ref) trees created before the call to this function. These trees become meaningless. """ -function cleanup_removed_rules!(g::Grammar) +function cleanup_removed_rules!(g::AbstractGrammar) rules_to_cleanup = findall(isequal(nothing), g.rules) # highest indices are removed first, otherwise their index will have shifted for v ∈ [g.rules, g.types, g.isterminal, g.iseval, g.childtypes] diff --git a/src/grammar_io.jl b/src/grammar_io.jl index 9ba028d..39867a9 100644 --- a/src/grammar_io.jl +++ b/src/grammar_io.jl @@ -1,9 +1,14 @@ +const OptionalPath = Union{Nothing, AbstractString} + """ - store_cfg(filepath::AbstractString, grammar::ContextFreeGrammar) + store_csg(g::ContextSensitiveGrammar, grammarpath::AbstractString, constraintspath::OptionalPath=nothing) -Writes a [`ContextFreeGrammar`](@ref) to the file provided by `filepath`. +Writes a [`ContextSensitiveGrammar`](@ref) to the files at `grammarpath` and `constraintspath`. +The `grammarpath` file will contain a [`ContextSensitiveGrammar`](@ref) definition, and the +`constraintspath` file will contain the [`Constraint`](@ref)s of the [`ContextSensitiveGrammar`](@ref). 
""" -function store_cfg(filepath::AbstractString, grammar::ContextFreeGrammar) +function store_csg(grammar::ContextSensitiveGrammar, filepath::AbstractString, constraintspath::OptionalPath=nothing) + # Store grammar as CFG open(filepath, write=true) do file if !isprobabilistic(grammar) for (type, rule) ∈ zip(grammar.types, grammar.rules) @@ -15,21 +20,28 @@ function store_cfg(filepath::AbstractString, grammar::ContextFreeGrammar) end end end -end + + # exit if no constraintspath is given + isnothing(constraintspath) && return + # Store constraints separately + open(constraintspath, write=true) do file + serialize(file, grammar.constraints) + end +end """ - read_cfg(filepath::AbstractString)::ContextFreeGrammar + read_csg(grammarpath::AbstractString, constraintspath::OptionalPath=nothing)::ContextSensitiveGrammar -Reads a [`ContextFreeGrammar`](@ref) from the file provided in `filepath`. +Reads a [`ContextSensitiveGrammar`](@ref) from the files at `grammarpath` and `constraintspath`. !!! danger Only open trusted grammars. Parts of the grammar can be passed to Julia's `eval` function. 
""" -function read_cfg(filepath::AbstractString)::ContextFreeGrammar +function read_csg(grammarpath::AbstractString, constraintspath::OptionalPath=nothing)::ContextSensitiveGrammar # Read the contents of the file into a string - file = open(filepath) + file = open(grammarpath) program::AbstractString = read(file, String) close(file) @@ -37,21 +49,32 @@ function read_cfg(filepath::AbstractString)::ContextFreeGrammar ex::Expr = Meta.parse("begin $program end") # Convert the expression to a context-free grammar - return expr2cfgrammar(ex) + g = expr2csgrammar(ex) + + if !isnothing(constraintspath) + file = open(constraintspath) + constraints = deserialize(file) + close(file) + else + constraints = Constraint[] + end + + return ContextSensitiveGrammar(g.rules, g.types, g.isterminal, + g.iseval, g.bytype, g.domains, g.childtypes, g.log_probabilities, constraints) end """ - read_pcfg(filepath::AbstractString)::ContextFreeGrammar + read_pcsg(grammarpath::AbstractString, constraintspath::OptionalPath=nothing)::ContextSensitiveGrammar -Reads a probabilistic [`ContextFreeGrammar`](@ref) from a file provided in `filepath`. +Reads a probabilistic [`ContextSensitiveGrammar`](@ref) from the files at `grammarpath` and `constraintspath`. !!! danger Only open trusted grammars. Parts of the grammar can be passed to Julia's `eval` function. 
""" -function read_pcfg(filepath::AbstractString)::ContextFreeGrammar +function read_pcsg(grammarpath::AbstractString, constraintspath::OptionalPath=nothing)::ContextSensitiveGrammar # Read the contents of the file into a string - file = open(filepath) + file = open(grammarpath) program::AbstractString = read(file, String) close(file) @@ -59,63 +82,16 @@ function read_pcfg(filepath::AbstractString)::ContextFreeGrammar ex::Expr = Meta.parse("begin $program end") # Convert the expression to a context-free grammar - return expr2pcfgrammar(ex) -end - -""" - store_csg(grammarpath::AbstractString, constraintspath::AbstractString, g::ContextSensitiveGrammar) - -Writes a [`ContextSensitiveGrammar`](@ref) to the files at `grammarpath` and `constraintspath`. -The `grammarpath` file will contain a [`ContextSensitiveGrammar`](@ref) definition, and the -`constraintspath` file will contain the [`Constraint`](@ref)s of the [`ContextSensitiveGrammar`](@ref). -""" -function store_csg(grammarpath::AbstractString, constraintspath::AbstractString, g::ContextSensitiveGrammar) - # Store grammar as CFG - store_cfg(grammarpath, ContextFreeGrammar(g.rules, g.types, - g.isterminal, g.iseval, g.bytype, g.domains, g.childtypes, g.log_probabilities)) + g = expr2pcsgrammar(ex) - # Store constraints separately - open(constraintspath, write=true) do file - serialize(file, g.constraints) + if !isnothing(constraintspath) + file = open(constraintspath) + constraints = deserialize(file) + close(file) + else + constraints = Constraint[] end -end - -""" - read_csg(grammarpath::AbstractString, constraintspath::AbstractString)::ContextSensitiveGrammar - -Reads a [`ContextSensitiveGrammar`](@ref) from the files at `grammarpath` and `constraintspath`. -The grammar path may also point to a [`ContextFreeGrammar`](@ref). - -!!! danger - Only open trusted grammars. - Parts of the grammar can be passed to Julia's `eval` function. 
-""" -function read_csg(grammarpath::AbstractString, constraintspath::AbstractString)::ContextSensitiveGrammar - g = read_cfg(grammarpath) - file = open(constraintspath) - constraints = deserialize(file) - close(file) - - return ContextSensitiveGrammar(g.rules, g.types, g.isterminal, - g.iseval, g.bytype, g.domains, g.childtypes, g.log_probabilities, constraints) -end - -""" - read_pcsg(grammarpath::AbstractString, constraintspath::AbstractString)::ContextSensitiveGrammar - -Reads a probabilistic [`ContextSensitiveGrammar`](@ref) from the files at `grammarpath` and `constraintspath`. -The grammar path may also point to a [`ContextFreeGrammar`](@ref). - -!!! danger - Only open trusted grammars. - Parts of the grammar can be passed to Julia's `eval` function. -""" -function read_pcsg(grammarpath::AbstractString, constraintspath::AbstractString)::ContextSensitiveGrammar - g = read_pcfg(grammarpath) - file = open(constraintspath) - constraints = deserialize(file) - close(file) - + return ContextSensitiveGrammar(g.rules, g.types, g.isterminal, g.iseval, g.bytype, g.domains, g.childtypes, g.log_probabilities, constraints) end diff --git a/src/rulenode_operators.jl b/src/rulenode_operators.jl index 5c8cc14..1401df8 100644 --- a/src/rulenode_operators.jl +++ b/src/rulenode_operators.jl @@ -1,15 +1,15 @@ -HerbCore.RuleNode(ind::Int, grammar::Grammar) = RuleNode(ind, nothing, [Hole(get_domain(grammar, type)) for type ∈ grammar.childtypes[ind]]) -HerbCore.RuleNode(ind::Int, _val::Any, grammar::Grammar) = RuleNode(ind, _val, [Hole(get_domain(grammar, type)) for type ∈ grammar.childtypes[ind]]) +HerbCore.RuleNode(ind::Int, grammar::AbstractGrammar) = RuleNode(ind, nothing, [Hole(get_domain(grammar, type)) for type ∈ grammar.childtypes[ind]]) +HerbCore.RuleNode(ind::Int, _val::Any, grammar::AbstractGrammar) = RuleNode(ind, _val, [Hole(get_domain(grammar, type)) for type ∈ grammar.childtypes[ind]]) rulesoftype(::Hole, ::Set{Int}) = Set{Int}() """ - rulesoftype(node::RuleNode, 
grammar::Grammar, ruletype::Symbol) + rulesoftype(node::RuleNode, grammar::AbstractGrammar, ruletype::Symbol) Returns every rule of nonterminal symbol `ruletype` that is also used in the [`AbstractRuleNode`](@ref) tree. """ -rulesoftype(node::RuleNode, grammar::Grammar, ruletype::Symbol) = rulesoftype(node, Set{Int}(grammar[ruletype])) -rulesoftype(::Hole, ::Grammar, ::Symbol) = Set{Int}() +rulesoftype(node::RuleNode, grammar::AbstractGrammar, ruletype::Symbol) = rulesoftype(node, Set{Int}(grammar[ruletype])) +rulesoftype(::Hole, ::AbstractGrammar, ::Symbol) = Set{Int}() """ @@ -46,15 +46,15 @@ rulesoftype(::Hole, ::Set{Int}, ::RuleNode) = Set() rulesoftype(::Hole, ::Set{Int}, ::Hole) = Set() """ - rulesoftype(node::RuleNode, grammar::Grammar, ruletype::Symbol, ignoreNode::RuleNode) + rulesoftype(node::RuleNode, grammar::AbstractGrammar, ruletype::Symbol, ignoreNode::RuleNode) Returns every rule of nonterminal symbol `ruletype` that is also used in the [`AbstractRuleNode`](@ref) tree, but not in the `ignoreNode` subtree. !!! warning The `ignoreNode` must be a subtree of `node` for it to have an effect. """ -rulesoftype(node::RuleNode, grammar::Grammar, ruletype::Symbol, ignoreNode::RuleNode) = rulesoftype(node, Set(grammar[ruletype]), ignoreNode) -rulesoftype(::Hole, ::Grammar, ::Symbol, ::RuleNode) = Set() +rulesoftype(node::RuleNode, grammar::AbstractGrammar, ruletype::Symbol, ignoreNode::RuleNode) = rulesoftype(node, Set(grammar[ruletype]), ignoreNode) +rulesoftype(::Hole, ::AbstractGrammar, ::Symbol, ::RuleNode) = Set() """ swap_node(expr::AbstractRuleNode, new_expr::AbstractRuleNode, path::Vector{Int}) @@ -172,12 +172,12 @@ end """ - rulenode2expr(rulenode::RuleNode, grammar::Grammar) + rulenode2expr(rulenode::RuleNode, grammar::AbstractGrammar) Converts a [`RuleNode`](@ref) into a Julia expression corresponding to the rule definitions in the grammar. The returned expression can be evaluated with Julia semantics using `eval()`. 
""" -function rulenode2expr(rulenode::RuleNode, grammar::Grammar) +function rulenode2expr(rulenode::RuleNode, grammar::AbstractGrammar) root = (rulenode._val !== nothing) ? rulenode._val : deepcopy(grammar.rules[rulenode.ind]) if !grammar.isterminal[rulenode.ind] # not terminal @@ -187,22 +187,22 @@ function rulenode2expr(rulenode::RuleNode, grammar::Grammar) end -function _rulenode2expr(rulenode::Hole, grammar::Grammar) +function _rulenode2expr(rulenode::Hole, grammar::AbstractGrammar) # Find the index of the first element that is true index = findfirst(==(true), rulenode.domain) return isnothing(index) ? :Nothing : grammar.types[index] end -rulenode2expr(rulenode::Hole, grammar::Grammar) = _rulenode2expr(rulenode::Hole, grammar::Grammar) +rulenode2expr(rulenode::Hole, grammar::AbstractGrammar) = _rulenode2expr(rulenode::Hole, grammar::AbstractGrammar) -function _rulenode2expr(expr::Expr, rulenode::RuleNode, grammar::Grammar, j=0) +function _rulenode2expr(expr::Expr, rulenode::RuleNode, grammar::AbstractGrammar, j=0) for (k,arg) in enumerate(expr.args) if isa(arg, Expr) expr.args[k],j = _rulenode2expr(arg, rulenode, grammar, j) elseif haskey(grammar.bytype, arg) child = rulenode.children[j+=1] if isa(child, Hole) - expr.args[k] = _rulenode2expr(child, grammar) - continue + expr.args[k] = _rulenode2expr(child, grammar) + continue end expr.args[k] = (child._val !== nothing) ? child._val : deepcopy(grammar.rules[child.ind]) @@ -215,10 +215,13 @@ function _rulenode2expr(expr::Expr, rulenode::RuleNode, grammar::Grammar, j=0) end -function _rulenode2expr(typ::Symbol, rulenode::RuleNode, grammar::Grammar, j=0) +function _rulenode2expr(typ::Symbol, rulenode::RuleNode, grammar::AbstractGrammar, j=0) retval = typ if haskey(grammar.bytype, typ) child = rulenode.children[1] + if isa(child, Hole) + return retval, j + end retval = (child._val !== nothing) ? 
child._val : deepcopy(grammar.rules[child.ind]) if !grammar.isterminal[child.ind] @@ -232,20 +235,20 @@ end """ Calculates the log probability associated with a rulenode in a probabilistic grammar. """ -function rulenode_log_probability(node::RuleNode, grammar::Grammar) +function rulenode_log_probability(node::RuleNode, grammar::AbstractGrammar) log_probability(grammar, node.ind) + sum((rulenode_log_probability(c, grammar) for c ∈ node.children), init=1) end -rulenode_log_probability(::Hole, ::Grammar) = 1 +rulenode_log_probability(::Hole, ::AbstractGrammar) = 1 """ - iscomplete(grammar::Grammar, node::RuleNode) + iscomplete(grammar::AbstractGrammar, node::RuleNode) Returns true if the expression represented by the [`RuleNode`](@ref) is a complete expression, meaning that it is fully defined and doesn't have any [`Hole`](@ref)s. """ -function iscomplete(grammar::Grammar, node::RuleNode) +function iscomplete(grammar::AbstractGrammar, node::RuleNode) if isterminal(grammar, node) return true elseif isempty(node.children) @@ -256,57 +259,93 @@ function iscomplete(grammar::Grammar, node::RuleNode) end end -iscomplete(grammar::Grammar, ::Hole) = false +iscomplete(grammar::AbstractGrammar, ::Hole) = false """ - return_type(grammar::Grammar, node::RuleNode) + return_type(grammar::AbstractGrammar, node::RuleNode) Gives the return type or nonterminal symbol in the production rule used by `node`. """ -return_type(grammar::Grammar, node::RuleNode)::Symbol = grammar.types[node.ind] +return_type(grammar::AbstractGrammar, node::RuleNode)::Symbol = grammar.types[node.ind] """ - child_types(grammar::Grammar, node::RuleNode) + child_types(grammar::AbstractGrammar, node::RuleNode) Returns the list of child types (nonterminal symbols) in the production rule used by `node`. 
""" -child_types(grammar::Grammar, node::RuleNode)::Vector{Symbol} = grammar.childtypes[node.ind] +child_types(grammar::AbstractGrammar, node::RuleNode)::Vector{Symbol} = grammar.childtypes[node.ind] """ - isterminal(grammar::Grammar, node::RuleNode)::Bool + isterminal(grammar::AbstractGrammar, node::RuleNode)::Bool Returns true if the production rule used by `node` is terminal, i.e., does not contain any nonterminal symbols. """ -isterminal(grammar::Grammar, node::RuleNode)::Bool = grammar.isterminal[node.ind] +isterminal(grammar::AbstractGrammar, node::RuleNode)::Bool = grammar.isterminal[node.ind] """ - nchildren(grammar::Grammar, node::RuleNode)::Int + nchildren(grammar::AbstractGrammar, node::RuleNode)::Int Returns the number of children in the production rule used by `node`. """ -nchildren(grammar::Grammar, node::RuleNode)::Int = length(child_types(grammar, node)) +nchildren(grammar::AbstractGrammar, node::RuleNode)::Int = length(child_types(grammar, node)) """ - isvariable(grammar::Grammar, node::RuleNode)::Bool + isvariable(grammar::AbstractGrammar, node::RuleNode)::Bool -Returns true if the rule used by `node` represents a variable. +Return true if the rule used by `node` represents a variable in a program (essentially, an input to the program) """ -isvariable(grammar::Grammar, node::RuleNode)::Bool = grammar.isterminal[node.ind] && grammar.rules[node.ind] isa Symbol +isvariable(grammar::AbstractGrammar, node::RuleNode)::Bool = ( + grammar.isterminal[node.ind] && + grammar.rules[node.ind] isa Symbol && + !_is_defined_in_modules(grammar.rules[node.ind], [Main, Base]) +) +""" + isvariable(grammar::AbstractGrammar, node::RuleNode, mod::Module)::Bool + +Return true if the rule used by `node` represents a variable. + +Taking into account the symbols defined in the given module(s). 
+""" +isvariable(grammar::AbstractGrammar, node::RuleNode, mod::Module...)::Bool = ( + grammar.isterminal[node.ind] && + grammar.rules[node.ind] isa Symbol && + !_is_defined_in_modules(grammar.rules[node.ind], [mod..., Main, Base]) +) + +""" + isvariable(grammar::AbstractGrammar, ind::Int)::Bool -isvariable(grammar::Grammar, ind::Int)::Bool = grammar.isterminal[ind] && grammar.rules[ind] isa Symbol +Return true if the rule with index `ind` represents a variable. +""" +isvariable(grammar::AbstractGrammar, ind::Int)::Bool = ( + grammar.isterminal[ind] && + grammar.rules[ind] isa Symbol && + !_is_defined_in_modules(grammar.rules[ind], [Main, Base]) +) +""" + isvariable(grammar::AbstractGrammar, ind::Int, mod::Module)::Bool +Return true if the rule with index `ind` represents a variable. + +Taking into account the symbols defined in the given module(s). +""" +isvariable(grammar::AbstractGrammar, ind::Int, mod::Module...)::Bool = ( + grammar.isterminal[ind] && + grammar.rules[ind] isa Symbol && + !_is_defined_in_modules(grammar.rules[ind], [mod..., Main, Base]) +) """ - contains_returntype(node::RuleNode, grammar::Grammar, sym::Symbol, maxdepth::Int=typemax(Int)) + contains_returntype(node::RuleNode, grammar::AbstractGrammar, sym::Symbol, maxdepth::Int=typemax(Int)) Returns true if the tree rooted at `node` contains at least one node at depth less than `maxdepth` with the given return type or nonterminal symbol. 
""" -function contains_returntype(node::RuleNode, grammar::Grammar, sym::Symbol, maxdepth::Int=typemax(Int)) +function contains_returntype(node::RuleNode, grammar::AbstractGrammar, sym::Symbol, maxdepth::Int=typemax(Int)) maxdepth < 1 && return false if return_type(grammar, node) == sym return true @@ -319,7 +358,7 @@ function contains_returntype(node::RuleNode, grammar::Grammar, sym::Symbol, maxd return false end -function Base.display(rulenode::RuleNode, grammar::Grammar) +function Base.display(rulenode::RuleNode, grammar::AbstractGrammar) root = rulenode2expr(rulenode, grammar) if isa(root, Expr) walk_tree(root) diff --git a/src/utils.jl b/src/utils.jl index 76a937d..b35d062 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -4,13 +4,13 @@ AbstractTrees.printnode(io::IO, node::RuleNode) = print(io, node.ind) """ - mindepth_map(grammar::Grammar) + mindepth_map(grammar::AbstractGrammar) -Returns the minimum depth achievable for each production rule in the [`Grammar`](@ref). +Returns the minimum depth achievable for each production rule in the [`AbstractGrammar`](@ref). In other words, this function finds the depths of the lowest trees that can be made using each of the available production rules as a root. """ -function mindepth_map(grammar::Grammar) +function mindepth_map(grammar::AbstractGrammar) dmap0 = Int[isterminal(grammar,i) ? 
1 : typemax(Int)/2 for i in eachindex(grammar.rules)] dmap1 = fill(-1, length(grammar.rules)) while dmap0 != dmap1 @@ -23,37 +23,37 @@ function mindepth_map(grammar::Grammar) end -function _mindepth(grammar::Grammar, rule_index::Int, dmap::AbstractVector{Int}) +function _mindepth(grammar::AbstractGrammar, rule_index::Int, dmap::AbstractVector{Int}) isterminal(grammar, rule_index) && return 1 return 1 + maximum([mindepth(grammar, ctyp, dmap) for ctyp in child_types(grammar, rule_index)]) end """ - mindepth(grammar::Grammar, typ::Symbol, dmap::AbstractVector{Int}) + mindepth(grammar::AbstractGrammar, typ::Symbol, dmap::AbstractVector{Int}) Returns the minimum depth achievable for a given nonterminal symbol. The minimum depth is the depth of the lowest tree that can be made using `typ` as a start symbol. `dmap` can be obtained from [`mindepth_map`](@ref). """ -function mindepth(grammar::Grammar, typ::Symbol, dmap::AbstractVector{Int}) +function mindepth(grammar::AbstractGrammar, typ::Symbol, dmap::AbstractVector{Int}) return minimum(dmap[grammar.bytype[typ]]) end """ SymbolTable -Data structure for mapping terminal symbols in the [`Grammar`](@ref) to their Julia interpretation. +Data structure for mapping terminal symbols in the [`AbstractGrammar`](@ref) to their Julia interpretation. """ const SymbolTable = Dict{Symbol,Any} """ - SymbolTable(grammar::Grammar, mod::Module=Main) + SymbolTable(grammar::AbstractGrammar, mod::Module=Main) Returns a [`SymbolTable`](@ref) populated with a mapping from symbols in the -[`Grammar`](@ref) to symbols in module `mod` or `Main`, if defined. +[`AbstractGrammar`](@ref) to symbols in module `mod` or `Main`, if defined. 
""" -function HerbGrammar.SymbolTable(grammar::Grammar, mod::Module=Main) +function HerbGrammar.SymbolTable(grammar::AbstractGrammar, mod::Module=Main) tab = SymbolTable() for rule in grammar.rules _add_to_symboltable!(tab, rule, mod) @@ -78,20 +78,24 @@ function _add_to_symboltable!(tab::SymbolTable, rule::Expr, mod::Module) return true end +function _apply_if_defined_in_modules(func::Function, s::Symbol, mods::Vector{Module}) + for mod in mods + if isdefined(mod, s) + func(mod, s) + return true + end + end + return false +end + +function _is_defined_in_modules(s::Symbol, mods::Vector{Module}) + _apply_if_defined_in_modules((mod, s) -> nothing, s, mods) +end function _add_to_symboltable!(tab::SymbolTable, s::Symbol, mod::Module) - if isdefined(mod, s) - tab[s] = getfield(mod, s) - return true - elseif isdefined(Base, s) - tab[s] = getfield(Base, s) - return true - elseif isdefined(Main, s) - tab[s] = getfield(Main, s) - return true - else - return false - end + _add_to_table! = (mod, s) -> tab[s] = getfield(mod, s) + + return _apply_if_defined_in_modules(_add_to_table!, s, [mod, Base, Main]) end diff --git a/test/runtests.jl b/test/runtests.jl index b64a4c3..f4a7f0a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,5 +2,6 @@ using HerbGrammar using Test @testset "HerbGrammar.jl" verbose=true begin - include("test_cfg.jl") + include("test_csg.jl") + include("test_rulenode_operators.jl") end diff --git a/test/test_cfg.jl b/test/test_cfg.jl deleted file mode 100644 index 4b069b1..0000000 --- a/test/test_cfg.jl +++ /dev/null @@ -1,89 +0,0 @@ -@testset verbose=true "CFGs" begin - @testset "creating grammars" begin - g₁ = @cfgrammar begin - Real = |(1:9) - end - @test g₁.rules == collect(1:9) - @test :Real ∈ g₁.types - - g₂ = @cfgrammar begin - Real = |([1,2,3]) - end - @test g₂.rules == [1,2,3] - - g₃ = @cfgrammar begin - Real = 1 | 2 | 3 - end - @test g₃.rules == [1,2,3] - end - - - @testset "adding rules to grammar" begin - g₁ = @cfgrammar begin - Real = 
|(1:2) - end - - # Basic adding - add_rule!(g₁, :(Real = 3)) - @test g₁.rules == [1, 2, 3] - - # Adding multiple rules in one line - add_rule!(g₁, :(Real = 4 | 5)) - @test g₁.rules == [1, 2, 3, 4, 5] - - # Adding already existing rules - add_rule!(g₁, :(Real = 5)) - @test g₁.rules == [1, 2, 3, 4, 5] - - # Adding multiple already existing rules - add_rule!(g₁, :(Real = |(1:9))) - @test g₁.rules == collect(1:9) - - # Adding other types - g₂ = @cfgrammar begin - Real = 1 | 2 | 3 - end - - add_rule!(g₂, :(Bool = Real ≤ Real)) - @test length(g₂.rules) == 4 - @test :Real ∈ g₂.types - @test :Bool ∈ g₂.types - @test g₂.rules[g₂.bytype[:Bool][1]] == :(Real ≤ Real) - @test g₂.childtypes[g₂.bytype[:Bool][1]] == [:Real, :Real] - - end - - - @testset "Writing and loading CFG to/from disk" begin - g₁ = @cfgrammar begin - Real = |(1:5) - Real = 6 | 7 | 8 - end - - store_cfg("toy_cfg.grammar", g₁) - g₂ = read_cfg("toy_cfg.grammar") - @test :Real ∈ g₂.types - @test g₂.rules == collect(1:8) - - # delete file afterwards - rm("toy_cfg.grammar") - end - - @testset "Writing and loading probabilistic CFG to/from disk" begin - g₁ = @pcfgrammar begin - 0.5 : Real = |(0:3) - 0.5 : Real = x - end - - store_cfg("toy_pcfg.grammar", g₁) - g₂ = read_pcfg("toy_pcfg.grammar") - @test :Real ∈ g₂.types - @test g₂.rules == [0, 1, 2, 3, :x] - @test g₂.log_probabilities == g₁.log_probabilities - - - # delete file afterwards - rm("toy_pcfg.grammar") - end - -end diff --git a/test/test_csg.jl b/test/test_csg.jl new file mode 100644 index 0000000..30864e2 --- /dev/null +++ b/test/test_csg.jl @@ -0,0 +1,188 @@ +@testset verbose=true "CSGs" begin + @testset "Create empty grammar" begin + g = @csgrammar begin end + @test isempty(g.rules) + @test isempty(g.types) + @test isempty(g.isterminal) + @test isempty(g.iseval) + @test isempty(g.bytype) + @test isempty(g.domains) + @test isempty(g.childtypes) + @test isnothing(g.log_probabilities) + end + + @testset "Creating grammars" begin + g₁ = @cfgrammar begin + 
Real = |(1:9) + end + @test g₁.rules == collect(1:9) + @test :Real ∈ g₁.types + + g₂ = @cfgrammar begin + Real = |([1,2,3]) + end + @test g₂.rules == [1,2,3] + + g₃ = @cfgrammar begin + Real = 1 | 2 | 3 + end + @test g₃.rules == [1,2,3] + end + + + @testset "Adding rules to grammar" begin + g₁ = @csgrammar begin + Real = |(1:2) + end + + # Basic adding + add_rule!(g₁, :(Real = 3)) + @test g₁.rules == [1, 2, 3] + + # Adding multiple rules in one line + add_rule!(g₁, :(Real = 4 | 5)) + @test g₁.rules == [1, 2, 3, 4, 5] + + # Adding already existing rules + add_rule!(g₁, :(Real = 5)) + @test g₁.rules == [1, 2, 3, 4, 5] + + # Adding multiple already existing rules + add_rule!(g₁, :(Real = |(1:9))) + @test g₁.rules == collect(1:9) + + # Adding other types + g₂ = @csgrammar begin + Real = 1 | 2 | 3 + end + + add_rule!(g₂, :(Bool = Real ≤ Real)) + @test length(g₂.rules) == 4 + @test :Real ∈ g₂.types + @test :Bool ∈ g₂.types + @test g₂.rules[g₂.bytype[:Bool][1]] == :(Real ≤ Real) + @test g₂.childtypes[g₂.bytype[:Bool][1]] == [:Real, :Real] + + @test_throws ArgumentError add_rule!(g₂, :(Real != Bool)) + end + + @testset "Merging two grammars" begin + g₁ = @csgrammar begin + Number = |(1:2) + Number = x + end + + g₂ = @csgrammar begin + Real = Real + Real + Real = Real * Real + end + + merge_grammars!(g₁, g₂) + + @test length(g₁.rules) == 5 + @test :Real ∈ g₁.types + end + + @testset "Writing and loading CSG to/from disk" begin + g₁ = @csgrammar begin + Real = |(1:5) + Real = 6 | 7 | 8 + end + + store_csg(g₁, "toy_cfg.grammar") + g₂ = read_csg("toy_cfg.grammar") + @test :Real ∈ g₂.types + @test g₂.rules == collect(1:8) + + # delete file afterwards + rm("toy_cfg.grammar") + end + + @testset "Writing and loading probabilistic CSG to/from disk" begin + g₁ = @pcsgrammar begin + 0.5 : Real = |(0:3) + 0.5 : Real = x + end + + store_csg(g₁, "toy_pcfg.grammar") + g₂ = read_pcsg("toy_pcfg.grammar") + @test :Real ∈ g₂.types + @test g₂.rules == [0, 1, 2, 3, :x] + @test 
g₂.log_probabilities == g₁.log_probabilities + + + # delete file afterwards + rm("toy_pcfg.grammar") + end + + @testset "creating probabilistic CSG" begin + g = @pcsgrammar begin + 0.5 : R = |(0:2) + 0.3 : R = x + 0.2 : B = true | false + end + + @test sum(map(exp, g.log_probabilities[g.bytype[:R]])) ≈ 1.0 + @test sum(map(exp, g.log_probabilities[g.bytype[:B]])) ≈ 1.0 + @test g.bytype[:R] == Int[1,2,3,4] + @test g.bytype[:B] == Int[5,6] + @test :R ∈ g.types && :B ∈ g.types + end + + @testset "creating a non-normalized PCSG" begin + g = @pcsgrammar begin + 0.5 : R = |(0:2) + 0.5 : R = x + 0.5 : B = true | false + end + + @test sum(map(exp, g.log_probabilities[g.bytype[:R]])) ≈ 1.0 + @test sum(map(exp, g.log_probabilities[g.bytype[:B]])) ≈ 1.0 + @test g.rules == [0, 1, 2, :x, :true, :false] + @test g.bytype[:R] == Int[1,2,3,4] + @test g.bytype[:B] == Int[5,6] + @test :R ∈ g.types && :B ∈ g.types + end + + @testset "Adding a rule to a probabilistic CSG" begin + g = @pcsgrammar begin + 0.5 : R = x + 0.5 : R = R + R + end + + add_rule!(g, 0.5, :(R = 1 | 2)) + + @test g.rules == [:x, :(R + R), 1, 2] + + add_rule!(g, 0.5, :(B = t | f)) + + @test g.bytype[:B] == Int[5, 6] + @test sum(map(exp, g.log_probabilities[g.bytype[:R]])) ≈ 1.0 + @test sum(map(exp, g.log_probabilities[g.bytype[:B]])) ≈ 1.0 + end + + @testset "Test that strict equality is used during rule creation" begin + g₁ = @csgrammar begin + R = x + R = R + R + end + + add_rule!(g₁, :(R = 1 | 2)) + + add_rule!(g₁,:(Bool = true)) + + @test all(g₁.rules .== [:x, :(R + R), 1, 2, true]) + + g₁ = @csgrammar begin + R = x + R = R + R + end + + add_rule!(g₁,:(Bool = true)) + + add_rule!(g₁, :(R = 1 | 2)) + + @test all(g₁.rules .== [:x, :(R + R), true, 1, 2]) + end + +end diff --git a/test/test_rulenode_operators.jl b/test/test_rulenode_operators.jl new file mode 100644 index 0000000..dee0c14 --- /dev/null +++ b/test/test_rulenode_operators.jl @@ -0,0 +1,18 @@ +module SomeDefinitions + a_variable_that_is_defined = 7 +end 
+ +@testset verbose = true "RuleNode Operators" begin + @testset "Check if a symbol is a variable" begin + g₁ = @cfgrammar begin + Real = |(1:5) + Real = a_variable + Real = a_variable_that_is_defined + end + + @test !isvariable(g₁, RuleNode(5, g₁), SomeDefinitions) + @test isvariable(g₁, RuleNode(6, g₁), SomeDefinitions) + @test !isvariable(g₁, RuleNode(7, g₁), SomeDefinitions) + @test isvariable(g₁, RuleNode(7, g₁)) + end +end