From 4395bc2361f67ec23d9ccabc2e261be8c88ac778 Mon Sep 17 00:00:00 2001 From: Sebastiaan Brand Date: Tue, 9 Jul 2024 16:33:48 +0200 Subject: [PATCH] compute inner product of complex vectors --- src/sylvan_aadd.c | 11 +++++++---- src/sylvan_aadd.h | 5 ++++- src/sylvan_edge_weights.c | 19 +++++++++++++++++++ src/sylvan_edge_weights.h | 3 +++ src/sylvan_edge_weights_complex.c | 6 ++++++ src/sylvan_edge_weights_complex.h | 1 + test/test_qmdd_basics.c | 9 ++++++++- 7 files changed, 48 insertions(+), 6 deletions(-) diff --git a/src/sylvan_aadd.c b/src/sylvan_aadd.c index 15ca12e..07cfb58 100644 --- a/src/sylvan_aadd.c +++ b/src/sylvan_aadd.c @@ -851,7 +851,7 @@ TASK_IMPL_4(AADD_WGT, aadd_inner_product, AADD, a, AADD, b, BDDVAR, nvars, BDDVA // TODO: allow for skipping variables (and multiply res w/ 2^{# skipped}) // (requires adding some wgt_from_int() function in wgt interface) if (nextvar == nvars) { - return wgt_mul(AADD_WEIGHT(a), AADD_WEIGHT(b)); + return wgt_mul(AADD_WEIGHT(a), wgt_conj(AADD_WEIGHT(b))); } // Get var(a) and var(b) @@ -869,7 +869,7 @@ TASK_IMPL_4(AADD_WGT, aadd_inner_product, AADD, a, AADD, b, BDDVAR, nvars, BDDVA if (cachenow) { if (cache_get4(CACHE_AADD_INPROD, AADD_TARGET(a), AADD_TARGET(b), nextvar, nvars, &res)) { res = wgt_mul(res, AADD_WEIGHT(a)); - res = wgt_mul(res, AADD_WEIGHT(b)); + res = wgt_mul(res, wgt_conj(AADD_WEIGHT(b))); return res; } } @@ -887,9 +887,12 @@ TASK_IMPL_4(AADD_WGT, aadd_inner_product, AADD, a, AADD, b, BDDVAR, nvars, BDDVA cache_put4(CACHE_AADD_INPROD, AADD_TARGET(a), AADD_TARGET(b), nextvar, nvars, res); } - // Multiply result with product of weights of a and b + // Multiply result with product of weights of a and (conjugate of) b + // (Note that we can compute the complex conjugate of |b> by taking the + // complex conjugate of all edge weights separately, since + // (w1 • w2)* = (w2* • w1*) and for scalars (w2* • w1*) = (w1* • w2*).) res = wgt_mul(res, AADD_WEIGHT(a)); - res = wgt_mul(res, AADD_WEIGHT(b)); + res = wgt_mul(res, wgt_conj(AADD_WEIGHT(b))); return res; } diff --git a/src/sylvan_aadd.h b/src/sylvan_aadd.h index 790d337..d34da44 100644 --- a/src/sylvan_aadd.h +++ b/src/sylvan_aadd.h @@ -160,7 +160,10 @@ TASK_DECL_4(AADD, aadd_matvec_mult_rec, AADD, AADD, BDDVAR, BDDVAR); TASK_DECL_4(AADD, aadd_matmat_mult_rec, AADD, AADD, BDDVAR, BDDVAR); -/* Computes inner product of two vectors a, b */ +/** + * Computes inner product of two vectors + * (Note that if b contains complex values, the complex conjugate is taken) +*/ #define aadd_inner_product(a,b,nvars) (RUN(aadd_inner_product,a,b,nvars,0)) TASK_DECL_4(AADD_WGT, aadd_inner_product, AADD, AADD, BDDVAR, BDDVAR); diff --git a/src/sylvan_edge_weights.c b/src/sylvan_edge_weights.c index c3b3a1b..6d7b34c 100644 --- a/src/sylvan_edge_weights.c +++ b/src/sylvan_edge_weights.c @@ -35,6 +35,7 @@ _weight_lookup_ptr_f _weight_lookup_ptr; init_one_zero_f init_one_zero; weight_abs_f weight_abs; weight_neg_f weight_neg; +weight_conj_f weight_conj; weight_sqr_f weight_sqr; weight_add_f weight_add; weight_sub_f weight_sub; @@ -81,6 +82,7 @@ void init_edge_weight_functions(edge_weight_type_t edge_weight_type) init_one_zero = (init_one_zero_f) &init_complex_one_zero; weight_abs = (weight_abs_f) &weight_complex_abs; weight_neg = (weight_neg_f) &weight_complex_neg; + weight_conj = (weight_conj_f) &weight_complex_conj; weight_sqr = (weight_sqr_f) &weight_complex_sqr; weight_add = (weight_add_f) &weight_complex_add; weight_sub = (weight_sub_f) &weight_complex_sub; @@ -394,6 +396,23 @@ wgt_neg(AADD_WGT a) return res; } +AADD_WGT +wgt_conj(AADD_WGT a) +{ + // special cases + if (a == AADD_ZERO || a == AADD_ONE || a == AADD_MIN_ONE) return a; + + AADD_WGT res; + + weight_t w = weight_malloc(); + weight_value(a, w); + weight_conj(w); + res = weight_lookup_ptr(w); + free(w); + + return res; +} + AADD_WGT wgt_add(AADD_WGT a, AADD_WGT b) { diff --git a/src/sylvan_edge_weights.h b/src/sylvan_edge_weights.h index 4e1c939..9a04269 100644 --- a/src/sylvan_edge_weights.h +++ b/src/sylvan_edge_weights.h @@ -71,6 +71,7 @@ typedef void (*init_one_zero_f)(void *wgt_store); /* Arithmetic operations on edge weights */ typedef void (*weight_abs_f)(weight_t a); // a <-- |a| typedef void (*weight_neg_f)(weight_t a); // a <-- -a +typedef void (*weight_conj_f)(weight_t a); // a <-- a* typedef void (*weight_sqr_f)(weight_t a); // a <-- a^2 typedef void (*weight_add_f)(weight_t a, weight_t b); // a <-- a + b typedef void (*weight_sub_f)(weight_t a, weight_t b); // a <-- a - b @@ -96,6 +97,7 @@ extern _weight_lookup_ptr_f _weight_lookup_ptr; extern init_one_zero_f init_one_zero; extern weight_abs_f weight_abs; extern weight_neg_f weight_neg; +extern weight_conj_f weight_conj; extern weight_sqr_f weight_sqr; extern weight_add_f weight_add; extern weight_sub_f weight_sub; @@ -136,6 +138,7 @@ void wgt_set_inverse_chaching(bool on); /* Arithmetic operations on AADD_WGT's */ AADD_WGT wgt_abs(AADD_WGT a); // returns |a| AADD_WGT wgt_neg(AADD_WGT a); // returns -a +AADD_WGT wgt_conj(AADD_WGT a); // returns a* AADD_WGT wgt_add(AADD_WGT a, AADD_WGT b); // returns a + b AADD_WGT wgt_sub(AADD_WGT a, AADD_WGT b); // returns a - b AADD_WGT wgt_mul(AADD_WGT a, AADD_WGT b); // returns a * b diff --git a/src/sylvan_edge_weights_complex.c b/src/sylvan_edge_weights_complex.c index fc78814..81b6aa0 100644 --- a/src/sylvan_edge_weights_complex.c +++ b/src/sylvan_edge_weights_complex.c @@ -95,6 +95,12 @@ weight_complex_neg(complex_t *a) a->i = -(a->i); } +void +weight_complex_conj(complex_t *a) +{ + a->i = -(a->i); +} + void weight_complex_sqr(complex_t *a) { diff --git a/src/sylvan_edge_weights_complex.h b/src/sylvan_edge_weights_complex.h index 78cff1c..f10f126 100644 --- a/src/sylvan_edge_weights_complex.h +++ b/src/sylvan_edge_weights_complex.h @@ -16,6 +16,7 @@ void init_complex_one_zero(void *wgt_store); void weight_complex_abs(complex_t *a); void weight_complex_neg(complex_t *a); +void weight_complex_conj(complex_t *a); void weight_complex_sqr(complex_t *a); void weight_complex_add(complex_t *a, complex_t *b); void weight_complex_sub(complex_t *a, complex_t *b); diff --git a/test/test_qmdd_basics.c b/test/test_qmdd_basics.c index e1da74f..8fd6481 100644 --- a/test/test_qmdd_basics.c +++ b/test/test_qmdd_basics.c @@ -337,7 +337,7 @@ int test_vector_addition() int test_inner_product() { - QMDD q0, q1, q00, q01, q10, q11, q000, q001, q010, q100; + QMDD q0, q1, q00, q01, q10, q11, q000, q001, q010, q100, q000i, q001i; bool x4[] = {0, 0, 0, 0}; BDDVAR nqubits = 4; @@ -349,6 +349,10 @@ int test_inner_product() q11 = aadd_plus(q1, q1); // [0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0] q000 = aadd_plus(q00, q0); // [0 0 3 0 0 0 0 0 0 0 0 0 0 0 0 0] q001 = aadd_plus(q00, q1); // [0 0 2 0 0 0 0 0 0 0 0 1 0 0 0 0] + q000i = aadd_bundle(AADD_TARGET(q000), wgt_mul(AADD_WEIGHT(q000), complex_lookup(0,1))); + // [0 0 3i 0 0 0 0 0 0 0 0 0 0 0 0 0] + q001i = aadd_bundle(AADD_TARGET(q001), wgt_mul(AADD_WEIGHT(q001), complex_lookup(0,1))); + // [0 0 2i 0 0 0 0 0 0 0 0 i 0 0 0 0] test_assert(aadd_inner_product(q00, q00, nqubits) == complex_lookup(4.0, 0.0)); test_assert(aadd_inner_product(q01, q01, nqubits) == complex_lookup(2.0, 0.0)); @@ -360,6 +364,9 @@ int test_inner_product() test_assert(aadd_inner_product(q01, q000, nqubits) == complex_lookup(3.0, 0.0)); test_assert(aadd_inner_product(q000, q001, nqubits) == complex_lookup(6.0, 0.0)); test_assert(aadd_inner_product(q01, q001, nqubits) == complex_lookup(3.0, 0.0)); + test_assert(aadd_inner_product(q000i, q000i, nqubits) == complex_lookup(9.0, 0.0)); + test_assert(aadd_inner_product(q001i, q001i, nqubits) == complex_lookup(5.0, 0.0)); + test_assert(aadd_inner_product(q000i, q001i, nqubits) == complex_lookup(6.0, 0.0)); if (VERBOSE) printf("aadd inner product: ok\n"); return 0;