基于 brain.js 的 Objective-C 神经网络库,适用于 iOS 和 Mac OS X。
此示例使用神经网络近似 XOR 函数
#import "SNNeuralNet.h"
SNTrainingRecord records[] = {
{SNInput(0,0), SNOutput(0)},
{SNInput(0,1), SNOutput(1)},
{SNInput(1,0), SNOutput(1)},
{SNInput(1,1), SNOutput(0)}
};
SNNeuralNet *net = [[SNNeuralNet alloc] initWithTrainingData:records
numRecords:4
numInputs:2
numOutputs:1];
double *output = [net runInput:SNInput(1, 0)];
printf("%f\n", output[0]); // 0.987
有几种方式可以创建一个 SNNeuralNet
实例。基类初始化方法如下
SNNeuralNet *net = [[SNNeuralNet alloc] initWithInputs:2 outputs:1];
这将创建一个默认的神经网络,包含一个隐藏层,这对于许多用途来说是足够的。然而,你也可以创建一个包含更多隐藏层并且大小可以定制的神经网络。
SNNeuralNet *net = [[SNNeuralNet alloc] initWithInputs:2 hiddenLayers:@[@3, @4] outputs:1];
此外,你可以一步创建和训练网络,尽管这会使得配置网络和创建自定义隐藏层变得不可能。
// see the section on training below
SNTrainingRecord records[] = {
{SNInput(0,0), SNOutput(0)},
{SNInput(0,1), SNOutput(1)},
{SNInput(1,0), SNOutput(1)},
{SNInput(1,1), SNOutput(0)}
};
SNNeuralNet *net = [[SNNeuralNet alloc] initWithTrainingData:records
numRecords:4
numInputs:2
numOutputs:1];
SNNeuralNet 有几个可配置的属性。一旦你拥有一个实例,你可以设置这些属性。理论上,这些属性在训练后是没有效果的,所以请确信使用不进行训练的构造函数。它们的默认值如下所示。
net.maxIterations = 20000; // maximum training iterations
net.minError = 0.005; // error threshold to reach
net.learningRate = 0.3; // influences how quickly the network trains
net.momentum = 0.1; // influences learning rate
在创建和配置你的网络之后,你应该用一些已知的数据对其进行训练。网络只能训练一次,因此请一次包含所有训练数据。如果你尝试多次调用 train
方法,将返回 -1
。你可以通过检查 net.isTrained
来查看一个网络是否已经训练过。
在训练之前,你需要制作一个 SNTrainingRecord
数组。这些是 C 结构,包含输入和输出双精度浮点数数组。你可以有尽可能多的输入或输出,但所有数据必须具有相同的输入和输出数量。存在名为 SNInput
和 SNOutput
的便利宏,可以用来包裹数组中的数据。
当你准备好训练网络时,使用 train
方法,并传递记录数组和记录数量。
SNTrainingRecord records[] = {
{SNInput(0,0), SNOutput(0)},
{SNInput(0,1), SNOutput(1)},
{SNInput(1,0), SNOutput(1)},
{SNInput(1,1), SNOutput(0)}
};
double error = [net train:records numRecords:4];
train
方法返回训练中发生的错误量,通常应低于 net.minError
,除非达到了 net.maxIterations
。
神经网络的真实用途在于其预测未知输入的能力。一旦您的SNNeuralNet
经过训练,您可以使用runInput
方法来获取预测输出。
double *output = [net runInput:SNInput(1, 0.4, 0)];
runInput
方法返回一个包含net.numOutputs
个条目的double数组。
MIT