题目连接ios
定义一次点分治的复杂度是全部分治中心分治时的子树大小之和。c++
给定一棵树,问全部点等几率被选作重心,点分治的指望复杂度。git
根据指望的线性性,答案等价于每一个点在点分树上的深度指望之和。数组
思路是从点对的角度考虑某一个点是否会产生贡献。
\[ E(depth[x])=\sum_{y=1}^n P(x\in subtree[y]) \]ide
也就是 \(x\) 在点分树上在 \(1\dots n\) 的子树中的几率和。spa
考虑点分树上 \(y\) 是 \(x\) 的祖先的条件,要求 \(x\) 和 \(y\) 构成的这条链上第一个在点分治过程当中被删除的点是 \(y\) ,因为链上被选中的几率相等,所以这个几率为 \(\frac{1}{dist(x,y) + 1}\)。code
因此答案为
\[ \sum_{x=1}^n\sum_{j=1}^n \frac{1}{dis(i,j) + 1}=\sum_{len = 0}^n \frac{cnt[i]}{i + 1} \]排序
所以须要点分治求长度为 \(i\) 的路径条数 \(cnt[i]\) ,注意到合并的时候是卷积的形式。ip
不考虑重复路径,把子树 dfs 一遍,直接本身进行卷积,再去掉子树内重复计数的路径便可。get
每一层最差以本身的 \(size\) 做为长度进行卷积,所以复杂度为 \(\mathcal O(n\log^2 n)\)
#include <cmath> #include <cstdio> #include <cctype> #include <cstdlib> #include <cstring> #include <iostream> #include <algorithm> #define N 65537 #define mod 998244353 using namespace std; typedef long long ll; inline int rd() { int x = 0; char c = getchar(); while (!isdigit(c)) c = getchar(); while (isdigit(c)) { x = x * 10 + (c ^ 48); c = getchar(); } return x; } inline void print(ll x) { int y = 10, len = 1; while(y <= x) {y *= 10; ++len;} while(len--) {y /= 10; putchar(x / y + 48); x %= y;} putchar('\n'); } inline int fpow(int x, int t = mod - 2) { int res = 1; while (t) { if (t & 1) res = 1ll * res * x % mod; x = 1ll * x * x % mod; t >>= 1; } return res; } int mxlen = (1 << 16), w[2][N], rev[N]; inline int mo(int x) { return x >= mod ? x - mod : x; } inline void init() { int per = fpow(3, (mod - 1) / mxlen); int invper = fpow(per); w[0][0] = w[1][0] = 1; for (int i = 1; i < mxlen; ++i) { w[0][i] = 1ll * w[0][i - 1] * per % mod; w[1][i] = 1ll * w[1][i - 1] * invper % mod; } } inline int Rev(int n) { int len = 1, bit = 0; while (len <= n) len <<= 1, ++bit; for (int i = 0; i < len; ++i) rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << (bit - 1))); return len; } inline void NTT(int *f, int len, int o) { for (int i = 0; i < len; ++i) if (i > rev[i]) swap(f[i], f[rev[i]]); for (int i = 1; i < len; i <<= 1) { int wn = mxlen / (i << 1); for (int j = 0; j < len; j += (i << 1)) { int nw = 0, x, y; for (int k = 0; k < i; ++k, nw += wn) { x = f[j + k]; y = 1ll * w[o][nw] * f[i + j + k] % mod; f[j + k] = mo(x + y); f[i + j + k] = mo(x - y + mod); } } } if (o == 1) { int invl = fpow(len); for (int i = 0; i < len; ++i) f[i] = 1ll * f[i] * invl % mod; } } bool vis[N]; int n, m, tot, totn, mx, rt, mxd; int bkt[N], cnt[N], sz[N], hd[N]; struct edge{int to, nxt;} e[N << 1]; inline void add(int u, int v) { e[++tot].to = v; e[tot].nxt = hd[u]; hd[u] = tot; e[++tot].to = u; e[tot].nxt = hd[v]; hd[v] = tot; } void getrt(int u, int fa) { sz[u] = 1; int mxs = 0; for (int i = hd[u], v; i; i = e[i].nxt) if ((v = e[i].to) != fa && !vis[v]) { getrt(v, u); sz[u] += sz[v]; mxs = max(mxs, sz[v]); } mxs = max(mxs, totn - sz[u]); if (mxs < mx) {mx = mxs; rt = u;} } void getsz(int u, int fa) { sz[u] = 1; for (int i = hd[u], v; i; i = e[i].nxt) if ((v = e[i].to) != fa && !vis[v]) { getsz(v, u); sz[u] += sz[v]; } } void dfs(int u, int fa, int dep) { ++bkt[dep]; mxd = max(mxd, dep); for (int i = hd[u], v; i; i = e[i].nxt) if ((v = e[i].to) != fa && !vis[v]) dfs(v, u, dep + 1); } inline void mul(int *a, int len, int o) { len = Rev(len << 1); NTT(a, len, 0); for (int i = 0; i < len; ++i) a[i] = 1ll * a[i] * a[i] % mod; NTT(a, len, 1); if (o > 0) for (int i = 0; i < len; ++i) cnt[i + 1] += a[i]; else for (int i = 0; i < len; ++i) cnt[i + 3] -= a[i]; for (int i = 0; i < len; ++i) a[i] = 0; } inline void calc(int u, int o) { mxd = 0; dfs(u, 0, 0); mul(bkt, mxd, o); } void divide(int u) { vis[u] = 1; calc(u, 1); for (int i = hd[u], v; i; i = e[i].nxt) if (!vis[v = e[i].to]) { calc(v, -1); getsz(v, u); totn = mx = sz[v]; rt = v; getrt(v, 0); divide(rt); } } int main() { init(); n = rd(); for (int i = 1; i < n; ++i) add(rd() + 1, rd() + 1); mx = totn = n; getrt(1, 0); divide(rt); double ans = 0.0; for (int i = 1; i <= n + 1; ++i) ans += (double) cnt[i] / i; printf("%.4lf", ans); return 0; }
在点分治求路径条数时,咱们尝试用按秩合并的思路去搞,也就是将子树按照最深深度排序,而后逐个合并计算答案。
开始的时候只有 \(bkt[0]=1\),而后按顺序卷每个子树求出来的计数数组 \(bktson\) 。
把贡献直接计算,而后再将 \(bktson\) 按位加到 \(bkt\) 上。
考虑复杂度,将子树按照深度从小到大排序后,每次卷积获得的新的链长不会超过新合并的子树深度的二倍,因此每次卷积的数组长度为 \(mxdep[v]\) 的,且每一个位置只会和其父节点卷积一次,所以总复杂度为 \(\mathcal O(n\log^2 n)\)
#include <cmath> #include <cstdio> #include <cctype> #include <cstdlib> #include <cstring> #include <iostream> #include <algorithm> #define N 65537 #define mod 998244353 using namespace std; typedef long long ll; inline int rd() { int x = 0; char c = getchar(); while (!isdigit(c)) c = getchar(); while (isdigit(c)) { x = x * 10 + (c ^ 48); c = getchar(); } return x; } inline void print(ll x) { int y = 10, len = 1; while(y <= x) {y *= 10; ++len;} while(len--) {y /= 10; putchar(x / y + 48); x %= y;} putchar('\n'); } inline int fpow(int x, int t = mod - 2) { int res = 1; while (t) { if (t & 1) res = 1ll * res * x % mod; x = 1ll * x * x % mod; t >>= 1; } return res; } int mxlen = (1 << 16), w[2][N], rev[N]; inline int mo(int x) { return x >= mod ? x - mod : x; } inline void init() { int per = fpow(3, (mod - 1) / mxlen); int invper = fpow(per); w[0][0] = w[1][0] = 1; for (int i = 1; i < mxlen; ++i) { w[0][i] = 1ll * w[0][i - 1] * per % mod; w[1][i] = 1ll * w[1][i - 1] * invper % mod; } } inline int Rev(int n) { int len = 1, bit = 0; while (len <= n) len <<= 1, ++bit; for (int i = 0; i < len; ++i) rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << (bit - 1))); return len; } inline void NTT(int *f, int len, int o) { for (int i = 0; i < len; ++i) if (i > rev[i]) swap(f[i], f[rev[i]]); for (int i = 1; i < len; i <<= 1) { int wn = mxlen / (i << 1); for (int j = 0; j < len; j += (i << 1)) { int nw = 0, x, y; for (int k = 0; k < i; ++k, nw += wn) { x = f[j + k]; y = 1ll * w[o][nw] * f[i + j + k] % mod; f[j + k] = mo(x + y); f[i + j + k] = mo(x - y + mod); } } } if (o == 1) { int invl = fpow(len); for (int i = 0; i < len; ++i) f[i] = 1ll * f[i] * invl % mod; } } bool vis[N]; double ans = 0.0; int n, m, tot, totn, mx, rt; int bkt[N], sz[N], hd[N]; struct edge{int to, nxt;} e[N << 1]; inline void add(int u, int v) { e[++tot].to = v; e[tot].nxt = hd[u]; hd[u] = tot; e[++tot].to = u; e[tot].nxt = hd[v]; hd[v] = tot; } void getrt(int u, int fa) { sz[u] = 1; int mxs = 0; for (int i = hd[u], v; i; i = e[i].nxt) if ((v = e[i].to) != fa && !vis[v]) { getrt(v, u); sz[u] += sz[v]; mxs = max(mxs, sz[v]); } mxs = max(mxs, totn - sz[u]); if (mxs < mx) {mx = mxs; rt = u;} } void getsz(int u, int fa) { sz[u] = 1; for (int i = hd[u], v; i; i = e[i].nxt) if ((v = e[i].to) != fa && !vis[v]) { getsz(v, u); sz[u] += sz[v]; } } int res[N], tmp[N]; inline int mul(int *a, int *b, int lena, int lenb) { int len = Rev(lenb << 1); for (int i = 0; i < lena; ++i) res[i] = a[i]; for (int i = lena; i < len; ++i) res[i] = 0; for (int i = 0; i < lenb; ++i) tmp[i] = b[i]; for (int i = lenb; i < len; ++i) tmp[i] = 0; NTT(res, len, 0); NTT(tmp, len, 0); for (int i = 0; i < len; ++i) res[i] = 1ll * res[i] * tmp[i] % mod; NTT(res, len, 1); for (int i = 0; i < len; ++i) ans += 2.0 * res[i] / (i + 1); return len; } int mxd[N], s[N], bkts[N]; inline bool cmp(int x, int y) {return mxd[x] < mxd[y];} int dfs(int u, int fa, int dep) { int resd = dep; for (int i = hd[u], v; i; i = e[i].nxt) if ((v = e[i].to) != fa && !vis[v]) resd = max(resd, dfs(v, u, dep + 1)); return resd; } void dfs2(int u, int fa, int dep) { ++bkts[dep]; for (int i = hd[u], v; i; i = e[i].nxt) if ((v = e[i].to) != fa && !vis[v]) dfs2(v, u, dep + 1); } void divide(int u) { vis[u] = 1; s[0] = 0; for (int i = hd[u], v; i; i = e[i].nxt) if (!vis[v = e[i].to]) { s[++s[0]] = v; mxd[v] = dfs(v, u, 1); } sort(s + 1, s + 1 + s[0], cmp); bkt[0] = 1; int nowlen = 1; for (int i = 1, v; i <= s[0]; ++i) { dfs2(v = s[i], 0, 1); nowlen = mul(bkt, bkts, nowlen, mxd[v] + 1); for (int i = 0; i <= mxd[v]; ++i) { bkt[i] += bkts[i]; bkts[i] = 0; } } for (int i = 0; i <= nowlen; ++i) bkt[i] = 0; for (int i = hd[u], v; i; i = e[i].nxt) if (!vis[v = e[i].to]) { getsz(v, u); totn = mx = sz[v]; rt = v; getrt(v, 0); divide(rt); } } int main() { init(); n = rd(); for (int i = 1; i < n; ++i) add(rd() + 1, rd() + 1); mx = totn = n; getrt(1, 0); divide(rt); printf("%.4lf", ans + n); return 0; }