如何在scikit-learn中将用户定义的指标用于最近的邻居?

Shi*_*dey 3 metrics distance nearest-neighbor scikit-learn

我正在使用scikit-learn 0.18.dev0。我知道在这里之前有人问过完全相同的问题。我尝试了此处显示的答案,但出现以下错误

>>> def mydist(x, y):
...     return np.sum((x-y)**2)
...
>>> X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3,   2]])

>>> nbrs = NearestNeighbors(n_neighbors=4, algorithm='ball_tree',
...            metric='pyfunc', func=mydist)
Run Code Online (Sandbox Code Playgroud)

错误信息 _init_params() got an unexpected keyword argument 'func'

看来此选项已被删除。如何在中使用用户定义的矩阵sklearn.neighbors

eic*_*erg 5

正确的关键字是metric

import numpy as np
from sklearn.neighbors import NearestNeighbors

def mydist(x, y):
    return np.sum((x-y)**2)

nn = NearestNeighbors(n_neighbors=4, algorithm='ball_tree', metric=myfunc)

X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3,   2]])
nn.fit(X)
Run Code Online (Sandbox Code Playgroud)

开发版本的文档字符串中也提到了这一点:https : //github.com/scikit-learn/scikit-learn/blob/86b1ba72771718acbd1e07fbdc5caaf65ae65440/sklearn/neighbors/unsupervised.py#L48