神奇的思路,还是要学习一下。
题意:给定一个字符串 \(s[1..n]\)。定义 \(lcp(i,j)\) 为后缀 \(s[i..n]\) 与 \(s[j..n]\) 的最长公共前缀,\(lcs(i,j)\) 为前缀 \(s[1..i]\) 与 \(s[1..j]\) 的最长公共后缀。求下式模 \(2^{64}\) 的值。
\[ \sum_{1\le i<j\le n} lcp(i,j)lcs(i,j)[lcp(i,j)\le k1][lcs(i,j)\le k2] \]
分析:
对于一对有贡献的 \(<i,j>\),我们将它们的 lcs 与 lcp 拼接起来,可知
\[ \begin{gathered} s[i-lcs(i,j)+1,\ i+lcp(i,j)-1]=s[j-lcs(i,j)+1,\ j+lcp(i,j)-1]\\ s[i-lcs(i,j)]\not=s[j-lcs(i,j)]\\ s[i+lcp(i,j)]\not=s[j+lcp(i,j)] \end{gathered} \]
这启发我们找出所有满足下列条件的子串对 \(<i,j,len>\)(约定越界位置 \(s[0]\)、\(s[n+1]\) 与任何字符都不相等):
\[ s[i,i+len-1]=s[j,j+len-1],s[i-1]\not=s[j-1],s[i+len]\not=s[j+len] \]
可以知道它的贡献如下。记 \(k=lcp\),则对应的 \(lcs=len-k+1\),限制条件为 \(k\le k1\) 且 \(len-k+1\le k2\):
\[ \sum_{k=\max(1,len-k2+1)}^{\min(len,k1)} k(len-k+1)=\sum_{k=1}^{\min(len,k1)}k(len-k+1)-\sum_{k=1}^{\max(0,len-k2)}k(len-k+1) \]
于是考虑建立后缀数组(SA),并记录每个后缀的前一个字符。
在 height 数组上按从高到低的顺序启发式合并相邻的后缀集合,一边合并一边统计答案。合并高度为 \(len\) 的两个集合时,每一对跨集合且前一个字符不同的后缀都贡献上式的值。
// Computes
//   sum_{1 <= i < j <= n} lcp(i,j) * lcs(i,j) * [lcp(i,j) <= k1] * [lcs(i,j) <= k2]
// modulo 2^64, where lcp(i,j) is the longest common prefix of the suffixes
// starting at i and j, and lcs(i,j) the longest common suffix of the prefixes
// ending at i and j.
//
// Input : a string s followed by the integers k1 and k2 (whitespace separated).
//         The code assumes s consists of lowercase letters 'a'..'z' only
//         (characters index cnt[..][26] via `- 'a'`).
// Output: the sum as an unsigned 64-bit integer.
//
// Each contributing pair lies inside exactly one maximal matched block
// <i, j, len> (equal substrings that extend neither left nor right). F(len) is
// the total contribution of one such block. Maximal blocks correspond to pairs
// of suffixes with LCP == len whose preceding characters differ. They are
// counted by merging suffix-array neighbours in decreasing order of height,
// small-to-large.
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <utility>

using ull = unsigned long long;  // unsigned arithmetic wraps, i.e. works mod 2^64

constexpr int N = 1e5 + 10;

int n, k1, k2;
char s[N];  // 1-indexed; s[0] and s[n+1] stay '\0' and act as sentinels
// sa[r]: start of the r-th smallest suffix; rc[i]: rank of suffix i;
// ht[r]: LCP of suffixes sa[r-1] and sa[r]; c: counting-sort buckets.
int sa[N], ht[N], rc[N], c[N];
// Group data, indexed by group id (a group is an interval of SA ranks):
// lp/rp = interval bounds, siz = size, cnt[g][ch] = members preceded by 'a'+ch.
// bl[r] = id of the group that currently contains SA rank r.
int lp[N], rp[N], bl[N], siz[N], cnt[N][26];

// Builds the suffix array of s[1..n] by prefix doubling with radix sort, then
// the height array with Kasai's algorithm. ht and rc double as the two rank
// buffers during construction and are overwritten with their final meaning.
void buildSa() {
    int *x = ht, *y = rc, i, p, k, m = 128;
    for (i = 0; i <= m; ++i) c[i] = 0;
    for (i = 1; i <= n; ++i) c[x[i] = s[i]]++;
    for (i = 1; i <= m; ++i) c[i] += c[i - 1];
    for (i = n; i >= 1; --i) sa[c[x[i]]--] = i;
    for (k = 1; k < n; k <<= 1) {
        // y: positions ordered by their second key (rank at offset k).
        // Suffixes without a second half sort first.
        for (i = n - k + 1, p = 0; i <= n; ++i) y[++p] = i;
        for (i = 1; i <= n; ++i)
            if (sa[i] > k) y[++p] = sa[i] - k;
        // Stable counting sort by the first key.
        for (i = 0; i <= m; ++i) c[i] = 0;
        for (i = 1; i <= n; ++i) c[x[y[i]]]++;
        for (i = 1; i <= m; ++i) c[i] += c[i - 1];
        for (i = n; i >= 1; --i) sa[c[x[y[i]]]--] = y[i];
        // Re-rank. After the swap, y holds the previous ranks.
        std::swap(x, y);
        x[sa[1]] = p = 1;
        for (i = 2; i <= n; ++i)
            x[sa[i]] = (y[sa[i]] == y[sa[i - 1]] && y[sa[i] + k] == y[sa[i - 1] + k]) ? p : ++p;
        if ((m = p) >= n) break;  // all ranks distinct: sorting is complete
    }
    for (i = 1; i <= n; ++i) rc[sa[i]] = i;
    // Kasai: ht[rc[i]] >= ht[rc[i-1]] - 1, so k drops by at most one per step.
    for (i = 1, k = 0; i <= n; ++i) {
        p = sa[rc[i] - 1];  // previous suffix in SA order (sa[0] == 0 for rank 1)
        if (k) k--;
        while (s[i + k] == s[p + k]) ++k;
        ht[rc[i]] = k;
    }
}

// (-height, rank) pairs: an ascending sort yields decreasing height.
std::pair<int, int> h[N];

// 1 + 2 + ... + x
ull sm(int x) { return (ull)x * (x + 1) / 2; }

// 1^2 + 2^2 + ... + x^2. Exact here because x <= n keeps the product in 64 bits.
ull ssm(int x) { return (ull)x * (2 * x + 1) * (x + 1) / 6; }

// Total contribution of one maximal matched block of length x:
//   sum_{k = max(1, x-k2+1)}^{min(x, k1)} k * (x - k + 1),
// where k plays the role of lcp and x-k+1 that of lcs. It is evaluated as the
// difference of two prefix sums of k*(x+1-k) = (x+1)*k - k^2.
ull F(int x) {
    if (x >= k1 + k2) return 0;  // the summation range is empty
    const int hi = std::min(x, k1);
    const int lo = std::max(0, x - k2);
    const ull total = (ull)(x + 1) * sm(hi) - ssm(hi);
    const ull excluded = (ull)(x + 1) * sm(lo) - ssm(lo);
    return total - excluded;
}

ull f[N];  // f[len] = F(len); f[0] stays 0

// Number of pairs (a in group x, b in group y) whose preceding characters
// differ. The suffix starting at position 1 has no preceding character. It is
// counted in siz but in no cnt bucket, so it pairs with every other suffix.
ull calc(int x, int y) {
    ull res = (ull)siz[x] * siz[y];
    for (int ch = 0; ch < 26; ++ch) res -= (ull)cnt[x][ch] * cnt[y][ch];
    return res;
}

// Merges group x into group y. Only x's rank interval is relabelled, so
// always passing the smaller group as x keeps total relabelling O(n log n).
void merge(int x, int y) {
    for (int ch = 0; ch < 26; ++ch) cnt[y][ch] += cnt[x][ch];
    for (int r = lp[x]; r <= rp[x]; ++r) bl[r] = y;
    lp[y] = std::min(lp[y], lp[x]);
    rp[y] = std::max(rp[y], rp[x]);
    siz[y] += siz[x];
}

int main() {
    // The field width (N - 10) keeps the read inside s[1..N-1] including '\0'.
    if (std::scanf("%100000s%d%d", s + 1, &k1, &k2) != 3) return 1;
    n = static_cast<int>(std::strlen(s + 1));
    k1 = std::min(k1, n);  // larger limits are never binding; also avoids k1+k2 overflow
    k2 = std::min(k2, n);
    for (int len = 1; len <= n; ++len) f[len] = F(len);
    buildSa();
    // Every SA rank starts as its own singleton group.
    for (int r = 1; r <= n; ++r) {
        lp[r] = rp[r] = bl[r] = r;
        siz[r] = 1;
        if (sa[r] > 1) cnt[r][s[sa[r] - 1] - 'a']++;
    }
    for (int r = 2; r <= n; ++r) h[r - 1] = std::make_pair(-ht[r], r);
    std::sort(h + 1, h + n);
    ull ans = 0;
    // Joining ranks r-1 and r at height len: every cross pair between the two
    // groups has LCP exactly len, because all larger heights were merged earlier.
    for (int i = 1; i < n; ++i) {
        const int len = -h[i].first;
        const int r = h[i].second;
        int x = bl[r];
        int y = bl[r - 1];
        if (siz[x] > siz[y]) std::swap(x, y);
        ans += f[len] * calc(x, y);
        merge(x, y);
    }
    std::printf("%llu\n", ans);
    return 0;
}