JOISC2020 首都

一开始所有人都以为,答案就是统计每种颜色所有点之间的路径上有多少种不同的颜色。后来才发现这是假的。因为如果两种颜色路径上有其他颜色cc,那么不仅仅是这个点,颜色cc所有点都要连入联通块。

对于这种“选了一个颜色,就必须选其他颜色”的问题,可以想到把关系连边之后用强连通分量解决。我们把得到的关系图用tarjan缩成DAG。然后没有出度的那些点里选SCCSCC最小的就好了(因为选又出度的必定不优秀)。

那么问题来了,如何在所有颜色之间依靠关系连边呢?一种方法是建虚树,然后把虚树上每条边代表的原树上的链用倍增优化连边(或者直接树剖也行?)连一下关系即可。但是这样搞有些麻烦,我们发现这张图连一些重边也不会有啥影响,所以我们每次可以把要加入虚树的所有点和这些点的的LCA之间的链连一条边,这样就很方便了。

有点细节。

/*
                                                  
  _|_|                              _|  _|    _|  
_|    _|  _|  _|_|  _|_|_|_|        _|  _|  _|    
_|    _|  _|_|          _|          _|  _|_|      
_|    _|  _|          _|      _|    _|  _|  _|    
  _|_|    _|        _|_|_|_|    _|_|    _|    _|  
                                                                                                    
*/ 
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<cstring>
#include<vector>
#include<queue>
#include<set>
//#define ls (rt<<1)
//#define rs (rt<<1|1)
#define vi vector<int>
#define pb push_back
#define mk make_pair
#define pii pair<int,int>
#define rep(i,a,b) for(int i=(a),i##end=(b);i<=i##end;i++)
#define fi first
#define se second
typedef long long ll;
using namespace std;
const int maxn=(2e5+10);
const int po=maxn*20;
int f[maxn][20],vpointid[maxn][20],cnt;
vi edge[maxn],side[po];
int color[maxn],dep[maxn],dfn[po],dfn_cnt;
int ans=1e9;
void dfs1(int u,int fa){
	dep[u]=dep[fa]+1;dfn[u]=++dfn_cnt;
	f[u][0]=fa;rep(i,1,18)f[u][i]=f[f[u][i-1]][i-1];
	vpointid[u][0]=++cnt;if(fa)side[cnt].pb(color[fa]);
	rep(i,1,18){
		vpointid[u][i]=++cnt;
		if(vpointid[u][i-1])side[cnt].pb(vpointid[u][i-1]);
		if(vpointid[f[u][i-1]][i-1])side[cnt].pb(vpointid[f[u][i-1]][i-1]);
	}
	rep(i,0,edge[u].size()-1){
		int v=edge[u][i];if(v==fa)continue;
		dfs1(v,u);
	}
}
int lca(int u,int v){
	if(dep[u]<dep[v])swap(u,v);
	for(int i=18;i>=0;i--)if(dep[f[u][i]]>=dep[v])u=f[u][i];
	if(u==v)return u;
	for(int i=18;i>=0;i--)if(f[u][i]!=f[v][i])u=f[u][i],v=f[v][i];
	return f[u][0];
}
int stk[po];
vi town[maxn],tmp;
bool ontree[po];
bool cmp1(int a,int b){return dfn[a]<dfn[b];}
void add_chain(int u,int fa,int c){
	for(int i=18;i>=0;i--){
		if(dep[f[u][i]]>=dep[fa]){
			side[c].pb(vpointid[u][i]);
			u=f[u][i];
		}
	}
}
int low[po],scc[po],scc_cnt,top;
int n,k;
vi buc;
void tarjan(int u){
	dfn[u]=low[u]=++dfn_cnt;stk[++top]=u;
	rep(i,0,(int)(side[u].size())-1){
		int v=side[u][i];
		if(!dfn[v])tarjan(v),low[u]=min(low[u],low[v]);
		else if(!scc[v])low[u]=min(low[u],dfn[v]);
	}
	if(dfn[u]==low[u]){
		buc.clear();
		scc_cnt++;int x=0;int sz=0;
		do{	
			x=stk[top--];
			scc[x]=scc_cnt;
			if(x<=k)sz++;
			buc.pb(x);
		}while(x!=u);bool flag=0;
		rep(j,0,buc.size()-1)if(!flag){
			int idx=buc[j];
			rep(k,0,(int)(side[idx].size())-1){
				int v=side[idx][k];
				if(scc[v]!=scc_cnt){
					flag=1;break;
				}
			}
		}
		if(!flag)ans=min(ans,sz);
	}
}
int main(){
	scanf("%d%d",&n,&k);cnt=k;
	rep(i,1,n-1){
		int u,v;scanf("%d%d",&u,&v);
		edge[u].pb(v);edge[v].pb(u);
	}
	rep(i,1,n)scanf("%d",&color[i]),town[color[i]].pb(i);
	dfs1(1,0);
	rep(i,1,k){
		tmp=town[i];int sz=tmp.size();
		int l=tmp[0];
		rep(j,1,sz-1){
			l=lca(tmp[j],l);
		}
		rep(j,0,(int)(tmp.size())-1){
			add_chain(tmp[j],l,i);
		}
	}
	memset(dfn,0,sizeof(dfn));dfn_cnt=0;top=0;
	rep(i,1,k)if(!dfn[i])tarjan(i);
	cout<<ans-1;
	return 0;
}