nncf.common.accuracy_aware_training.training_loop

Implementations of training loops to be used for accuracy aware training.

Classes

TrainingLoop

The training loop object that launches the training process via the run method.

BaseEarlyExitCompressionTrainingLoop

Base class to generalize functionality of derived training loop classes.

EarlyExitCompressionTrainingLoop

Training loop that does not modify compression parameters and exits as soon as (and if) the accuracy drop criterion is reached.

AdaptiveCompressionTrainingLoop

A training loop that automatically adjusts the compression rate to reach the maximum compression within the accuracy budget.

class nncf.common.accuracy_aware_training.training_loop.TrainingLoop

Bases: abc.ABC

The training loop object that launches the training process via the run method.

abstract property statistics: nncf.common.accuracy_aware_training.statistics.TrainingLoopStatistics

Returns statistics of the compressed model.

Return type:

nncf.common.accuracy_aware_training.statistics.TrainingLoopStatistics

abstract run(model, train_epoch_fn, validate_fn, configure_optimizers_fn=None, dump_checkpoint_fn=None, load_checkpoint_fn=None, early_stopping_fn=None, tensorboard_writer=None, log_dir=None, update_learning_rate_fn=None)

Implements the custom logic to run a training loop for model fine-tuning using the provided train_epoch_fn, validate_fn and configure_optimizers_fn callbacks.

Parameters:
  • model (TModel) – The model instance before fine-tuning.

  • train_epoch_fn (Callable) – A callback to fine-tune the model for a single epoch.

  • validate_fn (Callable) – A callback to evaluate the model on the validation dataset.

  • configure_optimizers_fn (Callable) – A callback to instantiate an optimizer and a learning rate scheduler.

  • dump_checkpoint_fn (Callable) – A callback to dump a checkpoint.

  • load_checkpoint_fn (Callable) – A callback to load a checkpoint.

  • early_stopping_fn (Callable) – A callback to check for an early stopping condition.

  • tensorboard_writer (Optional[TensorboardWriterType]) – The TensorBoard writer object to be used for logging.

  • log_dir (Union[pathlib.Path, str]) – The path to be used for logging and checkpoint saving.

  • update_learning_rate_fn (Callable) – A callback to update the learning rate after each epoch of the training loop.

Returns:

The fine-tuned model.
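The contract of run can be illustrated with a self-contained sketch that does not depend on NNCF. All of the following are hypothetical choices made for illustration:

  • the classes AbstractTrainingLoop and SimpleTrainingLoop;
  • the max_epochs parameter;
  • the callback signatures train_epoch_fn(model, epoch, optimizers) and validate_fn(model, epoch).

NNCF's own loops may pass different arguments to the callbacks, so check the library's examples for the real signatures.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, List, Optional


class AbstractTrainingLoop(ABC):
    """Hypothetical stand-in mirroring the documented TrainingLoop contract."""

    @abstractmethod
    def run(self, model: Any, train_epoch_fn: Callable, validate_fn: Callable,
            configure_optimizers_fn: Optional[Callable] = None,
            early_stopping_fn: Optional[Callable] = None) -> Any:
        """Fine-tune `model` using the callbacks and return the fine-tuned model."""


class SimpleTrainingLoop(AbstractTrainingLoop):
    """Runs at most `max_epochs` epochs (a hypothetical parameter)."""

    def __init__(self, max_epochs: int = 3) -> None:
        self.max_epochs = max_epochs
        self.history: List[float] = []  # validation accuracy after each epoch

    def run(self, model, train_epoch_fn, validate_fn,
            configure_optimizers_fn=None, early_stopping_fn=None):
        # Create the optimizer/scheduler once, if a factory callback is given.
        optimizers = configure_optimizers_fn() if configure_optimizers_fn else None
        for epoch in range(self.max_epochs):
            train_epoch_fn(model, epoch, optimizers)  # fine-tune for one epoch
            accuracy = validate_fn(model, epoch)      # evaluate on validation data
            self.history.append(accuracy)
            # Stop as soon as the caller's early-stopping criterion is met.
            if early_stopping_fn is not None and early_stopping_fn(accuracy):
                break
        return model


# Toy usage: a dict stands in for a real model.
model = {"steps": 0}


def train_epoch_fn(model, epoch, optimizers):
    model["steps"] += 1  # pretend one epoch of training happened


def validate_fn(model, epoch):
    return 0.5 + 0.1 * model["steps"]  # accuracy improves with training


loop = SimpleTrainingLoop(max_epochs=5)
trained = loop.run(model, train_epoch_fn, validate_fn,
                   early_stopping_fn=lambda acc: acc >= 0.65)
```

In this sketch, all model-specific work lives in the callbacks, so the loop itself stays framework-agnostic. That appears to be the motivation for the callback-based run signature.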

class nncf.common.accuracy_aware_training.training_loop.BaseEarlyExitCompressionTrainingLoop(compression_controller)

Bases: TrainingLoop, abc.ABC

Base class to generalize functionality of derived training loop classes.

Parameters:

compression_controller (nncf.api.compression.CompressionAlgorithmController) – The controller for the compression algorithm that is currently applied to the model to be trained.

property statistics: nncf.common.accuracy_aware_training.statistics.TrainingLoopStatistics

Returns statistics of the compressed model.

Return type:

nncf.common.accuracy_aware_training.statistics.TrainingLoopStatistics

run(model, train_epoch_fn, validate_fn, configure_optimizers_fn=None, dump_checkpoint_fn=None, load_checkpoint_fn=None, early_stopping_fn=None, tensorboard_writer=None, log_dir=None, update_learning_rate_fn=None)

Implements the custom logic to run a training loop for model fine-tuning using the provided train_epoch_fn, validate_fn and configure_optimizers_fn callbacks.

Parameters:
  • model (TModel) – The model instance before fine-tuning.

  • train_epoch_fn (Callable) – A callback to fine-tune the model for a single epoch.

  • validate_fn (Callable) – A callback to evaluate the model on the validation dataset.

  • configure_optimizers_fn (Callable) – A callback to instantiate an optimizer and a learning rate scheduler.

  • dump_checkpoint_fn (Callable) – A callback to dump a checkpoint.

  • load_checkpoint_fn (Callable) – A callback to load a checkpoint.

  • early_stopping_fn (Callable) – A callback to check for an early stopping condition.

  • tensorboard_writer (Optional[TensorboardWriterType]) – The TensorBoard writer object to be used for logging.

  • log_dir (Union[pathlib.Path, str]) – The path to be used for logging and checkpoint saving.

  • update_learning_rate_fn (Callable) – A callback to update the learning rate after each epoch of the training loop.

Returns:

The fine-tuned model.
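A matching pair of dump_checkpoint_fn and load_checkpoint_fn callbacks might look like the sketch below. It writes checkpoints under log_dir, which accepts either a pathlib.Path or a str, as documented.

The following details are assumptions for illustration:

  • the signatures dump_checkpoint_fn(state, log_dir, epoch) and load_checkpoint_fn(path);
  • the epoch_<N>.pkl naming scheme;
  • the use of pickle on a plain dict.

A real training script would save framework-specific state such as model and optimizer weights.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any, Dict, Union


def dump_checkpoint_fn(state: Dict[str, Any], log_dir: Union[Path, str], epoch: int) -> Path:
    """Hypothetical callback: write `state` to <log_dir>/epoch_<N>.pkl and return the path."""
    path = Path(log_dir) / f"epoch_{epoch}.pkl"
    path.parent.mkdir(parents=True, exist_ok=True)  # create log_dir if missing
    with path.open("wb") as f:
        pickle.dump(state, f)
    return path


def load_checkpoint_fn(path: Union[Path, str]) -> Dict[str, Any]:
    """Hypothetical callback: read back a state dict written by dump_checkpoint_fn."""
    with Path(path).open("rb") as f:
        return pickle.load(f)


# Round-trip a toy state through a temporary log directory (passed as a str).
with tempfile.TemporaryDirectory() as log_dir:
    saved = dump_checkpoint_fn({"weights": [0.1, 0.2], "accuracy": 0.75}, log_dir, epoch=3)
    restored = load_checkpoint_fn(saved)
    saved_name = saved.name
```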

class nncf.common.accuracy_aware_training.training_loop.EarlyExitCompressionTrainingLoop(nncf_config, compression_controller, uncompressed_model_accuracy, lr_updates_needed=True, verbose=True, dump_checkpoints=True)

Bases: BaseEarlyExitCompressionTrainingLoop

Training loop that does not modify compression parameters and exits as soon as (and if) the accuracy drop criterion is reached.

Parameters:
  • nncf_config (nncf.NNCFConfig) – The configuration object.

  • compression_controller (nncf.api.compression.CompressionAlgorithmController) – The controller for the compression algorithm that is currently applied to the model to be trained.

  • uncompressed_model_accuracy (float) – The uncompressed model accuracy, measured outside of this training loop to serve as the point of reference for fine-tuning the compressed model.

  • lr_updates_needed (bool) –

  • verbose (bool) – Whether to post additional data to TensorBoard.

  • dump_checkpoints (bool) – If True, dumps all checkpoints obtained during the training process; otherwise, keeps only the best checkpoint (accuracy-wise).
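The early-exit behaviour can be sketched as follows. This is not NNCF's implementation, and it makes three simplifications:

  • the accuracy budget is a hypothetical absolute max_accuracy_drop, measured against uncompressed_model_accuracy;
  • the loop is capped by a hypothetical max_epochs;
  • the dump_checkpoints flag is emulated by keeping either every epoch's record or only the best one.

As in the class description, compression parameters are never modified; only fine-tuning epochs run.

```python
from typing import Any, Callable, Dict, List, Optional


def early_exit_training(
    uncompressed_model_accuracy: float,
    train_epoch_fn: Callable[[int], None],
    validate_fn: Callable[[int], float],
    max_accuracy_drop: float,  # hypothetical absolute budget, in accuracy points
    max_epochs: int,           # hypothetical cap on fine-tuning epochs
    dump_checkpoints: bool = True,
) -> Dict[str, Any]:
    """Fine-tune at fixed compression until the accuracy drop fits the budget."""
    checkpoints: List[Dict[str, Any]] = []  # stands in for checkpoints on disk
    best: Optional[Dict[str, Any]] = None
    for epoch in range(max_epochs):
        train_epoch_fn(epoch)           # compression parameters are left untouched
        accuracy = validate_fn(epoch)
        record = {"epoch": epoch, "accuracy": accuracy}
        if best is None or accuracy > best["accuracy"]:
            best = record
        if dump_checkpoints:
            checkpoints.append(record)  # keep every checkpoint
        else:
            checkpoints = [best]        # keep only the best checkpoint so far
        # Exit as soon as the accuracy drop criterion is met.
        if uncompressed_model_accuracy - accuracy <= max_accuracy_drop:
            return {"reached": True, "epoch": epoch, "checkpoints": checkpoints}
    # The criterion was never met within max_epochs.
    return {"reached": False, "epoch": max_epochs - 1, "checkpoints": checkpoints}


accuracies = [70.0, 74.0, 75.5, 76.0]  # toy per-epoch validation accuracies
keep_all = early_exit_training(76.5, lambda e: None, lambda e: accuracies[e],
                               max_accuracy_drop=1.0, max_epochs=4)
keep_best = early_exit_training(76.5, lambda e: None, lambda e: accuracies[e],
                                max_accuracy_drop=1.0, max_epochs=4,
                                dump_checkpoints=False)
never = early_exit_training(76.5, lambda e: None, lambda e: accuracies[e],
                            max_accuracy_drop=0.1, max_epochs=4)
```

With a budget of 1.0 the loop stops at epoch 2, where the drop is 76.5 − 75.5 = 1.0. With a budget of 0.1 it never exits early. The "(and if)" in the class description refers to exactly this case.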

class nncf.common.accuracy_aware_training.training_loop.AdaptiveCompressionTrainingLoop(nncf_config, compression_controller, uncompressed_model_accuracy, lr_updates_needed=True, verbose=True, minimal_compression_rate=0.0, maximal_compression_rate=0.95, dump_checkpoints=True)

Bases: BaseEarlyExitCompressionTrainingLoop

A training loop that automatically adjusts the compression rate to reach the maximum compression within the accuracy budget.

Parameters:
  • nncf_config (nncf.NNCFConfig) – The configuration object.

  • compression_controller (nncf.api.compression.CompressionAlgorithmController) – The controller for the compression algorithm that is currently applied to the model to be trained.

  • uncompressed_model_accuracy (float) – The uncompressed model accuracy, measured outside of this training loop to serve as the point of reference for fine-tuning the compressed model.

  • lr_updates_needed (bool) –

  • verbose (bool) – Whether to post additional data to TensorBoard.

  • minimal_compression_rate (float) – Sets the minimal compression rate to be considered during the training loop.

  • maximal_compression_rate (float) – Sets the maximal compression rate to be considered during the training loop.

  • dump_checkpoints (bool) – If True, dumps all checkpoints obtained during the training process; otherwise, keeps only the best checkpoint (accuracy-wise).
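One way to picture how minimal_compression_rate and maximal_compression_rate bound the search is the step-based sketch below. It is an illustrative strategy, not necessarily the algorithm NNCF uses. Several details are hypothetical:

  • the initial step, the step-halving rule and min_step;
  • the absolute max_accuracy_drop budget;
  • accuracy_at(rate), which stands in for fine-tuning the model at a given compression rate and then validating it.

```python
from typing import Callable, Optional


def search_compression_rate(
    accuracy_at: Callable[[float], float],  # hypothetical: fine-tune at `rate`, return accuracy
    uncompressed_model_accuracy: float,
    max_accuracy_drop: float,               # hypothetical absolute budget
    minimal_compression_rate: float = 0.0,
    maximal_compression_rate: float = 0.95,
    initial_step: float = 0.1,              # hypothetical search parameters
    min_step: float = 0.025,
) -> Optional[float]:
    """Return the highest tried rate within budget, or None if even the minimum fails."""
    if min_step <= 0:
        raise ValueError("min_step must be positive")
    best: Optional[float] = None
    rate, step = minimal_compression_rate, initial_step
    while step >= min_step:
        drop = uncompressed_model_accuracy - accuracy_at(rate)
        if drop <= max_accuracy_drop:
            best = rate                      # feasible: remember it and push further
            if rate >= maximal_compression_rate:
                break                        # cannot compress any further
            rate = min(rate + step, maximal_compression_rate)
        else:
            step /= 2                        # overshot: retry closer to the best rate
            if best is None:
                break                        # even the minimal rate breaks the budget
            rate = min(best + step, maximal_compression_rate)
    return best


# Toy accuracy curve: the budget holds only up to a compression rate of 0.42.
found = search_compression_rate(lambda r: 76.0 if r <= 0.42 else 70.0, 76.5, 1.0)
capped = search_compression_rate(lambda r: 76.5, 76.5, 1.0)       # never degrades
infeasible = search_compression_rate(lambda r: 70.0, 76.5, 1.0)   # always too lossy
```

The search grows the rate while the accuracy drop stays within budget. After an overshoot, it halves the step and retries just above the best feasible rate, always staying inside [minimal_compression_rate, maximal_compression_rate].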

run(model, train_epoch_fn, validate_fn, configure_optimizers_fn=None, dump_checkpoint_fn=None, load_checkpoint_fn=None, early_stopping_fn=None, tensorboard_writer=None, log_dir=None, update_learning_rate_fn=None)

Implements the custom logic to run a training loop for model fine-tuning using the provided train_epoch_fn, validate_fn and configure_optimizers_fn callbacks.

Parameters:
  • model (TModel) – The model instance before fine-tuning.

  • train_epoch_fn (Callable) – A callback to fine-tune the model for a single epoch.

  • validate_fn (Callable) – A callback to evaluate the model on the validation dataset.

  • configure_optimizers_fn (Callable) – A callback to instantiate an optimizer and a learning rate scheduler.

  • dump_checkpoint_fn (Callable) – A callback to dump a checkpoint.

  • load_checkpoint_fn (Callable) – A callback to load a checkpoint.

  • early_stopping_fn (Callable) – A callback to check for an early stopping condition.

  • tensorboard_writer (Optional[TensorboardWriterType]) – The TensorBoard writer object to be used for logging.

  • log_dir (Union[pathlib.Path, str]) – The path to be used for logging and checkpoint saving.

  • update_learning_rate_fn (Callable) – A callback to update the learning rate after each epoch of the training loop.

Returns:

The fine-tuned model.