[NOIP2018 提升组] 保卫王国

[NOIP2018 提升组] 保卫王国


深受启发的题解ios

DP的集大成者。c++


倍增作法:\(O((n+m)\ log\ n)\)


首先,这道题最原始的问题就是经典的最大独立集问题。没有上司的舞会数组

咱们考虑DP:设\(dp(u,0)\)表明节点\(u\)没被选上,\(dp(u,1)\)表明结点\(u\)被选上,而后转移。spa

时间复杂度为\(O(n)\)code

那么对于本题的\(44pts\),咱们只须要每次暴力修改点值,而后跑一遍最大独立集的模板,便可以轻松得到了。blog


接下来,咱们先将问题简单化。图片

咱们能够先考虑每次修改一个点的时候该怎么作;ip

容易看到,对于以该点(不妨设为\(u\))为根的子树,给定状态(指令:选仍是不选,不妨设为\(state\))对应的dp值就是这棵子树对答案的影响。换言之,对于整棵子树,咱们不须要重复计算它们的dp值。咱们只须要将\(dp(u,state)\)做为答案。ci

那么,对于点\(u\)的祖先节点如何统计它们对答案的贡献啊?get

换根

也就是说,咱们直接将点\(u\)做为整棵树新的根结点,再按照咱们刚刚讨论过的“经典最大独立集”的作法进行求解。

这个方法针对于单次修改彻底能接受,但是\(m\)次修改呢?

咱们考虑如下方法:

  • 定义两个数组\(dp1(u,0/1)\)\(dp2(u,0/1)\)分别表明以\(u\)为根的子树最小值和整棵树中扔掉以\(u\)为根的子树的最小值(即不考虑以\(u\)根的这棵子树计算得来的答案,注意\(0\ or\ 1\)指的是\(u\)选不选,但不计算\(u\)的答案);
  • 转移:这里只举一例:\(dp2(u,0)=dp2(fa_u,1)+dp1(fa_u,1)-min(dp1(u,0),dp1(u,1)\)
  • 更新的时候,因为咱们单次修改对后面没有影响,咱们只须要将操做状态对应的dp值输出便可。

能够看到,这个作法其实就相似于换根dp的思想(或就是((()。


如今,咱们拓展——每次修改两个点的状态,该怎么办。

考虑到刚刚咱们是怎么作的。能够发现,对于被修改的两个点,不妨设为\(u\)\(v\)\(u\)\(v\)\(path(u,v)\)\(u\)\(v\)的简单路径所构成的点集)能够当作一个“广义节点”。

类比于刚刚的作法,咱们能够留意:

将这些点聚合成一个“广义节点”便可以按照上述作法解决。

既然如此,那么“广义节点”怎么求?

咱们是单纯把\(path(u,v)\)当成“广义节点”,仍是将路径上牵出来的子树囊括其中?

选择后者。

这是由于前者求解这些子树时依赖于路径上的每个点的选取状态,而想要求解最小值只能枚举


问题1

咱们先考虑——树退化为一条链的状况:\(u\)\(v\)是链上的两个端点。按照咱们以前的作法,直接将全部\(u\)\(v\)路径上的点所有缩掉。值得欢庆的是,这条路径上没有子树。

咱们对于“广义节点”外的结点直接跑最大独立集,对内再来一个独立集。

详细地讲,咱们在求解“广义节点”的内部独立集问题时,只不过和外面“绝缘”。换言之,咱们求解内部独立集的时候,把它当作一个独立的链来统计。

不过这样的效率是极低的,咱们不妨用倍增的思想处理 内部矛盾。

定义\(dp(u,i,0/1,0/1)\)表明从\(u\)\(2^i\)级祖先的最小值。

按照正常的倍增作法去作。而后对于一条链而言,咱们按二进制把它拆开来,一段一段看成求独立集时一个个结点维护便可。(不清楚的看代码)


问题2

咱们再考虑,当\(u\)\(v\)的祖先时的状况。


能够看到,在这个问题里面,与上一个子问题惟一有差别的地儿是:“广义节点”路径上是有子树的。由此,在初始化时,咱们还要算上子树的权值便可。(具体实现细节,能够认真思考)。


最后的问题

最后,咱们将直接面对最通常的状况,即当不存在一个节点是另外一个节点的祖先时,咱们该如何处理。

先抓住两点的最近公共祖先(记做\(lca\))。咱们不难观察到,假若\(lca\)选取状态肯定了,以下面的图片所示,那么这就至关于跑了两遍的“问题\(2\)”(\(u\)\(lca\)\(v\)\(lca\))再加上\(lca\)上面的一大团东西。

但是问题是\(lca\)选取状态不肯定。不要紧,咱们就枚举它的“选取状态”,分别对于每种选取状态进行求解更新。

#include<iostream>
#include<cstring>
#include<cstdio>
#include<vector>
#include<cmath>
#include<map>
#define PII pair <int, int>
#define MP make_pair
#define RE register
#define CLR(x, y) memset(x,y,sizeof x)
#define FOR(i, x, y) for(RE int i=x;i<=y;++i)
#define ROF(i, x, y) for(RE int i=x;i>=y;--i)
using namespace std;

typedef long long LL;

const int MAXN = 1e5 + 5;
const LL INF = 1e12;

template <class T> void read(T &x)
{
	bool mark = false;
	char ch = getchar();
	for(; ch < '0' || ch > '9'; ch = getchar()) if(ch == '-') mark = true;
	for(x = 0; ch >= '0' && ch <= '9'; ch = getchar()) x = (x << 3) + (x << 1) + ch - '0';
	if(mark) x = -x;
}

map <PII, bool> table;

vector <int> G[MAXN];
int n, m, t;
int F[MAXN][30] = {}, dep[MAXN] = {};
// F[u,i] -> u的2^i级祖先 dep[u] -> u到根节点的深度 T[u] -> u到根节点的距离的log值 
LL dp1[MAXN][2] = {}, dp2[MAXN][2] = {}, dp[MAXN][30][2][2] = {}, p[MAXN] = {};
// dp1[u][0/1] -> 在u子树中, 不选/选 u的最小值  dp2[u][0/1] -> 在整棵树除了u子树中, 不选/选 u的最小值  

// 倍增预处理 
void BFS()
{ 	
	int hh = 0, tt = 0, q[MAXN] = {};
	int u, v;
	dep[1] = 1, q[tt ++] = 1;
	while(hh < tt)
	{
		u = q[hh ++];
		for(RE int i = 0; i < G[u].size(); ++ i)
		{
			v = G[u][i];
			if(dep[v]) continue;
			dep[v] = dep[u] + 1, F[v][0] = u;
			FOR(i, 1, t) F[v][i] = F[F[v][i - 1]][i - 1];
			q[tt ++] = v;
		}
	}
	return;
}
// the prework of dp1[u, 0/1]
void dfs_dp1(int u)
{
	dp1[u][0] = 0, dp1[u][1] = p[u];
	for(RE int i = 0; i < G[u].size(); ++ i)
	{
		int v = G[u][i];
		if(v == F[u][0]) continue;
		dfs_dp1(v);
		dp1[u][0] += dp1[v][1], dp1[u][1] += min(dp1[v][0], dp1[v][1]);
	}
	return;
}
// the prework of dp2[u, 0/1]
void dfs_dp2(int u)// @
{
	int v;
	for(RE int i = 0; i < G[u].size(); ++ i)
	{
		v = G[u][i];
		if(v == F[u][0]) continue;
		dp2[v][0] = dp2[u][1] + dp1[u][1] - min(dp1[v][0], dp1[v][1]);
		dp2[v][1] = min(dp2[v][0], dp2[u][0] + dp1[u][0] - dp1[v][1]);
		dfs_dp2(v);
	}
	return;
}
// the prework of all of those dp arrays
void dp_prework()
{
	dfs_dp1(1), dfs_dp2(1);
	FOR(i, 1, n)
		FOR(j, 0, t) 
			FOR(x, 0, 1)
				FOR(y, 0, 1) dp[i][j][x][y] = INF;
	FOR(i, 2, n)
	{
		int fa = F[i][0];
		dp[i][0][0][1] = dp1[fa][1] - min(dp1[i][0], dp1[i][1]);
		dp[i][0][1][0] = dp1[fa][0] - dp1[i][1];
		dp[i][0][1][1] = dp1[fa][1] - min(dp1[i][0], dp1[i][1]);
	}
	FOR(j, 1, t)
	{
		FOR(i, 1, n)
		{
			int anc = F[i][j - 1];// anc -> ancestor
			dp[i][j][0][0] = min(dp[i][j - 1][0][0] + dp[anc][j - 1][0][0], dp[i][j - 1][0][1] + dp[anc][j - 1][1][0]);
			dp[i][j][0][1] = min(dp[i][j - 1][0][0] + dp[anc][j - 1][0][1], dp[i][j - 1][0][1] + dp[anc][j - 1][1][1]);
			dp[i][j][1][0] = min(dp[i][j - 1][1][0] + dp[anc][j - 1][0][0], dp[i][j - 1][1][1] + dp[anc][j - 1][1][0]);
			dp[i][j][1][1] = min(dp[i][j - 1][1][0] + dp[anc][j - 1][0][1], dp[i][j - 1][1][1] + dp[anc][j - 1][1][1]);
		}
	}
	return;
}
LL solve(int u, bool opt1, int v, bool opt2)
{
	if(dep[u] > dep[v]) swap(u, v), swap(opt1, opt2);
	LL flca[2], fu[2] = {INF, INF}, fv[2] = {INF, INF}, new_fu[2] = {INF, INF}, new_fv[2] = {INF, INF};
	fu[opt1] = dp1[u][opt1], fv[opt2] = dp1[v][opt2];
	ROF(i, t, 0)
	{
		if(dep[F[v][i]] >= dep[u])
		{
			new_fv[0] = new_fv[1] = INF;
			FOR(x, 0, 1)
				FOR(y, 0, 1)
					new_fv[x] = min(new_fv[x], fv[y] + dp[v][i][y][x]);

			FOR(x, 0, 1) fv[x] = new_fv[x];
			v = F[v][i];
		}
	}
	if(u == v) return fv[opt1] + dp2[u][opt1];
	ROF(i, t, 0)
	{
		if(F[u][i] != F[v][i])
		{
			new_fu[0] = new_fu[1] = new_fv[0] = new_fv[1] = INF;
			FOR(x, 0, 1)
				FOR(y, 0, 1)
					new_fv[x] = min(new_fv[x], fv[y] + dp[v][i][y][x]), new_fu[x] = min(new_fu[x], fu[y] + dp[u][i][y][x]);
			
			FOR(x, 0, 1) fu[x] = new_fu[x], fv[x] = new_fv[x];
			u = F[u][i], v = F[v][i];
		}
	}
	int lca = F[u][0];
	flca[0] = dp2[lca][0] + dp1[lca][0] - dp1[u][1] - dp1[v][1] + fu[1] + fv[1];
	flca[1] = dp2[lca][1] + dp1[lca][1] - min(dp1[u][0], dp1[u][1]) - min(dp1[v][0], dp1[v][1]) + min(fu[0], fu[1]) + min(fv[0], fv[1]);
	return min(flca[0], flca[1]);		
}
signed main()
{
	read(n), read(m);
	char type[10];
	cin >> type;
	t = log(n) / log(2) + 1;
	FOR(i, 1, n) read(p[i]), G[i].clear();
	int a, x, b, y;
	FOR(i, 2, n)
	{
		read(x), read(y);
		G[x].push_back(y), G[y].push_back(x);
		table[MP(x, y)] = table[MP(y, x)] = true;
	}
	BFS(), dp_prework();
	FOR(i, 1, m)
	{
		read(a), read(x), read(b), read(y);
		if(!x && !y && table.find(MP(a, b)) != table.end()) puts("-1");
		else printf("%lld\n", solve(a, x, b, y));
	}
	return 0;
}
相关文章
相关标签/搜索