DEV Community

Sam Oz


How to make your model training stop when it reaches a threshold?

Sometimes you want your model's training phase to stop once it reaches a certain threshold. This can help you avoid overfitting, or you may have a pipeline that should handle these situations automatically. Whatever your reasons, there is a solution: callbacks. According to Wikipedia, a callback is:

"In computer programming, a callback, also known as a "call-after" function, is any executable code that is passed as an argument to other code; that other code is expected to call back (execute) the argument at a given time. This execution may be immediate as in a synchronous callback, or it might happen at a later time as in an asynchronous callback. Programming languages support callbacks in different ways, often implementing them with subroutines, lambda expressions, blocks, or function pointers."
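The definition above can be made concrete with a few lines of plain Python. `run_steps` and its parameters are hypothetical names used only for illustration: the function receives another function as an argument and calls it back synchronously after each step. This is the same pattern Keras uses when it calls your callback's methods during training.

```python
from typing import Callable, List


def run_steps(n_steps: int, on_step_end: Callable[[int], None]) -> None:
    """Hypothetical worker: runs n_steps units of work, calling back after each."""
    for step in range(n_steps):
        # ... the real work for this step would happen here ...
        on_step_end(step)  # synchronous callback: executed immediately


seen: List[int] = []
# Pass a lambda as the callback; run_steps decides when to call it
run_steps(3, lambda step: seen.append(step))
print(seen)  # [0, 1, 2]
```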

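Here is a small chunk of code that you can use in TensorFlow/Keras. It trains a simple classifier on the MNIST dataset and uses a custom callback to stop training as soon as the training accuracy exceeds 99%.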

```python
import tensorflow as tf


def train_mnist():
    # Custom callback: Keras calls on_epoch_end after every training epoch
    class myCallback(tf.keras.callbacks.Callback):
        def on_epoch_end(self, epoch, logs=None):
            logs = logs or {}
            # TF 2.x logs metrics=['accuracy'] under 'accuracy' (older versions used 'acc')
            if logs.get('accuracy', 0) > 0.99:
                print("\nReached 99% accuracy so cancelling training!")
                self.model.stop_training = True

    mnist = tf.keras.datasets.mnist

    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    # Scale pixel values from [0, 255] to [0, 1]
    x_train, x_test = x_train / 255.0, x_test / 255.0
    callbacks = myCallback()
    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(512, activation=tf.nn.relu),
        tf.keras.layers.Dense(10, activation=tf.nn.softmax)
    ])

    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    # Model fitting: training ends early once the callback sets stop_training
    history = model.fit(x_train, y_train, epochs=10, callbacks=[callbacks])
    return history.epoch, history.history['accuracy'][-1]
```
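To see why setting `self.model.stop_training = True` ends training, here is a toy, pure-Python imitation of the loop that `model.fit()` runs. `ToyModel`, `ToyCallback`, `StopAtThreshold`, `fit`, and the fake accuracy values are all hypothetical; this is a simplified sketch of the idea, not Keras's actual implementation. The loop attaches the model to each callback, calls every callback's `on_epoch_end` after each epoch, and checks the model's `stop_training` flag before starting the next one.

```python
from typing import Dict, List, Optional


class ToyModel:
    """Hypothetical stand-in for a Keras model: only tracks the stop flag."""

    def __init__(self) -> None:
        self.stop_training = False


class ToyCallback:
    """Hypothetical base class mirroring the on_epoch_end hook."""

    def __init__(self) -> None:
        self.model: Optional[ToyModel] = None

    def on_epoch_end(self, epoch: int, logs: Optional[Dict[str, float]] = None) -> None:
        pass


class StopAtThreshold(ToyCallback):
    """Stops once logs[metric] exceeds threshold (same idea as myCallback)."""

    def __init__(self, metric: str = "accuracy", threshold: float = 0.99) -> None:
        super().__init__()
        self.metric = metric
        self.threshold = threshold

    def on_epoch_end(self, epoch: int, logs: Optional[Dict[str, float]] = None) -> None:
        logs = logs or {}
        if logs.get(self.metric, 0.0) > self.threshold:
            self.model.stop_training = True


def fit(model: ToyModel, fake_accuracies: List[float],
        callbacks: List[ToyCallback]) -> List[int]:
    """Simplified fit loop: returns the epochs that actually ran."""
    for cb in callbacks:
        cb.model = model  # Keras also attaches the model to each callback
    epochs_run = []
    for epoch, acc in enumerate(fake_accuracies):
        # ... one epoch of real training would happen here ...
        epochs_run.append(epoch)
        logs = {"accuracy": acc}
        for cb in callbacks:
            cb.on_epoch_end(epoch, logs)
        if model.stop_training:  # flag is checked after every epoch
            break
    return epochs_run


epochs = fit(ToyModel(), [0.90, 0.97, 0.992, 0.995], [StopAtThreshold()])
print(epochs)  # [0, 1, 2]
```

Training stops after the third epoch because 0.992 is the first value above 0.99; the fourth epoch never runs. Making the metric name and threshold constructor parameters lets the same callback be reused for other thresholds.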

References:
https://en.wikipedia.org/wiki/Callback_(computer_programming)
https://www.coursera.org/learn/introduction-tensorflow/
