ブログ

PySparkでk近傍法

こんにちは、ソフトウェアエンジニアの冨田恭平です。
前回書いたPySparkで日本語形態素解析に引き続き、Sparkについて書こうと思います。
Sparkを使っていて、大規模なデータで最近傍探索をしたいと思ったことはないでしょうか。私はあります。
最近傍探索が活躍するシナリオの例としては、広告のターゲティングで類似ユーザーを探したり、レコメンドエンジンで類似アイテムを探したりする場合が考えられます。今回は、n個のベクトルそれぞれに対し、類似度の高いベクトルk個を計算するk近傍法 (K-Nearest Neighbor, KNN)をPySparkで実行する方法を紹介します。

ローカルでのk近傍法

ローカルでのk近傍法をPythonで行う場合は、scikit-learnのNeighborsモジュールを使うと便利です。リンク先のチュートリアルでも説明されている通り、下記のようなコードでk近傍法が実行できます。

import numpy as np
from sklearn.neighbors import NearestNeighbors
X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
neighbors = NearestNeighbors(n_neighbors=2, algorithm='kd_tree').fit(X)
distances, indices = neighbors.kneighbors(X)

5行目でk近傍法を実行する際のkの値やアルゴリズムを指定しfit()でモデルを作成し、6行目で作成したモデルを元にkneighbors()でそれぞれの点に対して近傍点を求めるのがポイントです。モデルの作成や近傍点を求める際の計算量は指定するアルゴリズムによって異なります。例えば上の例のK-D Treeを使用していてベクトルの次元数dが小さい場合には、1つの点に対するk近傍点の探索はO(d \log(n))時間となります。ベクトルが疎な場合に使われる単純なbrute forceとなると、探索にO(dn)の時間がかかってしまいます。

scikit-learnを使ってPySparkでk近傍法

先ほど紹介したscikit-learnのNearestNeighborsクラスを使うと、kneighbors()で行う探索を分散処理することができます。モデルの作成はマスターノードで行い、そのモデルを全ノードにブロードキャスト、その後はブロードキャストしたモデルを使ってkneighbors()メソッドによる探索を分散して行います。コード例は下記のようになります。

spark = SparkSession.builder.appName('knn').getOrCreate()
sc = spark.sparkContext
X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
rdd = sc.parallelize(X)
# マスターノードに全データを集めてfit
collected_X = rdd.collect()
neighbors = NearestNeighbors(n_neighbors=2, algorithm='kd_tree').fit(collected_X)
# fit済みのモデルを全ノードにbroadcast
broadcasted_neighbors = sc.broadcast(neighbors)
# kneighbors()を分散処理
neighbors_rdd = rdd.map(lambda x: broadcasted_neighbors.value.kneighbors([x]))

実際に時間がかかるのはモデルの作成よりも全点に対して探索を行うkneighbors()の部分なので、ある程度の規模のデータであればこの方法が有効です。
この方法の注意点は、fit()によるモデルの作成自体はマスターノードで行われる点です。マスターノードでは収まりきらない量のデータを扱う場合はこの方法は使用できません。

分散処理で近似解を求める方法

厳密なk近傍点ではなく近似解を求める代わりに、モデルの作成を分散して行うようなライブラリも公開されており、それらを使用するとマスターノードに収まらないデータに対してもk近傍法を使用できます。spark-knnやLinkedInのScANNSなどが有名なようです。
マスターノードのメモリには収まるものの、疎ベクトルを対象としているなどの理由でkneighbors()による探索が分散しても遅い場合は、Facebook ResearchのFaissPySparNN、SpotifyのAnnoyなどを使用することも考えられるかもしれません。これらは近似解を効率的に探索するためのライブラリで、分散処理で探索する部分の処理が高速化されることが期待されます。実装方法はscikit learnの場合と同様で、モデルの学習をマスターノードで行い、学習済みモデルを全ノードにブロードキャストします。弊社の一部製品では実際にPySparNNを使用し、疎ベクトルに対するk近傍法の近似解を分散処理で計算しています。

まとめ

PythonやSparkでk近傍法を実行する様々な方法を紹介しましたが、扱うデータサイズによって適切な手法を選択する必要があることがわかりました。
小規模なデータに対してはローカルでscikit-learnを使うのが簡単そうです。1台のマシンに収まるくらいのデータではあるけれど探索に時間がかかってしまう規模のデータの場合は、sckit-learnやFaiss/PySparNN/Annoyなどのライブラリを使用しモデルの作成はマスターノードで、探索はワーカーノードで分散して行うのが良いかもしれません。1台のマシンに収まらない規模のデータであれば、spark-knnやScANNSなどを使って近似解を求めるのが良いのではないでしょうか。
FLYWHEELでは、大規模データや分散処理に興味のあるエンジニアを募集しています