Skip to content

Commit

Permalink
proper refinement and exact 2d wasserstein distance
Browse files Browse the repository at this point in the history
  • Loading branch information
enricofacca committed Aug 1, 2024
1 parent 0cf473c commit cd55e15
Showing 1 changed file with 81 additions and 28 deletions.
109 changes: 81 additions & 28 deletions tests/unit/test_geodetic_wasserstein.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
nref = 0


def test_case(rows=8, cols=8, nref=0,
x_src=1, y_src=1,
x_dst=6, y_dst=6,
permeability_value=1.0):
def test_case(rows=8, cols=8,
x_src=1, y_src=1,
x_dst=6, y_dst=6,
permeability_value = 1.0, nref=0):

h=1.0/8.0

# Coarse src image
src_square_2d = np.zeros((rows, cols), dtype=float)
Expand All @@ -30,8 +31,6 @@ def test_case(rows=8, cols=8, nref=0,
src_image_2d = darsia.Image(src_square_2d, **meta_2d)

# Coarse dst image
x_dst = 6
y_dst = 6
dst_squares_2d = np.zeros((rows, cols), dtype=float)
dst_squares_2d[x_dst:x_dst+1, y_dst:y_dst+1] = 1
dst_squares_2d = zoom(dst_squares_2d, 2**nref, order=0)
Expand All @@ -44,6 +43,70 @@ def test_case(rows=8, cols=8, nref=0,
dst_image_2d.img /= geometry_2d.integrate(dst_image_2d)


bar_src = [(x_src+0.5)/8, (y_src+0.5)/8]
bar_dst = [(x_dst+0.5)/8, (y_dst+0.5)/8]
intermidate_points = (bar_src + bar_dst)/2


a=0.5-bar_src[0] # x distance from the source to the layer
b=1/8 # half of the width of the layer
l=bar_src[1]-intermidate_points[1] # y_distance from the source to intermidate point


# Compute the incidence point of the ray starting at the source and
# going to the intermidate point.
# Varaible name from snell'law Wikipedia page
from scipy.optimize import toms748

def h(x):
"""
Rearranged derivative of travel time from the source to the intermidate point
"""
return x**2 * ( (l-x)**2+b**2) - permeability_layer **2 * (l-x)**2 * (x**2+a**2)

def L(x):
" length of the shortest path from the source center to the intermidate point"
L = np.sqrt(x**2+a**2) + np.sqrt((l-x)**2+b**2)
return 2 * L


x_opt = toms748(h,0,l)
print(f"{x_opt=} {x_opt+1.5/8=} {h(x_opt)=} L={L(x_opt)}")

theta_1 = np.arctan(x_opt/a)
theta_2 = np.arctan((l-x_opt)/b)

# check result
if not np.close(np.sin(theta_1)/np.sin(theta_2),permeability_value):
raise ValueError("Angles do not satisfiy Snell Law")


opt_direction_1 = [np.cos(theta_1),-np.sin(theta_1)]
opt_direction_2 = [np.cos(theta_2),-np.sin(theta_2)]

if np.close(theta_1<np.pi/4):
L_max = 1/8 / np.cos(theta_1)
tdens_max = L_max * src_value

def tdens_1(x,y):
if 0<(x-x_src/8)<1/8 and 0<(y-y_src)<1/8:
# compute the distance to the upper/left boundary along
# the direction
dist = (x_src * h - x) / np.sin(theta_1)
return dist







Lx = abs(x_src-x_dst)/8
Ly = abs(y_src-y_dst)/8
L=np.sqrt((x_src-x_dst)**2+(y_src-y_dst)**2)/8






Expand All @@ -55,9 +118,9 @@ def test_case(rows=8, cols=8, nref=0,
y_src = 1
src_square_2d = np.zeros((rows, cols), dtype=float)
src_square_2d[x_src:x_src+1, y_src:y_src+1] = 1
src_square_2d = zoom(src_square_2d, 2**nref, order=0)
meta_2d = {"width": 1, "height": 1, "space_dim": 2, "scalar": True}
src_image_2d = darsia.Image(src_square_2d, **meta_2d)
src_image_2d = darsia.uniform_refinement(src_image_2d, nref)



Expand All @@ -66,15 +129,17 @@ def test_case(rows=8, cols=8, nref=0,
y_dst = 6
dst_squares_2d = np.zeros((rows, cols), dtype=float)
dst_squares_2d[x_dst:x_dst+1, y_dst:y_dst+1] = 1
dst_squares_2d = zoom(dst_squares_2d, 2**nref, order=0)
dst_image_2d = darsia.Image(dst_squares_2d, **meta_2d)
dst_image_2d = darsia.uniform_refinement(dst_image_2d, nref)

# Rescale
shape_meta_2d = src_image_2d.shape_metadata()
geometry_2d = darsia.Geometry(**shape_meta_2d)
src_image_2d.img /= geometry_2d.integrate(src_image_2d)
dst_image_2d.img /= geometry_2d.integrate(dst_image_2d)

src_value = np.max(src_image_2d.img)

# Reference value for comparison
true_distance_2d = 0.379543951823
n = src_image_2d.shape[0]
Expand All @@ -83,37 +148,27 @@ def test_case(rows=8, cols=8, nref=0,
x = -np.outer(x_approx,np.ones(src_image_2d.shape[1]))
y = -np.outer(np.ones(src_image_2d.shape[1]),x_approx)
opt_pot = x + y
print(opt_pot.shape)
print(src_image_2d.img.shape)
#np.savetxt("opt_pot.npy",opt_pot,fmt='%.2e')
true_distance_2d=np.tensordot((src_image_2d.img-dst_image_2d.img),opt_pot,axes=((0,1),(0,1)))/(opt_pot.size)
print(f"{true_distance_2d=}")


permeability_layer = 1e1
permeability_layer = 0.1
permeability_2d = np.ones((rows, cols), dtype=float)
permeability_2d[0:8, 3:5] = 10
permeability_2d = zoom(permeability_2d, 2**nref, order=0)
permeability_2d[0:8, 3:5] = permeability_layer
kappa_2d = 1.0 / permeability_2d
kappa_image_2d = darsia.Image(kappa_2d, **meta_2d)

kappa_image_2d = darsia.uniform_refinement(kappa_image_2d, nref)



a=(1.5)/8
b=1/8
l=2.5/8

def f(x):
return (x/(np.sqrt(x**2+a**2)))/((l-x)/(np.sqrt((l-x)**2+b**2)))*permeability_layer

# sovle using scipy newton
from scipy.optimize import newton
from scipy.optimize import toms748

def f(x):
return (x/(np.sqrt(x**2+a**2)))/((l-x)/(np.sqrt((l-x)**2+b**2)))- permeability_layer

def h(x):
return x**2 * ( (l-x)**2+b**2) - permeability_layer **2 * (l-x)**2 * (x**2+a**2)

Expand All @@ -122,18 +177,16 @@ def derivative_h(x):


def L(x):
L= np.sqrt(x**2+a**2) + np.sqrt((l-x)**2+b**2)
L= np.sqrt(x**2+a**2) + np.sqrt((l-x)**2+b**2) / permeability_layer
return 2 * L


# use scipy newton to solve the equation
x0 = 0.0#(1.5/8)
x_tom = toms748(h,0,l)

print(f"{x_tom=} {x_tom+1.5/8=} {h(x_tom)=} {f(x_tom)=} L={L(x_tom)}")



true_distance_2d = src_value * (1/8)**2 * L(x_tom)
print(f"{x_tom=} {x_tom+1.5/8=} {h(x_tom)=} L={L(x_tom)} {(5/8)/np.cos(np.pi/4.0)} {true_distance_2d=} {src_value=}")

Lx = abs(x_src-x_dst)/8
Ly = abs(y_src-y_dst)/8
Expand Down Expand Up @@ -385,8 +438,8 @@ def test_sinkhorn(method_key, reg_key, dim):
# options.update({"formulation": "full"})
options.update({"verbose": True})
options.update({"num_iter": 400})
options.update({"linear_solver": "direct"})
options.update({"kappa": kappa_2d})
options.update({"linear_solver": "ksp"})
options.update({"kappa": kappa_image_2d.img})
distance, info = darsia.wasserstein_distance(
src_image[dim],
dst_image[dim],
Expand Down

0 comments on commit cd55e15

Please sign in to comment.