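// Essentially a product of two polynomials (self-convolution of the length-count histogram); take care to remove duplicate pairs.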
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <complex>
using namespace std;
// File-wide type alias and constants.
typedef long long LL;          // 64-bit integer type used for pair counts (res, sum)
const int maxn = 271111;       // capacity of every global array in this file
const double pi = acos(-1.0);  // pi to double precision, used for FFT twiddle angles
const double eps = 1e-6;       // tolerance constant; not referenced in the code shown
// Minimal complex number used by the FFT: a = real part, b = imaginary part.
// The trailing declarator also defines the global work arrays a[] and b[].
struct Complex {
    double a, b;
    // Zero-initialize so a default-constructed value is 0 + 0i rather than
    // indeterminate (matters for non-static instances).
    Complex() : a(0), b(0) {
    }
    Complex(double a, double b) :
            a(a), b(b) {
    }
    Complex operator +(const Complex& t) const {
        return Complex(a + t.a, b + t.b);
    }
    Complex operator -(const Complex& t) const {
        return Complex(a - t.a, b - t.b);
    }
    // (a + bi)(c + di) = (ac - bd) + (ad + bc)i
    Complex operator *(const Complex& t) const {
        return Complex(a * t.a - b * t.b, a * t.b + b * t.a);
    }
    // Assign a real scalar: *this becomes n + 0i.
    Complex& operator =(const double &n) {
        a = n;
        b = 0;
        return *this;
    }
} a[maxn], b[maxn];
char s1[maxn], s2[maxn];
// Permutes y[0..l) in place into bit-reversed index order, the input order
// expected by the iterative butterfly loop in FFT. Assumes l is a power of
// two (FFT's caller rounds its length up by doubling).
// `rev` is kept equal to the bit-reversal of `idx` by "incrementing" it in
// reversed-bit arithmetic at the end of every iteration.
void brc(Complex *y, int l) {
    int rev = l >> 1;
    for (int idx = 1; idx < l - 1; idx++) {
        // Swap each pair only once (when idx is the smaller index).
        if (idx < rev)
            swap(y[idx], y[rev]);
        // Reversed increment: clear 1-bits from the top down ...
        int bit = l >> 1;
        for (; rev >= bit; bit >>= 1)
            rev -= bit;
        // ... then set the first 0-bit reached.
        if (rev < bit)
            rev += bit;
    }
}
// Working arrays used by main() and reused for every test case.
int num[maxn],  // the n input lengths, num[0..n)
    cnt[maxn];  // cnt[v]: how many inputs equal v
LL res[maxn],   // res[s]: unordered pairs of distinct inputs whose sum is s
   sum[maxn];   // prefix sums of res
// In-place iterative radix-2 FFT over y[0..l).
//   y  : the l values to transform (overwritten with the result).
//   l  : transform length; assumed to be a power of two.
//   on : +1 for the forward transform, -1 for the inverse. The inverse also
//        scales both real and imaginary parts by 1/l, so FFT(y,l,1) followed
//        by FFT(y,l,-1) restores the input up to rounding error.
void FFT(Complex *y, int l, int on)
{
    brc(y, l);  // the butterflies below expect bit-reversed input order
    for (int h = 2; h <= l; h <<= 1) {
        // Root of unity for this stage; the angle's sign flips for the inverse.
        double r = on * 2.0 * pi / h;
        Complex wn(cos(r), sin(r));
        int p = h >> 1;
        for (int j = 0; j < l; j += h) {
            Complex w(1, 0);
            for (int k = j; k < j + p; k++) {
                Complex u = y[k];
                Complex t = w * y[k + p];
                y[k] = u + t;
                y[k + p] = u - t;
                w = w * wn;
            }
        }
    }
    if (on == -1) {
        // Normalize BOTH components; dividing only the real part would leave
        // the imaginary parts scaled by l.
        for (int i = 0; i < l; i++) {
            y[i].a /= l;
            y[i].b /= l;
        }
    }
}
int main() {
ios::sync_with_stdio(false);
int n, i, L = 1, ma, T;
scanf("%d", &T);
while (T--) {
ma = 0;
memset(cnt, 0, sizeof(cnt));
scanf("%d", &n);
for (i = 0; i < n; ++i) {
scanf("%d", num + i);
cnt[num[i]]++;
ma = max(ma, num[i]);
}
if (n < 3) {
printf("%.7f\n", 0.0);
continue;
}
++ma;
L = 1;
while (L < ma * 2)
L <<= 1;
for (i = 0; i < ma; ++i)
a[i] = Complex(cnt[i], 0);
for (i = ma; i < L; ++i)
a[i] = Complex(0, 0);
FFT(a, L, 1);
for (i = 0; i < L; ++i)
a[i] = a[i] * a[i];
FFT(a, L, -1);
for (i = 0; i < L; ++i)
res[i] = (a[i].a + 0.5);
for (i = 0; i <= ma; ++i)
res[i << 1] -= cnt[i];
for (i = 0; i < L; ++i)
res[i] >>= 1;
for (i = 1; i < L; ++i)
sum[i] = sum[i - 1] + res[i];
double tot = 0, den = 1.0 * n * (n - 1) * (n - 2);
for (i = 0; i < n; ++i)
tot += sum[num[i]] / den;
double ans = 1 - tot * 6.0;
printf("%.7f\n", ans);
}
return 0;
}