Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some cleanup and also aliasing support #19

Merged
merged 4 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 17 additions & 22 deletions src/parser.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ parse_identifier(token, qasm) = QasmExpression(:identifier, String(read_raw(toke
function parse_block_body(expr, tokens, stack, start, qasm)
is_scope = tokens[1][end] == lbrace
if is_scope
body = parse_scope(tokens, stack, start, qasm)
body_exprs = convert(Vector{QasmExpression}, collect(Iterators.reverse(body)))::Vector{QasmExpression}
scope_tokens = extract_expression(tokens, lbrace, rbrace, stack, start, qasm)
body = parse_qasm(scope_tokens, qasm, QasmExpression(:scope))
body_exprs = convert(Vector{QasmExpression}, collect(Iterators.reverse(body)))::Vector{QasmExpression}
foreach(body_expr->push!(body_exprs[1], body_expr), body_exprs[2:end])
push!(expr, body_exprs[1])
else # one line
Expand Down Expand Up @@ -104,9 +105,9 @@ function parse_function_def(tokens, stack, start, qasm)
return expr
end
function parse_gate_or_cal_def(head::Symbol, tokens, stack, start, qasm)
def_name = popfirst!(tokens)
def_name = popfirst!(tokens)
def_name[end] == identifier || throw(QasmParseError("$head must have a valid identifier as a name", stack, start, qasm))
def_name_id = parse_identifier(def_name, qasm)
def_name_id = parse_identifier(def_name, qasm)

def_args = parse_arguments_list(tokens, stack, start, qasm)
qubit_tokens = splice!(tokens, 1:findfirst(triplet->triplet[end]==lbrace, tokens)-1)
Expand Down Expand Up @@ -226,11 +227,6 @@ function extract_expression(tokens::Vector{Tuple{Int64, Int32, Token}}, opener,
return extracted_tokens
end

function parse_scope(tokens, stack, start, qasm)
scope_tokens = extract_expression(tokens, lbrace, rbrace, stack, start, qasm)
return parse_qasm(scope_tokens, qasm, QasmExpression(:scope))
end

function parse_list_expression(tokens::Vector{Tuple{Int64, Int32, Token}}, stack, start, qasm)
expr_list = QasmExpression[]
while !isempty(tokens) && first(tokens)[end] != semicolon
Expand Down Expand Up @@ -383,14 +379,13 @@ function expression_start(tokens, stack, start, qasm)
expr_head = parse_list_expression(interior_tokens, stack, start, qasm)
elseif start_token[end] == classical_type
type_tokens = pushfirst!(tokens, start_token)
raw_expr = parse_classical_type(type_tokens, stack, start, qasm)
raw_expr = parse_classical_type(type_tokens, stack, start, qasm)
expr_head = raw_expr
if !isempty(tokens) && first(tokens)[end] == lparen
interior = extract_expression(tokens, lparen, rparen, stack, start, qasm)
expr_head = QasmExpression(:cast, raw_expr, parse_expression(interior, stack, start, qasm))
elseif !isempty(tokens) && first(tokens)[end] == identifier
expr_head = QasmExpression(:classical_declaration, raw_expr, parse_expression(tokens, stack, start, qasm))
else
expr_head = raw_expr
end
elseif start_token[end] == waveform_token && next_token[end] != identifier
expr_head = QasmExpression(:waveform)
Expand Down Expand Up @@ -418,16 +413,14 @@ end
function parse_range(expr_head, tokens, stack, start, qasm)
popfirst!(tokens)
second_colon = findfirst(triplet->triplet[end] == colon, tokens)
step = QasmExpression(:integer_literal, 1)
if !isnothing(second_colon)
step_tokens = push!(splice!(tokens, 1:second_colon-1), (-1, Int32(-1), semicolon))
popfirst!(tokens) # colon
step = parse_expression(step_tokens, stack, start, qasm)::QasmExpression
else
step = QasmExpression(:integer_literal, 1)
end
if isempty(tokens) || first(tokens)[end] == semicolon # missing stop
stop = QasmExpression(:integer_literal, -1)
else
stop = QasmExpression(:integer_literal, -1)
if !isempty(tokens) && first(tokens)[end] != semicolon # missing stop
stop = parse_expression(tokens, stack, start, qasm)::QasmExpression
end
return QasmExpression(:range, QasmExpression[expr_head, step, stop])
Expand Down Expand Up @@ -485,11 +478,11 @@ function parse_unary_op(tokens, stack, start, qasm)
expr = QasmExpression(:complex_literal, -real(next_expr.args[1]) + im*imag(next_expr.args[1]))
end
elseif head(next_expr) == :binary_op && !next_token_is_paren
# replace first argument if next token isn't a paren
left_hand_side = next_expr.args[2]::QasmExpression
new_left_hand_side = QasmExpression(:unary_op, unary_op_symbol, left_hand_side)
next_expr.args[2] = new_left_hand_side
expr = next_expr
# replace first argument if next token isn't a paren
left_hand_side = next_expr.args[2]::QasmExpression
new_left_hand_side = QasmExpression(:unary_op, unary_op_symbol, left_hand_side)
next_expr.args[2] = new_left_hand_side
expr = next_expr
else
expr = QasmExpression(:unary_op, unary_op_symbol, next_expr)
end
Expand Down Expand Up @@ -718,6 +711,8 @@ function parse_qasm(clean_tokens::Vector{Tuple{Int64, Int32, Token}}, qasm::Stri
push!(stack, delay_expr)
elseif token == end_token
push!(stack, QasmExpression(:end))
elseif token == alias
push!(stack, QasmExpression(:alias, parse_expression(clean_tokens, stack, start, qasm)))
elseif token == identifier || token == builtin_gate
clean_tokens = pushfirst!(clean_tokens, (start, len, token))
expr = parse_expression(clean_tokens, stack, start, qasm)
Expand Down
190 changes: 123 additions & 67 deletions src/visitor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ function name(expr::QasmExpression)::String
head(expr) == :gate_call && return name(expr.args[1]::QasmExpression)
head(expr) == :gate_definition && return name(expr.args[1]::QasmExpression)
head(expr) == :classical_assignment && return name(expr.args[1].args[2]::QasmExpression)
head(expr) == :alias && return name(expr.args[1]::QasmExpression)
head(expr) == :hw_qubit && return replace(expr.args[1], "\$"=>"")
throw(QasmVisitorError("name not defined for expressions of type $(head(expr))"))
end
Expand All @@ -323,8 +324,14 @@ function _evaluate_qubits(::Val{:indexed_identifier}, v, qubit_expr::QasmExpress
haskey(mapping, qubit_name) || throw(QasmVisitorError("Missing input variable '$qubit_name'.", "NameError"))
qubit_ix = v(qubit_expr.args[2]::QasmExpression)
qubits = Iterators.flatmap(qubit_ix) do rq
haskey(mapping, qubit_name * "[$rq]") || throw(QasmVisitorError("Invalid qubit index '$rq' in '$qubit_name'.", "IndexError"))
return mapping[qubit_name * "[$rq]"]
if rq >= 0
haskey(mapping, qubit_name * "[$rq]") || throw(QasmVisitorError("Invalid qubit index '$rq' in '$qubit_name'.", "IndexError"))
return mapping[qubit_name * "[$rq]"]
else
qubit_size = length(mapping[qubit_name])
haskey(mapping, qubit_name * "[$(qubit_size + rq)]") || throw(QasmVisitorError("Invalid qubit index '$rq' in '$qubit_name'.", "IndexError"))
return mapping[qubit_name * "[$(qubit_size + rq)]"]
end
end
return collect(qubits)
end
Expand Down Expand Up @@ -451,6 +458,64 @@ function visit_gate_call(v::AbstractVisitor, program_expr::QasmExpression)
return
end

function visit_function_call(v, expr, function_name)
function_def = function_defs(v)[function_name]
function_body = function_def.body::Vector{QasmExpression}
declared_args = only(function_def.arguments.args)::QasmExpression
provided_args = only(expr.args[2].args)::QasmExpression
function_v = QasmFunctionVisitor(v, declared_args, provided_args)
return_val = nothing
body_exprs::Vector{QasmExpression} = head(function_body[1]) == :scope ? function_body[1].args : function_body
for f_expr in body_exprs
if head(f_expr) == :return
return_val = function_v(f_expr.args[1])
else
function_v(f_expr)
end
end
# remap qubits and classical variables
function_args = if head(declared_args) == :array_literal
convert(Vector{QasmExpression}, declared_args.args)::Vector{QasmExpression}
else
declared_args
end
called_args = if head(provided_args) == :array_literal
convert(Vector{QasmExpression}, provided_args.args)::Vector{QasmExpression}
else
provided_args
end
reverse_arguments_map = Dict{QasmExpression, QasmExpression}(zip(called_args, function_args))
reverse_qubits_map = Dict{Int, Int}()
for variable in filter(v->head(v) ∈ (:identifier, :indexed_identifier), keys(reverse_arguments_map))
variable_name = name(variable)
if haskey(classical_defs(v), variable_name) && classical_defs(v)[variable_name].type isa SizedArray && head(reverse_arguments_map[variable]) != :const_declaration
inner_variable_name = name(reverse_arguments_map[variable])
new_val = classical_defs(function_v)[inner_variable_name].val
back_assignment = QasmExpression(:classical_assignment, QasmExpression(:binary_op, Symbol("="), variable, new_val))
v(back_assignment)
elseif haskey(qubit_defs(v), variable_name)
outer_context_map = only(evaluate_qubits(v, variable))
inner_context_map = only(evaluate_qubits(function_v, reverse_arguments_map[variable].args[1]))
reverse_qubits_map[inner_context_map] = outer_context_map
end
end
mapper = isempty(reverse_qubits_map) ? identity : ix->remap(ix, reverse_qubits_map)
push!(v, map(mapper, function_v.instructions))
return return_val
end

function declaration_init(v, expr::QasmExpression)
var_type = expr.args[1].args[1]
init = if var_type isa SizedNumber
undef
elseif var_type isa SizedArray
fill(undef, v(var_type.size))
elseif var_type isa SizedBitVector
falses(max(0, v(var_type.size)))
end
return init, var_type
end

(v::AbstractVisitor)(i::Number) = i
(v::AbstractVisitor)(i::String) = i
(v::AbstractVisitor)(i::BitVector) = i
Expand Down Expand Up @@ -548,6 +613,58 @@ function (v::AbstractVisitor)(program_expr::QasmExpression)
isnothing(default) && throw(QasmVisitorError("no case matched and no default defined."))
foreach(v, convert(Vector{QasmExpression}, all_cases[default].args))
end
elseif head(program_expr) == :alias
alias_name = name(program_expr)
right_hand_side = program_expr.args[1].args[1].args[end]
if head(right_hand_side) == :binary_op
right_hand_side.args[1] == Symbol("++") || throw(QasmVisitorError("right hand side of alias must be either an identifier or concatenation"))
concat_left = right_hand_side.args[2]
concat_right = right_hand_side.args[3]
is_left_qubit = haskey(qubit_mapping(v), name(concat_left))
is_right_qubit = haskey(qubit_mapping(v), name(concat_right))
(is_left_qubit ⊻ is_right_qubit) && throw(QasmVisitorError("cannot concatenate qubit and classical arrays"))
if is_left_qubit
left_qs = v(concat_left)
right_qs = v(concat_right)
alias_qubits = collect(vcat(left_qs, right_qs))
qubit_size = length(alias_qubits)
qubit_defs(v)[alias_name] = Qubit(alias_name, qubit_size)
qubit_mapping(v)[alias_name] = alias_qubits
for qubit_i in 0:qubit_size-1
qubit_mapping(v)["$alias_name[$qubit_i]"] = [alias_qubits[qubit_i+1]]
end
else # both classical
throw(QasmVisitorError("classical array concatenation not yet supported!"))
end
elseif head(right_hand_side) == :identifier
referent_name = name(right_hand_side)
is_qubit = haskey(qubit_mapping(v), referent_name)
if is_qubit
qubit_defs(v)[alias_name] = qubit_defs(v)[referent_name]
qubit_mapping(v)[alias_name] = qubit_mapping(v)[referent_name]
qubit_size = length(qubit_mapping(v)[alias_name])
for qubit_i in 0:qubit_size-1
qubit_mapping(v)["$alias_name[$qubit_i]"] = qubit_mapping(v)["$referent_name[$qubit_i]"]
end
else
classical_defs(v)[alias_name] = classical_defs(v)[referent_name]
end
elseif head(right_hand_side) == :indexed_identifier
referent_name = name(right_hand_side)
is_qubit = haskey(qubit_mapping(v), referent_name)
if is_qubit
alias_qubits = v(right_hand_side)
qubit_size = length(alias_qubits)
qubit_defs(v)[alias_name] = Qubit(alias_name, qubit_size)
qubit_mapping(v)[alias_name] = collect(alias_qubits)
for qubit_i in 0:qubit_size-1
qubit_mapping(v)["$alias_name[$qubit_i]"] = [alias_qubits[qubit_i+1]]
end
else
referent = classical_defs(v)[referent_name]
classical_defs(v)[alias_name] = ClassicalVariable(alias_name, referent.type, view(referent.val, v(right_hand_side.args[end]) .+ 1), referent.is_const)
end
end
elseif head(program_expr) == :identifier
id_name = name(program_expr)
haskey(classical_defs(v), id_name) && return classical_defs(v)[id_name].val
Expand Down Expand Up @@ -610,7 +727,7 @@ function (v::AbstractVisitor)(program_expr::QasmExpression)
condition_value = while_v(program_expr.args[1])
end
elseif head(program_expr) == :classical_assignment
op = program_expr.args[1].args[1]::Symbol
op = program_expr.args[1].args[1]::Symbol
left_hand_side = program_expr.args[1].args[2]::QasmExpression
right_hand_side = program_expr.args[1].args[3]
var_name = name(left_hand_side)::String
Expand Down Expand Up @@ -651,14 +768,7 @@ function (v::AbstractVisitor)(program_expr::QasmExpression)
end
end
elseif head(program_expr) == :classical_declaration
var_type = program_expr.args[1].args[1]
init = if var_type isa SizedNumber
undef
elseif var_type isa SizedArray
fill(undef, v(var_type.size))
elseif var_type isa SizedBitVector
falses(max(0, v(var_type.size)))
end
init, var_type = declaration_init(v, program_expr)
# no initial value
if head(program_expr.args[2]) == :identifier
var_name = name(program_expr.args[2])
Expand All @@ -671,14 +781,7 @@ function (v::AbstractVisitor)(program_expr::QasmExpression)
end
elseif head(program_expr) == :const_declaration
head(program_expr.args[2]) == :classical_assignment || throw(QasmVisitorError("const declaration must assign an initial value."))
var_type = program_expr.args[1].args[1]
init = if var_type isa SizedNumber
undef
elseif var_type isa SizedArray
fill(undef, v(var_type.size))
elseif var_type isa SizedBitVector
falses(max(0, v(var_type.size)))
end
init, var_type = declaration_init(v, program_expr)
op, left_hand_side, right_hand_side = program_expr.args[2].args[1].args
var_name = name(left_hand_side)
v.classical_defs[var_name] = ClassicalVariable(var_name, var_type, init, false)
Expand Down Expand Up @@ -737,54 +840,7 @@ function (v::AbstractVisitor)(program_expr::QasmExpression)
return return_val[1]
else
hasfunction(v, function_name) || throw(QasmVisitorError("function $function_name not defined!"))
function_def = function_defs(v)[function_name]
function_body = function_def.body::Vector{QasmExpression}
declared_args = only(function_def.arguments.args)::QasmExpression
provided_args = only(program_expr.args[2].args)::QasmExpression
function_v = QasmFunctionVisitor(v, declared_args, provided_args)
return_val = nothing
body_exprs::Vector{QasmExpression} = head(function_body[1]) == :scope ? function_body[1].args : function_body
for f_expr in body_exprs
if head(f_expr) == :return
return_val = function_v(f_expr.args[1])
else
function_v(f_expr)
end
end
# remap qubits and classical variables
function_args = if head(declared_args) == :array_literal
convert(Vector{QasmExpression}, declared_args.args)::Vector{QasmExpression}
else
declared_args
end
called_args = if head(provided_args) == :array_literal
convert(Vector{QasmExpression}, provided_args.args)::Vector{QasmExpression}
else
provided_args
end
arguments_map = Dict{QasmExpression, QasmExpression}(zip(function_args, called_args))
reverse_arguments_map = Dict{QasmExpression, QasmExpression}(zip(called_args, function_args))
reverse_qubits_map = Dict{Int, Int}()
for variable in keys(reverse_arguments_map)
if head(variable) ∈ (:identifier, :indexed_identifier)
variable_name = name(variable)
if haskey(classical_defs(v), variable_name) && classical_defs(v)[variable_name].type isa SizedArray
if head(reverse_arguments_map[variable]) != :const_declaration
inner_variable_name = name(reverse_arguments_map[variable])
new_val = classical_defs(function_v)[inner_variable_name].val
back_assignment = QasmExpression(:classical_assignment, QasmExpression(:binary_op, Symbol("="), variable, new_val))
v(back_assignment)
end
elseif haskey(qubit_defs(v), variable_name)
outer_context_map = only(evaluate_qubits(v, variable))
inner_context_map = only(evaluate_qubits(function_v, reverse_arguments_map[variable].args[1]))
reverse_qubits_map[inner_context_map] = outer_context_map
end
end
end
mapper = isempty(reverse_qubits_map) ? identity : ix->remap(ix, reverse_qubits_map)
push!(v, map(mapper, function_v.instructions))
return return_val
return visit_function_call(v, program_expr, function_name)
end
elseif head(program_expr) == :function_definition
function_def = program_expr.args
Expand Down
Loading
Loading