【学习笔记】线段树浅谈
线段树,顾名思义,就是一棵树,上面的每个节点由一个“线段”构成,然后组成了一棵除去了最后一层外,是完全二叉树的树,又叫区间树,作用主要是单点修改,区间修改和区间查询,这里拿线段树的区间和来举例子
线段树的构建
线段树是一棵完全二叉树,我们可以可以按照与二叉堆相似的储存节点的方式,根节点为k,左子树为k<<1,右子树则为k<<1|1,因为树是递归定义的,所以我们开始递归建树。
1.建立结构体,用于存树
因为完全二叉树对应一个长度为a的区间,它需要的节点个数为4a,因此,我们要把树的大小开到区间长度的四倍
1 2 3 4 5
| const int maxn=1000010; struct tree{ int l,r; long long dat,add,mul; }t[maxn*4+5];
|
2.建立上传
因为每次建一棵树,父节点都要有相对应的数据更新,因此我们需要pushup来上传子节点的数据
1 2 3 4
| void pushup(int k) { t[k].dat=(t[k<<1].dat+t[k<<1|1].dat)%p; }
|
3.递归建树
然后用递归建树
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
| void build(int k,int l,int r) { t[k].l=l,t[k].r=r; t[k].add=0,t[k].mul=1; if(l==r) { t[k].dat=a[l]; return ; } int mid=(l+r)>>1; build(k<<1,l,mid); build(k<<1|1,mid+1,r); pushup(k); t[k].dat%=p; return ; }
|
这样,一棵线段树就建好了
线段树的操作
对于线段树,我们肯定不仅仅满足于建好一棵线段树,最主要的是,我们要用它做一些什么事情,来帮助我们更好地解决我们需要解决的问题。
前面说过,线段树的操作主要是点的修改和区间的数据查询,我们先说一说点的修改
1.点的修改
因为线段树是用递归定义的,因此,我们点的修改也可以用递归来进行修改这里不过多赘述,因为重要的在后面
2.区间修改
区间修改有两种修改方式,第一种是对应的某一个区间里面的数,每一个加上一个数,另一种是对应区间里面的数,每一个乘上一个数字。
对于加上一个数的方式,我们直接类似于点的修改,但是我们会发现,有一些树是并没有用到的,也就是说,我们每次在更新的时候去更新下一个节点,然后数据继续往下更新,我们就会造成很大的空间与时间的冗余,这些空间和时间都很大,我们都用得到!而且可以节省很大一部分空间!
因此,我们在这里就要使用一种特殊的操作,延迟标记,然后我们就可以节省大量的空间,以及时间,从而加快线段树的速度
大家可以看到,在之前的结构体定义中有add和mul,它们就分别是线段树加法和乘法的延迟标记
然后延迟标记在我们使用某棵树的时候需要传递下去,因此,我们与pushup相类似,我们要使用一个pushdown
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
| void pushdown(int k) { t[k<<1].dat=(t[k<<1].dat*t[k].mul+t[k].add*(t[k<<1].r-t[k<<1].l+1))%p; t[k<<1|1].dat=(t[k<<1|1].dat*t[k].mul+t[k].add*(t[k<<1|1].r-t[k<<1|1].l+1))%p; t[k<<1].mul=(t[k].mul*t[k<<1].mul)%p; t[k<<1|1].mul=(t[k].mul*t[k<<1|1].mul)%p; t[k<<1].add=(t[k<<1].add*t[k].mul+t[k].add)%p; t[k<<1|1].add=(t[k<<1|1].add*t[k].mul+t[k].add)%p; t[k].add=0; t[k].mul=1; }
|
然后我们就可以快乐的进行区间修改了
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
| void updata1(int k,int l,int r,int v) { if(l<=t[k].l&&r>=t[k].r) { t[k].dat=(t[k].dat*v)%p; t[k].mul=(t[k].mul*v)%p; return ; } pushdown(k); int mid=(l+r)>>1; if(l<=mid) updata1(k<<1,l,r,v); if(r>mid) updata1(k<<1|1,l,r,v); pushup(k); } void updata2(int k,int l,int r,int v) { if(l<=t[k].l&&r>=t[k].r) { t[k].dat=(t[k].dat+v*(t[k].r-t[k].l+1)%p; t[k].mul=(t[k].add+v)%p; return ; } pushdown(k); int mid=(l+r)>>1; if(l<=mid) updata2(k<<1,l,r,v); if(r>mid) updata2(k<<1|1,l,r,v); pushup(k); }
|
3.区间查询
区间查询和区间修改实际上差不多,但是要注意返回值的时候要
模,不然会出问题
1 2 3 4 5 6 7 8 9 10 11 12 13
| long long query(int k,int l,int r) { if(l<=t[k].r&&r>=t[k].r) return t[k].dat; pushdown(k); long long ans=0; int mid=(l+r)>>1; if(l<=mid) ans=(ans+query(k<<1,l,r))%p; if(r>mid) ans=(ans+query(k<<1|1,l,r))%p; return ans%p; }
|
这就是线段树的模板了
最终代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100
| #include <iostream> #include <cstdio> #include <algorithm> #include <cstring> #include <string> using namespace std; const int maxn=100010; int n,m,p; int a[maxn]; struct tree{ int l,r; long long dat,add,mul; }t[maxn*4+5]; void pushup(int k) { t[k].dat=(t[k<<1].dat+t[k<<1|1].dat)%p; } void pushdown(int k) { t[k<<1].dat=(t[k<<1].dat*t[k].mul+t[k].add*(t[k<<1].r-t[k<<1].l+1))%p; t[k<<1|1].dat=(t[k<<1|1].dat*t[k].mul+t[k].add*(t[k<<1|1].r-t[k<<1|1].l+1))%p; t[k<<1].mul=(t[k].mul*t[k<<1].mul)%p; t[k<<1|1].mul=(t[k].mul*t[k<<1|1].mul)%p; t[k<<1].add=(t[k<<1].add*t[k].mul+t[k].add)%p; t[k<<1|1].add=(t[k<<1|1].add*t[k].mul+t[k].add)%p; t[k].add=0; t[k].mul=1; }
void build(int k,int l,int r) { t[k].l=l,t[k].r=r; t[k].add=0,t[k].mul=1; if(l==r) { t[k].dat=a[l]; return ; } int mid=(l+r)>>1; build(k<<1,l,mid); build(k<<1|1,mid+1,r); pushup(k); t[k].dat%=p; return ; } void updata1(int k,int l,int r,int v) { if(l<=t[k].l&&r>=t[k].r) { t[k].dat=(t[k].dat*v)%p; t[k].mul=(t[k].mul*v)%p; return ; } pushdown(k); int mid=(l+r)>>1; if(l<=mid) updata1(k<<1,l,r,v); if(r>mid) updata1(k<<1|1,l,r,v); pushup(k); } void updata2(int k,int l,int r,int v) { if(l<=t[k].l&&r>=t[k].r) { t[k].dat=(t[k].dat+v*(t[k].r-t[k].l+1))%p; t[k].mul=(t[k].add+v)%p; return ; } pushdown(k); int mid=(l+r)>>1; if(l<=mid) updata2(k<<1,l,r,v); if(r>mid) updata2(k<<1|1,l,r,v); pushup(k); } long long query(int k,int l,int r) { if(l<=t[k].r&&r>=t[k].r) return t[k].dat; pushdown(k); long long ans=0; int mid=(l+r)>>1; if(l<=mid) ans=(ans+query(k<<1,l,r))%p; if(r>mid) ans=(ans+query(k<<1|1,l,r))%p; return ans%p; } int main() { }
|