Spark中的train(),run()和fit()函数之间的区别

dim*_*ima 1 machine-learning

在Java中使用Apache Spark(版本1.5.2)进行逻辑回归有几种选择:

spark.ml:

1) LogisticRegression lr = new LogisticRegression();
a) lr.train(dataFrame);
b) lr.fit(dataFrame);
Run Code Online (Sandbox Code Playgroud)

spark.mllib:

2) LogisticRegressionWithSGD lr = new LogisticRegressionWithSGD();
a) lr.train(rdd);
b) lr.run(rdd);

3) LogisticRegressionWithLBFGS lr = new LogisticRegressionWithLBFGS();
a) lr.train(rdd);
b) lr.run(rdd);
Run Code Online (Sandbox Code Playgroud)

我想知道a)和b)之间有什么区别,除了run()函数的GeneralizedLinearAlgorithm输出而不是另一个的LogisticRegressionModel?我在Java或Scala文档中找不到任何提示.在此先感谢您的帮助.

Vin*_*Bdn 8

Spark确实包含两个可用于机器学习的库:ML和MLLib.你能指定一下你正在使用的Spark版本吗?

MLLib.这是Spark的第一个机器学习库.它实际上具有非常浅的结构并且用于RDD运行.这在MLLib中是一种无政府状态,所以你必须查看代码才能知道使用哪一个.我不确定你使用的是哪种语言或版本,但对于scala上的Spark 1.6.0,有一个单例:

object LogisticRegressionWithSGD {
   def train(input: RDD[LabeledPoint], ...) = new LogisticRegressionWithSGD(...).run(input,...)
}
Run Code Online (Sandbox Code Playgroud)

这意味着火车将作为对象的静态方法被调用LogisticRegressionWithSGD,但是如果你有一个LogisticRegressionWithSGD只有一个run方法的实例:

LogisticRegressionWithSGD.train(rdd, parameters) 
// OR
val lr = new LogisticRegressionWithSGD() 
lr.run(rdd)
Run Code Online (Sandbox Code Playgroud)

无论如何,如果你有另一个版本,你绝对可以使用超级版本,即run.

ML.它是最新的库,它基于使用DataFrame,它基本上是一个RDD[Row](Row只是一个无类型对象的序列)和一个模式(即一个包含有关列名称,类型,元数据......的信息的对象).我绝对建议你使用它,因为它可以实现优化!在这种情况下,您应该使用fit所有估算器需要实现的方法.

说明: ML库使用的概念Pipeline(与sci-kit学习中的一样).管道实例基本上是一个阶段(类型PipelineStage)的数组,每个阶段都是一个Estimator或一个Transformer(有一些其他类型,例如Evaluator但我不会在这里进入它们,因为它们很少见).A Transformer只是一种转换数据的算法,所以它的主要方法是transform(DataFrame)输出另一个DataFrame.An Estimator是一种产生Model(子类型Transformer)的算法.它基本上是任何需要适应数据的块,所以它有一个fit(DataFrame)输出a 的函数Transformer.例如,如果你想将所有数据乘以$ 2 $,你只需要一个变换器来实现一个转换方法,它接受你的输入并乘以$ 2 $.如果您需要计算平均值并将其减去,则需要一个适合数据的估算器来计算平均值,并输出一个变换器,即减去所学习的平均值.因此,只要您使用ML,请使用fittransform方法.它允许您执行以下操作:

val trainingSet = // training DataFrame
val testSet = // test DataFrame
val lr = new LogisticRegession().setInputCol(...).setOutputCol(...) // + setParams()
val stage = // another stage, i.e. something that implements PipelineStage
val stages = Array(lr, stage)
val pipeline: Pipeline = new Pipeline().setStages(stages)
val model: PipelineModel = pipeline.fit(trainingSet)
val result: DataFrame = model.transform(testSet)
Run Code Online (Sandbox Code Playgroud)

现在,如果你真的想知道为什么train存在,它就是一个继承Predictor自己的功能Estimator.确实有一些可能的音调Estimators- 你可以计算平均值,IDF,......当你实现一个预测器,如逻辑回归,你有一个Predictor扩展的抽象类,Estimator并允许你一些快捷方式(例如它有一个标签列,功能列和预测列).特别是这段代码已经覆盖fit了相应的标签/功能/预测更改数据框的架构,你只需要实现自己的列车:

override def fit(dataset: DataFrame): M = {
   // This handles a few items such as schema validation.
   // Developers only need to implement train().
   transformSchema(dataset.schema, logging = true)
   copyValues(train(dataset).setParent(this))
}
protected def train(dataset: DataFrame): M
Run Code Online (Sandbox Code Playgroud)

如您所见,该train方法应该受到保护/私有,因此外部用户不会使用.

  • 不完全正确,调用 fit(dataframe) 将负责预测器常见的所有处理,例如检查标签和特征列是否存在并在模式中创建预测列。train() 完成核心工作,但是私有/受保护的,不应在外部使用(因为您不知道您是否使用它,您的标签列是否存在,例如)。此外,ML 的主要用途是将算法序列包装到 Pipelines 中,在本例中是调用 fit() 函数。 (2认同)