树状数组学习笔记

树状数组学习笔记

树状数组

本文参考视频:https://www.bilibili.com/video/av69667943?from=search&seid=204489840113652018c++

 

lowbit 操做

int lowbit(int x) {
    return x & (-x);
}

\(lowbit\) 操做是为了求出一个数字 \(x\) 在二进制形态下,最低位的 \(1\) 的大小。算法

例如 \((110100)_2\) 中最低位 \(1\) 的大小是 \((100)_2\)数组

\(lowbit\) 求解的方法是,先将 \(x\) 的二进制按位取反,而后 \(+1\) ,再按位与原数字。数据结构

例如:\((110100)_2\)ide

  1. 按位取反 \((001011)_2\)
  2. \(+1\) \((001100)_2\)
  3. 按位与原数 \((000100)_2\)

因为计算机中负数采用补码存储,因而第1、二步的操做能够简化为 \(\times (-1)\)函数

那么,\(lowbit\) 在树状数组中的做用究竟是什么?实际上, \(lowbit(x)\) 表明树状数组中第 \(x\) 位元素覆盖的区间长度,(能够参考顶部图片)即 \(t[x] = \sum_{i=x-lowbit(x)+1}^xa[i]\)。(\(t[]\) 表明树状数组,\(a[]\) 表明原数组)学习

也就是说,树状数组中第 \(x\) 位元素的值表明当前位置到前 \(lowbit(x)\) 位置的全部原数组元素之和。优化

 

单点修改和区间查询

//如下代码,默认原数组为 a[],树状数组为 t[]
void add(int pos, int x) { //pos位置加上x
    for (; pos <= n; pos += lowbit(pos)) { //n为数组大小
        t[pos] += x;
    }
}
int query_presum(int pos) { //查询pos位置的前缀和,即a1 + a2 + ... + apos
    int ans = 0;
    for (; pos > 0; pos -= lowbit(pos)) {
        ans += t[pos];
    }
    return ans;
}

int query_sum(int l, int r) { //[l, r]区间查询
    return query_presum(r) - query_presum(l - 1);
}

单点修改

树状数组中,每一个节点 \(x\) 的父节点均可以表示为 \(x + lowbit(x)\) 。利用这个性质,咱们就能够作到 \(O(logn)\) 单点修改。例如,咱们想要给 \(a[3]+1\) ,那么咱们须要对 \(t[3],t[4],t[8]\) \(+1\)spa

区间查询

区间查询咱们须要利用前缀和,例如求 \([l, r]\) 的区间和,咱们只需求 \(\sum_{i=1}^ra[i] - \sum_{i=1}^{l-1}a[i]\) 。利用 \(lowbit\) 的性质,咱们知道 \(x\) 位置的元素覆盖的长度为 \(lowbit(x)\) 。因而咱们只需每次将下标减去 \(lowbit(x)\) ,将当前位置的数值加上便可。例如 \(presum(7) = t[7] + t[6] + t[4]\)

例题1 P3374 【模板】树状数组 1

连接:https://www.luogu.com.cn/problem/P3374

#define _CRT_SECURE_NO_WARNINGS
#pragma GCC optimize(3)
#pragma GCC optimize("Ofast")
#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
#pragma comment(linker, "/stack:200000000")
#include <bits/stdc++.h>
#define SIZE 500010
#define rep(i, a, b) for (long long i = a; i <= b; ++i)
#define ll long long
using namespace std;
void io() { ios::sync_with_stdio(false); cin.tie(nullptr); cout.tie(nullptr); }
int n, m;
int t[SIZE];
int lowbit(int x) { return x & (-x); }

void add(int pos, int x) { //pos位置加上x
    for (; pos <= n; pos += lowbit(pos)) { //n为数组大小
        t[pos] += x;
    }
}

int query_presum(int pos) { //查询pos位置的前缀和,即a1 + a2 + ... + apos
    int ans = 0;
    for (; pos > 0; pos -= lowbit(pos)) {
        ans += t[pos];
    }
    return ans;
}

int query_sum(int l, int r) { //[l, r]区间查询
    return query_presum(r) - query_presum(l - 1);
}

int main() {
    io(); cin >> n >> m;
    rep(i, 1, n) {
        int x; cin >> x;
        add(i, x);
    }
    rep(i, 1, m) {
        int op; cin >> op;
        if (op == 1) {
            int pos, x; cin >> pos >> x;
            add(pos, x);
        }
        else {
            int l, r; cin >> l >> r;
            ll ans = query_sum(l, r);
            cout << ans << '\n';
        }
    }
}

 

区间修改和单点查询

这一部分的建树与以前不一样,先前所述的单点修改和区间查询,咱们只须要对于 \(a[i]\) 创建树状数组;可是如今咱们须要对 \(a[i]\) 的差分数组 \(p[i]\) 建树。

void add(int l, int r, int x) { //[l, r] 区间+x
    add(l, x);
    add(r + 1, -x);
}
int query_presum(int pos) { //单点查询,即对差分数组求前缀和
    int ans = 0;
    for (; pos > 0; pos -= lowbit(pos)) {
        ans += t[pos];
    }
    return ans;
}

差分思想

为了快速实现区间加和单点查询操做,咱们须要维护一个差分数组 \(p[i] = a[i] - a[i-1]\) ,而后对 \(p[i]\) 建树;咱们容易发现,对于差分数组求前缀和,即为单点查询:

\(\sum_{i-1}^xp[i]=(a[x]-a[x-1]) + (a[x-1]-a[x-2]) + ... + (a[2]-a[1]) + a[1] = a[x]\)

因而,对于一个差分数组,咱们能够利用树状数组 \(O(logn)\) 求前缀和的性质,实现更快的单点查询。那么,如何实现区间修改操做?

咱们不难发现, \(a\) 数组的区间 \([l, r]\) 同时加上一个数值 \(x\) 时,它的差分数组只有首尾两项的值会发生变化,由于差分数组维护的是相邻数字的差值,因此一个区间同时加上一个数字时,这个区间中的相邻数字的差值其实不会改变。因而,咱们只须要对 \(p[l]+x,p[r + 1-x]\) 便可,即进行两次单点修改。

例题2 P3368 【模板】树状数组 2

连接:https://www.luogu.com.cn/problem/P3368

#define _CRT_SECURE_NO_WARNINGS
#pragma GCC optimize(3)
#pragma GCC optimize("Ofast")
#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
#pragma comment(linker, "/stack:200000000")
#include <bits/stdc++.h>
#define SIZE 500010
#define rep(i, a, b) for (long long i = a; i <= b; ++i)
#define ll long long
using namespace std;
void io() { ios::sync_with_stdio(false); cin.tie(nullptr); cout.tie(nullptr); }
int n, m;
int t[SIZE], a[SIZE];
int lowbit(int x) { return x & (-x); }

void add(int pos, int x) { //pos位置加上x
    for (; pos <= n; pos += lowbit(pos)) { //n为数组大小
        t[pos] += x;
    }
}

void add(int l, int r, int x) { //[l, r] 区间+x
    add(l, x);
    add(r + 1, -x);
}

int query_presum(int pos) { //单点查询,即对差分数组求前缀和
    int ans = 0;
    for (; pos > 0; pos -= lowbit(pos)) {
        ans += t[pos];
    }
    return ans;
}
int main() {
    io(); cin >> n >> m;
    rep(i, 1, n) cin >> a[i];
    rep(i, 1, n) {
        int x = a[i] - a[i - 1];
        add(i, x);
    }
    rep(i, 1, m) {
        int op; cin >> op;
        if (op == 1) {
            int l, r, x; cin >> l >> r >> x;
            add(l, r, x);
        }
        else {
            int pos; cin >> pos;
            ll ans = query_presum(pos);
            cout << ans << '\n';
        }
    }
}

 

区间修改和区间查询

对于单点修改,咱们能够作到区间查询;那么,对于区间修改咱们是否只能作到单点查询?答案是否认的,咱们仍然能够经过维护差分数组的方法作到区间查询。

void add(int pos, int x, int t[]) { //由于要维护两个数组,加一个参数
    for (; pos <= n; pos += lowbit(pos)) {
        t[pos] += x;
    }
}

void add(int l, int r, int x) { //[l, r] 区间+x
    add(l, x, t1);
    add(r + 1, -x, t1);
    add(l, l * x, t2);
    add(r + 1, -x * (r + 1), t2);
}

int query_presum(int pos, int t[]) { //单点查询,即对差分数组求前缀和
    int ans = 0;
    for (; pos > 0; pos -= lowbit(pos)) {
        ans += t[pos];
    }
    return ans;
}

int query_sum2(int l, int r) { //区间修改下的区间查询
    int p1 = l * query_presum(l - 1, t1) - query_presum(l - 1, t2);
    int p2 = (r + 1) * query_presum(r, t1) - query_presum(r, t2);
    return p2 - p1;
}

区间查询

咱们仍然是从前缀和的角度出发,对于一个区间查询操做,咱们看做两次前缀和查询。

所以咱们考虑求 \(presum(x)=\sum^{x}_{i=1}a[i]=\sum^{x}_{i=1}\sum^{i}_{j=1}p[j]\) 。显然,这个式子难以计算,咱们须要对它变形:

\(\sum^{x}_{i=1}\sum^{i}_{j=1}p[j]=(x+1)\sum_{i=1}^{x}p[i]-\sum_{i=1}^{x}i\times p[i]\) (这步变换经过几何意义更容易理解,可参考上文提到的视频)

对于上述变形,咱们可使用另外一个树状数组维护 \(i\times p[i]\) 的前缀和来快速计算这个式子(想想为何?由于 \(p[i]\) 是一个差分数组,区间修改只会改变两项数值,所以 \(\times i\) 后,仍然只有首尾两项变化)。即:

//区间 [l, r] + x 操做时,还须要维护新的差分数组 i * p[i]
add1(l, x);
add1(r + 1, -x);
add2(l, l * x);
add2(r + 1, -x * (r + 1))
//add1操做维护 p[i],add2操做维护 i * p[i]

例题3 P3372 【模板】线段树 1

连接:https://www.luogu.com.cn/problem/P3372

#define _CRT_SECURE_NO_WARNINGS
#pragma GCC optimize(3)
#pragma GCC optimize("Ofast")
#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
#pragma comment(linker, "/stack:200000000")
#include <bits/stdc++.h>
#define SIZE 500010
#define rep(i, a, b) for (long long i = a; i <= b; ++i)
#define int long long
using namespace std;
void io() { ios::sync_with_stdio(false); cin.tie(nullptr); cout.tie(nullptr); }
int n, m;
int t1[SIZE], t2[SIZE], a[SIZE];
int lowbit(int x) { return x & (-x); }
void add(int pos, int x, int t[]) { //由于要维护两个数组,加一个参数
    for (; pos <= n; pos += lowbit(pos)) {
        t[pos] += x;
    }
}

void add(int l, int r, int x) { //[l, r] 区间+x
    add(l, x, t1);
    add(r + 1, -x, t1);
    add(l, l * x, t2);
    add(r + 1, -x * (r + 1), t2);
}

int query_presum(int pos, int t[]) { //单点查询,即对差分数组求前缀和
    int ans = 0;
    for (; pos > 0; pos -= lowbit(pos)) {
        ans += t[pos];
    }
    return ans;
}

int query_sum2(int l, int r) { //区间修改下的区间查询
    int p1 = l * query_presum(l - 1, t1) - query_presum(l - 1, t2);
    int p2 = (r + 1) * query_presum(r, t1) - query_presum(r, t2);
    return p2 - p1;
}

signed main() {
    io(); cin >> n >> m;
    rep(i, 1, n) cin >> a[i];
    rep(i, 1, n) {
        int x = a[i] - a[i - 1];
        add(i, x, t1);
        add(i, x * i, t2);
    }
    rep(i, 1, m) {
        int op; cin >> op;
        if (op == 1) {
            int l, r, x; cin >> l >> r >> x;
            add(l, r, x);
        }
        else {
            int l, r; cin >> l >> r;
            cout << query_sum2(l, r) << '\n';
        }
    }
}

 

简单应用

P1908 逆序对

连接:https://www.luogu.com.cn/problem/P1908

题解:求逆序对不只能够归并排序,还能用树状数组解决。因为数据可能很大,因此咱们须要先对数据离散化。离散化实际上就是创建原数组到一个 \(1,2,3, ..., n\) 的数组的映射关系;例如 24 33 1 99 25 等价于 2 4 1 5 3

须要注意的是,原数组中若是有相等的元素,离散化后他们的相对位置不能变化

完成离散化后,咱们考虑如何对离散化数组求逆序对:对于任意一个位置 \(pos\) 的元素而言,咱们须要求的实际上就是在 \(a_{pos}\) 以前而且大于它的元素,联系到树状数组可以快速维护前缀和的性质,咱们不难发现咱们只须要把某一位置以前全部小于它的元素置为 \(1\) ,小于它的元素置为 \(0\) ,就能用前缀和快速计算贡献。

设离散化后的数组为 \(p[]\) ,对于这个数组咱们从 \(1\)\(n\) 遍历,在任意一个位置 \(j\) 作以下操做:

for (int j = 1; j <= n; ++j) {
    add(j, 1); //单点修改,将 p[j] 置为 1
    ans += j - presum(p[j]); //计算贡献
}

显然,对于 \(p_j\) 而言,要统计他的贡献不须要考虑 \(j\) 位置以后的元素,上方所述的操做就能够将全部逆序对找到。

#define _CRT_SECURE_NO_WARNINGS
#pragma GCC optimize(3)
#pragma GCC optimize("Ofast")
#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
#pragma comment(linker, "/stack:200000000")
#include <bits/stdc++.h>
#define SIZE 500010
#define rep(i, a, b) for (long long i = a; i <= b; ++i)
#define int long long
using namespace std;
void io() { ios::sync_with_stdio(false); cin.tie(nullptr); cout.tie(nullptr); }
int n, m, ans;
int t[SIZE], a[SIZE];
int pos[SIZE]; //离散化
struct Node {
    int val;
    int id;
    bool operator< (const Node& b) {
        return (val < b.val) || (val == b.val && id < b.id);
    }
}p[SIZE];
int lowbit(int x) { return x & (-x); }
void add(int pos, int x) {
    for (; pos <= n; pos += lowbit(pos)) {
        t[pos] += x;
    }
}

int query_presum(int pos) {
    int ans = 0;
    for (; pos > 0; pos -= lowbit(pos)) {
        ans += t[pos];
    }
    return ans;
}

signed main() {
    io(); cin >> n;
    rep(i, 1, n) cin >> p[i].val, p[i].id = i;
    sort(p + 1, p + 1 + n);
    rep(i, 1, n) pos[p[i].id] = i;
    rep(i, 1, n) {
        add(pos[i], 1);
        ans += i - query_presum(pos[i]);
    }
    cout << ans;
}

P1972 [SDOI2009]HH的项链

连接:https://www.luogu.com.cn/problem/P1972

题解:刚开始想这道题的时候可能会认为须要一些可持久化的数据结构维护,事实上咱们只须要经过树状数组维护便可。

首先,咱们先要想到能够经过离线操做使得无序给出的查询区间有序,使得咱们能够避免重复更新区间。先将全部区间读入,而后以区间右端点为关键字排序。而后咱们维护一个树状数组,来记录贝壳的种类数量;可是某个贝壳重复出现怎么办?事实上对于重复出现的贝壳,咱们只须要考虑最右边的贝壳:例如 1 2 3 1 2 ,其实是 0 0 1 1 1 。更新过程以下:(自上而下对应 \(5\) 次更新)

1 0 0 0 0
1 1 0 0 0
1 1 1 0 0
0 1 1 1 0
0 0 1 1 1

为了实现这个过程,咱们还须要一个数组来记录某种贝壳是否先前出现过,以及它出现的位置。因而,咱们就能够对于每一个询问区间更新到它的右端点,而且只保留最后出现的贝壳。这样,对于每次询问的区间 \([l,r]\) ,咱们只须要记录 \(query\)_\(sum(l,r)\) 便可

#define _CRT_SECURE_NO_WARNINGS
#pragma GCC optimize(3)
#pragma GCC optimize("Ofast")
#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
#pragma comment(linker, "/stack:200000000")
#include <bits/stdc++.h>
#define SIZE 500010
#define rep(i, a, b) for (long long i = a; i <= b; ++i)
#define int long long
using namespace std;
void io() { ios::sync_with_stdio(false); cin.tie(nullptr); cout.tie(nullptr); }
int n, m, k, ans, nxt;
int t[SIZE], a[SIZE];
int Map[SIZE];
int lowbit(int x) { return x & (-x); }
struct Node {
    int l, r;
    int id;
    bool operator< (const Node& b) {
        return (r < b.r) || (r == b.r && l < b.l);
    }
}p[SIZE];
void add(int pos, int x) {
    for (; pos <= n; pos += lowbit(pos)) {
        t[pos] += x;
    }
}

int query_presum(int pos) {
    int ans = 0;
    for (; pos > 0; pos -= lowbit(pos)) {
        ans += t[pos];
    }
    return ans;
}

int query_sum(int l, int r) {
    return query_presum(r) - query_presum(l - 1);
}

signed main() {
    io(); cin >> n;
    rep(i, 1, n) cin >> a[i];
    cin >> m;
    rep(i, 1, m) cin >> p[i].l >> p[i].r, p[i].id = i;
    sort(p + 1, p + 1 + m);
    nxt = 1;
    vector<int> vec(m + 1);
    rep(i, 1, m) {
        rep(j, nxt, p[i].r) {
            if (Map[a[j]]) add(Map[a[j]], -1);
            add(j, 1);
            Map[a[j]] = j;
        }
        nxt = p[i].r + 1;
        vec[p[i].id] = query_sum(p[i].l, p[i].r);
    }
    rep(i, 1, m) cout << vec[i] << '\n';
}

P5673 【SWTR-02】Picking Gifts

连接:https://www.luogu.com.cn/problem/P5673

题解:显然,本题能够看做是前一题的升级版。不一样点在于,上一题一个区间内不能存在相同元素;而这一题能够从右往左存在 \(k\) 个相同元素,处理方法和前一题相似。

#define _CRT_SECURE_NO_WARNINGS
#pragma GCC optimize(3)
#pragma GCC optimize("Ofast")
#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
#pragma comment(linker, "/stack:200000000")
#include <bits/stdc++.h>
#define SIZE 500010
#define rep(i, a, b) for (long long i = a; i <= b; ++i)
#define int long long
using namespace std;
void io() { ios::sync_with_stdio(false); cin.tie(nullptr); cout.tie(nullptr); }
int n, m, k, ans, nxt;
int t[SIZE], a[SIZE], v[SIZE];
vector<int> q[SIZE];
vector<int> vec(SIZE >> 1);
int lowbit(int x) { return x & (-x); }
struct Node {
    int l, r;
    int id;
    bool operator< (const Node& b) {
        return (r < b.r) || (r == b.r && l < b.l);
    }
}p[SIZE];
void add(int pos, int x) {
    for (; pos <= n; pos += lowbit(pos)) {
        t[pos] += x;
    }
}

int query_presum(int pos) {
    int ans = 0;
    for (; pos > 0; pos -= lowbit(pos)) {
        ans += t[pos];
    }
    return ans;
}

int query_sum(int l, int r) {
    return query_presum(r) - query_presum(l - 1);
}

int main() {
    io(); cin >> n >> m >> k; --k;
    rep(i, 1, n) cin >> a[i];
    rep(i, 1, n) cin >> v[i], add(i, v[i]);
    rep(i, 1, m) {
        cin >> p[i].l >> p[i].r;
        p[i].id = i;
    }
    sort(p + 1, p + 1 + m);
    nxt = 1;
    rep(i, 1, m) {
        rep(j, nxt, p[i].r) {
            if (q[a[j]].size() >= k) {
                add(q[a[j]][0], -v[q[a[j]][0]]);
                q[a[j]].erase(q[a[j]].begin());
            }
            q[a[j]].emplace_back(j);
        }
        nxt = p[i].r + 1;
        vec[p[i].id] = query_sum(p[i].l, p[i].r);
    }
    rep(i, 1, m) cout << vec[i] << '\n';
}

P3369 【模板】普通平衡树

连接:https://www.luogu.com.cn/problem/P3369

题解:首先确定要离线操做,而后离散化,注意操做 \(4\) 不须要离散化。

单点加减操做咱们已经很熟悉了,只须要分别对于元素所在位置 \(+1\)\(-1\) 便可。

接着就是本题的核心操做求元素排名,第 \(k\) 大元素和前驱后继。为了方便表述,咱们将离散化后的数组表示为 \(a[]\)

那么求元素 \(a[pos]\) 的排名就变得至关简单了,注意到增删元素只是 \(±1\) ,所以 \(a[pos]\) 的排名即为 \(query\)_\(presum(a[pos] - 1) + 1\) ,即求出全部比它小的元素数量而后 \(+1\) 。能够注意的一点是,求逆序对的操做就是求排名。

对于第 \(k\) 大元素,注意到树状数组的二进制特征,咱们可使用倍增快速找到其位置(不熟悉倍增思想能够回想一下快速幂的实现。因为树状数组的 \(lowbit\) 构成特征,咱们能够经过倍增优化算法而不是二分查找)。具体实现能够参考代码中的 \(kth()\) 函数,而且结合树状数组的构成图理解。

有了上面的两种思想,咱们不难发现求前驱和后继就是上述操做的综合,先求出元素 \(a[pos]\) 的排名 \(rank_{a[pos]}\) ,前驱和后继就能分别表示为第 \(rank_{a[pos]}-1\)\(rank_{a[pos]} + 1\) 大元素。

#define _CRT_SECURE_NO_WARNINGS
#pragma GCC optimize(3)
#pragma GCC optimize("Ofast")
#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
#pragma comment(linker, "/stack:200000000")
#include <bits/stdc++.h>
#define SIZE 500010
#define rep(i, a, b) for (long long i = a; i <= b; ++i)
#define int long long
using namespace std;
void io() { ios::sync_with_stdio(false); cin.tie(nullptr); cout.tie(nullptr); }
int n, m, k, ans, cnt;
int t[SIZE], op[SIZE], a[SIZE], p[SIZE];
int lowbit(int x) { return x & (-x); }
void add(int pos, int x) {
    for (; pos <= n; pos += lowbit(pos)) {
        t[pos] += x;
    }
}

int query_presum(int pos) {
    int ans = 0;
    for (; pos > 0; pos -= lowbit(pos)) {
        ans += t[pos];
    }
    return ans;
}

int query_sum(int l, int r) {
    return query_presum(r) - query_presum(l - 1);
}

int kth(int k) {
    int ans = 0, cnt = 0;
    for (int i = 20; i >= 0; i--) {
        ans += (1 << i);
        if (ans > n || cnt + t[ans] >= k) ans -= (1 << i);
        else cnt += t[ans];
    }
    return ++ans;
}

int main() {
    io(); cin >> n;
    rep(i, 1, n) {
        cin >> op[i] >> a[i];
        if (op[i] != 4) p[++cnt] = a[i];
    }
    sort(p + 1, p + 1 + cnt);
    rep(i, 1, n) { //离散化
        if (op[i] != 4) {
            a[i] = lower_bound(p + 1, p + 1 + cnt, a[i]) - p;
        }
    }
    rep(i, 1, n) {
        if (op[i] == 1) add(a[i], 1);
        else if (op[i] == 2) add(a[i], -1);
        else if (op[i] == 3) cout << query_presum(a[i] - 1) + 1 << '\n';
        else if (op[i] == 4) cout << p[kth(a[i])] << '\n';
        else if (op[i] == 5) cout << p[kth(query_presum(a[i] - 1))] << '\n';
        else cout << p[kth(query_presum(a[i]) + 1)] << '\n';
    }
}

  因为硬盘损坏,许多数据丢失,保留的笔记先挂到博客上。

相关文章
相关标签/搜索