ガウス単純ベイズの説明

ガウス単純ベイズの説明

ソースノード: 2021431

ガウス単純ベイズの説明
ガウス単純ベイズ分類器の決定領域。 著者による画像。

 

これは、データ サイエンスの各キャリアの始まりにおける典型的な例だと思います。 単純ベイズ分類器. というか、むしろ 家族 単純ベイズ分類器にはさまざまな種類があります。 たとえば、多項単純ベイズ、ベルヌーイ単純ベイズ、およびガウス単純ベイズ分類器があり、それぞれが XNUMX つの小さな詳細でのみ異なることがわかります。 ナイーブ ベイズ アルゴリズムの設計は非常に単純ですが、多くの複雑な実世界の状況で役立つことが証明されています。

この記事では、あなたが学ぶことができます

  • 単純ベイズ分類器の仕組み
  • それらをあるがままに定義することが理にかなっている理由
  • NumPy を使用して Python でそれらを実装する方法。

あなたはでコードを見つけることができます 私のGithub.

ベイジアン統計に関する私の入門書をチェックすると少し役立つかもしれません ベイジアン推論の穏やかな紹介 ベイズの公式に慣れる。 分類子を scikit 学習準拠の方法で実装するので、私の記事をチェックすることも価値があります。 独自のカスタム scikit-learn 回帰を構築する. ただし、scikit-learn のオーバーヘッドは非常に小さいため、とにかく従うことができるはずです。

単純ベイズ分類の驚くほど単純な理論の探索を開始し、実装に進みます。

分類するとき、私たちは何に本当に関心がありますか? 私たちは実際に何をしているのか、入力と出力は何ですか? 答えは簡単です。

データ点 x が与えられたとき、x があるクラス c に属する確率は?

私たちが答えたいのはそれだけです どれか 分類。 このステートメントを条件付き確率として直接モデル化できます。 p(c|x).

たとえば、

  • 3クラス c₁c₂c₃,
  • 2つの機能で構成されています ×₁×₂,

分類器の結果は次のようになります p(c₁|×₁×₂)= 0.3、 p(c₂|×₁×₂)=0.5 および p(c₃|×₁×₂)=0.2。 出力として単一のラベルを気にする場合、確率が最も高いラベルを選択します。つまり、 c₂ ここでは50%の確率で。

単純ベイズ分類器は、これらの確率を直接計算しようとします。

ナイーブベイズ

さて、与えられたデータポイント x、計算したい p(c|x) すべてのクラス 次に、 c 最も高い確率で。 数式では、これはよく次のように表示されます。

 

ガウス単純ベイズの説明
著者による画像。

 

注: マックス p(c|x) argmax の間、最大確率を返します。 p(c|x)を返します c この最高確率で。

しかし、最適化する前に p(c|x)、それを計算できなければなりません。 これには、 ベイズの定理:

 

ガウス単純ベイズの説明
ベイズの定理。 著者による画像。

 

これは単純ベイズのベイズ部分です。 しかし今、次の問題があります。 p(x|c)と p(c)?

これが、単純ベイズ分類器のトレーニングのすべてです。

トレーニング

すべてを説明するために、おもちゃのデータセットを使用してみましょう XNUMXつの本当の特徴 ×₁×₂XNUMXつのクラス c₁c₂c₃ 以下では。

 

ガウス単純ベイズの説明
視覚化されたデータ。 著者による画像。

 

この正確なデータセットを次の方法で作成できます

from sklearn.datasets import make_blobs X, y = make_blobs(n_samples=20, centers=[(0,0), (5,5), (-5, 5)], random_state=0)

から始めましょう クラス確率 p(c)、あるクラスが c ラベル付きデータセットで観察されます。 これを推定する最も簡単な方法は、クラスの相対頻度を計算し、それらを確率として使用することです。 データセットを使用して、これが正確に何を意味するかを確認できます。

クラスとラベル付けされた 7 点のうち 20 点があります。 c₁ (青) データセットに含まれているため、次のように言います p(c₁)=7/20. クラスのポイントは7つです c₂ (赤)も同様なので、設定します p(c₂)=7/20. 最後のクラス c₃ (黄色) は 6 ポイントしかないため、 p(c₃)=6/20。

クラス確率のこの単純な計算は、最尤アプローチに似ています。 ただし、別のものを使用することもできます 事前の よろしかったら配布。 たとえば、このデータセットが真の人口を代表していないことがわかっている場合、 c₃ ケースの 50% に表示される必要がある場合は、設定します p(c₁)= 0.25、 p(c₂)=0.25 および p(c₃)=0.5。 テスト セットのパフォーマンスを向上させるのに役立つものは何でも。

今度は 尤度 p(x|c)=p(×₁×₂|c)。 この可能性を計算する XNUMX つの方法は、ラベル付きのサンプルのデータセットをフィルター処理することです。 次に、特徴を捉える分布 (2 次元ガウス分布など) を見つけようとします。 ×₁×₂.

残念ながら、通常、クラスごとに十分なサンプルがなく、可能性を適切に推定できません。

より堅牢なモデルを構築できるようにするために、 素朴な仮定 その特徴 ×₁×₂   確率的に独立、与えられた c. これは、数学をより簡単にするための単なる派手な方法です。

 

ガウス単純ベイズの説明
著者による画像

 

クラスごとに c。 これは、 素朴な 単純ベイズの一部は、この式が一般に成り立たないためです。 それでも、単純なベイズは、実際には優れた、時には優れた結果をもたらします。 特に、bag-of-words 機能を使用する NLP 問題では、多項式単純ベイズが役立ちます。

上記の引数は、見つけることができる単純ベイズ分類器と同じです。 あとは、モデル化の方法に依存します p(x₁|c₁)、p(x₂|c₁)、p(x₁|c₂)、p(x₂|c₂)、p(x₁|c₃) および p(x₂|c₃).

機能が 0 と 1 のみの場合は、 ベルヌーイ分布. それらが整数の場合、 多項分布. ただし、実際の機能値があり、 ガウシアン 分布であるため、ガウス ナイーブ ベイズという名前が付けられています。 以下の形を想定しています

 

ガウス単純ベイズの説明
著者による画像。

 

コラボレー μᵢ、ⱼ 平均であり、 σᵢ、ⱼ は、データから推定する必要がある標準偏差です。 これは、特徴ごとに XNUMX つの平均を取得することを意味します。 i クラスと結合 cⱼ, この場合、2*3=6 が意味します。 標準偏差も同様です。 これには例が必要です。

試算してみましょう μ₂、₁ および σ₂、₁。 なぜなら j=1、クラスのみに関心があります c₁、このラベルが付いたサンプルのみを保持しましょう。 次のサンプルが残ります。

# samples with label = c_1 array([[ 0.14404357, 1.45427351], [ 0.97873798, 2.2408932 ], [ 1.86755799, -0.97727788], [ 1.76405235, 0.40015721], [ 0.76103773, 0.12167502], [-0.10321885, 0.4105985 ], [ 0.95008842, -0.15135721]])

今、そのせいで i=2 XNUMX 番目の列だけを考慮する必要があります。 μ₂,₁ は平均であり、 σ₂,₁ この列の標準偏差、つまり μ₂,₁ = 0.49985176 および σ₂、₁ = 0.9789976。

これらの数値は、散布図を上からもう一度見ると意味があります。 特徴 ×₂ クラスからのサンプルの c₁ 写真からわかるように、約0.5です。

残りの XNUMX つの組み合わせについてこれを計算したら、完了です。

Python では、次のように実行できます。

from sklearn.datasets import make_blobs
import numpy as np # Create the data. The classes are c_1=0, c_2=1 and c_3=2.
X, y = make_blobs( n_samples=20, centers=[(0, 0), (5, 5), (-5, 5)], random_state=0
) # The class probabilities.
# np.bincounts counts the occurence of each label.
prior = np.bincount(y) / len(y) # np.where(y==i) returns all indices where the y==i.
# This is the filtering step.
means = np.array([X[np.where(y == i)].mean(axis=0) for i in range(3)])
stds = np.array([X[np.where(y == i)].std(axis=0) for i in range(3)])

受け取る

# priors
array([0.35, 0.35, 0.3 ])
# means array([[ 0.90889988, 0.49985176], [ 5.4111385 , 4.6491892 ], [-4.7841679 , 5.15385848]])
# stds
array([[0.6853714 , 0.9789976 ], [1.40218915, 0.67078568], [0.88192625, 1.12879666]])

これは、ガウス単純ベイズ分類器のトレーニングの結果です。

予測する

完全な予測式は

 

ガウス単純ベイズの説明
著者による画像。

 

新しいデータポイントを仮定しましょう x*=(-2, 5) が入ります。

 

ガウス単純ベイズの説明
著者による画像。

 

それがどのクラスに属しているかを確認するために、計算してみましょう p(c|x*) すべてのクラス。 写真から、それはクラスに属している必要があります c₃ = 2 ですが、見てみましょう。 分母は無視しよう p(x) ちょっと。 次のループを使用して、 j = 1、2、3。

x_new = np.array([-2, 5]) for j in range(3): print( f"Probability for class {j}: {(1/np.sqrt(2*np.pi*stds[j]**2)*np.exp(-0.5*((x_new-means[j])/stds[j])**2)).prod()*p[j]:.12f}" )

受け取る

Probability for class 0: 0.000000000263
Probability for class 1: 0.000000044359
Probability for class 2: 0.000325643718

もちろん、これら 確率 (そのように呼ぶべきではありません) 分母を無視したので、足し合わせて 0.00032569 にしないでください。 ただし、これらの正規化されていない確率を取り、それらの合計で割ると、合計が XNUMX になるため、これは問題ありません。 したがって、これら XNUMX つの値を合計約 XNUMX で割ると、次のようになります。

 

ガウス単純ベイズの説明
著者による画像。

 

予想通り、明らかな勝者です。 では、実装していきましょう!

この実装は明らかに効率的ではなく、数値的にも安定しておらず、教育目的のみに役立ちます。 ほとんどのことについて話し合ったので、今は簡単に理解できるはずです。 すべて無視できます check 関数、または私の記事を読む 独自のカスタム scikit-learn を構築する 彼らが正確に何をするかに興味があるなら。

私が実装したことに注意してください predict_proba 最初にメソッドを使用して確率を計算します。 方法 predict このメソッドを呼び出すだけで、argmax 関数を使用して最も高い確率でインデックス (= クラス) を返します (これもまたあります!)。 クラスは 0 から XNUMX までのクラスを待機します k-1、ここで k クラス数です。

import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted class GaussianNaiveBayesClassifier(BaseEstimator, ClassifierMixin): def fit(self, X, y): X, y = check_X_y(X, y) self.priors_ = np.bincount(y) / len(y) self.n_classes_ = np.max(y) + 1 self.means_ = np.array( [X[np.where(y == i)].mean(axis=0) for i in range(self.n_classes_)] ) self.stds_ = np.array( [X[np.where(y == i)].std(axis=0) for i in range(self.n_classes_)] ) return self def predict_proba(self, X): check_is_fitted(self) X = check_array(X) res = [] for i in range(len(X)): probas = [] for j in range(self.n_classes_): probas.append( ( 1 / np.sqrt(2 * np.pi * self.stds_[j] ** 2) * np.exp(-0.5 * ((X[i] - self.means_[j]) / self.stds_[j]) ** 2) ).prod() * self.priors_[j] ) probas = np.array(probas) res.append(probas / probas.sum()) return np.array(res) def predict(self, X): check_is_fitted(self) X = check_array(X) res = self.predict_proba(X) return res.argmax(axis=1)

実装のテスト

コードは非常に短いですが、間違いがなかったことを完全に確認するには長すぎます。 それで、それがどのようにうまくいくかを確認しましょう scikit-learn GaussianNB 分類器.

my_gauss = GaussianNaiveBayesClassifier()
my_gauss.fit(X, y)
my_gauss.predict_proba([[-2, 5], [0,0], [6, -0.3]])

outputs

array([[8.06313823e-07, 1.36201957e-04, 9.99862992e-01], [1.00000000e+00, 4.23258691e-14, 1.92051255e-11], [4.30879705e-01, 5.69120295e-01, 9.66618838e-27]])

を使用した予測 predict 方法は

# my_gauss.predict([[-2, 5], [0,0], [6, -0.3]])
array([2, 0, 1])

それでは、scikit-learn を使ってみましょう。 いくつかのコードを投げる

from sklearn.naive_bayes import GaussianNB gnb = GaussianNB()
gnb.fit(X, y)
gnb.predict_proba([[-2, 5], [0,0], [6, -0.3]])

収量

array([[8.06314158e-07, 1.36201959e-04, 9.99862992e-01], [1.00000000e+00, 4.23259111e-14, 1.92051343e-11], [4.30879698e-01, 5.69120302e-01, 9.66619630e-27]])

数値は分類子の数値と似ていますが、表示されている最後の数桁が少しずれています。 私たちは何か悪いことをしましたか? いいえ。 scikit-learn バージョンは単に別のハイパーパラメータを使用するだけです var_smoothing=1e-09 . これを ゼロ、正確に数値を取得します。 完全!

分類器の決定領域を見てください。 また、テストに使用した 56.9 つのポイントに印を付けました。 境界に近いその XNUMX 点が赤いクラスに属する可能性は XNUMX% しかありません。 predict_proba 出力します。 他の XNUMX つのポイントは、はるかに高い信頼度で分類されます。

 

ガウス単純ベイズの説明

3 つの新しい点がある決定領域。 著者による画像。

 

この記事では、ガウス単純ベイズ分類器がどのように機能するかを学び、そのように設計された理由について直感を示しました。これは、関心のある確率をモデル化するための直接的なアプローチです。 これをロジスティック回帰と比較してください。そこでは、シグモイド関数が適用された線形関数を使用して確率がモデル化されます。 これはまだ簡単なモデルですが、単純なベイズ分類器ほど自然には感じられません。

引き続き、いくつかの例を計算し、途中でいくつかの有用なコードを収集しました。 最後に、scikit-learn とうまく連携する方法で完全なガウス単純ベイズ分類器を実装しました。 つまり、パイプラインやグリッド検索などで使用できます。

最後に、scikit-learns 独自の Gaussian 単純ベイズ分類器をインポートし、私たちと scikit-learn の分類器の両方が同じ結果をもたらすかどうかをテストすることで、小さな健全性チェックを行いました。 このテストは成功しました。

 
 
ロバート・キューブラー博士 Publicis Mediaのデータサイエンティストであり、Towards DataScienceの著者です。

 
元の。 許可を得て転載。
 

タイムスタンプ:

より多くの KDナゲット