原文连接www.cnblogs.com/zhouzhendong/p/UOJ470.htmlhtml
作完情报中心来看这个题忽然发现两题有类似之处而后就会作了。c++
首先,咱们考虑将全部答案点对分为两类。git
第一种状况很是简单,这里不加赘述。spa
对于第二种状况,咱们首先考虑简单作法:code
考虑对于每个节点分开处理。htm
按照某一种顺序枚举它的子树,对于全部“一端在当前子树内,另外一端在当前子树以前的子树”的路径,咱们求它们的贡献。blog
接下来提到的“虚树“中默认加入当前节点。get
考虑对当前子树内路径端点创建虚树,而后在虚树上 dfs。对于虚树上的一个节点,它在另一个子树中有相同语言的节点就是它在虚树上的子树中的全部端点的另外一端点构成的虚树大小。it
一个节点的子树中全部端点对应的点构成的虚树能够由儿子节点的虚树合并而来。class
若是事先将虚树内的节点存在 set 中,则能够在关于点数较少的虚树的复杂度内合并两棵虚树,具体地说是 size * log(n) 。
考虑使用 DSU on tree,咱们能够获得一个 $O(n\log ^ 3n)$ 的作法。
注意到,在不少问题里,线段树合并均可以处理树上启发式合并的问题,并且复杂度都会降低。这里也相似,考虑合并两个 dfs序 分别独立的虚树时,只须要特殊考虑 dfs序 小的虚树的 dfs序最大节点和 dfs序 大的虚树的dfs序最小节点到根的路径交便可。
因而,咱们考虑采用线段树合并维护子树虚树 size,因为线段树合并中须要求 LCA,因此咱们考虑用 ST表 来求 LCA,作到单次询问 $O(1)$,便可获得一个总时间复杂度 $O((n+m)\log n)$ 的作法。
#include <bits/stdc++.h> #define clr(x) memset(x,0,sizeof x) #define For(i,a,b) for (int i=(a);i<=(b);i++) #define Fod(i,b,a) for (int i=(b);i>=(a);i--) #define fi first #define se second #define pb(x) push_back(x) #define mp(x,y) make_pair(x,y) #define outval(x) cerr<<#x" = "<<x<<endl #define outtag(x) cerr<<"---------------"#x"---------------"<<endl #define outarr(a,L,R) cerr<<#a"["<<L<<".."<<R<<"] = ";\ For(_x,L,R)cerr<<a[_x]<<" ";cerr<<endl; using namespace std; typedef long long LL; LL read(){ LL x=0,f=0; char ch=getchar(); while (!isdigit(ch)) f|=ch=='-',ch=getchar(); while (isdigit(ch)) x=(x<<1)+(x<<3)+(ch^48),ch=getchar(); return f?-x:x; } const int N=100005*2; int n,m; vector <int> e[N]; struct cha{ int x,y,lca; int xf,yf; }a[N]; int depth[N],fa[N][20]; int ett[N],c=0,I[N]; void dfs(int x,int pre,int d){ depth[x]=d,fa[x][0]=pre; For(i,1,19) fa[x][i]=fa[fa[x][i-1]][i-1]; ett[I[x]=++c]=x; for (int y : e[x]) if (y!=pre) dfs(y,x,d+1),ett[++c]=x; } int st[N][20],Log[N]; int min_dep(int x,int y){ return depth[x]<depth[y]?x:y; } void Get_ST(){ For(i,2,c) Log[i]=Log[i>>1]+1; For(i,1,c){ st[i][0]=ett[i]; For(j,1,19){ st[i][j]=st[i][j-1]; if (i-(1<<(j-1))>0) st[i][j]=min_dep(st[i][j],st[i-(1<<(j-1))][j-1]); } } } int LCA(int x,int y){ x=I[x],y=I[y]; if (x>y) swap(x,y); int d=Log[y-x+1]; return min_dep(st[x+(1<<d)-1][d],st[y][d]); } int Dis(int x,int y){ return depth[x]+depth[y]-2*depth[LCA(x,y)]; } namespace Seg{ const int S=N*20*2; int sz[S],lp[S],rp[S],ls[S],rs[S]; int cnt=0; void pushup(int rt){ if (!sz[ls[rt]]&&!sz[rs[rt]]) sz[rt]=lp[rt]=rp[rt]=0; else if (!sz[rs[rt]]) sz[rt]=sz[ls[rt]],lp[rt]=lp[ls[rt]],rp[rt]=rp[ls[rt]]; else if (!sz[ls[rt]]) sz[rt]=sz[rs[rt]],lp[rt]=lp[rs[rt]],rp[rt]=rp[rs[rt]]; else { sz[rt]=sz[ls[rt]]+sz[rs[rt]]-depth[LCA(rp[ls[rt]],lp[rs[rt]])]; lp[rt]=lp[ls[rt]],rp[rt]=rp[rs[rt]]; } } void Ins(int &rt,int L,int R,int x){ if (!rt) rt=++cnt,sz[rt]=ls[rt]=rs[rt]=lp[rt]=rp[rt]=0; if (L==R){ lp[rt]=rp[rt]=x,sz[rt]=depth[x]; return; } int mid=(L+R)>>1; if (I[x]<=mid) Ins(ls[rt],L,mid,x); else Ins(rs[rt],mid+1,R,x); pushup(rt); } int Merge(int x,int y,int L,int R){ if (!x||!y) return x|y; if (L==R) return x; int mid=(L+R)>>1,rt=++cnt; ls[rt]=Merge(ls[x],ls[y],L,mid); rs[rt]=Merge(rs[x],rs[y],mid+1,R); pushup(rt); return rt; } } int go_son(int x,int f){ Fod(i,19,0) if (depth[x]-(1<<i)>depth[f]) x=fa[x][i]; return x; } LL ans=0; vector <int> qid[N]; int up[N]; bool cmp_qid(int x,int y){ return I[a[x].xf]<I[a[y].xf]; } bool cmpI(int x,int y){ return I[x]<I[y]; } int rt[N]; void Solve(int x,int *id,int n){ static int t[N],st[N]; int tc=0,top=0; For(i,0,n-1) t[++tc]=a[id[i]].x; t[++tc]=x; sort(t+1,t+tc+1,cmpI); tc=unique(t+1,t+tc+1)-t-1; For(i,1,tc) rt[t[i]]=0; For(i,0,n-1) Seg::Ins(rt[a[id[i]].x],1,c,a[id[i]].y); For(_,1,tc){ int i=t[_]; if (top){ int lca=LCA(i,st[top]); while (depth[st[top]]>depth[lca]){ int now=st[top]; if (depth[st[top-1]]>=depth[lca]){ ans+=(LL)(depth[now]-depth[st[top-1]])*(Seg::sz[rt[now]]-depth[x]); rt[st[top-1]]=Seg::Merge(rt[st[top-1]],rt[now],1,c); top--; } else { ans+=(LL)(depth[now]-depth[lca])*(Seg::sz[rt[now]]-depth[x]); rt[lca]=rt[now]; st[top]=lca; break; } } } st[++top]=i; } while (top>1){ int now=st[top]; ans+=(LL)(depth[now]-depth[st[top-1]])*(Seg::sz[rt[now]]-depth[x]); rt[st[top-1]]=Seg::Merge(rt[st[top-1]],rt[now],1,c); top--; } } void Solve(int x,int pre){ for (int y : e[x]) if (y!=pre) Solve(y,x),up[x]=max(up[x],up[y]-1); ans+=up[x]; sort(qid[x].begin(),qid[x].end(),cmp_qid); int s=(int)qid[x].size(); for (int i=0,j;i<s;i=j+1){ for (j=i;j+1<s&&I[a[qid[x][i]].xf]==I[a[qid[x][j+1]].xf];j++); Solve(x,&qid[x][i],j-i+1); } } int main(){ n=read(),m=read(); For(i,1,n-1){ int x=read(),y=read(); e[x].pb(y),e[y].pb(x); } dfs(1,0,0); Get_ST(); For(i,1,m){ int x=a[i].x=read(),y=a[i].y=read(),lca=a[i].lca=LCA(x,y); up[x]=max(up[x],depth[x]-depth[lca]); up[y]=max(up[y],depth[y]-depth[lca]); if (x!=lca&&y!=lca){ a[i].xf=go_son(x,lca); a[i].yf=go_son(y,lca); if (I[a[i].xf]<I[a[i].yf]) swap(a[i].xf,a[i].yf),swap(a[i].x,a[i].y); qid[lca].pb(i); } } Solve(1,0); cout<<ans<<endl; return 0; }