CF1034C Region Separation

CF2700还是做不来啊嘤嘤嘤,这思路太仙了。

我们先考虑,如果我们把它分成了kk块然后终止操作,设S=aiS=\sum a_i,则每块权值和均为Sk\frac {S}{k}。而且我们发现,对于一个点,若其子树权值和为Sk\frac {S}{k},唯一的操作是把他这棵子树砍下来,不断操作至不能操作。发现其结果等价于将所有si0(modSk)s_i \equiv 0 (\bmod \frac {S}{k})的点至其父亲的边砍断的结果(其中sis_i为子树权值和),而且显然这是一个唯一解。

我们令f(k)f(k)为一个值为true/false\text{true} / \text{false}的函数,代表把全树分成kk块是否合法。f(k)=1f(k)=1的条件显然为有kk个满足si0(modSk)s_i \equiv 0 (\bmod \frac {S}{k})的点(每个点都是一个切完后连通块的根)。考虑如何求f(k)f(k),直接求每次复杂度都是O(n)O(n)的。思考下什么时候一个点ii会对f(k)f(k)的结果产生贡献,当且仅当sis_iSk\frac {S}{k}的倍数,即si=Sk×ts_i=\frac {S}{k}\times t,变形得k=Ssi×tk=\frac {S}{s_i}\times t。尝试消掉分母sis_i,得到k=Sgcd(S,si)×tk=\frac{S}{gcd(S,s_i)}\times t。像这样的有一堆数,每个数都对其所有倍数有1的贡献的问题,可以划归为toti=jicntjtot_i=\sum _{j\mid i}cnt_j。自然数约数和为nlnnn\ln n。这样我们就能求出所有的f(k)f(k)了。

然后只需要进行DP,因为联通块时刻保持和相等,所以每个连通块每轮进行切割的次数要和其他连通块相等,而且把任意一个连通块分为pp份的方案是唯一的。所以显然有dpk=f(k)×jkdpjdp_k=f(k)\times \sum _{j\mid k}dp_j

/*
                                                  
  _|_|                              _|  _|    _|  
_|    _|  _|  _|_|  _|_|_|_|        _|  _|  _|    
_|    _|  _|_|          _|          _|  _|_|      
_|    _|  _|          _|      _|    _|  _|  _|    
  _|_|    _|        _|_|_|_|    _|_|    _|    _|  
                                                                                                    
*/ 
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<cstring>
#include<vector>
#include<queue>
#include<set>
//#define ls (rt<<1)
//#define rs (rt<<1|1)
#define vi vector<int>
#define pb push_back
#define mk make_pair
#define pii pair<int,int>
#define rep(i,a,b) for(int i=(a),i##end=(b);i<=i##end;i++)
#define fi first
#define se second
typedef long long ll;
using namespace std;
const int maxn=1e6+100;
const int mod=1e9+7;
vi side[maxn];
int a[maxn];
ll s[maxn];
int cnt[maxn],tot[maxn];
bool f[maxn];
void dfs(int u,int fa){
	s[u]=a[u];
	rep(i,0,side[u].size()-1){
		int v=side[u][i];if(v==fa)continue;
		dfs(v,u);s[u]+=s[v];
	}
}
ll gcd(ll a,ll b){return (!b)?a:gcd(b,a%b);}
ll dp[maxn];
int main(){
	ios::sync_with_stdio(0);
	int n;cin>>n;
	rep(i,1,n)cin>>a[i];
	rep(i,2,n){
		int p;cin>>p;
		side[i].pb(p);side[p].pb(i);
	}
	dfs(1,0);
	rep(i,1,n)if(s[1]/gcd(s[1],s[i])<=n)cnt[s[1]/gcd(s[1],s[i])]++;
	rep(i,1,n)for(int j=i;j<=n;j+=i)tot[j]+=cnt[i];
	rep(i,1,n)if(tot[i]==i)f[i]=1;
	dp[1]=1;
	rep(i,1,n){
		dp[i]*=f[i];
		for(int j=2*i;j<=n;j+=i)dp[j]=(dp[j]+dp[i])%mod;
	}
	ll ans=0;
	rep(i,1,n)ans=(ans+dp[i])%mod;
	cout<<ans;
	return 0;
}