Problem Statement
You are given a tree with N vertices and N−1 edges. The vertices are numbered 1 to N, and the i-th edge connects Vertex a_i and Vertex b_i.
You have coloring materials of K colors. For each vertex in the tree, you will choose one of the
K colors to paint it, so that the following condition is satisfied:
If the distance between two different vertices x and y is less than or equal to two, x and y have different colors.
How many ways are there to paint the tree? Find the count modulo 1 000 000 007.
What is a tree?
A tree is a kind of graph. For details, see the Wikipedia article “Tree (graph theory)”.
What is distance?
The distance between two vertices x and y is the minimum number of edges one has to traverse to get from x to y.
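For very small inputs, both definitions can be checked directly. The sketch below is not part of the original solution; `bfsDistances` and `bruteForceCount` are hypothetical helpers. It computes distances by breadth-first search and then tries all K^N colorings, so it is only practical for tiny N and K, for example as a reference when testing a faster solution.

```cpp
#include <cassert>
#include <queue>
#include <utility>
#include <vector>

// Hypothetical helper: BFS distances from src on vertices 1..n (-1 = unreachable).
std::vector<int> bfsDistances(int n, const std::vector<std::vector<int>>& adj, int src) {
    std::vector<int> dist(n + 1, -1);
    std::queue<int> q;
    dist[src] = 0;
    q.push(src);
    while (!q.empty()) {
        int x = q.front();
        q.pop();
        for (int y : adj[x]) {
            if (dist[y] == -1) {
                dist[y] = dist[x] + 1;
                q.push(y);
            }
        }
    }
    return dist;
}

// Hypothetical brute-force counter: enumerates all k^n colorings (tiny inputs only)
// and keeps those in which every pair of distinct vertices at distance <= 2 differs.
long long bruteForceCount(int n, int k, const std::vector<std::pair<int, int>>& edges) {
    std::vector<std::vector<int>> adj(n + 1);
    for (const auto& e : edges) {
        assert(1 <= e.first && e.first <= n && 1 <= e.second && e.second <= n);
        adj[e.first].push_back(e.second);
        adj[e.second].push_back(e.first);
    }
    // Pairs (x, y) with x < y that must receive different colors.
    std::vector<std::pair<int, int>> conflicts;
    for (int x = 1; x <= n; x++) {
        std::vector<int> dist = bfsDistances(n, adj, x);
        for (int y = x + 1; y <= n; y++)
            if (dist[y] != -1 && dist[y] <= 2) conflicts.push_back({x, y});
    }
    std::vector<int> color(n + 1, 0);  // color[v] in [0, k); starts at all zeros
    long long count = 0;
    while (true) {
        bool ok = true;
        for (const auto& c : conflicts)
            if (color[c.first] == color[c.second]) { ok = false; break; }
        if (ok) count++;
        // Advance to the next coloring like an odometer in base k.
        int i = 1;
        while (i <= n && ++color[i] == k) {
            color[i] = 0;
            i++;
        }
        if (i > n) break;  // wrapped around: every coloring has been tried
    }
    return count;
}
```

On Sample 1 (the path 1-2-3-4 with K = 3), `bruteForceCount(4, 3, {{1, 2}, {2, 3}, {3, 4}})` reproduces the expected answer 6.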
Constraints
· 1 ≤ N, K ≤ 10^5
· 1 ≤ a_i, b_i ≤ N
· The given graph is a tree.
Input
Input is given from Standard Input in the following format:
N K
a_1 b_1
a_2 b_2
⋮
a_{N−1} b_{N−1}
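A minimal sketch of reading this format into an adjacency list, assuming vertices stay 1-based as in the statement. `TreeInput`, `readTreeInput` and `parseTreeInput` are hypothetical names chosen for illustration.

```cpp
#include <istream>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical container for one parsed test case.
struct TreeInput {
    int n = 0, k = 0;
    std::vector<std::vector<int>> adj;  // adj[v] lists the neighbors of vertex v (1-based)
};

// Reads "N K" followed by N-1 lines "a_i b_i" and builds an undirected adjacency list.
TreeInput readTreeInput(std::istream& in) {
    TreeInput t;
    in >> t.n >> t.k;
    t.adj.assign(t.n + 1, std::vector<int>());
    for (int i = 0; i < t.n - 1; i++) {
        int a, b;
        in >> a >> b;
        t.adj[a].push_back(b);
        t.adj[b].push_back(a);
    }
    return t;
}

// Convenience wrapper for parsing a sample input held in a string.
TreeInput parseTreeInput(const std::string& text) {
    std::istringstream in(text);
    return readTreeInput(in);
}
```

In a full program, `readTreeInput(std::cin)` would read the judge's input; the string wrapper is only for trying the samples.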
Output
Print the number of ways to paint the tree, modulo 1 000 000 007.
Sample Input
【Sample 1】
4 3
1 2
2 3
3 4
【Sample 2】
5 4
1 2
1 3
1 4
4 5
【Sample 3】
16 22
12 1
3 1
4 16
7 12
6 2
2 15
5 16
14 16
10 11
3 10
3 13
8 6
16 8
9 12
4 3
Sample Output
【Sample 1】
6
【Sample 2】
48
【Sample 3】
271414432
Note
For Sample 1, there are six ways to paint the tree.
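Sample 1 is the path 1-2-3-4 with K = 3, so the six colorings can be counted by hand by coloring the vertices in order along the path:

```latex
\begin{aligned}
\text{vertex } 1 &: 3 \text{ choices} \\
\text{vertex } 2 \text{ (differs from } 1) &: 2 \text{ choices} \\
\text{vertex } 3 \text{ (differs from } 1, 2) &: 1 \text{ choice} \\
\text{vertex } 4 \text{ (differs from } 2, 3) &: 1 \text{ choice} \\
\text{total} &= 3 \cdot 2 \cdot 1 \cdot 1 = 6
\end{aligned}
```

Vertex 4 is at distance 3 from vertex 1, so it may reuse vertex 1's color; here it is in fact forced to.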
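Approach
Root the tree at vertex 1 and color it from the top down. Any two vertices at distance at most 2 are parent and child, grandparent and grandchild, or two children of the same parent. Therefore it suffices to color each vertex's children while avoiding the colors already fixed near them:
· The root can take any of the K colors.
· The deg[root] children of the root must differ from the root and from one another, giving (K−1)(K−2)⋯(K−deg[root]) = C(K−1, deg[root]) ways.
· For every other vertex i, its deg[i] children must differ from i, from i's parent, and from one another, giving C(K−2, deg[i]) ways.
Here deg[i] is the number of children of vertex i. C(a, b) denotes the number of arrangements a!/(a−b)! (order matters because the children are distinct vertices), not the binomial coefficient, and it is 0 when b > a. The answer is the product of all these factors.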
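The counting rule above can also be sketched without recursion or factorial tables. This is a hypothetical variant, not the author's code. `countColorings` roots the tree at vertex 1, walks it with an iterative BFS, and multiplies each factor of the falling factorials directly. The total work is still O(N), because the children counts sum to N−1.

```cpp
#include <queue>
#include <utility>
#include <vector>

// Hypothetical sketch of the top-down counting rule, rooted at vertex 1.
// Uses an iterative BFS (no deep recursion) and multiplies each factor of the
// falling factorials directly instead of using factorial tables.
long long countColorings(int n, long long k, const std::vector<std::pair<int, int>>& edges) {
    const long long MOD = 1000000007LL;
    std::vector<std::vector<int>> adj(n + 1);
    for (const auto& e : edges) {
        adj[e.first].push_back(e.second);
        adj[e.second].push_back(e.first);
    }
    std::vector<bool> seen(n + 1, false);
    std::queue<int> q;
    q.push(1);
    seen[1] = true;
    long long ans = k % MOD;  // the root may take any of the K colors
    while (!q.empty()) {
        int x = q.front();
        q.pop();
        // Colors left for x's first child: all except x's own color,
        // and also except x's parent's color unless x is the root.
        long long available = (x == 1) ? k - 1 : k - 2;
        for (int y : adj[x]) {
            if (seen[y]) continue;  // in a tree, the only seen neighbor is x's parent
            seen[y] = true;
            if (available <= 0) return 0;  // no color left for this child
            ans = ans * available % MOD;
            available--;  // later siblings must also avoid this child's color
            q.push(y);
        }
    }
    return ans;
}
```

On the three samples this gives 6, 48 and 271414432, matching the expected outputs.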
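Code
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1e5 + 5;
const ll mod = 1e9 + 7;

vector<int> E[N];      // adjacency list
ll dep[N], deg[N];     // dep[x]: depth (root = 1); deg[x]: number of children of x
ll fac[N], inv[N];     // factorials and inverse factorials mod p
int n, m;              // n = number of vertices, m = number of colors K

// Root the subtree at x (parent fa) and record depth and child count of every vertex.
void dfs(int x, int fa) {
    dep[x] = dep[fa] + 1;  // the root is called with fa = 0, and dep[0] = 0
    for (int v : E[x]) {
        if (v == fa) continue;
        dfs(v, x);
        deg[x]++;
    }
}

// Fast exponentiation: a^b mod p.
ll qpow(ll a, ll b) {
    ll ret = 1;
    while (b) {
        if (b & 1) ret = ret * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return ret;
}

// Precompute i! and (i!)^(-1) mod p; the top inverse uses Fermat's little theorem.
void init() {
    fac[0] = 1;
    for (int i = 1; i < N; i++) fac[i] = fac[i - 1] * i % mod;
    inv[N - 1] = qpow(fac[N - 1], mod - 2);
    for (int i = N - 2; i >= 0; i--) inv[i] = inv[i + 1] * (i + 1) % mod;
}

// Number of arrangements a!/(a-b)!: ways to give b distinct children distinct colors
// chosen from a available ones. Returns 0 when b > a (this also covers a < 0).
ll C(ll a, ll b) {
    if (b > a) return 0;
    return fac[a] * inv[a - b] % mod;
}

int main() {
    init();
    scanf("%d%d", &n, &m);
    for (int i = 1; i < n; i++) {
        int u, v;
        scanf("%d%d", &u, &v);
        E[u].push_back(v);
        E[v].push_back(u);
    }
    dfs(1, 0);  // vertex 0 does not exist, so it is a safe sentinel parent (dep[-1] would be out of bounds)
    ll ans = m;  // K choices for the root's color
    for (int i = 1; i <= n; i++) {
        if (dep[i] == 1) ans = ans * C(m - 1, deg[i]) % mod;  // root's children avoid the root
        else ans = ans * C(m - 2, deg[i]) % mod;              // children avoid i and i's parent
    }
    printf("%lld\n", ans);
    return 0;
}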
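The C(a, b) in the code is the falling factorial a!/(a−b)!, evaluated from factorial and inverse-factorial tables. The inverses rely on Fermat's little theorem, x^(p−2) ≡ x^(−1) (mod p), which holds because 1 000 000 007 is prime. The stand-alone sketch below isolates that technique. `kMod`, `powMod`, `FactorialTable` and `arrangements` are hypothetical names, and the table size is a parameter instead of the fixed N used above.

```cpp
#include <vector>

constexpr long long kMod = 1000000007LL;  // prime modulus from the statement

// Fast exponentiation: a^e mod kMod.
long long powMod(long long a, long long e) {
    long long r = 1;
    a %= kMod;
    while (e > 0) {
        if (e & 1) r = r * a % kMod;
        a = a * a % kMod;
        e >>= 1;
    }
    return r;
}

// Hypothetical helper: tables of i! and (i!)^(-1) mod kMod for 0 <= i <= maxN.
struct FactorialTable {
    std::vector<long long> fac, invFac;

    explicit FactorialTable(int maxN) : fac(maxN + 1), invFac(maxN + 1) {
        fac[0] = 1;
        for (int i = 1; i <= maxN; i++) fac[i] = fac[i - 1] * i % kMod;
        // Fermat's little theorem: x^(p-2) is the inverse of x modulo a prime p.
        invFac[maxN] = powMod(fac[maxN], kMod - 2);
        // ((i-1)!)^(-1) = (i!)^(-1) * i, filled from the top down.
        for (int i = maxN; i >= 1; i--) invFac[i - 1] = invFac[i] * i % kMod;
    }

    // Arrangements a!/(a-b)! = a(a-1)...(a-b+1): ordered choices of b distinct items
    // out of a. Returns 0 when b > a (which covers a < 0). Requires a <= maxN.
    long long arrangements(long long a, long long b) const {
        if (b < 0 || b > a) return 0;
        return fac[a] * invFac[a - b] % kMod;
    }
};
```

For example, with K = 22 the root of Sample 3 has two children, contributing `arrangements(21, 2)` = 21 · 20 = 420. A non-root vertex with three children contributes `arrangements(20, 3)` = 20 · 19 · 18 = 6840.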