BOJ 18194, 경기과학고 2019 Bad Hair Day와 기댓값
- 난이도 : Platinum 1
- 출처 : 경기과학고 프로그래밍 대회 2019
풀이
각 소 $i$ 에 대해, $n!$ 개의 가능한 배치 중 $i$ 가 $j$ 를 볼 수 있는 경우의 수를 센다고 생각하자. 이때 이 값을 $f(i, j)$ 라고 하면, 우리가 구하는 값은 (기댓값의 선형성에 의해) \(\frac{1}{n!}\sum_{i, j} f(i, j)\) 가 될 것이다. 그러나, 구해야 하는 $f(i, j)$ 의 개수가 $O(n^2)$ 개이기 때문에, 이 방법으로는 어려워 보인다.
기준이 되는 한 마리 $i$ 번을 잡자. 이때, $i$ 번이 볼 수 있는 소의 마릿수의 기댓값을 빠르게 구할 수 있다면 전체 문제를 풀 수 있을 것이다. 이를 위해, 다음과 같이 생각하자.
- 어차피, $i$ 입장에서 자기보다 큰 소는 절대 볼 수 없다. 즉, 볼 수 있는 수는 최대 자기보다 작은 소 $k$ 마리이다.
- 자기보다 크거나 같은 (본인을 포함) 소들 $n-k$ 마리를 먼저 아무렇게나 배치했다고 치자. 이제, 이들 사이사이에 작은 소들을 집어넣는다고 생각할 것이다.
그림으로 생각해 보면, 자기 바로 오른쪽 자리 (본인보다 크거나 같은 소들만을 기준으로) 에 자기보다 작은 소들이 배치된다면 이 소를 볼 수 있고, 그렇지 않다면 볼 수 없다. 즉, $a_1 \dots a_{n-k}$ 까지의 사이사이에 $b_0 \dots b_{n-k}$ 를 배치한다고 할 때, 다음 조건을 만족하며 (소 $k$ 마리를 잘 배치하는 것이므로) \(\sum_{t = 0}^{n-k}b_t = k\) 이때, 우리가 기준으로 하는 $i$ 번 소의 바로 오른쪽인 $b_{i+1}$ 마리의 머리를 볼 수 있으므로, 우리가 알고자 하는 값은 $\mathbb{E}(b_{i+1})$ 이다. 자명하게, 각 $b_t$ 자리에는 전혀 차이가 없으므로, \(\mathbb{E}(b_{i+1}) = \frac{k}{n - k + 1}\) 따라서, 우리가 구하고자 하는 값은, $i$번 소보다 작은 소의 수를 $k_i$ 라고 할 떄, \(\sum_{i = 1}^{n}\frac{k_i}{n - k_i + 1}\) 이다. 여기서 $k_i$ 는 binary search를 이용해서 $O(\log n)$ 시간에 구할 수 있으므로, 전체 풀이는 $O(n \log n)$이고, 모듈러 연산을 해야 하므로 실제로는 나눗셈 한번에 $O(\log p)$ 시간이 걸려 $O(n \log n + n \log p)$ 시간이 된다.
Code
#include <bits/stdc++.h>
#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast")
#pragma GCC target("avx,avx2,fma")
#define ll long long
#define int ll
#define eps 1e-7
#define all(x) ((x).begin()),((x).end())
#define usecppio ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);
using namespace std;
using pii = pair<int, int>;
const int mod = 1000000007;
int n, ans;
vector <int> height;
int modpow(int b, int e)
{
int r = 1;
for (; e; b = b * b % mod, e /= 2)
if (e & 1) r = r * b % mod;
return r;
}
int modmul(int a, int b)
{
return (a * b) % mod;
}
int moddiv(int a, int b)
{
int ib = modpow(b, mod-2);
return modmul(a, ib);
}
int32_t main()
{
usecppio
cin >> n;
for (int i = 0; i < n; i++)
{
int h; cin >> h;
height.push_back(h);
}
sort(all(height));
for (auto h : height)
{
int tmp = 1;
int le = lower_bound(all(height), h) - height.begin();
tmp = modmul(tmp, le);
tmp = moddiv(tmp, n - le + 1);
ans += tmp;
ans %= mod;
}
cout << ans << '\n';
}