Sometimes you need model training to stop as soon as a metric reaches a certain threshold. Maybe this helps you avoid overfitting, or you have a pipeline and want it to handle such cases automatically. Whatever your reasons, there is a solution: callbacks. According to Wikipedia, a callback is:
"In computer programming, a callback, also known as a "call-after" function, is any executable code that is passed as an argument to other code; that other code is expected to call back (execute) the argument at a given time. This execution may be immediate as in a synchronous callback, or it might happen at a later time as in an asynchronous callback. Programming languages support callbacks in different ways, often implementing them with subroutines, lambda expressions, blocks, or function pointers."
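To make the definition concrete, here is a minimal plain-Python sketch of a synchronous callback: `process` (a hypothetical helper, not from any library) receives a function as an argument and calls it back right away for each result. A bound method and a lambda expression both work as the callback.

```python
def process(values, callback):
    """Double each value, then hand every result to `callback` immediately."""
    for value in values:
        result = value * 2
        # process() "calls back" the function it was given: a synchronous callback.
        callback(result)


collected = []
# The callback here is a bound method (list.append).
process([1, 2, 3], collected.append)
print(collected)  # [2, 4, 6]

# The same idea with a lambda expression as the callback.
process([5], lambda r: print("got", r))  # got 10
```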
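Here is a small piece of code you can use with TensorFlow's Keras API (`tf.keras`). It trains a simple classifier on MNIST and stops training once training accuracy exceeds 99%.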
```python
import tensorflow as tf


def train_mnist():
    class myCallback(tf.keras.callbacks.Callback):
        def on_epoch_end(self, epoch, logs=None):
            logs = logs or {}
            # TF 2.x logs this metric as 'accuracy'; TF 1.x used the key 'acc'.
            if logs.get('accuracy', 0) > 0.99:
                print("\nReached 99% accuracy so cancelling training!")
                self.model.stop_training = True

    mnist = tf.keras.datasets.mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    # Scale pixel values from [0, 255] to [0, 1]
    x_train, x_test = x_train / 255.0, x_test / 255.0

    callbacks = myCallback()

    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(512, activation=tf.nn.relu),
        tf.keras.layers.Dense(10, activation=tf.nn.softmax)
    ])

    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    # Model fitting: stops early once the callback sets stop_training
    history = model.fit(x_train, y_train, epochs=10, callbacks=[callbacks])

    return history.epoch, history.history['accuracy'][-1]
```
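Why is setting `self.model.stop_training = True` enough? Roughly, `fit()` runs a loop that calls each callback's `on_epoch_end` and then checks that flag before starting the next epoch. The pure-Python sketch below imitates that loop without TensorFlow. It is not Keras's actual source: `Callback`, `ToyModel` and `ThresholdCallback` are hypothetical stand-ins, the "training" just replays a list of made-up accuracies, and the hard-coded `'accuracy'`/0.99 check is turned into constructor parameters.

```python
class Callback:
    """Hypothetical, simplified stand-in for tf.keras.callbacks.Callback."""

    def set_model(self, model):
        self.model = model

    def on_epoch_end(self, epoch, logs=None):
        pass


class ToyModel:
    """Hypothetical model whose 'training' replays a list of accuracies."""

    def __init__(self, accuracies):
        self.accuracies = accuracies
        self.stop_training = False

    def fit(self, epochs, callbacks=()):
        self.stop_training = False  # reset the flag at the start of every fit
        history = []
        for cb in callbacks:
            cb.set_model(self)  # gives each callback access to self.model
        for epoch in range(epochs):
            logs = {"accuracy": self.accuracies[epoch]}
            history.append(logs["accuracy"])
            for cb in callbacks:
                cb.on_epoch_end(epoch, logs)
            # The flag is checked only after the epoch's callbacks have run,
            # so the epoch that crossed the threshold still completes.
            if self.stop_training:
                break
        return history


class ThresholdCallback(Callback):
    """Stop once logs[monitor] exceeds threshold (names are illustrative)."""

    def __init__(self, monitor="accuracy", threshold=0.99):
        self.monitor = monitor
        self.threshold = threshold

    def on_epoch_end(self, epoch, logs=None):
        value = (logs or {}).get(self.monitor)
        if value is not None and value > self.threshold:
            self.model.stop_training = True


model = ToyModel([0.90, 0.97, 0.995, 0.999, 0.999])
print(model.fit(epochs=5, callbacks=[ThresholdCallback()]))  # [0.9, 0.97, 0.995]
```

Using `logs.get(...)` with a `None` check means a misspelled metric name simply never triggers the stop instead of raising a `TypeError`.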
References:
https://en.wikipedia.org/wiki/Callback_(computer_programming)
https://www.coursera.org/learn/introduction-tensorflow/