nncf.common.accuracy_aware_training.training_loop
Implementations of training loops to be used for accuracy-aware training.
Classes#
TrainingLoop – The training loop object that launches the training process via the run method.
BaseEarlyExitCompressionTrainingLoop – Base class to generalize functionality of derived training loop classes.
EarlyExitCompressionTrainingLoop – Training loop that does not modify compression parameters and exits as soon as (and if) the accuracy drop criterion is reached.
AdaptiveCompressionTrainingLoop – A training loop that automatically adjusts the compression rate to reach maximum compression within the accuracy budget.
- class nncf.common.accuracy_aware_training.training_loop.TrainingLoop[source]#
Bases: abc.ABC
The training loop object that launches the training process via the run method.
- abstract property statistics: nncf.common.accuracy_aware_training.statistics.TrainingLoopStatistics[source]#
Returns statistics of the compressed model.
- abstract run(model, train_epoch_fn, validate_fn, configure_optimizers_fn=None, dump_checkpoint_fn=None, load_checkpoint_fn=None, early_stopping_fn=None, tensorboard_writer=None, log_dir=None, update_learning_rate_fn=None)[source]#
Implements the custom logic to run a training loop for model fine-tuning using the provided train_epoch_fn, validate_fn, and configure_optimizers_fn callbacks.
- Parameters:
model (TModel) – The model instance before fine-tuning.
train_epoch_fn (Callable) – A callback to fine-tune the model for a single epoch.
validate_fn (Callable) – A callback to evaluate the model on the validation dataset.
configure_optimizers_fn (Callable) – A callback to instantiate an optimizer and a learning rate scheduler.
dump_checkpoint_fn (Callable) – A callback to dump a checkpoint.
load_checkpoint_fn (Callable) – A callback to load a checkpoint.
early_stopping_fn (Callable) – A callback to check for an early stopping condition.
tensorboard_writer (Optional[TensorboardWriterType]) – The TensorBoard object to be used for logging.
log_dir (Union[pathlib.Path, str]) – The path to be used for logging and checkpoint saving.
update_learning_rate_fn (Callable) – A callback to update the learning rate after each epoch of the training loop.
- Returns:
The fine-tuned model.
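The callbacks passed to run can be sketched as plain Python functions. This is an illustrative sketch only: the exact arguments NNCF supplies to each callback depend on the framework backend in use, and the bodies below are placeholders rather than real training code.

```python
# Hypothetical callback skeletons for TrainingLoop.run (illustrative; the
# actual signatures expected by NNCF depend on the framework backend).

def train_epoch_fn(compression_ctrl, model, epoch=None,
                   optimizer=None, lr_scheduler=None):
    """Fine-tune `model` for a single epoch (body elided in this sketch)."""
    # ... iterate over the training dataset, step the optimizer ...
    return None

def validate_fn(model, epoch=None):
    """Evaluate `model` on the validation set and return its accuracy."""
    # ... run validation; a fixed value stands in for a real metric here ...
    return 0.761

def configure_optimizers_fn():
    """Instantiate and return an optimizer and a learning-rate scheduler."""
    optimizer, lr_scheduler = None, None  # e.g. a torch.optim optimizer
    return optimizer, lr_scheduler
```

A concrete training loop instance (such as EarlyExitCompressionTrainingLoop) would then receive these callbacks via run(model, train_epoch_fn, validate_fn, configure_optimizers_fn=configure_optimizers_fn, ...).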
- class nncf.common.accuracy_aware_training.training_loop.BaseEarlyExitCompressionTrainingLoop(compression_controller)[source]#
Bases: TrainingLoop, abc.ABC
Base class to generalize functionality of derived training loop classes.
- Parameters:
compression_controller (nncf.api.compression.CompressionAlgorithmController) –
- property statistics: nncf.common.accuracy_aware_training.statistics.TrainingLoopStatistics[source]#
Returns statistics of the compressed model.
- run(model, train_epoch_fn, validate_fn, configure_optimizers_fn=None, dump_checkpoint_fn=None, load_checkpoint_fn=None, early_stopping_fn=None, tensorboard_writer=None, log_dir=None, update_learning_rate_fn=None)[source]#
Implements the custom logic to run a training loop for model fine-tuning using the provided train_epoch_fn, validate_fn, and configure_optimizers_fn callbacks.
- Parameters:
model (TModel) – The model instance before fine-tuning.
train_epoch_fn (Callable) – A callback to fine-tune the model for a single epoch.
validate_fn (Callable) – A callback to evaluate the model on the validation dataset.
configure_optimizers_fn (Callable) – A callback to instantiate an optimizer and a learning rate scheduler.
dump_checkpoint_fn (Callable) – A callback to dump a checkpoint.
load_checkpoint_fn (Callable) – A callback to load a checkpoint.
early_stopping_fn (Callable) – A callback to check for an early stopping condition.
tensorboard_writer (Optional[TensorboardWriterType]) – The TensorBoard object to be used for logging.
log_dir (Union[pathlib.Path, str]) – The path to be used for logging and checkpoint saving.
update_learning_rate_fn (Callable) – A callback to update the learning rate after each epoch of the training loop.
- Returns:
The fine-tuned model.
- class nncf.common.accuracy_aware_training.training_loop.EarlyExitCompressionTrainingLoop(nncf_config, compression_controller, uncompressed_model_accuracy, lr_updates_needed=True, verbose=True, dump_checkpoints=True)[source]#
Bases: BaseEarlyExitCompressionTrainingLoop
Training loop that does not modify compression parameters and exits as soon as (and if) the accuracy drop criterion is reached.
- Parameters:
nncf_config (nncf.NNCFConfig) – The configuration object.
compression_controller (nncf.api.compression.CompressionAlgorithmController) – The controller for the compression algorithm that is currently applied to the model to be trained.
uncompressed_model_accuracy (float) – The uncompressed model accuracy, measured outside of this training loop to serve as the point of reference for fine-tuning the compressed model.
lr_updates_needed (bool) –
verbose (bool) – Whether to post additional data to TensorBoard.
dump_checkpoints (bool) – If True, dump all checkpoints obtained during the training process; otherwise, keep only the best checkpoint (accuracy-wise).
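The early-exit criterion this loop applies can be illustrated with a minimal pure-Python sketch (illustrative only; the real loop also manages the compression controller, checkpointing, and logging, and here accuracy values are expressed in percent):

```python
# Sketch of the early-exit idea: fine-tune epoch by epoch and stop as soon
# as the compressed model's accuracy drop fits within the allowed budget.

def early_exit_training(uncompressed_accuracy, maximal_accuracy_drop,
                        train_one_epoch, max_epochs):
    """Return (epoch, accuracy) at the first epoch meeting the criterion."""
    minimal_tolerable_accuracy = uncompressed_accuracy - maximal_accuracy_drop
    for epoch in range(max_epochs):
        accuracy = train_one_epoch(epoch)  # fine-tune + validate for one epoch
        if accuracy >= minimal_tolerable_accuracy:
            return epoch, accuracy  # exit as soon as the criterion is met
    return max_epochs - 1, accuracy  # budget was never reached

# Toy schedule: compressed-model accuracy recovers a little each epoch.
accuracies = [70.0, 73.0, 75.0, 76.0, 76.5]
epoch, acc = early_exit_training(
    uncompressed_accuracy=76.0,
    maximal_accuracy_drop=1.0,
    train_one_epoch=lambda e: accuracies[e],
    max_epochs=5,
)
# Exits at epoch 2 with accuracy 75.0 (a 1.0-point drop, within the budget).
```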
- class nncf.common.accuracy_aware_training.training_loop.AdaptiveCompressionTrainingLoop(nncf_config, compression_controller, uncompressed_model_accuracy, lr_updates_needed=True, verbose=True, minimal_compression_rate=0.0, maximal_compression_rate=0.95, dump_checkpoints=True)[source]#
Bases: BaseEarlyExitCompressionTrainingLoop
A training loop that automatically adjusts compression rate to reach maximum compression within accuracy budget.
- Parameters:
nncf_config (nncf.NNCFConfig) – The configuration object.
compression_controller (nncf.api.compression.CompressionAlgorithmController) – The controller for the compression algorithm that is currently applied to the model to be trained.
uncompressed_model_accuracy (float) – The uncompressed model accuracy, measured outside of this training loop to serve as the point of reference for fine-tuning the compressed model.
lr_updates_needed (bool) –
verbose (bool) – Whether to post additional data to TensorBoard.
minimal_compression_rate (float) – Sets the minimal compression rate to be considered during the training loop.
maximal_compression_rate (float) – Sets the maximal compression rate to be considered during the training loop.
dump_checkpoints (bool) – If True, dump all checkpoints obtained during the training process; otherwise, keep only the best checkpoint (accuracy-wise).
- run(model, train_epoch_fn, validate_fn, configure_optimizers_fn=None, dump_checkpoint_fn=None, load_checkpoint_fn=None, early_stopping_fn=None, tensorboard_writer=None, log_dir=None, update_learning_rate_fn=None)[source]#
Implements the custom logic to run a training loop for model fine-tuning using the provided train_epoch_fn, validate_fn, and configure_optimizers_fn callbacks.
- Parameters:
model (TModel) – The model instance before fine-tuning.
train_epoch_fn (Callable) – A callback to fine-tune the model for a single epoch.
validate_fn (Callable) – A callback to evaluate the model on the validation dataset.
configure_optimizers_fn (Callable) – A callback to instantiate an optimizer and a learning rate scheduler.
dump_checkpoint_fn (Callable) – A callback to dump a checkpoint.
load_checkpoint_fn (Callable) – A callback to load a checkpoint.
early_stopping_fn (Callable) – A callback to check for an early stopping condition.
tensorboard_writer (Optional[TensorboardWriterType]) – The TensorBoard object to be used for logging.
log_dir (Union[pathlib.Path, str]) – The path to be used for logging and checkpoint saving.
update_learning_rate_fn (Callable) – A callback to update the learning rate after each epoch of the training loop.
- Returns:
The fine-tuned model.
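The idea behind the adaptive loop — searching within [minimal_compression_rate, maximal_compression_rate] for the highest rate whose fine-tuned accuracy still fits the budget — can be sketched in pure Python. This is an assumption-laden illustration: the real loop interleaves fine-tuning with rate updates, while here a fixed accuracy-vs-rate function and a simple bisection stand in for that process (accuracy values in percent).

```python
# Sketch: bisect for the maximal compression rate that still satisfies the
# accuracy constraint (illustrative stand-in for the adaptive loop's search).

def adapt_compression_rate(accuracy_at_rate, minimal_tolerable_accuracy,
                           minimal_compression_rate=0.0,
                           maximal_compression_rate=0.95,
                           steps=20):
    """Return the highest acceptable rate found within the given bounds."""
    lo, hi = minimal_compression_rate, maximal_compression_rate
    best = lo
    for _ in range(steps):
        mid = (lo + hi) / 2
        if accuracy_at_rate(mid) >= minimal_tolerable_accuracy:
            best, lo = mid, mid  # rate is acceptable: try compressing harder
        else:
            hi = mid             # too aggressive: back off
    return best

# Toy model: accuracy degrades linearly with compression rate.
best_rate = adapt_compression_rate(
    accuracy_at_rate=lambda r: 76.0 - 10.0 * r,
    minimal_tolerable_accuracy=75.0,
)
# best_rate converges to ~0.1, where the toy accuracy hits the 75.0 floor.
```

The minimal_compression_rate and maximal_compression_rate constructor parameters play the role of the search bounds (lo, hi) in this sketch.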