[AH2017/HNOI2017]禮物 (FFT)

題面

LOJ傳送門

題解

直接把貢獻寫出來看看。
i=1n(aibi+x)2=i=1n(ai2+bi2)+nx2+2xi=1n(aibi)2i=1naibi\sum_{i=1}^n(a_i-b_i+x)^2=\sum_{i=1}^n(a_i^2+b_i^2)+nx^2+2x\sum_{i=1}^n(a_i-b_i)-2\sum_{i=1}^na_ib_i
發現當xx已知時,唯一需要考慮的就是最後一項,因爲可以循環位移數組。然後由於xx值域[m,m][-m,m]可枚舉,那麼問題就在於怎麼求maxi=1naibi\max\sum_{i=1}^na_ib_i

我們把aa數組倍長,就是求maxi=1nj=1nai+j1bj\max_{i=1}^{n}\sum_{j=1}^na_{i+j-1}b_j

bb數組倒序,就是求maxi=1nj=1nai+j1bnj+1\max_{i=1}^{n}\sum_{j=1}^na_{i+j-1}b_{n-j+1}

那麼這就是一個卷積的形式,直接FFTFFT乘起來求第n+1n+12n2*n項的最大值就行了。最後再枚舉xx求答案。實際上不用枚舉,直接二次函數對稱軸,不過mm只有100100,隨便做了。

CODE

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int MAXN = (1<<18) + 5;
const double Pi = acos(-1.0);
int n, m, a[MAXN], b[MAXN];
LL sa, sb, s2;
struct cp {
	double x, y;
	cp(){}
	cp(double xx, double yy):x(xx), y(yy){}
	inline cp operator+(const cp &o)const { return cp(x+o.x, y+o.y); }
	inline cp operator-(const cp &o)const { return cp(x-o.x, y-o.y); }
	inline cp operator*(const cp &o)const { return cp(x*o.x-y*o.y, x*o.y+y*o.x); }
}f[MAXN], g[MAXN];
int rev[MAXN];
void FFT(cp *arr, int len, int flg) {
	for(int i = 0; i < len; ++i)if(rev[i]<i)swap(arr[i], arr[rev[i]]);
	cp u, v, wn, w;
	for(int i = 2; i <= len; i<<=1) {
		wn = cp(cos(2*Pi/i*flg), sin(2*Pi/i*flg));
		for(int j = 0; j < len; j += i) {
			w = cp(1, 0);
			for(int k = j; k < j + i/2; ++k) {
				u = arr[k];
				v = arr[k+i/2]*w;
				arr[k] = u + v;
				arr[k+i/2] = u - v;
				w = w * wn;
			}
		}
	}
	if(flg == -1) for(int i = 0; i < len; ++i) arr[i].x/=len;
}
int main () {
	scanf("%d%d", &n, &m);
	for(int i = 1; i <= n; ++i) scanf("%d", &a[i]), sa += a[i], s2 += a[i]*a[i], a[n+i] = a[i];
	for(int i = n; i >= 1; --i) scanf("%d", &b[i]), sb += b[i], s2 += b[i]*b[i];
	int len = 1; while(len <= (3*n)) len<<=1;
	for(int i = 1; i < len; ++i) rev[i] = (rev[i>>1]>>1)|((i&1)*(len>>1));
	for(int i = 0; i < len; ++i) f[i] = cp(a[i], 0), g[i] = cp(b[i], 0);
	FFT(f, len, 1), FFT(g, len, 1);
	for(int i = 0; i < len; ++i) f[i] = f[i] * g[i];
	FFT(f, len, -1);
	LL mx = 0, ans = 1ll<<50;
	for(int i = n+1; i <= 2*n; ++i)
		mx = max(mx, (LL)round(f[i].x));
	for(LL x = -m; x <= m; ++x)
		ans = min(ans, s2 + n*x*x + 2*x*(sa-sb) - 2*mx);
	printf("%lld\n", ans);
}
發佈了373 篇原創文章 · 獲贊 241 · 訪問量 5萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章