Virus Tree 2 (Tree)

Problem Statement
You are given a tree with N vertices and N−1 edges. The vertices are numbered 1 to N, and the i-th edge connects Vertex a_i and b_i.
You have coloring materials of K colors. For each vertex in the tree, you will choose one of the
K colors to paint it, so that the following condition is satisfied:
If the distance between two different vertices x and y is less than or equal to two, x and y have different colors.
How many ways are there to paint the tree? Find the count modulo 1 000 000 007.
What is a tree?
A tree is a kind of graph. For details, see Wikipedia "Tree (graph theory)".
What is distance?
The distance between two vertices x and y is the minimum number of edges one has to traverse to get from x to y.
Constraints
· 1 ≤ N, K ≤ 10^5
· 1 ≤ a_i, b_i ≤ N
· The given graph is a tree.

Input
Input is given from Standard Input in the following format:

N K
a_1 b_1
a_2 b_2
⋮
a_{N−1} b_{N−1}

Output
Print the number of ways to paint the tree, modulo 1 000 000 007.

Sample Input
[Sample 1]
4 3
1 2
2 3
3 4
[Sample 2]
5 4
1 2
1 3
1 4
4 5
[Sample 3]
16 22
12 1
3 1
4 16
7 12
6 2
2 15
5 16
14 16
10 11
3 10
3 13
8 6
16 8
9 12
4 3

Sample Output
[Sample 1]
6
[Sample 2]
48
[Sample 3]
271414432

Hint
In Sample 1, there are six ways to paint the tree.

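Approach
Root the tree at vertex 1 and paint it from the top down. The root can take any of the k colors. When the children of a vertex are painted, they must differ from that vertex (distance 1) and from each other (siblings are at distance 2). Children of a non-root vertex must also differ from that vertex's parent (distance 2); children of the root have no such grandparent. Every other vertex within distance 2 of a child is one of its own descendants, which is painted later and accounts for the child then. So if deg[i] is the number of children of vertex i and A(n, m) = n!/(n−m)! is the number of ways to give m vertices distinct colors chosen from n (0 if m > n), the children of the root can be painted in A(k−1, deg[root]) ways and the children of any other vertex i in A(k−2, deg[i]) ways. The answer is k times the product of these factors over all vertices. Note that the helper C(a, b) in the code computes this ordered count A(a, b), not a binomial coefficient.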
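The same counting argument can be written as one formula. In this sketch, r denotes the root (vertex 1) and d_v the number of children of v; both symbols are introduced here for illustration. Applying it to Sample 2 (root 1 with children 2, 3, 4, and vertex 4 with child 5) reproduces the expected 48.

```latex
A(n, m) =
\begin{cases}
  \dfrac{n!}{(n-m)!} & 0 \le m \le n \\[4pt]
  0                  & m > n
\end{cases}

\text{ans} \equiv K \cdot A(K-1,\, d_r) \cdot \prod_{v \ne r} A(K-2,\, d_v) \pmod{10^9+7}

\text{Sample 2: } 4 \cdot A(3,3) \cdot A(2,1) \cdot A(2,0)^3 = 4 \cdot 6 \cdot 2 \cdot 1 = 48
```

The case m > n also covers K = 1, where K − 2 is negative and any vertex with a child yields 0.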

Code

#pragma GCC optimize(3,"Ofast","inline")
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1e5+5;
const ll mod=1e9+7;
 
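vector<int>E[N];                 // adjacency lists
ll dep[N],deg[N],fac[N],inv[N];  // depth, number of children, factorials, inverse factorials
int n,m;                         // n vertices, m = K colors

// Record each vertex's depth and number of children (deg[x]).
// Vertex 0 is unused, so it can serve as the root's parent with dep[0]=0.
void dfs(int x,int fa)
{
    dep[x]=dep[fa]+1;
    for(int v:E[x])
    {
        if(v==fa) continue;
        dfs(v,x);
        deg[x]++;
    }
}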
 
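// Fast modular exponentiation: returns a^b mod mod
ll qpow(ll a,ll b)
{
    ll ret=1;
    while(b)
    {
        if(b&1) ret=ret*a%mod;
        a=a*a%mod;
        b>>=1;
    }
    return ret;
}
// fac[i]=i! mod mod and inv[i]=(i!)^(-1) mod mod.
// inv[N-1] comes from Fermat's little theorem (mod is prime), then inv[i]=inv[i+1]*(i+1).
void init()
{
    fac[0]=1;
    for(int i=1;i<N;i++) fac[i]=fac[i-1]*i%mod;
    inv[N-1]=qpow(fac[N-1],mod-2);
    for(int i=N-2;i>=0;i--) inv[i]=inv[i+1]*(i+1)%mod;
}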
 
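// Despite the name, this returns the ordered count A(a,b)=a!/(a-b)!:
// the ways to give b vertices distinct colors chosen from a.
// Returns 0 if b>a, which also covers a<0 (a=m-2 when K=1).
ll C(ll a,ll b)
{
    if(b>a) return 0;
    return fac[a]*inv[a-b]%mod;
}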
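int main()
{
    init();
    scanf("%d%d",&n,&m);
    for(int i=1;i<n;i++)
    {
        int u,v;
        scanf("%d%d",&u,&v);
        E[u].push_back(v);
        E[v].push_back(u);
    }
    dfs(1,0);   // root at vertex 1; passing -1 as the parent would read dep[-1] out of bounds
    ll ans=m;   // the root can take any of the K colors
    for(int i=1;i<=n;i++)
    {
        // Children of the root only avoid the root's color;
        // children of any other vertex avoid that vertex's color and its parent's color.
        if(dep[i]==1) ans=ans*C(m-1,deg[i])%mod;
        else ans=ans*C(m-2,deg[i])%mod;
    }
    printf("%lld\n",ans);
    return 0;
}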