给出一棵n个点的树,树上的边权要么为0,要么为1
php
要求找出有多少条路径,知足:node
1.路径上0的数量等于1的数量spa
2.可以在这条路径上找到一个点(不包括起点和终点),使得起点到这个点,终点到这个点所构成的两条路径都知足条件1code
点分治
blog
对于一个分治中心,咱们处理通过分治中心的路径数get
先把边权为0的看成边权为-1,而后当前分治树求出每一个点的深度(也就是根到当前点的路径上的边权和)string
若是x到根的路径上出现与x深度相同的点,则说明这两个点能够构成一条知足条件1的路径,那么咱们就把x打上标记it
而后对于每个点进行分类讨论:io
1.深度为0,没打标记----->那么这个点能 和 与本身不在一棵子树上且深度为0的点 构成答案路径ast
2.深度为0,打了标记----->那么这个点能 和 与本身不在一棵子树上且深度为0的点 以及 根节点 构成答案路径
假设x!=0
3.深度为x,没打标记----->那么这个点能 和 与本身不在一棵子树上且深度为-x且打了标记的点 构成答案路径
4.深度为x,打了标记----->那么这个点能 和 与本身不在一棵子树上且深度为-x的点 构成答案路径
上面的操做用tol[0或1][x]记录每种深度的数量就能作了
#include<cstdio> #include<cstring> #include<cstdlib> #include<cmath> #include<algorithm> #define mp(x,y) make_pair(x,y) #define pin pair<int,int> #define Maxn 110000 using namespace std; typedef long long LL; struct node{int x,y,c,next;}a[Maxn*2];int len,last[Maxn]; void ins(int x,int y,int c){a[++len]=(node){x,y,c,last[x]};last[x]=len;} int sum,ms[Maxn],tot[Maxn],rt; bool v[Maxn]; void getrt(int x,int fa) { tot[x]=1;ms[x]=0; for(int k=last[x];k;k=a[k].next) { int y=a[k].y; if(y==fa||v[y]==true) continue; getrt(y,x); tot[x]+=tot[y]; ms[x]=max(ms[x],tot[y]); } ms[x]=max(ms[x],sum-tot[x]); if(ms[x]<ms[rt]) rt=x; } LL ans; int dep[Maxn];//深度 int bo[Maxn];//标记 pin sta[Maxn];int tp; int cnt[Maxn*2]; int tol1[2][Maxn*2];//点种类的数量,tol[0~1][x+Maxn]表示是否打标记且深度为x的点数 int tol2[2][Maxn*2]; void getdep(int x,int fa) { if(tol2[bo[x]][dep[x]+Maxn]==1) sta[++tp]=mp(bo[x],dep[x]); for(int k=last[x];k;k=a[k].next) { int y=a[k].y; if(y==fa||v[y]==true) continue; dep[y]=dep[x]+a[k].c; if(cnt[dep[y]+Maxn]!=0) bo[y]=1; cnt[dep[y]+Maxn]++; tol2[bo[y]][dep[y]+Maxn]++; getdep(y,x); cnt[dep[y]+Maxn]--; } } int ts[Maxn],sp; void clear(int x,int fa) { for(int k=last[x];k;k=a[k].next) { int y=a[k].y; if(y==fa||v[y]==true) continue; tol1[bo[y]][dep[y]+Maxn]--; bo[y]=0; clear(y,x); } } void solve(int x,int fa) { v[x]=true;dep[x]=0; for(int k=last[x];k;k=a[k].next) { int y=a[k].y; if(y==fa||v[y]==true) continue; dep[y]=a[k].c;tp=0; tol2[bo[y]][dep[y]+Maxn]++; cnt[dep[y]+Maxn]++; getdep(y,0); cnt[dep[y]+Maxn]--; for(int i=1;i<=tp;i++) { int p1=sta[i].first,p2=sta[i].second; if(p1==0&&p2==0) ans+=(LL)tol2[0][Maxn]*(tol1[0][Maxn]+tol1[1][Maxn]); if(p1==1&&p2==0) ans+=(LL)tol2[1][Maxn]*(tol1[0][Maxn]+tol1[1][Maxn]+1); if(p1==0&&p2!=0) ans+=(LL)tol2[0][p2+Maxn]*tol1[1][Maxn-p2]; if(p1==1&&p2!=0) ans+=(LL)tol2[1][p2+Maxn]*(tol1[1][Maxn-p2]+tol1[0][Maxn-p2]); } for(int i=1;i<=tp;i++) { int p1=sta[i].first,p2=sta[i].second; tol1[p1][p2+Maxn]+=tol2[p1][p2+Maxn]; tol2[p1][p2+Maxn]=0; } } clear(x,0); for(int k=last[x];k;k=a[k].next) { int y=a[k].y; if(y==fa||v[y]==true) continue; rt=0;sum=tot[y]; getrt(y,0); solve(rt,0); } } int main() { //freopen("a.in","r",stdin); //freopen("vio.out","w",stdout); int n; scanf("%d",&n); len=0;memset(last,0,sizeof(last)); for(int i=1;i<n;i++) { int x,y,c; scanf("%d%d%d",&x,&y,&c); if(c==0) c=-1; ins(x,y,c);ins(y,x,c); } sum=ms[0]=n; rt=0;getrt(1,0); memset(v,false,sizeof(v)); tp=0;solve(rt,0); printf("%lld\n",ans); return 0; }