Back to : algorithms
Contents

풀이

κ²°κ΅­ 문제의 μš”μ μ€, μœ„ / 쀑간/ μ•„λž˜λ₯Ό $u_i, m_i, l_i$ 라고 ν•  λ•Œ, $u_i + l_j = 2m_k$ 인 μˆœμ„œμŒ $(i, j, k)$의 개수λ₯Ό μ„ΈλŠ” λ¬Έμ œμ΄λ‹€.

이 λ¬Έμ œλŠ” 맀우 잘 μ•Œλ €μ§„ Convolution 문제둜, μˆ˜μ—΄μ„ μΉ΄μš΄νŒ… λ‹€ν•­μ‹μœΌλ‘œ μΈμ½”λ”©ν•˜λ©΄ μ‰½κ²Œ ν’€ 수 μžˆλ‹€. μˆ˜μ—΄μ΄ 1, 3, 4, 5, 5λ©΄ $x + x^3 + x^4 + 2x^5$ 인 μ‹μœΌλ‘œ. μ΄λ ‡κ²Œ μΈμ½”λ”©ν•΄μ„œ $u(x), m(x), l(x)$ λ₯Ό λ§Œλ“  λ‹€μŒ,

$u(x)l(x)$을 κ³„μ‚°ν•˜λ©΄ κ·Έ $n$μ°¨ν•­μ˜ κ³„μˆ˜κ°€ $u_i + l_j = n$ 인 μˆœμ„œμŒ $i, j$ 의 κ°œμˆ˜κ°€ λœλ‹€. λ”°λΌμ„œ, 각 $n$에 λŒ€ν•΄, $m_k = n$인 $k$의 κ°œμˆ˜μ™€ $u(x)l(x)$의 $2n$μ°¨ ν•­μ˜ κ³„μˆ˜λ₯Ό κ³±ν•΄μ„œ λ”ν•˜λ©΄ κ·ΈλŒ€λ‘œ κ΅¬ν•˜λŠ” 닡이 되고, 이λ₯Ό FFTλ₯Ό μ΄μš©ν•˜μ—¬ $O(n \log n)$에 ꡬ할 수 있음이 잘 μ•Œλ €μ Έ μžˆλ‹€.

μ—¬λ‹΄μœΌλ‘œβ€¦ ICPC μ„œμšΈ 리저널은 이런 Do you know *** λ¬Έμ œκ°€ κ½€ 많이 λ‚˜μ˜€λŠ” 편인 것 κ°™λ‹€. 2017 FFT, 2017 인예 FFT, 2019 인예 LiChaoTree / LR Flow …

μ½”λ“œ

#include <bits/stdc++.h>
#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast")
#pragma GCC target("avx,avx2,fma")
#define ll long long
#define int ll
#define eps 1e-7
#define all(x) ((x).begin()),((x).end())
#define usecppio ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);
using namespace std;
using pii = pair<int, int>;

#define sz(v) ((int)(v).size())
typedef vector<int> vi;
typedef complex<double> base;

void fft(vector <base> &a, bool invert)
{
    int n = sz(a);
    for (int i=1,j=0;i<n;i++){
        int bit = n >> 1;
        for (;j>=bit;bit>>=1) j -= bit;
        j += bit;
        if (i < j) swap(a[i],a[j]);
    }
    for (int len=2;len<=n;len<<=1){
        double ang = 2*M_PI/len*(invert?-1:1);
        base wlen(cos(ang),sin(ang));
        for (int i=0;i<n;i+=len){
            base w(1);
            for (int j=0;j<len/2;j++){
                base u = a[i+j], v = a[i+j+len/2]*w;
                a[i+j] = u+v;
                a[i+j+len/2] = u-v;
                w *= wlen;
            }
        }
    }
    if (invert){
        for (int i=0;i<n;i++) a[i] /= n;
    }
}

void multiply(const vi &a,const vi &b,vi &res)
{
    vector <base> fa(all(a)), fb(all(b));
    int n = 1;
    while (n < max(sz(a),sz(b))) n <<= 1;
    fa.resize(n); fb.resize(n);
    fft(fa,false); fft(fb,false);
    for (int i=0;i<n;i++) fa[i] *= fb[i];
    fft(fa,true);
    res.resize(n);
    for (int i=0;i<n;i++)
        res[i] = (int)(fa[i].real()+(fa[i].real()>0?0.5:-0.5));
}

vi u(121212, 0), m(121212, 0), l(121212, 0), c;

void input(vi &v)
{
    int n; cin >> n;
    for (int i = 0; i < n; i++)
    {
        int x; cin >> x;
        v[x+30000]++;
    }
}

int32_t main()
{
    usecppio
    input(u); input(m); input(l);
    multiply(u, l, c);
    int ans = 0;
    for (int i = 0; i <= 60000; i++)
        ans += m[i] * c[2*i];
    cout << ans << '\n';
}