Skip to content

Commit

Permalink
compute inner product of complex vectors
Browse files Browse the repository at this point in the history
  • Loading branch information
sebastiaanbrand committed Jul 9, 2024
1 parent e35d373 commit 4395bc2
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 6 deletions.
11 changes: 7 additions & 4 deletions src/sylvan_aadd.c
Original file line number Diff line number Diff line change
Expand Up @@ -851,7 +851,7 @@ TASK_IMPL_4(AADD_WGT, aadd_inner_product, AADD, a, AADD, b, BDDVAR, nvars, BDDVA
// TODO: allow for skipping variables (and multiply res w/ 2^{# skipped})
// (requires adding some wgt_from_int() function in wgt interface)
if (nextvar == nvars) {
return wgt_mul(AADD_WEIGHT(a), AADD_WEIGHT(b));
return wgt_mul(AADD_WEIGHT(a), wgt_conj(AADD_WEIGHT(b)));
}

// Get var(a) and var(b)
Expand All @@ -869,7 +869,7 @@ TASK_IMPL_4(AADD_WGT, aadd_inner_product, AADD, a, AADD, b, BDDVAR, nvars, BDDVA
if (cachenow) {
if (cache_get4(CACHE_AADD_INPROD, AADD_TARGET(a), AADD_TARGET(b), nextvar, nvars, &res)) {
res = wgt_mul(res, AADD_WEIGHT(a));
res = wgt_mul(res, AADD_WEIGHT(b));
res = wgt_mul(res, wgt_conj(AADD_WEIGHT(b)));
return res;
}
}
Expand All @@ -887,9 +887,12 @@ TASK_IMPL_4(AADD_WGT, aadd_inner_product, AADD, a, AADD, b, BDDVAR, nvars, BDDVA
cache_put4(CACHE_AADD_INPROD, AADD_TARGET(a), AADD_TARGET(b), nextvar, nvars, res);
}

// Multiply result with product of weights of a and b
// Multiply result with product of weights of a and (conjugate of) b
// (Note that we can compute the complex conjugate of |b> by taking the
// complex conjugate of all edge weights separately, since
// (w1 • w2)* = (w2* • w1*) and for scalars (w2* • w1*) = (w1* • w2*).)
res = wgt_mul(res, AADD_WEIGHT(a));
res = wgt_mul(res, AADD_WEIGHT(b));
res = wgt_mul(res, wgt_conj(AADD_WEIGHT(b)));
return res;
}

Expand Down
5 changes: 4 additions & 1 deletion src/sylvan_aadd.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,10 @@ TASK_DECL_4(AADD, aadd_matvec_mult_rec, AADD, AADD, BDDVAR, BDDVAR);
TASK_DECL_4(AADD, aadd_matmat_mult_rec, AADD, AADD, BDDVAR, BDDVAR);


/* Computes inner product of two vectors a, b */
/**
* Computes inner product of two vectors <b|a>
* (Note that if b contains complex values, the complex conjugate is taken)
*/
#define aadd_inner_product(a,b,nvars) (RUN(aadd_inner_product,a,b,nvars,0))
TASK_DECL_4(AADD_WGT, aadd_inner_product, AADD, AADD, BDDVAR, BDDVAR);

Expand Down
19 changes: 19 additions & 0 deletions src/sylvan_edge_weights.c
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ _weight_lookup_ptr_f _weight_lookup_ptr;
init_one_zero_f init_one_zero;
weight_abs_f weight_abs;
weight_neg_f weight_neg;
weight_conj_f weight_conj;
weight_sqr_f weight_sqr;
weight_add_f weight_add;
weight_sub_f weight_sub;
Expand Down Expand Up @@ -81,6 +82,7 @@ void init_edge_weight_functions(edge_weight_type_t edge_weight_type)
init_one_zero = (init_one_zero_f) &init_complex_one_zero;
weight_abs = (weight_abs_f) &weight_complex_abs;
weight_neg = (weight_neg_f) &weight_complex_neg;
weight_conj = (weight_conj_f) &weight_complex_conj;
weight_sqr = (weight_sqr_f) &weight_complex_sqr;
weight_add = (weight_add_f) &weight_complex_add;
weight_sub = (weight_sub_f) &weight_complex_sub;
Expand Down Expand Up @@ -394,6 +396,23 @@ wgt_neg(AADD_WGT a)
return res;
}

AADD_WGT
wgt_conj(AADD_WGT a)
{
// special cases
if (a == AADD_ZERO || a == AADD_ONE || a == AADD_MIN_ONE) return a;

AADD_WGT res;

weight_t w = weight_malloc();
weight_value(a, w);
weight_conj(w);
res = weight_lookup_ptr(w);
free(w);

return res;
}

AADD_WGT
wgt_add(AADD_WGT a, AADD_WGT b)
{
Expand Down
3 changes: 3 additions & 0 deletions src/sylvan_edge_weights.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ typedef void (*init_one_zero_f)(void *wgt_store);
/* Arithmetic operations on edge weights */
typedef void (*weight_abs_f)(weight_t a); // a <-- |a|
typedef void (*weight_neg_f)(weight_t a); // a <-- -a
typedef void (*weight_conj_f)(weight_t a); // a <-- a*
typedef void (*weight_sqr_f)(weight_t a); // a <-- a^2
typedef void (*weight_add_f)(weight_t a, weight_t b); // a <-- a + b
typedef void (*weight_sub_f)(weight_t a, weight_t b); // a <-- a - b
Expand All @@ -96,6 +97,7 @@ extern _weight_lookup_ptr_f _weight_lookup_ptr;
extern init_one_zero_f init_one_zero;
extern weight_abs_f weight_abs;
extern weight_neg_f weight_neg;
extern weight_conj_f weight_conj;
extern weight_sqr_f weight_sqr;
extern weight_add_f weight_add;
extern weight_sub_f weight_sub;
Expand Down Expand Up @@ -136,6 +138,7 @@ void wgt_set_inverse_chaching(bool on);
/* Arithmetic operations on AADD_WGT's */
AADD_WGT wgt_abs(AADD_WGT a); // returns |a|
AADD_WGT wgt_neg(AADD_WGT a); // returns -a
AADD_WGT wgt_conj(AADD_WGT a); // returns a*
AADD_WGT wgt_add(AADD_WGT a, AADD_WGT b); // returns a + b
AADD_WGT wgt_sub(AADD_WGT a, AADD_WGT b); // returns a - b
AADD_WGT wgt_mul(AADD_WGT a, AADD_WGT b); // returns a * b
Expand Down
6 changes: 6 additions & 0 deletions src/sylvan_edge_weights_complex.c
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ weight_complex_neg(complex_t *a)
a->i = -(a->i);
}

void
weight_complex_conj(complex_t *a)
{
a->i = -(a->i);
}

void
weight_complex_sqr(complex_t *a)
{
Expand Down
1 change: 1 addition & 0 deletions src/sylvan_edge_weights_complex.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ void init_complex_one_zero(void *wgt_store);

void weight_complex_abs(complex_t *a);
void weight_complex_neg(complex_t *a);
void weight_complex_conj(complex_t *a);
void weight_complex_sqr(complex_t *a);
void weight_complex_add(complex_t *a, complex_t *b);
void weight_complex_sub(complex_t *a, complex_t *b);
Expand Down
9 changes: 8 additions & 1 deletion test/test_qmdd_basics.c
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ int test_vector_addition()

int test_inner_product()
{
QMDD q0, q1, q00, q01, q10, q11, q000, q001, q010, q100;
QMDD q0, q1, q00, q01, q10, q11, q000, q001, q010, q100, q000i, q001i;
bool x4[] = {0, 0, 0, 0};
BDDVAR nqubits = 4;

Expand All @@ -349,6 +349,10 @@ int test_inner_product()
q11 = aadd_plus(q1, q1); // [0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0]
q000 = aadd_plus(q00, q0); // [0 0 3 0 0 0 0 0 0 0 0 0 0 0 0 0]
q001 = aadd_plus(q00, q1); // [0 0 2 0 0 0 0 0 0 0 0 1 0 0 0 0]
q000i = aadd_bundle(AADD_TARGET(q000), wgt_mul(AADD_WEIGHT(q000), complex_lookup(0,1)));
// [0 0 3i 0 0 0 0 0 0 0 0 0 0 0 0 0]
q001i = aadd_bundle(AADD_TARGET(q001), wgt_mul(AADD_WEIGHT(q001), complex_lookup(0,1)));
// [0 0 2i 0 0 0 0 0 0 0 0 i 0 0 0 0]

test_assert(aadd_inner_product(q00, q00, nqubits) == complex_lookup(4.0, 0.0));
test_assert(aadd_inner_product(q01, q01, nqubits) == complex_lookup(2.0, 0.0));
Expand All @@ -360,6 +364,9 @@ int test_inner_product()
test_assert(aadd_inner_product(q01, q000, nqubits) == complex_lookup(3.0, 0.0));
test_assert(aadd_inner_product(q000, q001, nqubits) == complex_lookup(6.0, 0.0));
test_assert(aadd_inner_product(q01, q001, nqubits) == complex_lookup(3.0, 0.0));
test_assert(aadd_inner_product(q000i, q000i, nqubits) == complex_lookup(9.0, 0.0));
test_assert(aadd_inner_product(q001i, q001i, nqubits) == complex_lookup(5.0, 0.0));
test_assert(aadd_inner_product(q000i, q001i, nqubits) == complex_lookup(6.0, 0.0));

if (VERBOSE) printf("aadd inner product: ok\n");
return 0;
Expand Down

0 comments on commit 4395bc2

Please sign in to comment.