Введение
В последние годы машинное обучение становится все более популярным, поскольку оно обеспечивает решение различных реальных проблем. Java, универсальный язык программирования, также можно использовать для разработки приложений машинного обучения с использованием библиотеки Deeplearning4j (DL4J). В этом сообщении блога мы рассмотрим основы машинного обучения в Java с использованием DL4J и рассмотрим простой пример, чтобы продемонстрировать его возможности.
Настройка Deeplearning4j
Чтобы начать работу с DL4J, вам необходимо настроить проект Java. Мы рекомендуем использовать Maven или Gradle в качестве инструмента сборки. В этом уроке мы будем использовать Maven.
Сначала создайте новый проект Maven и добавьте следующие зависимости в файл pom.xml
:
<dependencies> <dependency> <groupId>org.deeplearning4j</groupId> <artifactId>deeplearning4j-core</artifactId> <version>1.0.0-beta7</version> </dependency> <dependency> <groupId>org.nd4j</groupId> <artifactId>nd4j-native-platform</artifactId> <version>1.0.0-beta7</version> </dependency> </dependencies>
Подготовка набора данных
В этом примере мы будем использовать популярный набор данных Iris, который содержит 150 образцов цветков ириса с четырьмя характеристиками: длина чашелистика, ширина чашелистика, длина лепестка и ширина лепестка. Набор данных состоит из трех классов, каждый из которых представляет тип цветка ириса: Iris Setosa, Iris Versicolor и Iris Virginica.
Мы будем использовать встроенный итератор набора данных DL4J для загрузки набора данных Iris. Добавьте следующий код в метод main
:
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; public static void main(String[] args) { int batchSize = 10; int numClasses = 3; DataSetIterator iterator = new IrisDataSetIterator(batchSize, numClasses); }
Создание модели нейронной сети
Теперь, когда у нас есть набор данных, давайте создадим простую модель нейронной сети. Мы будем использовать многослойный персептрон с прямой связью (MLP) с двумя скрытыми слоями. Добавьте следующий код в свой метод main
:
import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.learning.config.Nesterovs; import org.nd4j.linalg.lossfunctions.LossFunctions; MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Nesterovs(0.1, 0.9)) .list() .layer(0, new DenseLayer.Builder().nIn(4).nOut(10).activation(Activation.RELU).build()) .layer(1, new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.RELU).build()) .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) .activation(Activation.SOFTMAX) .nIn(10) .nOut(numClasses) .build()) .build(); MultiLayerNetwork model = new MultiLayerNetwork(configuration); model.init();
Обучение модели
Теперь, когда у нас настроена модель нейронной сети, мы можем обучить ее, используя наш набор данных Iris. Мы будем обучать модель на 100 эпох. Добавьте следующий код в свой метод main
:
int numEpochs = 100; for (int i = 0; i < numEpochs; i++) { iterator.reset(); model.fit(iterator); }
Оценка модели
После обучения модели нам необходимо оценить ее производительность на тестовых данных. Для этого мы будем использовать встроенный в DL4J класс Evaluation
. Добавьте следующий код в свой метод main
:
import org.deeplearning4j.eval.Evaluation; import org.nd4j.linalg.dataset.DataSet; iterator.reset(); Evaluation evaluation = new Evaluation(numClasses); while (iterator.hasNext()) { DataSet batch = iterator.next(); model.rnnClearPreviousState(); evaluation.eval(batch.getLabels(), model.output(batch.getFeatures())); } System.out.println(evaluation.stats());
Запуск приложения
Теперь вы можете запустить свое Java-приложение. Выходные данные должны отображать оценочную статистику для обученной модели, включая точность, достоверность, полноту и оценку F1.
Заключение
В этом сообщении блога мы представили машинное обучение на Java с использованием библиотеки Deeplearning4j. Мы продемонстрировали, как настроить проект, загрузить набор данных, создать простую модель нейронной сети, обучить модель и оценить ее производительность. С помощью этой основы вы сможете дополнительно изучить возможности DL4J и применить его к более сложным задачам машинного обучения.