Tune 有六个需要理解的关键组件。
简而言之,Trainable 是一个可以传递到 Tune 运行中的对象。 Ray Tune 有两种定义可训练的方法,即函数 API 和类 API。两者都是定义可训练的有效方法,但通常建议使用 Function API,并在本指南的其余部分中使用。
假设我们想要优化一个简单的目标函数,例如 a (x ** 2) + b,其中 a 和 b 是我们想要调整以最小化目标的超参数。由于目标也有一个变量 x,我们需要测试 x 的不同值。给定 a、b 和 x 的具体选择,我们可以评估目标函数并获得最小化分数。
可训练对象:函数
from ray import train
def objective(x, a, b): # Define an objective function.
return a * (x**0.5) + b
def trainable(config): # Pass a "config" dictionary into your trainable.
for x in range(20): # "Train" for 20 iterations and compute intermediate scores.
score = objective(x, config["a"], config["b"])
train.report({
"score": score}) # Send the score to Tune.
可训练对象:类
from ray import tune
def objective(x, a, b):
return a * (x**2) + b
class Trainable(tune.Trainable):
def setup(self, config):
# config (dict): A dict of hyperparameters
self.x = 0
self.a = config["a"]
self.b = config["b"]
def step(self): # This is called iteratively.
score = objective(self.x, self.a, self.b)
self.x += 1
return {
"score": score}