Modular exponentiation for large numbers in C++

Posted 2024-08-20 01:06:53 · 902 words · 10 views · 0 comments

I've recently been working on an implementation of the Miller-Rabin primality test. I'm limiting it to 32-bit numbers, because this is a just-for-fun project to familiarize myself with C++, and I don't want to have to work with anything 64-bit for a while. An added bonus is that the algorithm is deterministic for all 32-bit numbers, so I can significantly increase efficiency because I know exactly which witnesses to test.
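For context, a deterministic Miller-Rabin test for 32-bit inputs might look roughly like the sketch below. It assumes the witness set {2, 7, 61}, which is known to suffice for every n < 4,759,123,141 (and therefore every 32-bit n); the question does not say which witnesses it uses, and the names `mod_pow64` and `is_prime32` are hypothetical. Note that it performs each multiplication in a 64-bit intermediate, which is exactly the overflow issue discussed below.

```cpp
#include <cstdint>

// Hypothetical helper: (base^exp) % mod by exponentiation by squaring.
// Residues are < 2^32, so every product of two of them fits in 64 bits.
std::uint32_t mod_pow64(std::uint32_t base, std::uint32_t exp, std::uint32_t mod) {
    std::uint64_t result = 1 % mod;
    std::uint64_t b = base % mod;
    while (exp != 0) {
        if (exp & 1)
            result = result * b % mod;
        b = b * b % mod;
        exp >>= 1;
    }
    return static_cast<std::uint32_t>(result);
}

// Hypothetical sketch: deterministic Miller-Rabin for 32-bit n, assuming
// witnesses {2, 7, 61} (sufficient for all n < 4,759,123,141).
bool is_prime32(std::uint32_t n) {
    if (n < 2)
        return false;
    if (n % 2 == 0)
        return n == 2;

    // Write n - 1 = d * 2^s with d odd.
    std::uint32_t d = n - 1;
    int s = 0;
    while ((d & 1) == 0) {
        d >>= 1;
        ++s;
    }

    const std::uint32_t witnesses[] = {2, 7, 61};
    for (std::uint32_t a : witnesses) {
        if (a % n == 0)
            continue;  // n is the witness itself (7 or 61), which is prime
        std::uint64_t x = mod_pow64(a, d, n);
        if (x == 1 || x == n - 1)
            continue;
        bool composite = true;
        for (int r = 1; r < s; ++r) {
            x = x * x % n;  // x < n < 2^32, so x * x fits in 64 bits
            if (x == n - 1) {
                composite = false;
                break;
            }
        }
        if (composite)
            return false;  // a proves that n is composite
    }
    return true;
}
```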

So for low numbers, the algorithm works exceptionally well. However, part of the process relies upon modular exponentiation, that is, (num ^ pow) % mod. So, for example:

3 ^ 2 % 5 = 
9 % 5 = 
4

Here is the code I have been using for this modular exponentiation:

unsigned mod_pow(unsigned num, unsigned pow, unsigned mod)
{
    unsigned test;
    for(test = 1; pow; pow >>= 1)
    {
        if (pow & 1)
            test = (test * num) % mod;
        num = (num * num) % mod;
    }

    return test;

}

As you might have already guessed, problems arise when the arguments are all exceptionally large numbers. For example, if I want to test the number 673109 for primality, I will at one point have to find:

(2 ^ 168277) % 673109

Now, 2 ^ 168277 is an exceptionally large number, and somewhere in the process it overflows test, which results in an incorrect evaluation.

On the other hand, arguments such as

4000111222 ^ 3 % 1608

also evaluate incorrectly, for much the same reason.

Does anyone have suggestions for doing modular exponentiation in a way that prevents this overflow and/or manipulates it to produce the correct result? (The way I see it, overflow is just another form of modulo, that is, num % (UINT_MAX+1).)
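To see why treating overflow as "just another modulo" does not rescue the result, the hypothetical helper below computes the same square twice: once in 32 bits (so unsigned arithmetic silently reduces modulo 2^32 before the `% m`) and once exactly in 64 bits. Reducing modulo 2^32 first is harmless only when m divides 2^32, which an odd modulus such as 673109 never does.

```cpp
#include <cstdint>

// Hypothetical demo: (a * a) % m evaluated two ways.
struct SquareMod {
    std::uint32_t wrapped;  // a * a wraps modulo 2^32 first, then % m
    std::uint32_t exact;    // a * a computed exactly in 64 bits, then % m
};

SquareMod square_mod_both_ways(std::uint32_t a, std::uint32_t m) {
    std::uint32_t wrapped = (a * a) % m;  // unsigned overflow: reduced mod 2^32
    std::uint64_t wide = static_cast<std::uint64_t>(a) * a;
    return {wrapped, static_cast<std::uint32_t>(wide % m)};
}
```

With a = 673108 and m = 673109, the exact answer is 1 (because 673108 ≡ -1 mod 673109), while the wrapped version yields 21068.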

左秋 2024-08-27 01:06:53

Exponentiation by squaring still "works" for modular exponentiation. Your problem isn't that 2 ^ 168277 is an exceptionally large number; it's that one of your intermediate results is a fairly large number (bigger than 2^32), because 673109 is bigger than 2^16.

So I think the following will work. It's possible I've missed a detail, but the basic idea is sound, and this is how "real" crypto code might do large modular exponentiation (although not with 32- and 64-bit numbers, but with bignums that never need more than about 2 * log2(modulus) bits):

  • Start with exponentiation by squaring, as you have.
  • Perform the actual squaring in a 64-bit unsigned integer.
  • Reduce modulo 673109 at each step to get back within the 32-bit range, as you do.

Obviously that's a bit awkward if your C++ implementation doesn't have a 64-bit integer, although you can always fake one.
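One way to "fake" the wide product is to never form it at all: compute (a * b) % m by double-and-add (the additive analogue of exponentiation by squaring), keeping every intermediate below m. This is a sketch assuming 1 <= m <= UINT32_MAX; the helper names are hypothetical. It trades speed (up to 32 modular additions per multiplication) for staying entirely in 32-bit arithmetic.

```cpp
#include <cstdint>

// (x + y) % m for x, y < m, without ever computing a value >= 2^32.
std::uint32_t add_mod32(std::uint32_t x, std::uint32_t y, std::uint32_t m) {
    // x + y >= m exactly when x >= m - y; m - y > 0 because y < m.
    return (x >= m - y) ? x - (m - y) : x + y;
}

// (a * b) % m by double-and-add: walk the bits of b, doubling a mod m.
std::uint32_t mul_mod32(std::uint32_t a, std::uint32_t b, std::uint32_t m) {
    a %= m;
    std::uint32_t result = 0;
    while (b != 0) {
        if (b & 1)
            result = add_mod32(result, a, m);
        a = add_mod32(a, a, m);
        b >>= 1;
    }
    return result;
}

// The question's square-and-multiply loop, with every product done by mul_mod32.
std::uint32_t mod_pow32(std::uint32_t num, std::uint32_t pow, std::uint32_t mod) {
    std::uint32_t test = 1 % mod;
    num %= mod;
    for (; pow != 0; pow >>= 1) {
        if (pow & 1)
            test = mul_mod32(test, num, mod);
        num = mul_mod32(num, num, mod);
    }
    return test;
}
```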

There's an example on slide 22 here: http://www.cs.princeton.edu/courses/archive/spr05/cos126/lectures/22.pdf, although it uses very small numbers (less than 2^16), so it may not illustrate anything you don't already know.

Your other example, 4000111222 ^ 3 % 1608, would work in your current code if you just reduce 4000111222 modulo 1608 before you start. 1608 is small enough that you can safely multiply any two mod-1608 numbers in a 32-bit int.
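A minimal sketch of that fix, assuming the question's loop is otherwise unchanged: reduce `num` once before the loop. After that every multiplication is of two values below `mod`, which cannot overflow a 32-bit unsigned as long as mod <= 65536 (since 65535 * 65535 < 2^32). The name `mod_pow_small` is hypothetical.

```cpp
#include <cstdint>

// The question's mod_pow, with the base reduced before the loop.
// Safe without a wider type only while (mod - 1)^2 fits in 32 bits, i.e. mod <= 65536.
std::uint32_t mod_pow_small(std::uint32_t num, std::uint32_t pow, std::uint32_t mod) {
    num %= mod;                      // e.g. 4000111222 % 1608 == 574
    std::uint32_t test = 1 % mod;    // 1 % mod also handles mod == 1
    for (; pow; pow >>= 1) {
        if (pow & 1)
            test = (test * num) % mod;  // both factors < mod
        num = (num * num) % mod;        // both factors < mod
    }
    return test;
}
```

For example, 4000111222 % 1608 = 574, 574^2 % 1608 = 1444, and 1444 * 574 % 1608 = 736.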

倥絔 2024-08-27 01:06:53

I wrote something like this recently for RSA in C++; it's a bit messy, though.

#include "BigInteger.h"
#include <iostream>
#include <sstream>
#include <stack>

BigInteger::BigInteger() {
    digits.push_back(0);
    negative = false;
}

BigInteger::~BigInteger() {
}

void BigInteger::addWithoutSign(BigInteger& c, const BigInteger& a, const BigInteger& b) {
    int sum_n_carry = 0;
    int n = (int)a.digits.size();
    if (n < (int)b.digits.size()) {
        n = b.digits.size();
    }
    c.digits.resize(n);
    for (int i = 0; i < n; ++i) {
        unsigned short a_digit = 0;
        unsigned short b_digit = 0;
        if (i < (int)a.digits.size()) {
            a_digit = a.digits[i];
        }
        if (i < (int)b.digits.size()) {
            b_digit = b.digits[i];
        }
        sum_n_carry += a_digit + b_digit;
        c.digits[i] = (sum_n_carry & 0xFFFF);
        sum_n_carry >>= 16;
    }
    if (sum_n_carry != 0) {
        putCarryInfront(c, sum_n_carry);
    }
    while (c.digits.size() > 1 && c.digits.back() == 0) {
        c.digits.pop_back();
    }
    //std::cout << a.toString() << " + " << b.toString() << " == " << c.toString() << std::endl;
}

void BigInteger::subWithoutSign(BigInteger& c, const BigInteger& a, const BigInteger& b) {
    int sub_n_borrow = 0;
    int n = a.digits.size();
    if (n < (int)b.digits.size())
        n = (int)b.digits.size();
    c.digits.resize(n);
    for (int i = 0; i < n; ++i) {
        unsigned short a_digit = 0;
        unsigned short b_digit = 0;
        if (i < (int)a.digits.size())
            a_digit = a.digits[i];
        if (i < (int)b.digits.size())
            b_digit = b.digits[i];
        sub_n_borrow += a_digit - b_digit;
        if (sub_n_borrow >= 0) {
            c.digits[i] = sub_n_borrow;
            sub_n_borrow = 0;
        } else {
            c.digits[i] = 0x10000 + sub_n_borrow;
            sub_n_borrow = -1;
        }
    }
    while (c.digits.size() > 1 && c.digits.back() == 0) {
        c.digits.pop_back();
    }
    //std::cout << a.toString() << " - " << b.toString() << " == " << c.toString() << std::endl;
}

int BigInteger::cmpWithoutSign(const BigInteger& a, const BigInteger& b) {
    int n = (int)a.digits.size();
    if (n < (int)b.digits.size())
        n = (int)b.digits.size();
    //std::cout << "cmp(" << a.toString() << ", " << b.toString() << ") == ";
    for (int i = n-1; i >= 0; --i) {
        unsigned short a_digit = 0;
        unsigned short b_digit = 0;
        if (i < (int)a.digits.size())
            a_digit = a.digits[i];
        if (i < (int)b.digits.size())
            b_digit = b.digits[i];
        if (a_digit < b_digit) {
            //std::cout << "-1" << std::endl;
            return -1;
        } else if (a_digit > b_digit) {
            //std::cout << "+1" << std::endl;
            return +1;
        }
    }
    //std::cout << "0" << std::endl;
    return 0;
}

void BigInteger::multByDigitWithoutSign(BigInteger& c, const BigInteger& a, unsigned short b) {
    unsigned int mult_n_carry = 0;
    c.digits.clear();
    c.digits.resize(a.digits.size());
    for (int i = 0; i < (int)a.digits.size(); ++i) {
        unsigned short a_digit = 0;
        unsigned short b_digit = b;
        if (i < (int)a.digits.size())
            a_digit = a.digits[i];
        mult_n_carry += a_digit * b_digit;
        c.digits[i] = (mult_n_carry & 0xFFFF);
        mult_n_carry >>= 16;
    }
    if (mult_n_carry != 0) {
        putCarryInfront(c, mult_n_carry);
    }
    //std::cout << a.toString() << " x " << b << " == " << c.toString() << std::endl;
}

void BigInteger::shiftLeftByBase(BigInteger& b, const BigInteger& a, int times) {
    b.digits.resize(a.digits.size() + times);
    for (int i = 0; i < times; ++i) {
        b.digits[i] = 0;
    }
    for (int i = 0; i < (int)a.digits.size(); ++i) {
        b.digits[i + times] = a.digits[i];
    }
}

void BigInteger::shiftRight(BigInteger& a) {
    //std::cout << "shr " << a.toString() << " == ";
    for (int i = 0; i < (int)a.digits.size(); ++i) {
        a.digits[i] >>= 1;
        if (i+1 < (int)a.digits.size()) {
            if ((a.digits[i+1] & 0x1) != 0) {
                a.digits[i] |= 0x8000;
            }
        }
    }
    //std::cout << a.toString() << std::endl;
}

void BigInteger::shiftLeft(BigInteger& a) {
    bool lastBit = false;
    for (int i = 0; i < (int)a.digits.size(); ++i) {
        bool bit = (a.digits[i] & 0x8000) != 0;
        a.digits[i] <<= 1;
        if (lastBit)
            a.digits[i] |= 1;
        lastBit = bit;
    }
    if (lastBit) {
        a.digits.push_back(1);
    }
}

void BigInteger::putCarryInfront(BigInteger& a, unsigned short carry) {
    BigInteger b;
    b.negative = a.negative;
    b.digits.resize(a.digits.size() + 1);
    b.digits[a.digits.size()] = carry;
    for (int i = 0; i < (int)a.digits.size(); ++i) {
        b.digits[i] = a.digits[i];
    }
    a.digits.swap(b.digits);
}

void BigInteger::divideWithoutSign(BigInteger& c, BigInteger& d, const BigInteger& a, const BigInteger& b) {
    c.digits.clear();
    c.digits.push_back(0);
    BigInteger two("2");
    BigInteger e = b;
    BigInteger f("1");
    BigInteger g = a;
    BigInteger one("1");
    while (cmpWithoutSign(g, e) >= 0) {
        shiftLeft(e);
        shiftLeft(f);
    }
    shiftRight(e);
    shiftRight(f);
    while (cmpWithoutSign(g, b) >= 0) {
        g -= e;
        c += f;
        while (cmpWithoutSign(g, e) < 0) {
            shiftRight(e);
            shiftRight(f);
        }
    }
    e = c;
    e *= b;
    f = a;
    f -= e;
    d = f;
}

BigInteger::BigInteger(const BigInteger& other) {
    digits = other.digits;
    negative = other.negative;
}

BigInteger::BigInteger(const char* other) {
    digits.push_back(0);
    negative = false;
    BigInteger ten;
    ten.digits[0] = 10;
    const char* c = other;
    bool make_negative = false;
    if (*c == '-') {
        make_negative = true;
        ++c;
    }
    while (*c != 0) {
        BigInteger digit;
        digit.digits[0] = *c - '0';
        *this *= ten;
        *this += digit;
        ++c;
    }
    negative = make_negative;
}

bool BigInteger::isOdd() const {
    return (digits[0] & 0x1) != 0;
}

BigInteger& BigInteger::operator=(const BigInteger& other) {
    if (this == &other) // handle self assignment
        return *this;
    digits = other.digits;
    negative = other.negative;
    return *this;
}

BigInteger& BigInteger::operator+=(const BigInteger& other) {
    BigInteger result;
    if (negative) {
        if (other.negative) {
            result.negative = true;
            addWithoutSign(result, *this, other);
        } else {
            int a = cmpWithoutSign(*this, other);
            if (a < 0) {
                result.negative = false;
                subWithoutSign(result, other, *this);
            } else if (a > 0) {
                result.negative = true;
                subWithoutSign(result, *this, other);
            } else {
                result.negative = false;
                result.digits.clear();
                result.digits.push_back(0);
            }
        }
    } else {
        if (other.negative) {
            int a = cmpWithoutSign(*this, other);
            if (a < 0) {
                result.negative = true;
                subWithoutSign(result, other, *this);
            } else if (a > 0) {
                result.negative = false;
                subWithoutSign(result, *this, other);
            } else {
                result.negative = false;
                result.digits.clear();
                result.digits.push_back(0);
            }
        } else {
            result.negative = false;
            addWithoutSign(result, *this, other);
        }
    }
    negative = result.negative;
    digits.swap(result.digits);
    return *this;
}

BigInteger& BigInteger::operator-=(const BigInteger& other) {
    BigInteger neg_other = other;
    neg_other.negative = !neg_other.negative;
    return *this += neg_other;
}

BigInteger& BigInteger::operator*=(const BigInteger& other) {
    BigInteger result;
    for (int i = 0; i < (int)digits.size(); ++i) {
        BigInteger mult;
        multByDigitWithoutSign(mult, other, digits[i]);
        BigInteger shift;
        shiftLeftByBase(shift, mult, i);
        BigInteger add;
        addWithoutSign(add, result, shift);
        result = add;
    }
    if (negative != other.negative) {
        result.negative = true;
    } else {
        result.negative = false;
    }
    //std::cout << toString() << " x " << other.toString() << " == " << result.toString() << std::endl;
    negative = result.negative;
    digits.swap(result.digits);
    return *this;
}

BigInteger& BigInteger::operator/=(const BigInteger& other) {
    BigInteger result, tmp;
    divideWithoutSign(result, tmp, *this, other);
    result.negative = (negative != other.negative);
    negative = result.negative;
    digits.swap(result.digits);
    return *this;
}

BigInteger& BigInteger::operator%=(const BigInteger& other) {
    BigInteger c, d;
    divideWithoutSign(c, d, *this, other);
    *this = d;
    return *this;
}

bool BigInteger::operator>(const BigInteger& other) const {
    if (negative) {
        if (other.negative) {
            return cmpWithoutSign(*this, other) < 0;
        } else {
            return false;
        }
    } else {
        if (other.negative) {
            return true;
        } else {
            return cmpWithoutSign(*this, other) > 0;
        }
    }
}

BigInteger& BigInteger::powAssignUnderMod(const BigInteger& exponent, const BigInteger& modulus) {
    BigInteger zero("0");
    BigInteger one("1");
    BigInteger e = exponent;
    BigInteger base = *this;
    *this = one;
    while (cmpWithoutSign(e, zero) != 0) {
        //std::cout << e.toString() << " : " << toString() << " : " << base.toString() << std::endl;
        if (e.isOdd()) {
            *this *= base;
            *this %= modulus;
        }
        shiftRight(e);
        base *= BigInteger(base);
        base %= modulus;
    }
    return *this;
}

std::string BigInteger::toString() const {
    std::ostringstream os;
    if (negative)
        os << "-";
    BigInteger tmp = *this;
    BigInteger zero("0");
    BigInteger ten("10");
    tmp.negative = false;
    std::stack<char> s;
    while (cmpWithoutSign(tmp, zero) != 0) {
        BigInteger tmp2, tmp3;
        divideWithoutSign(tmp2, tmp3, tmp, ten);
        s.push((char)(tmp3.digits[0] + '0'));
        tmp = tmp2;
    }
    while (!s.empty()) {
        os << s.top();
        s.pop();
    }
    /*
    for (int i = digits.size()-1; i >= 0; --i) {
        os << digits[i];
        if (i != 0) {
            os << ",";
        }
    }
    */
    return os.str();
}

And an example usage.

BigInteger a("87682374682734687"), b("435983748957348957349857345"), c("2348927349872344");

// Calculates pow(87682374682734687, 435983748957348957349857345) % 2348927349872344
a.powAssignUnderMod(b, c);

It's fast too, and supports an unlimited number of digits.

各自安好 2024-08-27 01:06:53

Two things:

  • Are you using the appropriate data type? In other words, does UINT_MAX allow you to have 673109 as an argument?

673109 itself fits, but your code still fails: once num reaches 2^16 or more, the next num * num overflows. Use a bigger data type to hold this intermediate value.

  • How about taking the modulo at every possible overflow opportunity, such as:

    test = ((test % mod) * (num % mod)) % mod;

Edit:

#include <cstdio>

unsigned mod_pow(unsigned num, unsigned pow, unsigned mod)
{
    // 64-bit intermediates: the product of two values below mod (< 2^32)
    // is below 2^64, so it cannot overflow.
    unsigned long long test;
    unsigned long long n = num;
    for(test = 1 % mod; pow; pow >>= 1)
    {
        if (pow & 1)
            test = ((test % mod) * (n % mod)) % mod;
        n = ((n % mod) * (n % mod)) % mod;
    }

    return (unsigned)test; /* test < mod, so narrowing back to unsigned is safe */
}

int main()
{
    /* (2 ^ 168277) % 673109 */
    printf("%u\n", mod_pow(2, 168277, 673109));
    return 0;
}

树深时见影 2024-08-27 01:06:53
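    package playTime;

    // Right-to-left binary modular exponentiation: computes x^e mod y, where the
    // exponent e is supplied as its binary digits, least-significant bit first.
    public class play {

        public static long count = 0;       // index of the current exponent bit
        public static long binSlots = 10;   // number of bits in binArray
        public static long y = 645;         // modulus
        public static long finalValue = 1;  // running result
        public static long x = 11;          // base

        public static void main(String[] args){

            // Exponent 644 = 0b1010000100, stored least-significant bit first.
            int[] binArray = new int[]{0,0,1,0,0,0,0,1,0,1};

            x = BME(x, count, binArray);

            System.out.print("\nfinal value:"+finalValue);

        }

        public static long BME(long x, long count, int[] binArray){

            if(count == binSlots){
                return finalValue;
            }

            // If this exponent bit is set, multiply the current power of the base into the result.
            if(binArray[(int) count] == 1){
                finalValue = finalValue*x%y;
            }

            // Square the base for the next bit, reducing mod y so it stays below y.
            x = (x*x)%y;
            System.out.print("Array("+binArray[(int) count]+") "
                            +"x("+x+")" +" finalVal("+finalValue + ")\n");

            count++;

            return BME(x, count, binArray);
        }

    }
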
堇年纸鸢 2024-08-27 01:06:53

LL is for long long int

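typedef long long LL;
const LL P = 673109;  // the answer assumes a global modulus P; this value is only an example

LL power_mod(LL a, LL k) {
    if (k == 0)
        return 1;
    LL temp = power_mod(a, k/2);  // recursive call on half the exponent
    LL res;

    res = ( ( temp % P ) * (temp % P) ) % P;
    if (k % 2 == 1)
        res = ((a % P) * (res % P)) % P;
    return res;
}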

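Use the above recursive function to compute the modular exponent. It does not overflow because every product is formed from two values already reduced modulo P, so each product is below P^2, which fits in a long long as long as P is less than about 3 × 10^9.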

Sample test run:
a = 2 and k = 168277 gives 518358, which is correct, and the function runs in O(log(k)) time.

莫言歌 2024-08-27 01:06:53

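You could use the following identity:

(a * b) mod m = ((a mod m) * (b mod m)) mod m

Try applying it in a straightforward way and improving incrementally:

    // With 32-bit unsigned operands this still overflows once mod > 65536;
    // test and num then need a 64-bit type (e.g. unsigned long long).
    if (pow & 1)
        test = ((test % mod) * (num % mod)) % mod;
    num = ((num % mod) * (num % mod)) % mod;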