diff --git a/bigint.h b/bigint.h index bcd5662..429768d 100644 --- a/bigint.h +++ b/bigint.h @@ -37,6 +37,7 @@ #include #include #include +#include namespace BigInt { @@ -49,15 +50,22 @@ namespace BigInt { { if (!is_bigint(s)) throw std::runtime_error("Invalid Big Integer."); - str = s; - is_big = true; + if(s[0] == '-') + { + is_neg = true; + vec = string_to_vector(s.substr(1)); + } + else + { + vec = string_to_vector(s); + } } bigint(const char c) { int temp = static_cast(c); if (isdigit(temp)) { - base_repr = char_to_int(c); + *this = bigint(char_to_int(c)); } else { throw std::runtime_error("Invalid Big Integer has been fed."); } @@ -71,61 +79,65 @@ namespace BigInt { bigint(double n) : bigint(static_cast(n)) {} - bigint(long long n) : base_repr(n) {} + bigint(long long n) { + +// if ( n >= 1000000000000000000) + if ( n >= 1000000000000000000 || n <= -1000000000000000000) + { + vec.emplace_back(n / 1000000000000000000); + vec.emplace_back(n % 1000000000000000000); + } + else{ + vec.emplace_back(n); + } + + if (n < 0) + { + is_neg = true; + for (auto& x : vec) { x = std::abs(x); } + } + } bigint(const bigint &n) { *this = n; } - bigint(const char *string) : bigint(std::string(string)) {} + bigint(const char* string) : bigint(std::string(string)) {} + + bigint(std::vector n) {this->vec = n;} bigint& operator=(const bigint& other) { if (this == &other) return *this; - this->is_big = other.is_big; - if (this->is_big) - { - this->str = other.str; - } - else - { - this->base_repr = other.base_repr; - } + this->is_neg = other.is_neg; + this->vec = other.vec; return *this; } explicit operator int() const { - if (!is_big) { - return static_cast(base_repr); - } - std::stringstream ss(str); - int num; - ss >> num; - return num; + return static_cast(vec.back()); } friend std::ostream &operator<<(std::ostream &stream, const bigint &n) { - stream << (n.is_big ? n.str : std::to_string(n.base_repr)); + stream << vector_to_string(n.vec); return stream; } bigint operator+=(const bigint &rhs) { - if (this->is_big || rhs.is_big) - { - *this = add(this->is_big ? this->str : std::to_string(this->base_repr), rhs.is_big ? rhs.str : std::to_string(rhs.base_repr)); - } - else if (rhs.base_repr > 0 && this->base_repr > std::numeric_limits::max() - rhs.base_repr || - rhs < 0 && this->base_repr < std::numeric_limits::min() - rhs.base_repr) - { - *this = add(std::to_string(this->base_repr), std::to_string(rhs.base_repr)); - } - else - { - this->base_repr += rhs.base_repr; - } + *this = add(*this, rhs); + +// else if (rhs.base_repr > 0 && this->base_repr > std::numeric_limits::max() - rhs.base_repr || +// rhs < 0 && this->base_repr < std::numeric_limits::min() - rhs.base_repr) +// { +// *this = add(std::to_string(this->base_repr), std::to_string(rhs.base_repr)); +// } +// else +// { +// this->base_repr += rhs.base_repr; +// } return *this; } @@ -138,19 +150,18 @@ namespace BigInt { bigint operator-=(const bigint &rhs) { - if (this->is_big || rhs.is_big) - { - *this = subtract(this->is_big ? this->str : std::to_string(this->base_repr), rhs.is_big ? rhs.str : std::to_string(rhs.base_repr)); - } - else if ((rhs.base_repr < 0 && this->base_repr > std::numeric_limits::max() + rhs.base_repr) || - (rhs > 0 && this->base_repr < std::numeric_limits::min() + rhs.base_repr)) - { - *this = subtract(std::to_string(this->base_repr), std::to_string(rhs.base_repr)); - } - else - { - this->base_repr -= rhs.base_repr; - } + + *this = subtract(*this, rhs); + +// else if ((rhs.base_repr < 0 && this->base_repr > std::numeric_limits::max() + rhs.base_repr) || +// (rhs > 0 && this->base_repr < std::numeric_limits::min() + rhs.base_repr)) +// { +// *this = subtract(std::to_string(this->base_repr), std::to_string(rhs.base_repr)); +// } +// else +// { +// this->base_repr -= rhs.base_repr; +// } return *this; } @@ -163,23 +174,21 @@ namespace BigInt { bigint operator*=(const bigint &rhs) { - if (this->is_big || rhs.is_big) - { - *this = multiply(this->is_big ? this->str : std::to_string(this->base_repr), rhs.is_big ? rhs.str : std::to_string(rhs.base_repr)); - } - else if ((this->base_repr == -1 && rhs.base_repr == std::numeric_limits::min()) || - (rhs.base_repr == -1 && this->base_repr == std::numeric_limits::min()) || - (rhs.base_repr != 0 && - this->base_repr > std::numeric_limits::max() / rhs.base_repr) || - (rhs.base_repr != 0 && - this->base_repr < std::numeric_limits::min() / rhs.base_repr)) - { - *this = multiply(std::to_string(this->base_repr), std::to_string(rhs.base_repr)); - } - else - { - this->base_repr *= rhs.base_repr; - } + *this = multiply(*this, rhs); + +// else if ((this->base_repr == -1 && rhs.base_repr == std::numeric_limits::min()) || +// (rhs.base_repr == -1 && this->base_repr == std::numeric_limits::min()) || +// (rhs.base_repr != 0 && +// this->base_repr > std::numeric_limits::max() / rhs.base_repr) || +// (rhs.base_repr != 0 && +// this->base_repr < std::numeric_limits::min() / rhs.base_repr)) +// { +// *this = multiply(std::to_string(this->base_repr), std::to_string(rhs.base_repr)); +// } +// else +// { +// this->base_repr *= rhs.base_repr; +// } return *this; } @@ -192,15 +201,7 @@ namespace BigInt { bigint &operator/=(const bigint &rhs) { - if (this->is_big || rhs.is_big) - { - *this = divide(this->is_big ? this->str : std::to_string(this->base_repr), rhs.is_big ? rhs.str : std::to_string(rhs.base_repr)); - } - else - { - this->base_repr /= rhs.base_repr; - } - + *this = divide(*this, rhs); return *this; } @@ -213,15 +214,7 @@ namespace BigInt { bigint operator%=(const bigint &rhs) { - if (this->is_big || rhs.is_big) - { - *this = mod(this->is_big ? this->str : std::to_string(this->base_repr), rhs.is_big ? rhs.str : std::to_string(rhs.base_repr)); - } - else - { - this->base_repr %= rhs.base_repr; - } - + *this = mod(*this, rhs); return *this; } @@ -260,12 +253,11 @@ namespace BigInt { friend bool operator==(const bigint &l, const bigint &r) { - if (l.is_big || r.is_big) + if (l.is_neg != r.is_neg) { - return (l.is_big ? l.str : std::to_string(l.base_repr)) == - (r.is_big ? r.str : std::to_string(r.base_repr)); + return false; } - return l.base_repr == r.base_repr; + return l.vec == r.vec; } friend bool operator!=(const bigint &l, const bigint &r) @@ -273,15 +265,7 @@ namespace BigInt { friend bool operator<(const bigint &lhs, const bigint &rhs) { - if (lhs.is_big || rhs.is_big) - { - return less_than(lhs.is_big ? lhs.str : std::to_string(lhs.base_repr), rhs.is_big ? rhs.str : std::to_string(rhs.base_repr)); - } - - return lhs.base_repr < rhs.base_repr; - - - + return less_than(lhs, rhs); } friend bool operator>(const bigint &l, const bigint &r) @@ -295,10 +279,7 @@ namespace BigInt { explicit operator bool() const { - if (!this->is_big) { - return this->base_repr; - } - return this->str != "0" || !this->str.empty(); + return !(this->vec.empty()) || this->vec.front(); } friend std::hash; @@ -325,11 +306,13 @@ namespace BigInt { inline static bigint abs(const bigint &s) { - if (!s.is_big) { - return std::abs(s.base_repr); - } if (is_negative(s)) - return s.str.substr(1, s.str.length() - 1); + { + bigint temp = s; + temp.is_neg = false; + + return temp; + } return s; } @@ -359,22 +342,12 @@ namespace BigInt { inline static bool is_even(const bigint &input) { - if (!input.is_big) - { - return !(input.base_repr & 1); - } -#if __cplusplus >= 202002L - return (input.str.ends_with("0") || input.str.ends_with("2") || input.str.ends_with("4") || - input.str.ends_with("6") || input.str.ends_with("8")); -#else - auto back = input.str.back(); - return back == '0' || back == '2' || back == '4' || back == '6' || back == '8'; -#endif + return !(input.vec.back() & 1); } inline static bool is_negative(const bigint &input) { - return !input.is_big ? input.base_repr < 0 : input.str[0] == '-'; + return input.is_neg; } static bool is_prime(const bigint &); @@ -382,16 +355,7 @@ namespace BigInt { inline static bigint sum_of_digits(const bigint& input) { bigint sum; - if (input.is_big) - { - for (auto c : input.str) - { - sum += char_to_int(c); - } - } - else - { - auto base = input.base_repr; + for (auto base : input.vec) { for (sum = 0; base > 0; sum += base % 10, base /= 10); } return sum; @@ -400,12 +364,22 @@ namespace BigInt { static bigint random(size_t length); private: - bool is_big = false; - std::string str; - long long int base_repr{}; + std::vector vec{}; + bool is_neg{}; // Function Definitions for Internal Uses - static std::string trim(std::string); + + static bigint trim(const bigint& input) { + auto temp = input; + while (temp.vec.front() == 0){ + temp.vec.erase(temp.vec.begin()); + } + return temp; + } + + static std::vector string_to_vector(std::string input); + + static std::string vector_to_string(const std::vector& input); static bigint add(const bigint &, const bigint &); @@ -434,6 +408,8 @@ namespace BigInt { static bool is_bigint(const std::string &); + static int count_digits(const bigint&); + inline static int char_to_int(const char input) { return input - '0'; @@ -444,9 +420,11 @@ namespace BigInt { return input + '0'; } - inline static bigint negate(bigint input) + inline static bigint negate(const bigint& input) { - return input.str.insert(0, "-"); + bigint temp = input; + temp.is_neg = true; + return temp; } inline static bool less_than(const bigint& lhs, const bigint& rhs) @@ -461,12 +439,12 @@ namespace BigInt { return is_negative(lhs); } - if(lhs.str.length() == rhs.str.length()) + if(lhs.vec.size() == rhs.vec.size()) { - return lhs.str < rhs.str; + return lhs.vec < rhs.vec; } - return lhs.str.length() < rhs.str.length(); + return lhs.vec.size() < rhs.vec.size(); } }; @@ -482,8 +460,29 @@ namespace BigInt { return s.find_first_not_of("0123456789", 0) == std::string::npos; } + std::pair add_with_carry(long long lhs, long long rhs) + { + long long max_number = 1000000000000000000; + auto sum = lhs + rhs; + + if (sum >= max_number) + { + // Carry needs to happen + auto carry = sum / max_number; + auto result = sum % max_number; + return {carry, result}; + } + else + { + return {0, sum}; + } + } + inline bigint bigint::add(const bigint &lhs, const bigint &rhs) { + // Ensure LHS is larger than RHS, and both are positive + if (lhs == 0) return rhs; + if (rhs == 0) return lhs; if (is_negative(lhs) && is_negative(rhs)) { return negate(add(abs(lhs) ,abs(rhs))); @@ -496,209 +495,216 @@ namespace BigInt { { return lhs - abs(rhs); } + if (lhs < rhs) + { + return add(rhs, lhs); + } - // Actual string addition implementation - int lhs_length = lhs.str.length(); - int rhs_length = rhs.str.length(); - int max_length = std::max(lhs_length, rhs_length); - std::string result; - int carry = 0; + bigint full_rhs = rhs; + std::vector> carry_result(lhs.vec.size() + 1); - for (int i = 0; i < max_length; ++i) + // Fill the smaller to match the larger size + while (lhs.vec.size() > full_rhs.vec.size()) { - auto digit_1 = i < lhs_length ? char_to_int(lhs.str[lhs_length - 1 - i]) : 0; - auto digit_2 = i < rhs_length ? char_to_int(rhs.str[rhs_length - 1 - i]) : 0; + full_rhs.vec.insert(full_rhs.vec.begin(), 0); + } - int sum = digit_1 + digit_2 + carry; - carry = sum / 10; - sum = sum % 10; + std::transform(lhs.vec.rbegin(), lhs.vec.rend(), full_rhs.vec.rbegin(), carry_result.rbegin(), add_with_carry); - result.insert(result.begin(), int_to_char(sum)); + std::vector final(lhs.vec.size() + 1); + for (int i = carry_result.size() - 1; i >= 0; --i) { + final[i] += carry_result[i].second; + final[i - 1] += carry_result[i].first; } - if (carry > 0) - { - result.insert(result.begin(), int_to_char(carry)); - } + return trim(bigint(final)); + } - return result; + std::pair subtract_with_borrow(long long lhs, long long rhs) + { + long long max_number = 1000000000000000000; + if (lhs < rhs) + { + // Borrow needs to happen + auto result = (lhs + max_number) - rhs; + return {1, result}; // 1 represents a borrow + } + else + { + auto result = lhs - rhs; + return {0, result}; // 0 means no borrow + } } inline bigint bigint::subtract(const bigint &lhs, const bigint &rhs) { - if (lhs == rhs) - { - return 0; - } + // Ensure LHS is larger than RHS, and both are positive + if (rhs == 0) return lhs; + if (lhs == rhs) return 0; if (is_negative(lhs) && is_negative(rhs)) { return subtract(abs(rhs), abs(lhs)); } - if (is_negative(rhs)) { return add(lhs, abs(rhs)); } - if (is_negative(lhs)) { return add(lhs, negate(rhs)); } - if (lhs < rhs) { return negate(subtract(rhs, lhs)); } - // Actual string subtraction implementation - int lhs_length = lhs.str.size(); - int rhs_length = rhs.str.size(); - std::string result; - int borrow = 0; - - // Subtract digits from right to left - for (int i = 0; i < lhs_length; ++i) { - int digit1 = char_to_int(lhs.str[lhs_length - 1 - i]); - int digit2 = (i < rhs_length) ? char_to_int(rhs.str[rhs_length - 1 - i]) : 0; - - // Apply borrow if necessary - digit1 -= borrow; + bigint full_rhs = rhs; + std::vector> borrow_result(lhs.vec.size()); + // Fill the smaller to match the larger size + while (lhs.vec.size() > full_rhs.vec.size()) + { + full_rhs.vec.insert(full_rhs.vec.begin(), 0); + } - if (digit1 < digit2) - { - digit1 += 10; - borrow = 1; - } else { - borrow = 0; - } + std::transform(lhs.vec.rbegin(), lhs.vec.rend(), full_rhs.vec.rbegin(), borrow_result.rbegin(), + subtract_with_borrow); - int diff = digit1 - digit2; - result.insert(result.begin(), int_to_char(diff)); + std::vector final(lhs.vec.size()); + for (int i = borrow_result.size() - 1; i >= 0; --i) { + final[i] += borrow_result[i].second; + if (borrow_result[i].first) + final[i - 1] -= borrow_result[i].first; } - return trim(result);; + return trim(final); } inline bigint bigint::multiply(const bigint &lhs, const bigint &rhs) { - if (is_negative(lhs) && is_negative(rhs)) { + if (lhs == 0 || rhs == 0) return 0; + if (lhs == 1) return rhs; + if (rhs == 1) return lhs; + + if (is_negative(lhs) && is_negative(rhs)) + { return (abs(lhs) * abs(lhs)); } - if (is_negative(lhs) || is_negative(rhs)) { + if (is_negative(lhs) || is_negative(rhs)) + { return negate(abs(lhs) * abs(rhs)); } - - std::string ans = ""; - - int str1_len = lhs.str.length(); - int str2_len = rhs.str.length(); - std::vector result(str1_len + str2_len, 0); - int i_n1 = 0; - int i_n2 = 0; - for (int i = str1_len - 1; i >= 0; i--) { - int carry = 0; - int n1 = lhs.str[i] - '0'; - i_n2 = 0; - for (int j = str2_len - 1; j >= 0; j--) { - int n2 = rhs.str[j] - '0'; - int sum = n1 * n2 + result[i_n1 + i_n2] + carry; - carry = sum / 10; - result[i_n1 + i_n2] = sum % 10; - i_n2++; - } - if (carry > 0) { - result[i_n1 + i_n2] += carry; - } - i_n1++; - } - int i = result.size() - 1; - while (i >= 0 && result[i] == 0) { - i--; - } - if (i == -1) { - return 0; - } - while (i >= 0) { - ans += std::to_string(result[i--]); + if (lhs < rhs) + { + return multiply(rhs, lhs); } - return ans; - } + const long long base = 1000000000000000000LL; // 10^18 + std::vector result(lhs.vec.size() + rhs.vec.size(), 0); - inline bigint bigint::divide(const bigint &numerator, const bigint &denominator) - { // return arithmetic division of str1/str2 + auto lhs_it = lhs.vec.rbegin(); + auto rhs_it = rhs.vec.rbegin(); - if (denominator == 0) { - throw std::domain_error("Attempted to divide by zero."); - } - if (numerator == denominator) + for (auto it_lhs = lhs.vec.rbegin(); it_lhs != lhs.vec.rend(); ++it_lhs) { - return 1; + for (auto it_rhs = rhs.vec.rbegin(); it_rhs != rhs.vec.rend(); ++it_rhs) + { + // Calculate the product and the corresponding indices in the result vector + // use 128 bits to carefully store overflow + __int128 mul = static_cast<__int128>(*it_lhs) * static_cast<__int128>(*it_rhs); + auto pos_low_it = result.rbegin() + (std::distance(lhs.vec.rbegin(), it_lhs) + std::distance(rhs.vec.rbegin(), it_rhs)); + auto pos_high_it = pos_low_it + 1; + + // Add the product to the result vector + *pos_low_it += mul % base; + if (pos_high_it != result.rend()) + { + *pos_high_it += mul / base; + } + + // Handle carry + if (*pos_low_it >= base) { + if (pos_high_it != result.rend()) { + *pos_high_it += *pos_low_it / base; + } + *pos_low_it %= base; + } + } } - if (denominator == 1) { - return numerator; + // Handle carries for remaining positions + for (auto r_iter = result.rbegin(); r_iter != result.rend() - 1; ++r_iter) { + if (*r_iter >= base) + { + *(r_iter + 1) += *r_iter / base; + *r_iter %= base; + } } - if (numerator == 0) { - return 0; + return trim(result); + } + + + inline bigint bigint::divide(const bigint &numerator, const bigint &denominator) + { + if (denominator == 0) + { + throw std::domain_error("Attempted to divide by zero."); } + if (numerator == denominator) return 1; + if (denominator == 1) return numerator; + if (numerator == 0) return 0; if (is_negative(numerator) && is_negative(denominator)) { return divide(abs(numerator), abs(denominator)); } - else if (is_negative(numerator) || is_negative(denominator)) + if (is_negative(numerator) || is_negative(denominator)) { return negate(divide(abs(numerator), abs(denominator))); } - - if (denominator > numerator) + if (numerator < denominator) { return 0; } - - bigint count = "0"; - bigint temp = numerator; - while(temp >= denominator) + if (numerator.vec.size() <= 1) { - int lenDiff = temp.str.length() - denominator.str.length(); - if(lenDiff > 0 && temp.str[0] > denominator.str[0]) - { - count += pow(10, lenDiff); - temp -= denominator * pow(10, lenDiff); - } - else if(lenDiff > 0) - { - count += pow(10, lenDiff-1); - temp -= denominator * pow(10, lenDiff-1); - } - else - { - count++; - temp -= denominator; - } + return numerator.vec.back() / denominator.vec.back(); } - return count; - } + bigint remainder = numerator; + bigint quotient = 0; + auto count = count_digits(remainder) - count_digits(denominator) - 1; - inline std::string bigint::trim(std::string s) - { - if (s == "0") - return s; + auto numerator_size = pow(10, count); + + auto temp = denominator * numerator_size; + + while (denominator * numerator_size < remainder) + { + temp = denominator * numerator_size; + remainder -= temp; + quotient += numerator_size; + count = count_digits(remainder) - count_digits(denominator) - 1; - int i = s[0] == '-' ? 1 : 0; - while (s[i] == '0') { - s.erase(i, 1); + if (numerator_size <= 1) + { + quotient += remainder / denominator; + break; + } + if (remainder.vec.size() <= 1) + { + quotient += remainder.vec.back() / denominator.vec.back(); + break; + } + numerator_size = pow(10, count); } - return s; + return quotient; } inline bigint bigint::sqrt(const bigint &input) @@ -739,8 +745,8 @@ namespace BigInt { if (input == 1) return 0; - if (!input.is_big) { - return std::log2(input.base_repr); + if (input.vec.size() == 1) { + return std::log2(input.vec.back()); } bigint exponent = 0; @@ -766,11 +772,15 @@ namespace BigInt { if (input == 1) return 0; - if (!input.is_big) { - return std::log10(input.base_repr); + if (input.vec.size() == 1) { + return std::log10(input.vec.back()); + } + int count = 0; + for (auto number : input.vec) { + count += count_digits(number); } - return static_cast(input.str.length()) - 1; + return count - 1; } inline bigint bigint::logwithbase(const bigint &input, const bigint &base) @@ -827,7 +837,6 @@ namespace BigInt { return ans; } - inline bool bigint::is_prime(const bigint &s) { if (is_negative(s) || s == 1) @@ -875,6 +884,47 @@ namespace BigInt { return {str}; } + std::vector bigint::string_to_vector(std::string input) { + // Break into chunks of 18 characters + std::vector result; + int chunk_size = 18; + + if (input.size() > chunk_size) + { + // Pad the length to get appropriate sized chunks + while ( input.size() % chunk_size != 0) + { + input.insert(0, "0"); + } + } + for (int i = 0; i < input.size(); i+=chunk_size) + { + std::string temp_str = input.substr(i, chunk_size); + result.emplace_back(stoll(temp_str)); + } + + return result; + } + + std::string bigint::vector_to_string(const std::vector& input) { + std::stringstream ss; + bool first = true; + for (auto partial : input) { + if (first) { + ss << partial; // No padding for the first number + first = false; + } else { + ss << std::setw(18) << std::setfill('0') << partial; // Pad to 18 digits + } + } + return ss.str(); + } + + int bigint::count_digits(const bigint & input) { + std::string my_string = vector_to_string(input.vec); + return my_string.length() - 1; + } + } // namespace::BigInt namespace std { @@ -882,7 +932,14 @@ namespace std { struct hash { std::size_t operator()(const BigInt::bigint& input) const { - return input.is_big ? std::hash()(input.str) : std::hash()(input.base_repr); + std::size_t seed = input.vec.size(); + for(auto x : input.vec) { + x = ((x >> 16) ^ x) * 0x45d9f3b; + x = ((x >> 16) ^ x) * 0x45d9f3b; + x = (x >> 16) ^ x; + seed ^= x + 0x9e3779b9 + (seed << 6) + (seed >> 2); + } + return seed; } }; } diff --git a/tests/test.cpp b/tests/test.cpp index 61922e3..12f5a7a 100644 --- a/tests/test.cpp +++ b/tests/test.cpp @@ -252,7 +252,7 @@ TEST(BigInt_test, Speed_Tests) } t2 = high_resolution_clock::now(); std::cout << " Multiplications: " << "[" << number_size << "] " - << formatTime(duration_cast(t2 - t1).count()) ; + << formatTime(duration_cast(t2 - t1).count()) ; // Division timing t1 = high_resolution_clock::now();