洛谷P2023《维护序列》

多操作线段树标记下方是有顺序的

题目描述

老师交给小可可一个维护数列的任务,现在小可可希望你来帮他完成。 有长为N的数列,不妨设为a1,a2,…,aN 。有如下三种操作形式:
(1)把数列中的一段数全部乘一个值;
(2)把数列中的一段数全部加一个值;
(3)询问数列中的一段数的和,由于答案可能很大,你只需输出这个数模P的值。

输入输出格式

输入格式:

第一行两个整数N和P(1≤P≤1000000000)。
第二行含有N个非负整数,从左到右依次为a1,a2,…,aN, (0≤ai≤1000000000,1≤i≤N)。
第三行有一个整数M,表示操作总数。
从第四行开始每行描述一个操作,输入的操作有以下三种形式:
操作1:“1 t g c”(不含双引号)。表示把所有满足t≤i≤g的ai改为ai×c(1≤t≤g≤N,0≤c≤1000000000)。
操作2:“2 t g c”(不含双引号)。表示把所有满足t≤i≤g的ai改为ai+c (1≤t≤g≤N,0≤c≤1000000000)。
操作3:“3 t g”(不含双引号)。询问所有满足t≤i≤g的ai的和模P的值 (1≤t≤g≤N)。
同一行相邻两数之间用一个空格隔开,每行开头和末尾没有多余空格。

输出格式:

对每个操作3,按照它在输入中出现的顺序,依次输出一行一个整数表示询问结果。

输入输出样例

输入样例

1
2
3
4
5
6
7
8
7 43
1 2 3 4 5 6 7
5
1 2 5 5
3 2 4
2 3 7 9
3 1 3
3 4 7

输出样例

1
2
3
2
35
8

解题思路

多操作线段树模板题,同《线段树 2》

大致方向

首先我们来康一康只有区间加的时候怎么做
维护一个标记add[i]表示节点i对应的区间[l,r]被加了多少
在下放标记时,sum[i]会被更新为sum[i] + add[i] * (r - l + 1)
我们把它看作 $x + b$ 的形式,其中sum[i]对应$x$, add[i] 对应$b$,后面的看作常数就好啦
那么区间加乘的形式就应该是 $ax+b$,也就意味着要多维护一个标记mul[i]表示节点i对应的区间[l,r]被乘了多少,sum[i]会被更新为sum[i] * mul[i] + add[i] * (r - l + 1)


区间修改

先看乘法,比如i节点对应区间[l,r]被乘了一个$k$,本质上就是$k(ax+b)$,拆出来就是$kax + kb$,也就是把mul[i]add[i]都乘上一个$k$
加法本质上就是 $ax + b + k$,整理得 $ax + (b + k)$,那么把add[i]加上$k$就行了

标记下放

同样地,把每个节点看作 $ax+b$ 的关系,在这里i节点对应的区间和[l,r]为$ax+b$,左子树lc(i)对应的区间和[l, mid]为$a’y+b’$
遵循先乘后加的原则,对左子树乘上一个$a$得

然后加上 $b$

整理得

观察下这个式子,把它写成$ax+b$的形式

发现了什么?

本质上就是,
左子树的乘法标记 乘上 当前点的乘法标记
左子树的加法标记 先乘上 当前点的乘法标记 再加上 当前点的加法标记

mul[lc(i)] *= mul[i], add[lc(i)] = add[lc(i)] * mul[i] + add[i]

对右子树进行一遍同样的操作,清空标记即可(稍有常识的人都知道mul[i]要初始化为1)

代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#include <iostream>
#include <cstring>
#include <cstdio>

#define FILE_IN(__fname) freopen(__fname, "r", stdin)
#define FILE_OUT(__fname) freopen(__fname, "w", stdout)
#define rep(a,s,t,i) for (int a = s; a <= t; a += i)
#define repp(a,s,t,i) for (int a = s; a < t; a += i)
#define countdown(s) while (s --> 0)
#define IMPROVE_IO() std::ios::sync_with_stdio(false)

using std::cin;
using std::cout;
using std::endl;

const int MAXN = 200000 + 10;

int getint() { int x; scanf("%d", &x); return x; }
long long int getll() { long long int x; scanf("%lld", &x); return x; }

int n, CH, m;
long long int a[MAXN / 2];

struct SegmentTree {
long long int sum[MAXN << 2];
long long int mul[MAXN << 2], add[MAXN << 2];

#define lc(x) ((x << 1))
#define rc(x) ((x << 1 | 1))

SegmentTree() {
memset(sum, 0, sizeof sum);
memset(mul, 1, sizeof mul);
memset(add, 0, sizeof add);
}
void PushTag(int root, int l, int r) {
if (mul[root] == 1 && add[root] == 0) return;
// 该处标记不存在或已被下放
if (l != r) {
mul[lc(root)] = mul[lc(root)] * mul[root] % CH;
mul[rc(root)] = mul[rc(root)] * mul[root] % CH;
add[lc(root)] = (add[lc(root)] * mul[root] % CH + add[root]) % CH;
add[rc(root)] = (add[rc(root)] * mul[root] % CH + add[root]) % CH;
}
sum[root] = (sum[root] * mul[root] % CH + add[root] * (r - l + 1) % CH) % CH;
mul[root] = 1; add[root] = 0;
}
void buildTree(int root, int l, int r, long long int *seq) {
mul[root] = 1; add[root] = 0;
if (l == r) { sum[root] = seq[l]; return; }
int mid = (l + r) >> 1;
buildTree(lc(root), l, mid, seq);
buildTree(rc(root), mid + 1, r, seq);
sum[root] = (sum[lc(root)] + sum[rc(root)]) % CH;
}
long long int Query(int root, int l, int r, int ll, int rr) {
PushTag(root, l, r);
if (ll <= l && r <= rr) return sum[root];
long long int ret = 0;
int mid = (l + r) >> 1;
if (ll <= mid) ret = (ret + Query(lc(root), l, mid, ll, rr)) % CH;
if (mid + 1 <= rr) ret = (ret + Query(rc(root), mid + 1, r, ll, rr)) % CH;
return ret;
}
void Modify(int method, int root, int l, int r, int ll, int rr, long long int k) {
PushTag(root, l, r);
if (ll <= l && r <= rr) {
if (method == 1) {
mul[root] = mul[root] * k % CH;
add[root] = add[root] * k % CH;
} else {
add[root] = (add[root] + k) % CH;
}
return;
}
int mid = (l + r) >> 1;
if (ll <= mid) Modify(method, lc(root), l, mid, ll, rr, k);
if (mid + 1 <= rr) Modify(method, rc(root), mid + 1, r, ll, rr, k);
PushTag(lc(root), l, mid);
PushTag(rc(root), mid + 1, r);
sum[root] = (sum[lc(root)] + sum[rc(root)]) % CH;
}
} Tree;

int main() {
n = getint(); CH = getint();
rep (i, 1, n, 1) a[i] = getint();
Tree.buildTree(1, 1, n, a);
m = getint();
countdown (m) {
int op = getint();
int l = getint();
int r = getint();
switch (op) {
case 1: {
long long int k = getll();
Tree.Modify(1, 1, 1, n, l, r, k);
break;
}
case 2: {
long long int k = getll();
Tree.Modify(2, 1, 1, n, l, r, k);
break;
}
case 3: {
printf("%lld\n", Tree.Query(1, 1, n, l, r));
break;
}
}
}
return 0;
}