Codeforces 613D-Kingdom and its Cities 虚树

一道虚树板子题。

虚树的作用就是:有的时候解决一些树上问题时,整棵树大小很大,但是所需要的“关键点”大小很小。为了节省时间 对有kk个关键点的虚树进行处理,能使得整棵树的大小变成2k2k

虚树在这里:OIWIKI-虚树有很详细的介绍。

其实就是利用栈维护树上的一条链。然后考虑栈顶节点和下一个关键点的关系,如果lcalca等于栈顶(在同一条链)就直接入栈。否则分类讨论栈中top1top-1的节点xxlcalca的关系:

  1. dfs(x)>dfs(lca)dfs(x)>dfs(lca)这代表xx和栈顶节点构成的链在树的“分叉点”下面的情况(lcalca为分叉点)。所以直接加入xx和栈顶之间的边,然后弹栈顶。注意 这一步显然需要用while执行多次。
  2. dfs(x)<dfs(lca)dfs(x)<dfs(lca)这代表xx和栈顶节点构成的链(不包含两端,因为这里lcaxlca\neq x,前面又判断了lcastk[top]lca\neq stk[top])上的一个节点就是lcalca。所以lcalca没入栈,因此添加栈顶和lcalca的边,栈顶出栈,lcalca入栈即可。
  3. dfs(x)=dfs(lca)dfs(x)=dfs(lca)这代表lca=xlca=x。因此加入边(lca,x)(lca,x),然后出栈即可。

这里有个实现的小技巧:在每次入栈一个新的节点时进行临接表初始化即可。

最后不要忘了将栈中剩余的节点代表的链加进去。

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<cstring>
#include<vector>
#include<queue>
#define vi vector<int>
#define pb push_back
#define mk make_pair
#define pii pair<int,int>
#define rep(i,a,b) for(int i=(a),i##end=(b);i<=i##end;i++)
#define fi first
#define se second
typedef long long ll;
using namespace std;
const int maxn=100500;
vi side[maxn],vit[maxn],kp;
int dfn[maxn],f[maxn][23],dep[maxn],dfn_cnt,stk[maxn],top;
void dfs(int u,int fa){
    f[u][0]=fa;dep[u]=dep[fa]+1;dfn[u]=++dfn_cnt;
    rep(i,1,22)f[u][i]=f[f[u][i-1]][i-1];
    rep(i,0,(int)side[u].size()-1){
        int v=side[u][i];if(v==fa)continue;
        dfs(v,u);
    }
}
bool iskp[maxn];
bool cmp(int a,int b){return dfn[a]<dfn[b];}
void roll_back(int k){
    rep(i,0,k-1)iskp[kp[i]]=0;
}
void in(int x){
    vit[x].clear();stk[++top]=x;
}   
int k;
int c[maxn];
ll dp(int u,int fa){
    ll tot=0,ans=0;
    rep(i,0,(int)vit[u].size()-1){
        int v=vit[u][i];if(v==fa)continue;
        ans+=dp(v,u);tot+=c[v];
    }
    if(iskp[u]){
        c[u]=1;ans+=tot;        
    }
    else if(tot>1)c[u]=0,ans++;
    else c[u]=tot;
    return ans;
}
int lca(int u,int v){
    if(dep[u]<dep[v])swap(u,v);
    for(int i=22;i>=0;i--)if(dep[f[u][i]]>=dep[v])u=f[u][i];
    if(u==v)return u;
    for(int i=22;i>=0;i--)if(f[u][i]!=f[v][i])u=f[u][i],v=f[v][i];
    return f[u][0];
}
int main(){
    int n;scanf("%d",&n);
    rep(i,1,n-1){
        int u,v;scanf("%d%d",&u,&v);side[u].pb(v);side[v].pb(u);
    }
    dfs(1,0);
    int q;scanf("%d",&q);
    while(q--){
        scanf("%d",&k);kp.clear();
        rep(i,1,k){
            int u;scanf("%d",&u);kp.pb(u);iskp[u]=1;
        }
        bool flag=1;
        rep(i,0,k-1){
            if(iskp[f[kp[i]][0]]){
                printf("-1\n");roll_back(k);flag=0;break;
            }
        }
        if(!flag)continue;
        sort(kp.begin(),kp.end(),cmp);
        stk[top=1]=1;vit[1].clear();
        rep(i,0,k-1){
            int x=kp[i];
            if(x==1)continue;
            int l=lca(x,stk[top]);
            if(l==stk[top]){in(x);continue;}
            while(dfn[l]<dfn[stk[top-1]]){vit[stk[top-1]].pb(stk[top]);vit[stk[top]].pb(stk[top-1]);top--;}
            if(dfn[l]>dfn[stk[top-1]]){
                vit[l].clear();vit[l].pb(stk[top]);vit[stk[top]].pb(l);stk[top]=l;
            }
            else {vit[stk[top]].pb(l);vit[l].pb(stk[top]);top--;}
            in(x);
        }
        rep(i,1,top-1){vit[stk[i]].pb(stk[i+1]);vit[stk[i+1]].pb(stk[i]);}
        printf("%lld\n",dp(1,0));
        roll_back(k);
    }
}