题目大意:给定一个长度为 \(N\) 的序列,求从序列中选出 \(K\) 个数的集合乘积之和是多少。ios
题解:
因为是选出 \(K\) 个数字组成的集合,可知对于要计算的 \(K\) 元组来讲是没有标号的,而元组是由序列中 \(N\) 个数字组合而成的。所以,将要求的元组看做组合对象,该组合对象是由 \(N\) 个不一样种类的组合对象组成的,且组合对象是没有标号的,所以采用普通生成函数计算便可。
对于第 \(i\) 个数的普通生成函数为 \[(1 + a_ix)\],所以,原组合对象的生成函数是\[\prod\limits_{i = 1}^{n}(1+a_ix)\]。能够经过分治乘法来进行计算,时间复杂度为 \(O(nlogn)\)。c++
代码以下函数
#include <bits/stdc++.h> using namespace std; typedef long long LL; const int mod = 998244353, g = 3, ig = 332748118; LL fpow(LL a, LL b, LL c) { LL ret = 1 % c; for (; b; b >>= 1, a = a * a % mod) if (b & 1) ret = ret * a % mod; return ret; } void ntt(vector<LL> &v, vector<int> &rev, int opt) { int tot = v.size(); for (int i = 0; i < tot; i++) if (i < rev[i]) swap(v[i], v[rev[i]]); for (int mid = 1; mid < tot; mid <<= 1) { LL wn = fpow(opt == 1 ? g : ig, (mod - 1) / (mid << 1), mod); for (int j = 0; j < tot; j += mid << 1) { LL w = 1; for (int k = 0; k < mid; k++) { LL x = v[j + k], y = v[j + mid + k] * w % mod; v[j + k] = (x + y) % mod, v[j + mid + k] = (x - y + mod) % mod; w = w * wn % mod; } } } if (opt == -1) { LL itot = fpow(tot, mod - 2, mod); for (int i = 0; i < tot; i++) v[i] = v[i] * itot % mod; } } vector<LL> convolution(vector<LL> &a, int cnta, vector<LL> &b, int cntb, const function<LL(LL, LL)> &calc) { int bit = 0, tot = 1; while (tot <= 2 * max(cnta, cntb)) bit++, tot <<= 1; vector<int> rev(tot); for (int i = 0; i < tot; i++) rev[i] = rev[i >> 1] >> 1 | (i & 1) << (bit - 1); vector<LL> foo(tot), bar(tot); for (int i = 0; i < cnta; i++) foo[i] = a[i]; for (int i = 0; i < cntb; i++) bar[i] = b[i]; ntt(foo, rev, 1), ntt(bar, rev, 1); for (int i = 0; i < tot; i++) foo[i] = calc(foo[i], bar[i]); ntt(foo, rev, -1); return foo; } int main() { //freopen("data.in", "r", stdin); ios::sync_with_stdio(false); cin.tie(0), cout.tie(0); int n, K; cin >> n >> K; vector<LL> a(n); for (int i = 0; i < n; i++) { cin >> a[i]; } int m; cin >> m; while (m--) { int opt; cin >> opt; vector<LL> b = a; if (opt == 1) { int q, x, y; cin >> q >> x >> y; x--; b[x] = y; for (int i = 0; i < n; i++) { b[i] = (q - b[i] + mod) % mod; } } else { int q, l, r, d; cin >> q >> l >> r >> d; l--, r--; for (int i = l; i <= r; i++) { b[i] = (b[i] + d) % mod; } for (int i = 0; i < n; i++) { b[i] = (q - b[i] + mod) % mod; } } function<vector<LL>(int, int)> solve = [&](int l, int r) { if (l == r) { return vector<LL> {1, b[l]}; } int mid = l + r >> 1; vector<LL> ls = solve(l, mid); vector<LL> rs = solve(mid + 1, r); return convolution(ls, mid - l + 2, rs, r - mid + 1, [&](LL a, LL b) { return a * b % mod; }); }; vector<LL> ans = solve(0, n - 1); cout << ans[K] << endl; } return 0; }