BOJ 15977, KOI 2018 고등부 3번 조화로운 행렬
풀이
29점 풀이
서브태스크 1과 3만 먼저 해결하자. 즉, $n = 2$ 인 경우. 문제를 잘 해석해 보면, 다음 풀이에 도달하는 것은 어렵지 않다.
- 1행의 원소를 기준으로 전체 배열을 정렬한 다음
- 2행에서 Longest Increasing Subseqeunce를 찾으면 정답이 된다.
이 풀이는 DP, Set 또는 다른 자료구조를 이용해 $O(n \log n)$ 시간에 LIS를 구하는 매우 잘 알려진 문제이고, 다이아 4인 이 문제에 도전하는 정도의 PS 실력이 있다면 아마 풀이를 찾는 것도, 구현도 매우 쉬울 것이므로 설명은 생략.
100점 풀이
서브태스크 3의 아이디어를 확장하자. 위 알고리즘을 그대로 돌리면, 다음과 같은 아이디어를 떠올릴 수 있다.
- 1행의 원소를 기준으로 전체 배열을 정렬한 다음
- (2, 3)행의 원소를 Pair로 잡고 Pair에게 $(x, y) < (a, b) \iff x < y , a < b$ 라는 partial order를 부여하자. 이 기준으로 LIS를 찾으면 정답이 된다.
결국 문제는 Pair-LIS를 어떻게 찾을 것인가에 대한 고민이다.
Longest Increasing Seqeunce를 Segment tree를 이용하여 구하는 방법이 있다. 요지는 일단 다이나믹 프로그래밍인데, 다이나믹 프로그래밍을 대충 기술해 보면 \(D_i = \max_{j > i, a_j > a_i} D_j\) 위 DP식으로 생각할 수 있다. 이를 쉽게 해결하기 위해, $i$가 큰 것부터, 즉 수열을 거꾸로 보면서 DP를 갱신하면, 다음 쿼리를 빠르게 처리하면 된다. \(D_i = \max_{a_j > a_i} D_j\) 이 문제는 Range-maximum segment tree에 $[a_i + 1, M]$ 범위 쿼리를 날려서 쉽게 해결할 수 있다.
Pair 문제도 별로 다르지 않다. 다음과 같이 DP식을 쓰자. \(D_i = \max_{a_j > a_i, b_j > b_i} D_j\) 똑같이 해결하되, $[a_i + 1, M]$ 범위 쿼리 대신 $[a_i + 1, M] \times [b_i + 1, M]$ 직사각형 쿼리를 날려서 진행하면 된다. 일반적으로 이러한 $M \times M$ 세그먼트 트리 구축에 $O(M^2)$ 메모리를 필요로 하고, $O(\log^2 M)$ 시간에 한 쿼리를 처리할 수 있음이 잘 알려져 있다.
Dynamic Segment Tree
이 문제에서 $M$은 수열의 최대 범위이므로 10억이다. 먼저, 10억 칸의 메모리도 시간도 없으나, 각 원소가 다름이 주어지고, 우리에게 중요한 것은 실제값이 아닌 대소관계이므로 전체를 좌표압축하자. 이때 1행은 어차피 LIS에서는 쓰지 않을 것이므로 2, 3행만 압축해서 40만으로 줄일 수 있다. 이제 $M$ 이 40만인데, 여전히 $M^2$ 메모리는 없다. 이제 $M \in O(n)$ 이므로 대충 $n$으로 퉁쳐가며 쓰기로 하자.
잘 관찰해 보면, $n^2$ 칸의 거대한 직사각형 중 우리가 실제로 값을 업데이트할 칸은 $O(n)$ 칸에 불과하다. 즉, 대부분의 노드가 사실 빈 칸을 저장하고 있다는 것이다. 이러한 낭비를 해결하기 위한 좋은 방법으로 Dynamic Segment Tree가 있다.
Dynamic Segment Tree에 대해서는 JusticeHui님의 블로그 글 이 가장 이해하기 좋은 자료 중 하나인 것 같으니 이를 참고하자. 요점은 노드들을 On-demand 방식으로 만들면서 진행하고, 이를 포인터로 잘 관리하면서 움직이는 것이다.
다만 이 문제에 경우 Dynamic 2D Segment Tree를 구현해야 하므로, 위 글을 보고 나면 secmem에 올라온 blisstoner님의 글 도 참고하자. 이 글은 2D segment tree를 소개하고 있지만, 메모리를 아끼는 방법도 소개하고 있다. 메모리가 상당히 타이트하므로 열심히 구현해야 하며, 나는 어떤 다른 코포 유저가 IOI games를 설명하기 위해 구현한 코드를 대부분 베껴넣고 내 평소 코딩 스타일에 맞게 조금 고쳐서 작성했다.
슬프게도 2D Dynamic segment tree는 상수가 정말 크기 때문에, 5초의 제한 시간에 맞추기가 쉽지 않다. oj.uz에서의 제한 시간은 4초던데 나는 아직 이 코드를 (FASTIO를 적용했음에도 불구하고) 4초 안에 구겨넣지 못했다. 언젠가 다시 해볼 생각.
코드
#include <bits/stdc++.h>
#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast")
#pragma GCC target("avx,avx2,fma")
#define ll long long
#define eps 1e-7
#define all(x) ((x).begin()),((x).end())
#define usecppio ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);
using namespace std;
using pii = pair<int, int>;
void solve_2();
int m, n;
int arr[3][202020];
set <int> s;
const int MAXR = 404040;
const int MAXC = 404040;
char buf[10000011];
int idx, widx;
inline void readint(int & r) {
r = 0;
char c = buf[idx++];
while (c < 33) c = buf[idx++];
while (c >= 48 && c <= 57) {
r = r * 10 + (c & 15);
c = buf[idx++];
}
}
// max-dynamic-2seg
// msg555 impl, some modifications
struct Dynamic2DSeg
{
struct Node
{
Node (int l, int r, int v = 0) : l(l), r(r), v(v), lc(NULL), rc(NULL){}
int l, r, v;
Node *lc, *rc;
};
struct RowNode
{
RowNode () : lc(NULL), rc(NULL), rowRoot(0, MAXC, 0){}
Node rowRoot;
RowNode *lc, *rc;
};
RowNode root;
static void update2(Node *node, int idx, int dt)
{
int lo = node -> l;
int hi = node -> r;
int mid = (lo + hi) / 2;
if (lo + 1 == hi)
{
node -> v = dt;
return;
}
Node *& next = idx < mid ? node -> lc : node -> rc;
if (next == NULL)
next = new Node(idx, idx + 1, dt);
else if (next -> l <= idx && idx < next -> r)
update2(next, idx, dt);
else
{
do {
(idx < mid ? hi : lo) = mid;
mid = (lo + hi) / 2;
} while ((idx < mid) == (next->l < mid));
Node * nnode = new Node(lo, hi);
(next->l < mid ? nnode->lc : nnode->rc) = next;
next = nnode;
update2(nnode, idx, dt);
}
node -> v = max(node->lc?node->lc->v:0, node->rc?node->rc->v:0);
}
static void update1(RowNode * node, int lo, int hi, int ridx, int idx, int dt)
{
int mid = (lo + hi) / 2;
if (lo + 1 == hi)
update2(&node -> rowRoot, idx, dt);
else
{
RowNode *& nnode = ridx < mid ? node -> lc : node -> rc;
(ridx < mid ? hi : lo) = mid;
if (nnode == NULL)
nnode = new RowNode();
update1(nnode, lo, hi, ridx, idx, dt);
dt = max(node->lc ? query2(&node->lc->rowRoot, idx, idx + 1) : 0,
node->rc ? query2(&node->rc->rowRoot, idx, idx + 1) : 0);
update2(&node->rowRoot, idx, dt);
}
}
static int query2(Node *node, int a, int b)
{
if (node == NULL || b <= node -> l || node -> r <= a)
return 0;
else if (a <= node -> l && node -> r <= b)
return node -> v;
return max(query2(node->lc, a, b), query2(node->rc, a, b));
}
static int query1(RowNode *node, int lo, int hi, int a1, int b1, int a2, int b2)
{
if (node == NULL || b1 <= lo || hi <= a1)
return 0;
else if (a1 <= lo && hi <= b1)
return query2(&node->rowRoot, a2, b2);
int mid = (lo + hi) / 2;
return max(query1(node->lc, lo, mid, a1, b1, a2, b2),
query1(node->rc, mid, hi, a1, b1, a2, b2));
}
void update(int x, int y, int dt)
{
update1(&root, 0, MAXR, x, y, dt);
}
int query(int x1, int y1, int x2, int y2)
{
return query1(&root, 0, MAXR, x1, x2 + 1, y1, y2 + 1);
}
} DSEG;
void solve_3()
{
vector <pair<int, pii>> v;
for (int i = 0; i < n; i++)
v.push_back({arr[0][i], {arr[1][i], arr[2][i]}});
sort(all(v));
vector <int> ans(n, 0);
for (int i = n-1; i >= 0; i--)
{
int cx = v[i].second.first;
int cy = v[i].second.second;
ans[i] = DSEG.query(cx+1, cy+1, 1e9+1, 1e9+1) + 1;
DSEG.update(cx, cy, ans[i]);
}
int r = 0;
for (auto it:ans)
r = max(it, r);
cout << r << '\n';
}
vector <int> compress;
int32_t main()
{
fread(buf, 1, 6262626, stdin);
readint(m); readint(n);
for (int i = 0; i < m; i++)
{
for (int j = 0; j < n; j++)
{
readint(arr[i][j]);
if (i)
compress.push_back(arr[i][j]);
}
}
sort(all(compress));
compress.erase(unique(all(compress)), compress.end());
for (int i = 1; i < m; i++)
for (int j = 0; j < n; j++)
arr[i][j] = lower_bound(all(compress), arr[i][j]) - compress.begin();
if (m == 2)
solve_2();
else solve_3();
}
void solve_2()
{
vector <pii> v;
for (int i = 0; i < n; i++)
v.push_back({arr[0][i], arr[1][i]});
sort(all(v));
vector <int> vv;
for (int i = 0; i < n; i++)
vv.push_back(v[i].second);
set <int> :: iterator it;
for (int i = 0; i < n; i++)
{
s.insert(vv[i]);
it = s.upper_bound(vv[i]);
if (it != s.end())
s.erase(it);
}
cout << s.size() << '\n';
}