Как использовать обратный вызов, чтобы остановить обучение при адекватной производительности

В этой статье я объясню, как управлять обучением нейронной сети в Tensorflow с помощью обратных вызовов. Обратный вызов — это функция, которая многократно вызывается во время процесса (например, обучения нейронной сети) и обычно служит для подтверждения или исправления определенных действий.

В машинном обучении мы можем использовать обратные вызовы, чтобы определить, что происходит до, во время или в конце эпохи обучения. Это особенно полезно для регистрации производительности или остановки обучения, если наша метрика производительности достигает определенного порога. Этот механизм называется ранняя остановка.

Например, если вы установили 1000 эпох, а желаемая точность уже достигнута к эпохе 200, обучение остановится автоматически, избавляя вас от возможной переобучения вашей модели. Давайте посмотрим, как это реализовано в Tensorflow и Питон.

Давайте заложим основу, импортировав набор данных fashion_mnist из Tensorflow. Мы будем использовать этот набор данных, чтобы объяснить, как работают обратные вызовы.

Вот как выглядит наш набор данных

Класс ранней остановки

Второй шаг – создание класса, посвященного ранней остановке. Мы создадим класс, который будет наследоваться от tf.keras.callbacks.Callback и позволит нам остановить обучение, когда будет достигнута точность 95%. Обратный вызов будет использовать функцию on_epoch_end, чтобы остановить обучение, если условие выполняется при просмотре журналов, предоставленных моделью Tensorflow.

Здесь мы обращаемся к методу on_epoch_end, унаследованному от tf.keras.callbacks.Callback, и переопределяем его поведение, кодируя условие, которое приведет к остановке обучения.

Давайте продолжим реализацию нашей модели Tensorflow.

Задача классификации с глубокой нейронной сетью

Мы будем использовать нейронную сеть с несколькими слоями для классификации одежды в наборе данных. Лучшим подходом было бы использование сверточной нейронной сети, но для этого примера отлично подойдет глубокая нейронная сеть.

Теперь мы готовы обучить модель. Чтобы использовать обратный вызов, просто поместите объект в список, который будет передан аргументу обратного вызова в методе fit() модели.

Вот как настроить обратный вызов для управления обучением нейронной сети! Надеюсь, вы сегодня узнали что-то новое 👍

Оставайтесь сильными и переживайте трудные моменты. Это будет того стоить. Ваше здоровье!

Шаблон кода