【CF266E】More Queries to Array... - 線段樹

題目描述

You've got an array, consisting of \(n\) integers: \(a_{1},a_{2},...,a_{n}\). Your task is to quickly run the queries of two types:

  1. Assign value \(x\) to all elements from \(l\) to \(r\) inclusive. After such query the values of the elements of array \(a_{l},a_{l+1},...,a_{r}\) become equal to \(x\).
  2. Calculate and print sum \(\sum_{i=l}^{r} a_{i}\cdot(i-l+1)^{k}\), where \(k\) doesn't exceed \(5\). As the value of the sum can be rather large, you should print it modulo \(1000000007\ (10^{9}+7)\).

題目大意

給定一段序列 \(a_1,a_2,\ldots,a_n\),
維護兩種操作:
\(=\ l\ r\ x\) 表示將區間 \([l,r]\) 的值賦爲 \(x\)
\(?\ l\ r\ k\) 表示輸出 \(\left(\sum_{i=l}^{r}a_i(i-l+1)^k\right) \bmod (10^9+7)\)

思路

用二項式定理展開一下

\[\begin{aligned} &\sum_{i=l}^{r}a_i[i+(1-l)]^k\\ =&\sum_{i=l}^{r}a_i\sum_{j=0}^{k}i^j(1-l)^{k-j}C_k^j\\ =&\sum_{j=0}^{k}(1-l)^{k-j}C_k^j\sum_{i=l}^{r}a_ii^j \end{aligned} \]

所以用線段樹維護區間和 \(\sum a_ii^j,\ j\in[0,5]\) 就好了

#include <cstdio>
// Binomial coefficients: c[k][j] = C(k, j) for 0 <= j <= k <= 5, zero elsewhere.
const int c[6][6] = {
	{ 1, 0,  0,  0, 0, 0 },
	{ 1, 1,  0,  0, 0, 0 },
	{ 1, 2,  1,  0, 0, 0 },
	{ 1, 3,  3,  1, 0, 0 },
	{ 1, 4,  6,  4, 1, 0 },
	{ 1, 5, 10, 10, 5, 1 }
};
// Base capacity used to size the segment-tree arrays below (maxn << 3 nodes).
const int maxn = 100010;
// Modulus for every stored sum; also doubles as the "no pending assignment"
// marker in laz[] (see build/pushdown).
const int mod = 1000000007;
typedef long long ll;
// n and m are the two integers read first in main (array length, query count).
int n, m;
// laz[node]: value assigned to the node's whole range but not yet pushed to
// its children, or `mod` when nothing is pending.
int laz[maxn << 3];
// sum[node][j]: sum of a_i * i^j over the node's range, reduced modulo `mod`.
ll sum[maxn << 3][6];
// Prefix power sum 1^k + 2^k + ... + n^k reduced modulo `mod`, evaluated with
// the closed-form (Faulhaber) polynomials for k = 1..5.
// For any other k (the callers use k = 0) the value n itself is returned,
// i.e. the number of terms.
// The k = 1 and k = 2 branches divide exactly before reducing, so they rely
// on n*(n+1)*(2*n+1) fitting in a long long.
inline ll powerkth(ll n,int k) {
	switch (k) {
		case 1:
			// n(n+1)/2
			return n*(n+1)/2%mod;
		case 2:
			// n(n+1)(2n+1)/6
			return n*(n+1)*(2*n+1)/6%mod;
		case 3: {
			// n^2 (n+1)^2 / 4, with 250000002 = 4^{-1} (mod 1e9+7)
			ll numer = n*n%mod*(n+1)%mod*(n+1)%mod;
			return numer*250000002ll%mod;
		}
		case 4: {
			// n(n+1)(2n+1)(3n^2+3n-1) / 30, with 233333335 = 30^{-1} (mod 1e9+7)
			ll head = n*(n+1)%mod*(2*n+1)%mod;
			ll tail = 3*n*n%mod+3*n%mod-1;
			return head*tail%mod*233333335ll%mod;
		}
		case 5: {
			// n^2 (n+1)^2 (2n^2+2n-1) / 12, with 83333334 = 12^{-1} (mod 1e9+7)
			ll head = n*n%mod*(n+1)%mod*(n+1)%mod;
			ll tail = 2*n*n%mod+2*n%mod-1;
			return head*tail%mod*83333334ll%mod;
		}
		default:
			return n;
	}
}
// Recomputes all six power-weighted sums of `root` as the modular sum of the
// corresponding entries of its two children (root*2 and root*2+1).
inline void pushup(int root) {
	const int lc = root<<1;
	const int rc = lc|1;
	for (int p = 0;p < 6;p++) {
		sum[root][p] = (sum[lc][p]+sum[rc][p])%mod;
	}
}
inline void pushdown(int root,int l,int r) {
	int mid = l+r>>1;
	if (laz[root] ^ mod) {
		laz[root<<1] = laz[root];
		laz[root<<1|1] = laz[root];
		for (int i = 0;i <= 5;i++) {
			sum[root<<1][i] = laz[root]*((powerkth(mid,i)-powerkth(l-1,i)+mod)%mod)%mod;
			sum[root<<1|1][i] = laz[root]*((powerkth(r,i)-powerkth(mid,i)+mod)%mod)%mod;
		}
		laz[root] = mod;
	}
}
// Builds the subtree of node `root` covering positions [l, r].
// Leaves read their value from stdin in left-to-right order and store
// a_l * l^j (mod `mod`) for j = 0..5; inner nodes combine their children.
// Every visited node starts with no pending assignment (laz == mod).
inline void build(int l,int r,int root) {
	laz[root] = mod;
	if (l != r) {
		const int mid = (l+r)>>1;
		build(l,mid,root<<1);
		build(mid+1,r,root<<1|1);
		pushup(root);
		return;
	}
	// Leaf: sum[root][0] is the raw element, each next entry multiplies by l.
	scanf("%lld",&sum[root][0]);
	for (int p = 0;p < 5;p++) {
		sum[root][p+1] = sum[root][p]*l%mod;
	}
}
inline void update(int l,int r,int ul,int ur,int root,ll x) {
	if (l > ur || r < ul) return;
	if (ul <= l && r <= ur) {
		laz[root] = x;
		for (int i = 0;i <= 5;i++) sum[root][i] = x*((powerkth(r,i)-powerkth(l-1,i)+mod)%mod)%mod;
		return;
	}
	pushdown(root,l,r);
	int mid = l+r>>1;
	update(l,mid,ul,ur,root<<1,x);
	update(mid+1,r,ul,ur,root<<1|1,x);
	pushup(root);
}
// Returns the sum of a_i * i^k over positions i in [ql, qr], modulo `mod`;
// `root` is the node covering [l, r]. Disjoint nodes contribute 0.
inline ll query(int l,int r,int ql,int qr,int root,int k) {
	const bool disjoint = (qr < l) || (r < ql);
	if (disjoint) return 0;

	const bool covered = (ql <= l) && (r <= qr);
	if (covered) return sum[root][k];

	// Children must see any pending assignment before being read.
	pushdown(root,l,r);
	const int mid = (l+r)>>1;
	const ll left_part = query(l,mid,ql,qr,root<<1,k);
	const ll right_part = query(mid+1,r,ql,qr,root<<1|1,k);
	return (left_part+right_part)%mod;
}
int main() {
	for (scanf("%d%d",&n,&m),build(1,n,1);m--;) {
		char ch; int l,r,k;
		scanf("%s%d%d%d",&ch,&l,&r,&k);
		if (ch == '=') update(1,n,l,r,1,k);
		else {
			if (l == 1) { printf("%lld\n",query(1,n,l,r,1,k)); continue; }
			ll ans = 0;
			for (int i = 0;i <= k;i++) {
				ll tmp = 1;
				for (int j = 1;j <= k-i;j++) tmp = (tmp*(1-l)%mod+mod)%mod;
				(ans += tmp*c[k][i]%mod*query(1,n,l,r,1,i)%mod) %= mod;
			}
			printf("%lld\n",ans);
		}
	}
	return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章