在 scikit-learn 中为 k-nn 使用用户定义的距离度量

Cue*_*nta 2 python scikit-learn

我有这个代码:

import pandas as pd
import numpy as np 
import matplotlib.pyplot as plt 
import sklearn.neighbors as ng 

def mydist(x, y):
    return np.sum((x-y)**2)

if __name__ == '__main__':
    nn = ng.KNeighborsClassifier(n_neighbors=3, algorithm='ball_tree',metric='mydist')
Run Code Online (Sandbox Code Playgroud)

我正在使用 sci-kit learn 0.18.1 并且出现此错误

ValueError: Metric 'mydist' not valid for algorithm 'ball_tree'
Run Code Online (Sandbox Code Playgroud)

我也尝试使用 algorithm = 'brute' 但错误仍然存​​在。

这是什么原因造成的?如何正确使用用户定义的距离度量?

Max*_*axU 5

以下是ball_tree算法的有效指标列表- 在scikit-learn内部检查指定的指标是否在其中:

In [114]: from sklearn.neighbors import BallTree

In [115]: BallTree.valid_metrics
Out[115]:
['euclidean',
 'l2',
 'minkowski',
 'p',
 'manhattan',
 'cityblock',
 'l1',
 'chebyshev',
 'infinity',
 'seuclidean',
 'mahalanobis',
 'wminkowski',
 'hamming',
 'canberra',
 'braycurtis',
 'matching',
 'jaccard',
 'dice',
 'kulsinski',
 'rogerstanimoto',
 'russellrao',
 'sokalmichener',
 'sokalsneath',
 'haversine',
 'pyfunc']       # <--- NOTE
Run Code Online (Sandbox Code Playgroud)

所以尝试指定metric='pyfunc'metric_params={"func":mydist}

knn = ng.KNeighborsClassifier(n_neighbors=3, algorithm='ball_tree',
                              metric='pyfunc', metric_params={"func":mydist})
Run Code Online (Sandbox Code Playgroud)