Введение

В последние годы машинное обучение становится все более популярным, поскольку оно обеспечивает решение различных реальных проблем. Java, универсальный язык программирования, также можно использовать для разработки приложений машинного обучения с использованием библиотеки Deeplearning4j (DL4J). В этом сообщении блога мы рассмотрим основы машинного обучения в Java с использованием DL4J и рассмотрим простой пример, чтобы продемонстрировать его возможности.

Настройка Deeplearning4j

Чтобы начать работу с DL4J, вам необходимо настроить проект Java. Мы рекомендуем использовать Maven или Gradle в качестве инструмента сборки. В этом уроке мы будем использовать Maven.

Сначала создайте новый проект Maven и добавьте следующие зависимости в файл pom.xml:

<dependencies>
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-core</artifactId>
        <version>1.0.0-beta7</version>
    </dependency>
    <dependency>
        <groupId>org.nd4j</groupId>
        <artifactId>nd4j-native-platform</artifactId>
        <version>1.0.0-beta7</version>
    </dependency>
</dependencies>

Подготовка набора данных

В этом примере мы будем использовать популярный набор данных Iris, который содержит 150 образцов цветков ириса с четырьмя характеристиками: длина чашелистика, ширина чашелистика, длина лепестка и ширина лепестка. Набор данных состоит из трех классов, каждый из которых представляет тип цветка ириса: Iris Setosa, Iris Versicolor и Iris Virginica.

Мы будем использовать встроенный итератор набора данных DL4J для загрузки набора данных Iris. Добавьте следующий код в метод main:

import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

public static void main(String[] args) {
    int batchSize = 10;
    int numClasses = 3;
    DataSetIterator iterator = new IrisDataSetIterator(batchSize, numClasses);
}

Создание модели нейронной сети

Теперь, когда у нас есть набор данных, давайте создадим простую модель нейронной сети. Мы будем использовать многослойный персептрон с прямой связью (MLP) с двумя скрытыми слоями. Добавьте следующий код в свой метод main:

import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;

MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
        .updater(new Nesterovs(0.1, 0.9))
        .list()
        .layer(0, new DenseLayer.Builder().nIn(4).nOut(10).activation(Activation.RELU).build())
        .layer(1, new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.RELU).build())
        .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
        .activation(Activation.SOFTMAX)
        .nIn(10)
        .nOut(numClasses)
        .build())
    .build();

MultiLayerNetwork model = new MultiLayerNetwork(configuration);
model.init();

Обучение модели

Теперь, когда у нас настроена модель нейронной сети, мы можем обучить ее, используя наш набор данных Iris. Мы будем обучать модель на 100 эпох. Добавьте следующий код в свой метод main:

int numEpochs = 100;
for (int i = 0; i < numEpochs; i++) {
    iterator.reset();
    model.fit(iterator);
}

Оценка модели

После обучения модели нам необходимо оценить ее производительность на тестовых данных. Для этого мы будем использовать встроенный в DL4J класс Evaluation. Добавьте следующий код в свой метод main:

import org.deeplearning4j.eval.Evaluation;
import org.nd4j.linalg.dataset.DataSet;

iterator.reset();
Evaluation evaluation = new Evaluation(numClasses);
while (iterator.hasNext()) {
    DataSet batch = iterator.next();
    model.rnnClearPreviousState();
    evaluation.eval(batch.getLabels(), model.output(batch.getFeatures()));
}

System.out.println(evaluation.stats());

Запуск приложения

Теперь вы можете запустить свое Java-приложение. Выходные данные должны отображать оценочную статистику для обученной модели, включая точность, достоверность, полноту и оценку F1.

Заключение

В этом сообщении блога мы представили машинное обучение на Java с использованием библиотеки Deeplearning4j. Мы продемонстрировали, как настроить проект, загрузить набор данных, создать простую модель нейронной сети, обучить модель и оценить ее производительность. С помощью этой основы вы сможете дополнительно изучить возможности DL4J и применить его к более сложным задачам машинного обучения.

  1. Веб-сайт Deeplearning4j
  2. Репозиторий Deeplearning4j на GitHub

Понравилось читать? Еще не являетесь участником Medium? Вы можете поддержать мою работу напрямую, зарегистрировавшись по моей реферальной ссылке здесь. Это быстро, просто и не требует дополнительных затрат. Спасибо за вашу поддержку!