多项式 学习笔记

作者: xht37 分类: 笔记 发布时间: 2020-01-11 16:11

点击数:237

唉,只能功利的学学了……

多项式

长成

$$
f(x) = \sum_{i=0}^n f_ix^i
$$
这个鬼样子的式子就叫做多项式。

一般要求 $f_n \ne 0$,在这种情况下,$n$ 被称为多项式 $f(x)$ 的,记作 $\operatorname{deg} f$。

拉格朗日插值

【模板】P4781 【模板】拉格朗日插值

由小学知识可知 $n$ 个点 $(x_i,y_i)$ 可以唯一地确定一个多项式 $y = f(x)$。

现在,给定这 $n$ 个点,请你确定这个多项式,并求出 $f(k) \bmod 998244353$。

首先,有 $\operatorname{deg} f < n$。

构造 $n$ 个多项式 $g_{1\dots n}(x)$,其中 $g_i(x) = y_i\prod_{i \ne j} \frac{x-x_j}{x_i-x_j}$。

在这种构造下,$g_i(x_i) = y_i$,$g_i(x_j) = 0(i \ne j)$。

因此,有 $f(x) = \sum_{i=1}^n g_i(x)$,即:
$$
f(x) = \sum_{i=1}^n\left(y_i\prod_{i \ne j} \frac{x-x_j}{x_i-x_j}\right)
$$
那么,$f(k)$ 显然可以 $\mathcal O(n^2)$ 求了。

const int N = 2e3 + 7;
int n;
modint k, x[N], y[N], ans;

int main() {
    rd(n), rd(k);
    for (int i = 1; i <= n; i++) rd(x[i]), rd(y[i]);
    for (int i = 1; i <= n; i++) {
        modint a = y[i], b = 1;
        for (int j = 1; j <= n; j++)
            if (i != j) a *= k - x[j], b *= x[i] - x[j];
        ans += a / b;
    }
    print(ans);
    return 0;
}

注意,当 $x_i = i$ 或者 $x_i = i-1$ 时,进一步化简公式可以做到 $\mathcal O(n)$ 求 $f(k)$,但是要求出多项式还是得 $\mathcal O(n^2)$。

多项式乘法

快速傅里叶变换 (FFT)

利用复数将 $\mathcal O(n^2)$ 的多项式乘法优化到 $\mathcal O(n \log n)$。

【模板】P3803 【模板】多项式乘法

namespace FFT {
    const int N = 1 << 21 | 1;
    const double PI = acos(-1);
    struct I {
        double x, y;
        inline I() {}
        inline I(double x, double y) : x(x), y(y) {}
        inline I &operator = (int o) { return x = o, y = 0, *this; }
        inline I operator + (const I o) const { return I(x + o.x, y + o.y); }
        inline I operator - (const I o) const { return I(x - o.x, y - o.y); }
        inline I operator * (const I o) const {
            return I(x * o.x - y * o.y, x * o.y + y * o.x);
        }
    } a[N], b[N];
    inline void rd(I &x) { int o; ::rd(o), x = o; }
    inline void print(I x, char k = '\n') { ::print((int)(x.x + 0.5), k); }
    int n, m, k, l, r[N];
    inline void fft(I *a, int n, int x) {
        for (int i = 0; i < n; i++) if (i < r[i]) swap(a[i], a[r[i]]);
        for (int o = 2, k = 1; o <= n; o <<= 1, k <<= 1) {
            I W = I(cos(PI / k), x * sin(PI / k));
            for (int i = 0; i < n; i += o) {
                I w = I(1, 0);
                for (int j = 0; j < k; j++, w = w * W) {
                    I x = a[i+j], y = w * a[i+j+k];
                    a[i+j] = x + y, a[i+j+k] = x - y;
                }
            }
        }
    }
    inline void solve() {
        k = 1, l = 0;
        while (k <= n + m) k <<= 1, ++l;
        for (int i = 0; i < k; i++)
            r[i] = (r[i>>1] >> 1) | ((i & 1) << (l - 1));
        for (int i = n + 1; i < k; i++) a[i] = 0;
        for (int i = m + 1; i < k; i++) b[i] = 0;
        fft(a, k, 1), fft(b, k, 1);
        for (int i = 0; i < k; i++) a[i] = a[i] * b[i];
        fft(a, k, -1);
        for (int i = 0; i <= n + m; i++) a[i].x /= k;
    }
}
using namespace FFT;

int main() {
    rd(n), rd(m);
    for (int i = 0; i <= n; i++) rd(a[i]);
    for (int i = 0; i <= m; i++) rd(b[i]);
    solve();
    for (int i = 0; i <= n + m; i++) print(a[i], " \n"[i==n+m]);
    return 0;
}

快速数论变换 (NTT)

利用原根将 $\mathcal O(n^2)$ 的多项式乘法优化到 $\mathcal O(n \log n)$。

【模板】P3803 【模板】多项式乘法

namespace NTT {
    const int N = 1 << 21 | 1;
    int n, m, k, l, r[N];
    modint vk, g = 3, vg = 332748118, a[N], b[N];
    inline void ntt(modint *a, int n, int x) {
        for (int i = 0; i < n; i++) if (i < r[i]) swap(a[i], a[r[i]]);
        for (int o = 2, k = 1; o <= n; o <<= 1, k <<= 1) {
            modint W = (x > 0 ? g : vg) ^ ((P - 1) / o);
            for (int i = 0; i < n; i += o) {
                modint w = 1;
                for (int j = 0; j < k; j++, w *= W)
                    a[i+j+k] *= w,
                    a[i+j] += a[i+j+k],
                    a[i+j+k] = a[i+j] - a[i+j+k] - a[i+j+k];
            }
        }
    }
    inline void solve() {
        k = 1, l = 0;
        while (k <= n + m) k <<= 1, ++l;
        vk = (modint)1 / k;
        for (int i = 0; i < k; i++)
            r[i] = (r[i>>1] >> 1) | ((i & 1) << (l - 1));
        for (int i = n + 1; i < k; i++) a[i] = 0;
        for (int i = m + 1; i < k; i++) b[i] = 0;
        ntt(a, k, 1), ntt(b, k, 1);
        for (int i = 0; i < k; i++) a[i] *= b[i];
        ntt(a, k, -1);
        for (int i = 0; i <= n + m; i++) a[i] *= vk;
    }
}
using namespace NTT;

int main() {
    rd(n), rd(m);
    for (int i = 0; i <= n; i++) rd(a[i]);
    for (int i = 0; i <= m; i++) rd(b[i]);
    solve();
    for (int i = 0; i <= n + m; i++) print(a[i], " \n"[i==n+m]);
    return 0;
}

多项式位运算卷积 (FWT)

记对序列 $a$ 进行快速沃尔什变换后的序列为 $fwt[a]$。

已知序列 $a,b$,求一个新序列 $c = a \cdot b$,直接计算是 $\mathcal O(n^2)$ 的。

若 $a \to fwt[a]$ 和 $b \to fwt[b]$ 是 $\mathcal O(n \log n)$ 的,而 $fwt[c] = fwt[a] \cdot fwt[b]$ 是 $\mathcal O(n)$ 的,同时 $fwt[c] \to c$ 也是 $\mathcal O(n \log n)$ 的。

那么我们可以利用上述过程 $\mathcal O(n \log n)$ 求出 $c$。

在 OI 中,FWT 是用于解决对下标进行位运算卷积问题的方法。

$$
c_{i}=\sum_{i=j \oplus k} a_{j} b_{k}
$$

其中 $\oplus$ 是二元位运算中的一种。

要求

$$
c_{i}=\sum_{i=j | k} a_{j} b_{k}
$$

显然有 $j|i = i, k|i=i \to (j|k)|i = i$。

构造 $fwt[a]_i = \sum_{j|i=i} a_j$。

则有

$$
\begin{aligned}
fwt[a] \times fwt[b] &= \left(\sum_{j|i=i} a_j\right)\left(\sum_{k|i=i} b_k\right) \\
&= \sum_{j|i=i} \sum_{k|i=i} a_jb_k \\
&= \sum_{(j|k)|i = i} a_jb_k \\
&= fwt[c]
\end{aligned}
$$

$a \to fwt[a]$

要求
$$
fwt[a]_i = \sum_{j|i=i} a_j
$$
令 $a_0$ 表示 $a$ 中下标最高位为 $0$ 的那部分序列,$a_1$ 表示 $a$ 中下标最高位为 $1$ 的那部分序列。

则有
$$
fwt[a] = \text{merge}(fwt[a_0], fwt[a_0] + fwt[a_1])
$$
其中 $\text{merge}$ 表示「拼接」,$+$ 表示对应位置相加。

于是可以分治。

inline void OR(modint *f) {
    for (int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
        for (int i = 0; i < n; i += o)
            for (int j = 0; j < k; j++)
                f[i+j+k] += f[i+j];
}

$fwt[a] \to a$


$$
fwt[a] = \text{merge}(fwt[a_0], fwt[a_0] + fwt[a_1])
$$
可得
$$
a = \text{merge}(a_0, a_1 – a_0)
$$

inline void IOR(modint *f) {
    for (int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
        for (int i = 0; i < n; i += o)
            for (int j = 0; j < k; j++)
                f[i+j+k] -= f[i+j];
}

显然两份代码可以合并。

inline void OR(modint *f, modint x = 1) {
    for (int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
        for (int i = 0; i < n; i += o)
            for (int j = 0; j < k; j++)
                f[i+j+k] += f[i+j] * x;
}

同理或。

inline void AND(modint *f, modint x = 1) {
    for (int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
        for (int i = 0; i < n; i += o)
            for (int j = 0; j < k; j++)
                f[i+j] += f[i+j+k] * x;
}

异或

定义 $x\otimes y=\text{popcount}(x \& y) \bmod 2$,其中 $\text{popcount}$ 表示「二进制下 $1$ 的个数」。

满足 $(i\otimes j) \text{^} (i\otimes k)=i\otimes(j \text{^} k)$。

构造 $fwt[a]_i = \sum_{i\otimes j = 0} a_j – \sum_{i\otimes j = 1} a_j$。

则有

$$
\begin{aligned}
fwt[a] \times fwt[b] &= \left(\sum_{i\otimes j = 0} a_j – \sum_{i\otimes j = 1} a_j\right)\left(\sum_{i\otimes k = 0} b_k – \sum_{i\otimes k = 1} b_k\right) \\
&=\left(\sum_{i\otimes j=0}a_j\right)\left(\sum_{i\otimes k=0}b_k\right)-\left(\sum_{i\otimes j=0}a_j\right)\left(\sum_{i\otimes k=1}b_k\right)-\left(\sum_{i\otimes j=1}a_j\right)\left(\sum_{i\otimes k=0}b_k\right)+\left(\sum_{i\otimes j=1}a_j\right)\left(\sum_{i\otimes k=1}b_k\right) \\
&=\sum_{i\otimes(j \text{^} k)=0}a_jb_k-\sum_{i\otimes(j\text{^} k)=1}a_jb_k \\
&= fwt[c]
\end{aligned}
$$

因此

$$
\begin{aligned}
fwt[a] &= \text{merge}(fwt[a_0] + fwt[a_1], fwt[a_0] – fwt[a_1]) \\
a &= \text{merge}(\frac{a_0 + a_1}2, \frac{a_0 – a_1}2)
\end{aligned}
$$

inline void XOR(modint *f, modint x = 1) {
    for (int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
        for (int i = 0; i < n; i += o)
            for (int j = 0; j < k; j++)
                f[i+j] += f[i+j+k],
                f[i+j+k] = f[i+j] - f[i+j+k] - f[i+j+k],
                f[i+j] *= x, f[i+j+k] *= x;
}

【模板】P4717 【模板】快速沃尔什变换 (FWT)

const int N = 1 << 17 | 1;
int n, m;
modint A[N], B[N], a[N], b[N];

inline void in() {
    for (int i = 0; i < n; i++) a[i] = A[i], b[i] = B[i];
}

inline void get() {
    for (int i = 0; i < n; i++) a[i] *= b[i];
}

inline void out() {
    for (int i = 0; i < n; i++) print(a[i], " \n"[i==n-1]);
}

inline void OR(modint *f, modint x = 1) {
    for (int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
        for (int i = 0; i < n; i += o)
            for (int j = 0; j < k; j++)
                f[i+j+k] += f[i+j] * x;
}

inline void AND(modint *f, modint x = 1) {
    for (int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
        for (int i = 0; i < n; i += o)
            for (int j = 0; j < k; j++)
                f[i+j] += f[i+j+k] * x;
}

inline void XOR(modint *f, modint x = 1) {
    for (int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
        for (int i = 0; i < n; i += o)
            for (int j = 0; j < k; j++)
                f[i+j] += f[i+j+k],
                f[i+j+k] = f[i+j] - f[i+j+k] - f[i+j+k],
                f[i+j] *= x, f[i+j+k] *= x;
}

int main() {
    rd(m), n = 1 << m;
    for (int i = 0; i < n; i++) rd(A[i]);
    for (int i = 0; i < n; i++) rd(B[i]);
    in(), OR(a), OR(b), get(), OR(a, P - 1), out();
    in(), AND(a), AND(b), get(), AND(a, P - 1), out();
    in(), XOR(a), XOR(b), get(), XOR(a, (modint)1 / 2), out();
    return 0;
}

分治 FFT

分治 FFT 并不局限于 FFT,是将 CDQ 分治和多项式乘法 (FFT/NTT) 结合起来的思想。

【模板】P4721 【模板】分治 FFT

给定序列 $g_{1\dots n – 1}$,求序列 $f_{0\dots n – 1}$。

其中 $f_i=\sum_{j=1}^if_{i-j}g_j$,边界为 $f_0=1$。

答案对 $998244353$ 取模。

稍微变换一下,有 $f_i=\sum_{j=0}^{i-1}f_{j}g_{i-j}$。

注意到 $f_i$ 是依赖于 $f_{0\dots i-1}$ 的,因此如果依次计算每一项,时间复杂度无法承受。

考虑 CDQ 分治,即对于 $f_{l\dots r}$ 的值,我们先直接计算 $f_{l\dots mid}$ 的值,再考虑 $f_{l\dots mid}$ 对 $f_{mid+1 \dots r}$ 的贡献,最后计算 $f_{mid+1 \dots r}$ 的值。

现在的问题变成如何计算 $f_{l\dots mid}$ 对 $f_{mid+1 \dots r}$ 的贡献。

考虑对于 $x \in [mid+1,r]$,$f_{l \dots mid}$ 对 $f_x$ 的贡献为:
$$
\sum_{i=l}^{mid} f_ig_{x-i}
$$
设 $a_i = f_{i+l}, b_i = g_{i+1}$,则转化为:
$$
\sum_{i=0}^{mid-l} a_ib_{x-l-1-i}
$$
发现这是个卷积式,设 $n = r – l$,可以 $\mathcal O(n \log n)$ 计算。

总时间复杂度 $\mathcal T(n) = \mathcal T(\frac n2) + \mathcal O(n \log n) = \mathcal O(n \log^2 n)$。

const int N = 1e5 + 7;
int n;
modint f[N], g[N];

namespace NTT {
    const int N = 1 << 21 | 1;
    int n, m, k, l, r[N];
    modint vk, g = 3, vg = 332748118, a[N], b[N];
    inline void ntt(modint *a, int n, int x) {
        for (int i = 0; i < n; i++) if (i < r[i]) swap(a[i], a[r[i]]);
        for (int o = 2, k = 1; o <= n; o <<= 1, k <<= 1) {
            modint W = (x > 0 ? g : vg) ^ ((P - 1) / o);
            for (int i = 0; i < n; i += o) {
                modint w = 1;
                for (int j = 0; j < k; j++, w *= W)
                    a[i+j+k] *= w,
                    a[i+j] += a[i+j+k],
                    a[i+j+k] = a[i+j] - a[i+j+k] - a[i+j+k];
            }
        }
    }
    inline void solve() {
        k = 1, l = 0;
        while (k <= n + m) k <<= 1, ++l;
        vk = (modint)1 / k;
        for (int i = 0; i < k; i++)
            r[i] = (r[i>>1] >> 1) | ((i & 1) << (l - 1));
        for (int i = n + 1; i < k; i++) a[i] = 0;
        for (int i = m + 1; i < k; i++) b[i] = 0;
        ntt(a, k, 1), ntt(b, k, 1);
        for (int i = 0; i < k; i++) a[i] *= b[i];
        ntt(a, k, -1);
        for (int i = 0; i <= n + m; i++) a[i] *= vk;
    }
}

void cdq(int l, int r) {
    if (l == r) return;
    int mid = (l + r) >> 1;
    cdq(l, mid);
    NTT::n = mid - l, NTT::m = r - l - 1;
    for (int i = 0; i <= mid - l; i++) NTT::a[i] = f[i+l];
    for (int i = 0; i < r - l; i++) NTT::b[i] = g[i+1];
    NTT::solve();
    for (int i = mid + 1; i <= r; i++) f[i] += NTT::a[i-l-1];
    cdq(mid + 1, r);
}

int main() {
    rd(n), f[0] = 1;
    for (int i = 1; i < n; i++) rd(g[i]);
    cdq(0, n - 1);
    for (int i = 0; i < n; i++) print(f[i], " \n"[i==n-1]);
    return 0;
}

参考资料

$\textstyle$

发表评论

电子邮件地址不会被公开。 必填项已用*标注