Skip to content

Commit

Permalink
Fix broken test + Use gcd from STL
Browse files Browse the repository at this point in the history
  • Loading branch information
roastduck committed Jan 18, 2024
1 parent 0906acd commit a650f5c
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 36 deletions.
11 changes: 5 additions & 6 deletions include/math/rational.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
#define FREE_TENSOR_RATIONAL_H

#include <iostream>

#include <math/utils.h>
#include <numeric>

namespace freetensor {

Expand All @@ -14,7 +13,7 @@ template <class T> struct Rational {
if (p == 0) {
q_ = 1;
} else {
T g = gcd(p, q);
T g = std::gcd(p, q);
p_ /= g, q_ /= g;
if (q_ < 0) {
p_ = -p_, q_ = -q_;
Expand All @@ -27,14 +26,14 @@ template <class T> struct Rational {
}

friend Rational operator+(const Rational<T> &lhs, const Rational<T> &rhs) {
T g = gcd(lhs.q_, rhs.q_);
T g = std::gcd(lhs.q_, rhs.q_);
T p = rhs.q_ / g * lhs.p_ + lhs.q_ / g * rhs.p_;
T q = lhs.q_ / g * rhs.q_;
return Rational<T>{p, q};
}

friend Rational operator-(const Rational<T> &lhs, const Rational<T> &rhs) {
T g = gcd(lhs.q_, rhs.q_);
T g = std::gcd(lhs.q_, rhs.q_);
T p = rhs.q_ / g * lhs.p_ - lhs.q_ / g * rhs.p_;
T q = lhs.q_ / g * rhs.q_;
return Rational<T>{p, q};
Expand All @@ -58,7 +57,7 @@ template <class T> struct Rational {
}

friend auto operator<=>(const Rational<T> &lhs, const Rational<T> &rhs) {
T g = gcd(lhs.q_, rhs.q_);
T g = std::gcd(lhs.q_, rhs.q_);
return rhs.q_ / g * lhs.p_ <=> lhs.q_ / g * rhs.p_;
}

Expand Down
23 changes: 3 additions & 20 deletions include/math/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ namespace freetensor {

template <typename T>
concept IntegralExceptBool = requires {
requires std::integral<T>;
requires !std::same_as<T, bool>;
};
requires std::integral<T>;
requires !std::same_as<T, bool>;
};

// NOTE: For floating-points, we always use double to deal with compile-time
// operations
Expand All @@ -38,23 +38,6 @@ inline auto mod(IntegralExceptBool auto a, IntegralExceptBool auto b) {
return m;
}

template <IntegralExceptBool T, IntegralExceptBool U> auto gcd(T _x, U _y) {
std::common_type_t<T, U> x = std::abs(_x), y = std::abs(_y);
if (x < y) {
std::swap(x, y);
}
do {
auto z = x % y;
x = y;
y = z;
} while (y);
return x;
}

inline auto lcm(IntegralExceptBool auto x, IntegralExceptBool auto y) {
return x / gcd(x, y) * y;
}

template <class T> T square(T x) { return x * x; }
inline bool square(bool x) { return x && x; }

Expand Down
1 change: 1 addition & 0 deletions src/analyze/comp_unique_bounds_combination.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <container_utils.h>
#include <math/bounds.h>
#include <math/min_max.h>
#include <math/utils.h>

namespace freetensor {

Expand Down
3 changes: 2 additions & 1 deletion src/math/bounds.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <algorithm>
#include <climits>
#include <functional>
#include <numeric>
#include <type_traits>

#include <analyze/all_uses.h>
Expand All @@ -14,7 +15,7 @@ commonDenominator(const LinearExpr<Rational<int64_t>> &_lin) {
auto lin = _lin;
auto common = lin.bias_.q_;
for (auto &&[k, a] : lin.coeff_) {
common = lcm(common, k.q_);
common = std::lcm(common, k.q_);
}
lin.bias_.p_ *= common / lin.bias_.q_;
lin.bias_.q_ = common;
Expand Down
25 changes: 17 additions & 8 deletions src/pass/shrink_linear_indices.cc
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
#include <algorithm>
#include <numeric>
#include <unordered_map>

#include <analyze/all_defs.h>
#include <analyze/analyze_linear.h>
#include <analyze/comp_transient_bounds.h>
#include <analyze/comp_unique_bounds_combination.h>
#include <container_utils.h>
#include <math/utils.h>
#include <mutator.h>
#include <pass/shrink_linear_indices.h>
#include <visitor.h>
Expand Down Expand Up @@ -158,18 +158,27 @@ Stmt shrinkLinearIndices(const Stmt &_ast, const ID &vardef) {
for (size_t n = bound.size(), i = n - 1; ~i; i--) {
int g = newCoeff[0];
for (size_t j = 1; j <= i; j++) {
g = gcd(g, newCoeff[j]);
g = std::gcd(g, newCoeff[j]);
}
int64_t l = LLONG_MAX, u = LLONG_MIN;
if (i + 1 < n) {
for (size_t j = i + 1; j < n; j++) {
l = std::min(l, newCoeff[j] * bound[j].second.lower_);
u = std::max(u, newCoeff[j] * bound[j].second.upper_);
// TODO: Use saturation arithmetic when C++26 is available for safer
// code when there may be overflow
for (size_t j = i + 1; j < n; j++) {
if (auto x = bound[j].second.lower_; x > LLONG_MIN) {
l = std::min(l, newCoeff[j] * x);
} else {
l = LLONG_MIN;
}
} else {
if (auto x = bound[j].second.upper_; x < LLONG_MAX) {
u = std::max(u, newCoeff[j] * x);
} else {
u = LLONG_MAX;
}
}
if (l > u) {
l = u = 0;
}
if (u - l + 1 < g) {
if (l > LLONG_MIN && u < LLONG_MAX && u - l + 1 < g) {
for (size_t j = 0; j <= i; j++) {
newCoeff[j] = newCoeff[j] / g * (u - l + 1);
}
Expand Down
3 changes: 2 additions & 1 deletion src/pass/simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -822,7 +822,8 @@ Expr SimplifyPass::visit(const IfExpr &_op) {
if (thenLin.coeff_[i].k_ == elseLin.coeff_[j].k_ &&
HashComparator{}(thenLin.coeff_[i].a_, elseLin.coeff_[j].a_)) {
common.coeff_.emplace_back(thenLin.coeff_[i]);
thenLin.coeff_[i++].k_ = elseLin.coeff_[j++].k_ = 0;
thenLin.coeff_[i].k_ = elseLin.coeff_[j].k_ = 0;
j++;
}
}
if (thenLin.bias_ == elseLin.bias_) {
Expand Down

0 comments on commit a650f5c

Please sign in to comment.