Scikit-Learn接口包装器

我们可以通过包装器将Sequential模型(仅有一个输入)作为Scikit-Learn工作流的一部分,相关的包装器定义在keras.wrappers.scikit_learn.py

目前,有两个包装器可用:

keras.wrappers.scikit_learn.KerasClassifier(build_fn=None, **sk_params)实现了sklearn的分类器接口

keras.wrappers.scikit_learn.KerasRegressor(build_fn=None, **sk_params)实现了sklearn的回归器接口

参数

  • build_fn:可调用的函数或类对象

  • sk_params:模型参数和训练参数

build_fn应构造、编译并返回一个Keras模型,该模型将稍后用于训练/测试。build_fn的值可能为下列三种之一:

  1. 一个函数

  2. 一个具有call方法的类对象

  3. None,代表你的类继承自KerasClassifierKerasRegressor,其call方法为其父类的call方法

sk_params以模型参数和训练(超)参数作为参数。合法的模型参数为build_fn的参数。注意,‘build_fn’应提供其参数的默认值。所以我们不传递任何值给sk_params也可以创建一个分类器/回归器

sk_params还接受用于调用fitpredictpredict_probascore方法的参数,如nb_epochbatch_size等。这些用于训练或预测的参数按如下顺序选择:

  1. 传递给fitpredictpredict_probascore的字典参数

  2. 传递个sk_params的参数

  3. keras.models.Sequentialfitpredictpredict_probascore的默认值

当使用scikit-learn的grid_search接口时,合法的可转换参数是你可以传递给sk_params的参数,包括训练参数。即,你可以使用grid_search来搜索最佳的batch_sizenb_epoch以及其他模型参数