keisukeのブログ

***乱雑です!自分用のメモです!*** 統計や機械学習の勉強と、読み物を書く練習と、備忘録用のブログ

【Python】ふたつの配列からすべての組み合わせを評価

引数をふたつとる関数f([x,y])が存在するとします(この場合は引数2つではなく、要素数2の配列をひとつ引数にとるのほうが正しいですが・・・)。
たとえば、f([x,y]) = x+y などです。
このfはうまく実装されているので、f([[1,3], [2,4]])とすると(つまり[1+3, 2+4])、
[f([1,3]), f([2,4])]としたのと同じ効果(つまり[1+3, 2+4])を持ちます(これはnumpyやmatlabならお馴染みの設計だと思います)。

では、この関数fの引数x,yについて、候補が複数あったらどうなるでしょう?
そしてその候補を重複なく選んですべて評価したいときはどうすればよいでしょう?
つまり、x=[1,2,3]、y=[4,5]について、f([1,4]), f([1,5]), f([2,4]), f([2,5]), f([3,4]), f([3,5]) のように評価したいわけです。

ナイーブに考えると、こうなります:

>>> import numpy as np
>>> x = np.array([1,2,3])
>>> y = np.array([4,5])
>>> ret = [f([xi,yi]) for xi in x for yi in y]

しかしこれではせっかくfが引数を行列として処理できるように実装されているのに、それを活かせません。処理速度も遅くなる場合が大半だと思います。
このような場合、numpyのmeshgridが使えます:

>>> xx, yy = np.meshgrid(x, y)
>>> ret = f(np.c_[xx.ravel(), yy.ravel()])

パッと見だと何やってるかよくわかりづらいと思うので解説します。

まず、xx, yy = np.meshgrid(x,y)は、xx = np.tile(x, (len(y),1); yy = np.tile(y, (len(x),1).T; と同じです。サイズがlen(y) X len(x)の行列をふたつ作ります。
先ほどの例で言うと、xx=[[1,2,3],[1,2,3]], yy=[[4,4,4],[5,5,5]]となります。
この時点で、ふたつの行列xx,yyのインデックスi,jを順番に見ていくと、(i,j)=(0,0)のとき[1,4]、(i,j)=(0,1)のとき[2,4]、などとなり、すべての組み合わせを重複なく選べそうになっています。
続いてのxx.ravel(), yy.ravel()ですが、これは単純にxxとyyをバラバラにするだけです。xx.ravelをすると、[1,2,3, 1,2,3], yy.ravel()をすると[4,4,4, 5,5,5]となります。
次にnp.c_[]ですが、これは与えられたオブジェクトの第二軸(second axis)をもとに合体させます。この場合、np.c_[ [1,2,3, 1,2,3], [4,4,4, 5,5,5] ] は、[[1,4],[2,4],[3,4], [1,5],[2,5],[3,5]]となります。
これをfに与えれば、やりたかった重複なくすべての組み合わせの評価ができます。

実際こんなの何に使うのかというと、格子点すべてで関数を評価するときに便利です。
\begin{align*}y>-x^2\end{align*}を視覚化したいとしましょう。これは\begin{align*}y+x^2>0\end{align*}ですから、視覚化するときは、あらゆる(x,y)の組に対して\begin{align*}y+x^2>0\end{align*}が成立すれば黄、しなければ青でグラフを塗りつぶせば良いです。このあらゆる(x,y)の組というのが、np.meshgrid()で作れます。

>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> x = np.arange(-2,2, 0.01)
>>> x2 = x**2
>>> y = np.arange(-5,0, 0.01)
>>> xx, yy = np.meshgrid(x2,y)
>>> arg = np.c_[xx.ravel(), yy.ravel()]
>>> ret = np.sum(arg, axis=1)  # calc x^2+y
>>> plt.contourf(x, y, ret.reshape(xx.shape)>0, cmap=plt.cm.Paired, alpha=0.8)
>>> plt.show()

f:id:kaisk:20141105040848p:plain