KD Tree & Ball Tree - KNN classifier 속도 향상

  • 발표자: 익명의 인턴

1. 소개

이 글에서는 다차원 데이터를 빠르게 삽입/삭제/탐색할 수 있는 자료구조인 KD-Tree와 Ball-Tree를 소개하고자 한다. 1차원에서 가장 가까운 점(NN)을 찾는 것은 Binary Search로 로그 시간복잡도에 구할 수 있다. 그러나 Nearest Neighbor 등의 Image Classification에서는 다차원의 데이터를 다루어야 하기에, K차원의 점을 효과적으로 저장하는 자료구조가 필요하다.

두 알고리즘 모두 구현은 다소 복잡할 수도 있으나, 다행히도 Sklearn에 구현되어 있다.

2. KD-Tree

2.1. 동작

K차원의 점을 저장하는 자료구조이다. 트리의 각 높이는 ( 높이 mod K ) 축을 담당하며, 각 서브트리는 중앙값을 기준으로 나뉘어 이진트리를 이룬다.

이해하기 좋은 그림자료가 아래의 링크에 존재한다.

http://www.secmem.org/blog/2019/05/09/트리의-종류와-이해/

2.2. 시간복잡도

처음 초기화할 때 O(N), 기타 연산(삽입, 삭제, 탐색)은 대략 O(log N)에 이루어진다.

2.3. 구현

Sklearn에 구현되어 있다. 로컬에서 37분 걸린다.

import numpy as np
from sklearn.neighbors import KDTree

class NearestNeighbor(object):
    def __init__(self):
        pass

    def train(self, X, y):
        self.kdt = KDTree(X)
        self.ytr = y
    def predict(self, X, k):
        num_test = X.shape[0]
        if num_test<k:
            k=num_test
        Ypred = np.zeros(num_test, dtype = self.ytr.dtype)
        for i in tqdm_notebook(range(num_test)):
            dist, idx = self.kdt.query([X[i]],k)
            Ypred[i]=self.ytr[idx]
        return Ypred

nn=NearestNeighbor()
nn.train(Xtr_rows,Ytr)

predict = nn.predict(Xte_rows)
print('accuracy :%f' % (np.mean(predict == Yte)))

2.4. 단점

K를 주기로 담당하는 구간이 반으로 줄어들기 때문에, K가 지나치게 클 경우 불균형해진다. 일반적으로 $N>(2^k)$ 인 경우에 KD-Tree가 효과적으로 작동한다.

3. Ball-Tree

3.1. 동작

Ball-Tree는 범위를 기준으로 차원을 내림차순 정렬한 후, 서브트리에 속한 범위의 반을 기준으로 KD-Tree를 적용한 것이다. 사실 KD-Tree와 거의 같은 개념이다. KD-Tree를 약간 최적화한 것이 Ball-Tree이다.

3.2. 시간복잡도

처음 초기화할 때 $O(NlgN)$, 기타 연산(삽입, 삭제, 탐색)은 대략 $O(log N)$에 이루어진다.

동일해보이지만, 다차원 랜덤데이터에 대해서 랜덤함을 보정하는 효과가 있기 때문에 균형 잡힌 트리가 만들어질 가능성이 보다 높다. 초기화 과정에서 $logN$이 붙으나, Query 연산의 기댓값이 빨라지는 Trade-Off 관계이다.

Ball Tree 역시 삽입 및 삭제 Query가 많아지면 불균형해질 가능성이 높아진다.

3.3. 구현

Sklearn에 구현되어 있다. 로컬에서 29분 걸린다.

import numpy as np
from sklearn.neighbors import BallTree

class NearestNeighbor(object):
    def __init__(self):
        pass

    def train(self, X, y):
        self.bt = BallTree(X)
        self.ytr = y
    def predict(self, X, k):
        num_test = X.shape[0]
        if num_test<k:
            k=num_test
        Ypred = np.zeros(num_test, dtype = self.ytr.dtype)
        for i in tqdm_notebook(range(num_test)):
            dist, idx = self.bt.query([X[i]],k)
            Ypred[i]=self.ytr[idx]
        return Ypred

nn=NearestNeighbor()
nn.train(Xtr_rows,Ytr)

predict = nn.predict(Xte_rows)
print('accuracy :%f' % (np.mean(predict == Yte)))

4. 번외

4.1. 시간복잡도

지금까지 탐색 연산의 시간복잡도를 $log N$이라 적어두었는데 이는 위키피셜이고, 조사한 증명에 따르면 $log^{d-1} N$이다. 물론 다른 더 좋은 증명이 있을 수도 있다.

4.2. Balanced KD-Tree

삽입 및 삭제가 일어나지 않는 정적인 KD-Tree는 전처리를 통해 부모노드가 관리하는 서브 트리에 속한 정점의 인덱스의 중앙값을 기준으로 이진트리를 구성하여 성능을 높일 수 있다. 이 경우 처음 초기화할 때 $O(Nlog^2N)$, 기타 연산(삽입, 삭제, 탐색)은 $O(log N)$에 이루어진다.

이러한 방식은 차원의 범위의 중간값을 기준으로 하는 Ball-Tree와 유사하나, Update Query가 없는 경우 언제나 균형잡힌 이진트리를 만들기 때문에 성능 차이를 보인다.

5. Reference

KD-Tree

KD-Tree Wiki

Ball-Tree

Ball-Tree Wiki

시간복잡도