diff --git a/confluence/integrator.py b/confluence/integrator.py index 3fb1e16..b0064f4 100644 --- a/confluence/integrator.py +++ b/confluence/integrator.py @@ -30,6 +30,11 @@ def __init__(self, n, A0, H, W, S, dA, routing_table): self.nreaches = len(A0) self.A0 = np.array(A0) self.n = np.array(n) + w = np.mean(W, axis=0) + h = np.mean(H, axis=0) + if dA is None: + dA = np.array([(w[r] + W[np.argmin(H[:, r]), r]) / 2 * (h[r] - H[np.argmin(H[:, r]), r]) + for r in range(self.nreaches)]).T self.data = (H, W, S, dA) self.rivs = self._riverTopology(routing_table) @@ -55,12 +60,8 @@ def objective(self, x): H, W, S, dA = self.data w = np.mean(W, axis=0) h = np.mean(H, axis=0) - if dA is None: - dA = np.array([(w[r] + W[np.argmin(H[:, r]), r]) / 2 * (h[r] - H[np.argmin(H[:, r]), r]) - for r in range(self.nreaches)]).T - else: - dA = np.mean(dA, axis=0) - Q = 1 / n * (A0 + dA)**(5 / 3) * w**(-2 / 3) * np.mean(S, + da = np.mean(dA, axis=0) + Q = 1 / n * (A0 + da)**(5 / 3) * w**(-2 / 3) * np.mean(S, axis=0)**(1 / 2) error = [np.sqrt((Q[i] - np.sum(Q[j] for j in self.rivs[i]))**2) for i in self.rivs] return np.sum(error) @@ -72,12 +73,8 @@ def constraint(self, i, x): H, W, S, dA = self.data w = np.mean(W, axis=0) h = np.mean(H, axis=0) - if dA is None: - dA = np.array([(w[r] + W[np.argmin(H[:, r]), r]) / 2 * (h[r] - H[np.argmin(H[:, r]), r]) - for r in range(self.nreaches)]).T - else: - dA = np.mean(dA, axis=0) - Q = 1 / n * (A0 + dA)**(5 / 3) * w**(-2 / 3) * np.mean(S, + da = np.mean(dA, axis=0) + Q = 1 / n * (A0 + da)**(5 / 3) * w**(-2 / 3) * np.mean(S, axis=0)**(1 / 2) return Q[i] - np.sum([Q[j] for j in self.rivs[i]])