如果这篇博客帮助到你,可以请我喝一杯咖啡~
CC BY-NC-SA 4.0 (除特别声明或转载文章外)
观察题意,题目求的是询问串与编号从 $1$ 到 $i$ 的后缀的 $\text{LCP}$ 之和加 $i-1$,其中 $i$ 为询问串第一次匹配的位置;若无匹配位置则求得东西为询问串和所有后缀的 $\text{LCP}$ 之和加 $n$。
考虑对长序列号建后缀树,每个节点维护子树内编号最小的后缀的编号,将询问串放上去匹配,我们将失配的位置或者匹配到的位置存下来,如果在边上,就把这个点分裂出来。设每条边边权为 $w_i$,初始全为 $0$,把询问离线下来,按 $i$ 排序,我们要做的就是把一个后缀对应的节点到根的所有边的 $w_i$ 加 $1$,查询一个节点到根的 $\sum w_ilen_i$,树剖加线段树或者全局平衡二叉树均可,时间复杂度 $O(n\log^2 n)$ 或 $O(n\log n)$。
具体的细节可以看代码:
$O(n\log^2n)$:
#include<cstdio>
#include<algorithm>
int n,m;
char s[100010],t[100010];
struct suffixTree{
static const int inf=100000000;
int link[200010],len[300010],start[300010],s[200010],n,tail,now,rem,ch[300010][11];
suffixTree():link(),len(),start(),s(),n(0),tail(1),now(1),rem(0),ch(){len[0]=inf;}
int newnode(int st,int le){
link[++tail]=1;start[tail]=st;len[tail]=le;return tail;
}
void extend(int x){
s[++n]=x,rem++;
for(int last=1;rem;){
while(rem>len[ch[now][s[n-rem+1]]]) rem-=len[now=ch[now][s[n-rem+1]]];
int &v=ch[now][s[n-rem+1]];
int c=s[start[v]+rem-1];
if(!v||x==c){
link[last]=now;last=now;
if(!v) v=newnode(n,inf);
else break;
}else{
int u=newnode(start[v],rem-1);
ch[u][c]=v;ch[u][x]=newnode(n,inf);
start[v]+=rem-1;len[v]-=rem-1;
link[last]=v=u;last=u;
}
if(now==1) rem--;
else now=link[now];
}
}
}T;
int dep[300010],p[300010],pos[100010];
void dfs(int x){
if(T.start[x]+T.len[x]-1>n){dep[x]-=T.len[x],dep[x]+=n-T.start[x]+1,pos[p[x]=n-dep[x]+1]=x;return;}
p[x]=n+1;
for(int i=0;i<11;i++)if(T.ch[x][i]){
dep[T.ch[x][i]]=dep[x]+T.len[T.ch[x][i]];
dfs(T.ch[x][i]);
p[x]=std::min(p[x],p[T.ch[x][i]]);
}
}
int qcnt;
long long ans[100010];
struct que{
int pos,qq,id;
bool operator <(que const &x)const{
return qq<x.qq;
}
}q[100010];
void insert(int x,const char *c){
if(*c==0){
q[qcnt].pos=x,q[qcnt].qq=p[x];
ans[qcnt]=p[x]-1;
return;
}
if(T.ch[x][*c-'0']){
int s=T.ch[x][*c-'0'];
int k=*c-'0';
for(int i=0;i<T.len[s];i++){
if(*c==0){
int cnt=++T.tail;
T.start[cnt]=T.start[s];
T.len[cnt]=i;
T.start[s]+=i;
T.len[s]-=i;
T.ch[x][k]=cnt;
T.ch[cnt][T.s[T.start[s]]]=s;
p[cnt]=p[s];
q[qcnt].pos=cnt,q[qcnt].qq=p[s];
ans[qcnt]=p[s]-1;
return;
}else if(*c!=::s[T.start[s]+i]){
int cnt=++T.tail;
T.start[cnt]=T.start[s];
T.len[cnt]=i;
T.start[s]+=i;
T.len[s]-=i;
T.ch[x][k]=cnt;
T.ch[cnt][T.s[T.start[s]]]=s;
p[cnt]=p[s];
q[qcnt].pos=cnt,q[qcnt].qq=n;
ans[qcnt]=n;
return;
}else ++c;
}
insert(s,c);
}else{
q[qcnt].pos=x,q[qcnt].qq=n;
ans[qcnt]=n;
}
}
int dfn[300010],rk[300010],sz[300010],top[300010],f[300010],son[300010];
void dfs1(int x,int fa){
sz[x]=1,f[x]=fa;
for(int i=0;i<11;i++)if(T.ch[x][i]){
dfs1(T.ch[x][i],x);
sz[x]+=sz[T.ch[x][i]];
if(sz[T.ch[x][i]]>sz[son[x]])son[x]=T.ch[x][i];
}
}
int ct;
void dfs2(int x,int tp){
dfn[x]=++ct,rk[ct]=x,top[x]=tp;
if(son[x])dfs2(son[x],tp);else return;
for(int i=0;i<11;i++)if(T.ch[x][i]&&T.ch[x][i]!=son[x]) dfs2(T.ch[x][i],T.ch[x][i]);
}
struct node{
long long sum,p,tag;
}tr[1200010];
void build(int x=1,int l=1,int r=T.tail){
if(l==r) return tr[x].sum=T.len[rk[l]],void();
int mid=(l+r)>>1,ls=x<<1,rs=x<<1|1;
build(ls,l,mid),build(rs,mid+1,r);
tr[x].sum=tr[ls].sum+tr[rs].sum;
}
void addtag(int x,int t){
tr[x].tag+=t;
tr[x].p+=t*tr[x].sum;
}
void pushdown(int x,int ls,int rs){
if(tr[x].tag)addtag(ls,tr[x].tag),addtag(rs,tr[x].tag),tr[x].tag=0;
}
void update(int pl,int pr,int x=1,int l=1,int r=T.tail){
if(l==pl&&r==pr) return addtag(x,1);
int mid=(l+r)>>1,ls=x<<1,rs=x<<1|1;
pushdown(x,ls,rs);
if(pr<=mid) update(pl,pr,ls,l,mid);
else if(pl>mid) update(pl,pr,rs,mid+1,r);
else update(pl,mid,ls,l,mid),update(mid+1,pr,rs,mid+1,r);
tr[x].p=tr[ls].p+tr[rs].p;
}
long long query(int pl,int pr,int x=1,int l=1,int r=T.tail){
if(l==pl&&r==pr) return tr[x].p;
int mid=(l+r)>>1,ls=x<<1,rs=x<<1|1;
pushdown(x,ls,rs);
if(pr<=mid) return query(pl,pr,ls,l,mid);
else if(pl>mid) return query(pl,pr,rs,mid+1,r);
else return query(pl,mid,ls,l,mid)+query(mid+1,pr,rs,mid+1,r);
}
void update(int x){while(x)update(dfn[top[x]],dfn[x]),x=f[top[x]];}
long long query(int x){long long ans=0;while(x)ans+=query(dfn[top[x]],dfn[x]),x=f[top[x]];return ans;}
int main(){
scanf("%d%s",&n,s+1);
for(int i=1;i<=n;i++)T.extend(s[i]-'0');
s[n+1]='$';
T.extend(10);
T.ch[1][10]=0;
dfs(1);
scanf("%d",&m);
for(qcnt=1;qcnt<=m;qcnt++){
scanf("%s",t+1);
q[qcnt].id=qcnt;
insert(1,t+1);
}
dfs1(1,0);
dfs2(1,1);
build();
std::sort(q+1,q+m+1);
for(int i=1,j=1;i<=m;i++){
for(;j<=q[i].qq;j++)update(pos[j]);
ans[q[i].id]+=query(q[i].pos);
}
for(int i=1;i<=m;i++)printf("%lld\n",ans[i]);
return 0;
}
这个是 $O(n\log n)$:
#include<cstdio>
#include<algorithm>
#include<vector>
int n,m;
char ch[100010],t[100010];
struct suffixTree{
static const int inf=100000000;
int link[200010],len[300010],start[300010],s[200010],n,tail,now,rem,ch[300010][11];
suffixTree():link(),len(),start(),s(),n(0),tail(1),now(1),rem(0),ch(){len[0]=inf;}
int newnode(int st,int le){
link[++tail]=1;start[tail]=st;len[tail]=le;return tail;
}
void extend(int x){
s[++n]=x,rem++;
for(int last=1;rem;){
while(rem>len[ch[now][s[n-rem+1]]]) rem-=len[now=ch[now][s[n-rem+1]]];
int &v=ch[now][s[n-rem+1]];
int c=s[start[v]+rem-1];
if(!v||x==c){
link[last]=now;last=now;
if(!v) v=newnode(n,inf);
else break;
}else{
int u=newnode(start[v],rem-1);
ch[u][c]=v;ch[u][x]=newnode(n,inf);
start[v]+=rem-1;len[v]-=rem-1;
link[last]=v=u;last=u;
}
if(now==1) rem--;
else now=link[now];
}
}
}T;
int dep[300010],p[300010],pos[100010];
std::vector<int> g[300010];
void dfs(int x){
if(T.start[x]+T.len[x]-1>n){dep[x]-=T.len[x],dep[x]+=T.n-T.start[x]+1,pos[p[x]=T.n-dep[x]+1]=x;return;}
p[x]=n+1;
for(int i=0;i<11;i++)if(T.ch[x][i]){
dep[T.ch[x][i]]=dep[x]+T.len[T.ch[x][i]];
dfs(T.ch[x][i]);
p[x]=std::min(p[x],p[T.ch[x][i]]);
// printf("%d %d ",x,T.ch[x][i]);
// for(int j=T.start[T.ch[x][i]];j<std::min(T.n+1,T.start[T.ch[x][i]]+T.len[T.ch[x][i]]);j++)putchar(T.s[j]+'0');
// puts("");
}
}
int qcnt;
long long ans[100010];
struct que{
int pos,qq,id;
bool operator <(que const &x)const{
return qq<x.qq;
}
}q[100010];
void insert(int x,const char *c){
if(*c==0){
q[qcnt].pos=x,q[qcnt].qq=p[x];
ans[qcnt]=p[x]-1;
return;
}
if(T.ch[x][*c-'0']){
int s=T.ch[x][*c-'0'];
int k=*c-'0';
for(int i=0;i<T.len[s];i++){
if(*c==0){
int cnt=++T.tail;
T.start[cnt]=T.start[s];
T.len[cnt]=i;
dep[cnt]=dep[x]+i;
T.start[s]+=i;
T.len[s]-=i;
T.ch[x][k]=cnt;
T.ch[cnt][T.s[T.start[s]]]=s;
p[cnt]=p[s];
q[qcnt].pos=cnt,q[qcnt].qq=p[s];
ans[qcnt]=p[s]-1;
return;
}else if(*c!=::ch[T.start[s]+i]){
int cnt=++T.tail;
T.start[cnt]=T.start[s];
T.len[cnt]=i;
dep[cnt]=dep[x]+i;
T.start[s]+=i;
T.len[s]-=i;
T.ch[x][k]=cnt;
T.ch[cnt][T.s[T.start[s]]]=s;
p[cnt]=p[s];
q[qcnt].pos=cnt,q[qcnt].qq=n;
ans[qcnt]=n;
return;
}else ++c;
}
insert(s,c);
}else{
q[qcnt].pos=x,q[qcnt].qq=n;
ans[qcnt]=n;
}
}
int sz[300010],son[300010],f[300010],k,s[300010][2];
long long val[300010];
void dfs1(int x){
sz[x]=1;
for(int i=0,to;i<11;i++)if(T.ch[x][i]){
g[x].push_back(to=T.ch[x][i]);
val[to]=dep[to]-dep[x];
dfs1(to);
sz[x]+=sz[to];
if(sz[to]>sz[son[x]])son[x]=to;
}
}
int b[300010],bs[300010];
long long ss[300010];
int cbuild(int l,int r){
int x=l,y=r,m=r;
while(x<=y){
int mid=(x+y)>>1;
if(bs[mid-1]-bs[l-1]<=bs[r]-bs[mid-1])m=mid,x=mid+1;
else y=mid-1;
}
x=b[m];
if(m!=l) f[s[x][0]=cbuild(l,m-1)]=x;
if(m!=r) f[s[x][1]=cbuild(m+1,r)]=x;
ss[x]=val[x]+ss[s[x][0]]+ss[s[x][1]];
return x;
}
int build(int x){
int y=x;
do
for(int to:g[y])
if(to!=son[y])
f[build(to)]=y;
while((y=son[y]));
do
b[++y]=x,bs[y]=bs[y-1]+sz[x]-sz[son[x]];
while((x=son[x]));
return cbuild(1,y);
}
long long lz[300010],v[300010];
inline void update(int x){
long long p=0;
int t=1;
while(x){
v[x]=v[x]+p;
if(t){
++lz[x];
if(s[x][1])--lz[s[x][1]];
p=p+val[x]+ss[s[x][0]];
v[x]=v[x]-ss[s[x][1]];
}
t=(x!=s[f[x]][0]?1:0);
if(t&&x!=s[f[x]][1])p=0;
x=f[x];
}
}
inline int query(int x){
long long p=0,ret=0;
int t=1;
while(x){
if(t){
ret=ret+v[x]-v[s[x][1]];
ret=ret-ss[s[x][1]]*lz[s[x][1]];
p=p+val[x]+ss[s[x][0]];
}
ret=ret+p*lz[x];
t=(x!=s[f[x]][0]?1:0);
if(t&&x!=s[f[x]][1])p=0;
x=f[x];
}
return ret;
}
int main(){
scanf("%d%s",&n,ch+1);
for(int i=1;i<=n;i++)T.extend(ch[i]-'0');
ch[n+1]='$';
T.extend(10);
T.ch[1][10]=0;
dfs(1);
scanf("%d",&m);
for(qcnt=1;qcnt<=m;qcnt++){
scanf("%s",t+1);
q[qcnt].id=qcnt;
insert(1,t+1);
}
dfs1(1);
build(1);
std::sort(q+1,q+m+1);
for(int i=1,j=1;i<=m;i++){
for(;j<=q[i].qq;j++)update(pos[j]);
ans[q[i].id]+=query(q[i].pos);
}
for(int i=1;i<=m;i++)printf("%lld\n",ans[i]);
return 0;
}