diff --git a/src/parser.jl b/src/parser.jl index bf45304..1d0f4d7 100644 --- a/src/parser.jl +++ b/src/parser.jl @@ -4,8 +4,9 @@ parse_identifier(token, qasm) = QasmExpression(:identifier, String(read_raw(toke function parse_block_body(expr, tokens, stack, start, qasm) is_scope = tokens[1][end] == lbrace if is_scope - body = parse_scope(tokens, stack, start, qasm) - body_exprs = convert(Vector{QasmExpression}, collect(Iterators.reverse(body)))::Vector{QasmExpression} + scope_tokens = extract_expression(tokens, lbrace, rbrace, stack, start, qasm) + body = parse_qasm(scope_tokens, qasm, QasmExpression(:scope)) + body_exprs = convert(Vector{QasmExpression}, collect(Iterators.reverse(body)))::Vector{QasmExpression} foreach(body_expr->push!(body_exprs[1], body_expr), body_exprs[2:end]) push!(expr, body_exprs[1]) else # one line @@ -104,9 +105,9 @@ function parse_function_def(tokens, stack, start, qasm) return expr end function parse_gate_or_cal_def(head::Symbol, tokens, stack, start, qasm) - def_name = popfirst!(tokens) + def_name = popfirst!(tokens) def_name[end] == identifier || throw(QasmParseError("$head must have a valid identifier as a name", stack, start, qasm)) - def_name_id = parse_identifier(def_name, qasm) + def_name_id = parse_identifier(def_name, qasm) def_args = parse_arguments_list(tokens, stack, start, qasm) qubit_tokens = splice!(tokens, 1:findfirst(triplet->triplet[end]==lbrace, tokens)-1) @@ -226,11 +227,6 @@ function extract_expression(tokens::Vector{Tuple{Int64, Int32, Token}}, opener, return extracted_tokens end -function parse_scope(tokens, stack, start, qasm) - scope_tokens = extract_expression(tokens, lbrace, rbrace, stack, start, qasm) - return parse_qasm(scope_tokens, qasm, QasmExpression(:scope)) -end - function parse_list_expression(tokens::Vector{Tuple{Int64, Int32, Token}}, stack, start, qasm) expr_list = QasmExpression[] while !isempty(tokens) && first(tokens)[end] != semicolon @@ -383,14 +379,13 @@ function expression_start(tokens, stack, start, qasm) expr_head = parse_list_expression(interior_tokens, stack, start, qasm) elseif start_token[end] == classical_type type_tokens = pushfirst!(tokens, start_token) - raw_expr = parse_classical_type(type_tokens, stack, start, qasm) + raw_expr = parse_classical_type(type_tokens, stack, start, qasm) + expr_head = raw_expr if !isempty(tokens) && first(tokens)[end] == lparen interior = extract_expression(tokens, lparen, rparen, stack, start, qasm) expr_head = QasmExpression(:cast, raw_expr, parse_expression(interior, stack, start, qasm)) elseif !isempty(tokens) && first(tokens)[end] == identifier expr_head = QasmExpression(:classical_declaration, raw_expr, parse_expression(tokens, stack, start, qasm)) - else - expr_head = raw_expr end elseif start_token[end] == waveform_token && next_token[end] != identifier expr_head = QasmExpression(:waveform) @@ -418,16 +413,14 @@ end function parse_range(expr_head, tokens, stack, start, qasm) popfirst!(tokens) second_colon = findfirst(triplet->triplet[end] == colon, tokens) + step = QasmExpression(:integer_literal, 1) if !isnothing(second_colon) step_tokens = push!(splice!(tokens, 1:second_colon-1), (-1, Int32(-1), semicolon)) popfirst!(tokens) # colon step = parse_expression(step_tokens, stack, start, qasm)::QasmExpression - else - step = QasmExpression(:integer_literal, 1) end - if isempty(tokens) || first(tokens)[end] == semicolon # missing stop - stop = QasmExpression(:integer_literal, -1) - else + stop = QasmExpression(:integer_literal, -1) + if !isempty(tokens) && first(tokens)[end] != semicolon # missing stop stop = parse_expression(tokens, stack, start, qasm)::QasmExpression end return QasmExpression(:range, QasmExpression[expr_head, step, stop]) @@ -485,11 +478,11 @@ function parse_unary_op(tokens, stack, start, qasm) expr = QasmExpression(:complex_literal, -real(next_expr.args[1]) + im*imag(next_expr.args[1])) end elseif head(next_expr) == :binary_op && !next_token_is_paren - # replace first argument if next token isn't a paren - left_hand_side = next_expr.args[2]::QasmExpression - new_left_hand_side = QasmExpression(:unary_op, unary_op_symbol, left_hand_side) - next_expr.args[2] = new_left_hand_side - expr = next_expr + # replace first argument if next token isn't a paren + left_hand_side = next_expr.args[2]::QasmExpression + new_left_hand_side = QasmExpression(:unary_op, unary_op_symbol, left_hand_side) + next_expr.args[2] = new_left_hand_side + expr = next_expr else expr = QasmExpression(:unary_op, unary_op_symbol, next_expr) end @@ -718,6 +711,8 @@ function parse_qasm(clean_tokens::Vector{Tuple{Int64, Int32, Token}}, qasm::Stri push!(stack, delay_expr) elseif token == end_token push!(stack, QasmExpression(:end)) + elseif token == alias + push!(stack, QasmExpression(:alias, parse_expression(clean_tokens, stack, start, qasm))) elseif token == identifier || token == builtin_gate clean_tokens = pushfirst!(clean_tokens, (start, len, token)) expr = parse_expression(clean_tokens, stack, start, qasm) diff --git a/src/visitor.jl b/src/visitor.jl index 5b3207f..238d95a 100644 --- a/src/visitor.jl +++ b/src/visitor.jl @@ -304,6 +304,7 @@ function name(expr::QasmExpression)::String head(expr) == :gate_call && return name(expr.args[1]::QasmExpression) head(expr) == :gate_definition && return name(expr.args[1]::QasmExpression) head(expr) == :classical_assignment && return name(expr.args[1].args[2]::QasmExpression) + head(expr) == :alias && return name(expr.args[1]::QasmExpression) head(expr) == :hw_qubit && return replace(expr.args[1], "\$"=>"") throw(QasmVisitorError("name not defined for expressions of type $(head(expr))")) end @@ -323,8 +324,14 @@ function _evaluate_qubits(::Val{:indexed_identifier}, v, qubit_expr::QasmExpress haskey(mapping, qubit_name) || throw(QasmVisitorError("Missing input variable '$qubit_name'.", "NameError")) qubit_ix = v(qubit_expr.args[2]::QasmExpression) qubits = Iterators.flatmap(qubit_ix) do rq - haskey(mapping, qubit_name * "[$rq]") || throw(QasmVisitorError("Invalid qubit index '$rq' in '$qubit_name'.", "IndexError")) - return mapping[qubit_name * "[$rq]"] + if rq >= 0 + haskey(mapping, qubit_name * "[$rq]") || throw(QasmVisitorError("Invalid qubit index '$rq' in '$qubit_name'.", "IndexError")) + return mapping[qubit_name * "[$rq]"] + else + qubit_size = length(mapping[qubit_name]) + haskey(mapping, qubit_name * "[$(qubit_size + rq)]") || throw(QasmVisitorError("Invalid qubit index '$rq' in '$qubit_name'.", "IndexError")) + return mapping[qubit_name * "[$(qubit_size + rq)]"] + end end return collect(qubits) end @@ -451,6 +458,64 @@ function visit_gate_call(v::AbstractVisitor, program_expr::QasmExpression) return end +function visit_function_call(v, expr, function_name) + function_def = function_defs(v)[function_name] + function_body = function_def.body::Vector{QasmExpression} + declared_args = only(function_def.arguments.args)::QasmExpression + provided_args = only(expr.args[2].args)::QasmExpression + function_v = QasmFunctionVisitor(v, declared_args, provided_args) + return_val = nothing + body_exprs::Vector{QasmExpression} = head(function_body[1]) == :scope ? function_body[1].args : function_body + for f_expr in body_exprs + if head(f_expr) == :return + return_val = function_v(f_expr.args[1]) + else + function_v(f_expr) + end + end + # remap qubits and classical variables + function_args = if head(declared_args) == :array_literal + convert(Vector{QasmExpression}, declared_args.args)::Vector{QasmExpression} + else + declared_args + end + called_args = if head(provided_args) == :array_literal + convert(Vector{QasmExpression}, provided_args.args)::Vector{QasmExpression} + else + provided_args + end + reverse_arguments_map = Dict{QasmExpression, QasmExpression}(zip(called_args, function_args)) + reverse_qubits_map = Dict{Int, Int}() + for variable in filter(v->head(v) ∈ (:identifier, :indexed_identifier), keys(reverse_arguments_map)) + variable_name = name(variable) + if haskey(classical_defs(v), variable_name) && classical_defs(v)[variable_name].type isa SizedArray && head(reverse_arguments_map[variable]) != :const_declaration + inner_variable_name = name(reverse_arguments_map[variable]) + new_val = classical_defs(function_v)[inner_variable_name].val + back_assignment = QasmExpression(:classical_assignment, QasmExpression(:binary_op, Symbol("="), variable, new_val)) + v(back_assignment) + elseif haskey(qubit_defs(v), variable_name) + outer_context_map = only(evaluate_qubits(v, variable)) + inner_context_map = only(evaluate_qubits(function_v, reverse_arguments_map[variable].args[1])) + reverse_qubits_map[inner_context_map] = outer_context_map + end + end + mapper = isempty(reverse_qubits_map) ? identity : ix->remap(ix, reverse_qubits_map) + push!(v, map(mapper, function_v.instructions)) + return return_val +end + +function declaration_init(v, expr::QasmExpression) + var_type = expr.args[1].args[1] + init = if var_type isa SizedNumber + undef + elseif var_type isa SizedArray + fill(undef, v(var_type.size)) + elseif var_type isa SizedBitVector + falses(max(0, v(var_type.size))) + end + return init, var_type +end + (v::AbstractVisitor)(i::Number) = i (v::AbstractVisitor)(i::String) = i (v::AbstractVisitor)(i::BitVector) = i @@ -548,6 +613,58 @@ function (v::AbstractVisitor)(program_expr::QasmExpression) isnothing(default) && throw(QasmVisitorError("no case matched and no default defined.")) foreach(v, convert(Vector{QasmExpression}, all_cases[default].args)) end + elseif head(program_expr) == :alias + alias_name = name(program_expr) + right_hand_side = program_expr.args[1].args[1].args[end] + if head(right_hand_side) == :binary_op + right_hand_side.args[1] == Symbol("++") || throw(QasmVisitorError("right hand side of alias must be either an identifier or concatenation")) + concat_left = right_hand_side.args[2] + concat_right = right_hand_side.args[3] + is_left_qubit = haskey(qubit_mapping(v), name(concat_left)) + is_right_qubit = haskey(qubit_mapping(v), name(concat_right)) + (is_left_qubit ⊻ is_right_qubit) && throw(QasmVisitorError("cannot concatenate qubit and classical arrays")) + if is_left_qubit + left_qs = v(concat_left) + right_qs = v(concat_right) + alias_qubits = collect(vcat(left_qs, right_qs)) + qubit_size = length(alias_qubits) + qubit_defs(v)[alias_name] = Qubit(alias_name, qubit_size) + qubit_mapping(v)[alias_name] = alias_qubits + for qubit_i in 0:qubit_size-1 + qubit_mapping(v)["$alias_name[$qubit_i]"] = [alias_qubits[qubit_i+1]] + end + else # both classical + throw(QasmVisitorError("classical array concatenation not yet supported!")) + end + elseif head(right_hand_side) == :identifier + referent_name = name(right_hand_side) + is_qubit = haskey(qubit_mapping(v), referent_name) + if is_qubit + qubit_defs(v)[alias_name] = qubit_defs(v)[referent_name] + qubit_mapping(v)[alias_name] = qubit_mapping(v)[referent_name] + qubit_size = length(qubit_mapping(v)[alias_name]) + for qubit_i in 0:qubit_size-1 + qubit_mapping(v)["$alias_name[$qubit_i]"] = qubit_mapping(v)["$referent_name[$qubit_i]"] + end + else + classical_defs(v)[alias_name] = classical_defs(v)[referent_name] + end + elseif head(right_hand_side) == :indexed_identifier + referent_name = name(right_hand_side) + is_qubit = haskey(qubit_mapping(v), referent_name) + if is_qubit + alias_qubits = v(right_hand_side) + qubit_size = length(alias_qubits) + qubit_defs(v)[alias_name] = Qubit(alias_name, qubit_size) + qubit_mapping(v)[alias_name] = collect(alias_qubits) + for qubit_i in 0:qubit_size-1 + qubit_mapping(v)["$alias_name[$qubit_i]"] = [alias_qubits[qubit_i+1]] + end + else + referent = classical_defs(v)[referent_name] + classical_defs(v)[alias_name] = ClassicalVariable(alias_name, referent.type, view(referent.val, v(right_hand_side.args[end]) .+ 1), referent.is_const) + end + end elseif head(program_expr) == :identifier id_name = name(program_expr) haskey(classical_defs(v), id_name) && return classical_defs(v)[id_name].val @@ -610,7 +727,7 @@ function (v::AbstractVisitor)(program_expr::QasmExpression) condition_value = while_v(program_expr.args[1]) end elseif head(program_expr) == :classical_assignment - op = program_expr.args[1].args[1]::Symbol + op = program_expr.args[1].args[1]::Symbol left_hand_side = program_expr.args[1].args[2]::QasmExpression right_hand_side = program_expr.args[1].args[3] var_name = name(left_hand_side)::String @@ -651,14 +768,7 @@ function (v::AbstractVisitor)(program_expr::QasmExpression) end end elseif head(program_expr) == :classical_declaration - var_type = program_expr.args[1].args[1] - init = if var_type isa SizedNumber - undef - elseif var_type isa SizedArray - fill(undef, v(var_type.size)) - elseif var_type isa SizedBitVector - falses(max(0, v(var_type.size))) - end + init, var_type = declaration_init(v, program_expr) # no initial value if head(program_expr.args[2]) == :identifier var_name = name(program_expr.args[2]) @@ -671,14 +781,7 @@ function (v::AbstractVisitor)(program_expr::QasmExpression) end elseif head(program_expr) == :const_declaration head(program_expr.args[2]) == :classical_assignment || throw(QasmVisitorError("const declaration must assign an initial value.")) - var_type = program_expr.args[1].args[1] - init = if var_type isa SizedNumber - undef - elseif var_type isa SizedArray - fill(undef, v(var_type.size)) - elseif var_type isa SizedBitVector - falses(max(0, v(var_type.size))) - end + init, var_type = declaration_init(v, program_expr) op, left_hand_side, right_hand_side = program_expr.args[2].args[1].args var_name = name(left_hand_side) v.classical_defs[var_name] = ClassicalVariable(var_name, var_type, init, false) @@ -737,54 +840,7 @@ function (v::AbstractVisitor)(program_expr::QasmExpression) return return_val[1] else hasfunction(v, function_name) || throw(QasmVisitorError("function $function_name not defined!")) - function_def = function_defs(v)[function_name] - function_body = function_def.body::Vector{QasmExpression} - declared_args = only(function_def.arguments.args)::QasmExpression - provided_args = only(program_expr.args[2].args)::QasmExpression - function_v = QasmFunctionVisitor(v, declared_args, provided_args) - return_val = nothing - body_exprs::Vector{QasmExpression} = head(function_body[1]) == :scope ? function_body[1].args : function_body - for f_expr in body_exprs - if head(f_expr) == :return - return_val = function_v(f_expr.args[1]) - else - function_v(f_expr) - end - end - # remap qubits and classical variables - function_args = if head(declared_args) == :array_literal - convert(Vector{QasmExpression}, declared_args.args)::Vector{QasmExpression} - else - declared_args - end - called_args = if head(provided_args) == :array_literal - convert(Vector{QasmExpression}, provided_args.args)::Vector{QasmExpression} - else - provided_args - end - arguments_map = Dict{QasmExpression, QasmExpression}(zip(function_args, called_args)) - reverse_arguments_map = Dict{QasmExpression, QasmExpression}(zip(called_args, function_args)) - reverse_qubits_map = Dict{Int, Int}() - for variable in keys(reverse_arguments_map) - if head(variable) ∈ (:identifier, :indexed_identifier) - variable_name = name(variable) - if haskey(classical_defs(v), variable_name) && classical_defs(v)[variable_name].type isa SizedArray - if head(reverse_arguments_map[variable]) != :const_declaration - inner_variable_name = name(reverse_arguments_map[variable]) - new_val = classical_defs(function_v)[inner_variable_name].val - back_assignment = QasmExpression(:classical_assignment, QasmExpression(:binary_op, Symbol("="), variable, new_val)) - v(back_assignment) - end - elseif haskey(qubit_defs(v), variable_name) - outer_context_map = only(evaluate_qubits(v, variable)) - inner_context_map = only(evaluate_qubits(function_v, reverse_arguments_map[variable].args[1])) - reverse_qubits_map[inner_context_map] = outer_context_map - end - end - end - mapper = isempty(reverse_qubits_map) ? identity : ix->remap(ix, reverse_qubits_map) - push!(v, map(mapper, function_v.instructions)) - return return_val + return visit_function_call(v, program_expr, function_name) end elseif head(program_expr) == :function_definition function_def = program_expr.args diff --git a/test/runtests.jl b/test/runtests.jl index f281fb2..93b5d6e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -192,6 +192,85 @@ Quasar.builtin_gates[] = complex_builtin_gates (type="gphase", arguments=InstructionArgument[2*π], targets=[0, 1], controls=[0=>0, 1=>1], exponent=1.0), ] end + @testset "Qubit aliasing" begin + qasm = """ + qubit[2] one; + qubit[2] two; + // Aliased register of four qubits + let concatenated = two ++ one; + // First qubit in aliased qubit array + let first = concatenated[0]; + // Last qubit in aliased qubit array + let last = concatenated[-1]; + let new_cat = concatenated; + h concatenated[2]; + x concatenated[1]; + y first; + z last; + i new_cat[0]; + """ + parsed = parse_qasm(qasm) + visitor = QasmProgramVisitor() + visitor(parsed) + @test visitor.instructions == [(type="h", arguments=InstructionArgument[], targets=[0], controls=Pair{Int,Int}[], exponent=1.0), + (type="x", arguments=InstructionArgument[], targets=[3], controls=Pair{Int,Int}[], exponent=1.0), + (type="y", arguments=InstructionArgument[], targets=[2], controls=Pair{Int,Int}[], exponent=1.0), + (type="z", arguments=InstructionArgument[], targets=[1], controls=Pair{Int,Int}[], exponent=1.0), + (type="i", arguments=InstructionArgument[], targets=[2], controls=Pair{Int,Int}[], exponent=1.0), + ] + qasm = """ + qubit[2] one; + bit[2] two = "10"; + let concatenated = two ++ one; + """ + parsed = parse_qasm(qasm) + visitor = QasmProgramVisitor() + @test_throws Quasar.QasmVisitorError("cannot concatenate qubit and classical arrays") visitor(parsed) + + qasm = """ + qubit[2] one; + qubit[2] two; + let concatenated = two - one; + """ + parsed = parse_qasm(qasm) + visitor = QasmProgramVisitor() + @test_throws Quasar.QasmVisitorError("right hand side of alias must be either an identifier or concatenation") visitor(parsed) + end + @testset "Classical aliasing" begin + qasm = """ + bit[2] one = "01"; + bit[2] two = "10"; + // Aliased register of four bits + let concatenated = two ++ one; // "1001" + // First bit in aliased qubit array + let first = concatenated[0]; + // Last qubit in aliased qubit array + let last = concatenated[-1]; + let new_cat = concatenated; + """ + parsed = parse_qasm(qasm) + visitor = QasmProgramVisitor() + @test_throws Quasar.QasmVisitorError("classical array concatenation not yet supported!") visitor(parsed) + #@test collect(visitor.classical_defs["concatenated"].val) == BitVector((true, false, false, true)) + #@test visitor.classical_defs["first"].val == true + #@test visitor.classical_defs["last"].val == true + #@test collect(visitor.classical_defs["new_cat"].val) == BitVector((true, false, false, true)) + # test that these are *references* + qasm = """ + bit[2] one = "01"; + bit[2] two = "10"; + // Aliased register of four bits + let concatenated = one; // "01" + // First bit in aliased qubit array + let first = concatenated[0]; + concatenated[1] = false; + """ + parsed = parse_qasm(qasm) + visitor = QasmProgramVisitor() + visitor(parsed) + @test visitor.classical_defs["one"].val == BitVector((false, false)) + @test only(visitor.classical_defs["first"].val) == false + end @testset "Randomized Benchmarking" begin qasm = """ qubit[2] q; @@ -1400,31 +1479,31 @@ Quasar.builtin_gates[] = complex_builtin_gates @test visitor.classical_defs["array_2"].val == zeros(Int, 10) @test visitor.classical_defs["array_3"].val == [1, 2, 3, 4, 0, 6, 0, 8, 0, 10] end - # TODO - #=@testset "Rotation parameter expressions" begin - @testset "Operation: $operation" for (operation, state_vector) in + @testset "Rotation parameter expressions" begin + @testset "Operation: $operation, argument: $arg" for (operation, arg) in [ - ["rx(π) q[0];", [0, -im]], - ["rx(pi) q[0];", [0, -im]], - ["rx(ℇ) q[0];", [0.21007866, -0.97768449im]], - ["rx(euler) q[0];", [0.21007866, -0.97768449im]], - ["rx(τ) q[0];", [-1, 0]], - ["rx(tau) q[0];", [-1, 0]], - ["rx(pi + pi) q[0];", [-1, 0]], - ["rx(pi - pi) q[0];", [1, 0]], - ["rx(-pi + pi) q[0];", [1, 0]], - ["rx(pi * 2) q[0];", [-1, 0]], - ["rx(pi / 2) q[0];", [0.70710678, -0.70710678im]], - ["rx(-pi / 2) q[0];", [0.70710678, 0.70710678im]], - ["rx(-pi) q[0];", [0, im]], - ["rx(pi + 2 * pi) q[0];", [0, im]], - ["rx(pi + pi / 2) q[0];", [-0.70710678, -0.70710678im]], - ["rx((pi / 4) + (pi / 2) / 2) q[0];", [0.70710678, -0.70710678im]], - ["rx(0) q[0];", [1, 0]], - ["rx(0 + 0) q[0];", [1, 0]], - ["rx((1.1 + 2.04) / 2) q[0];", [0.70738827, -0.70682518im]], - ["rx((6 - 2.86) * 0.5) q[0];", [0.70738827, -0.70682518im]], - ["rx(pi ** 2) q[0];", [0.22058404, 0.97536797im]], + ["rx(π) q[0];", π], + ["rx(pi) q[0];", π], + ["rx(ℇ) q[0];", ℯ], + ["rx(euler) q[0];", ℯ], + ["rx(τ) q[0];", 2π], + ["rx(tau) q[0];", 2π], + ["rx(pi + pi) q[0];", 2π], + ["rx(pi - pi) q[0];", 0], + ["rx(-pi + pi) q[0];", 0], + ["rx(-pi - pi) q[0];", -2π], + ["rx(pi * 2) q[0];", 2π], + ["rx(pi / 2) q[0];", π/2], + ["rx(-pi / 2) q[0];", -π/2], + ["rx(-pi) q[0];", -π], + ["rx(pi + 2 * pi) q[0];", 3π], + ["rx(pi + pi / 2) q[0];", 3π/2], + ["rx((pi / 4) + (pi / 2) / 2) q[0];", π/2], + ["rx(0) q[0];", 0], + ["rx(0 + 0) q[0];", 0], + ["rx((1.1 + 2.04) / 2) q[0];", 3.14/2], + ["rx((6 - 2.86) * 0.5) q[0];", (6 - 2.86) * 0.5], + ["rx(pi ** 2) q[0];", π^2], ] qasm = """ OPENQASM 3.0; @@ -1432,8 +1511,12 @@ Quasar.builtin_gates[] = complex_builtin_gates qubit[1] q; $operation """ + parsed = parse_qasm(qasm) + visitor = QasmProgramVisitor() + visitor(parsed) + @test visitor.instructions[1] == (type="rx", arguments=InstructionArgument[arg], targets=[0], controls=Pair{Int,Int}[], exponent=1.0) end - end=# + end @testset "Qubits with variable as size" begin qasm_string = """ OPENQASM 3.0;