39 from hydra.core.hydra_config
import HydraConfig
42 warnings.filterwarnings(
"ignore")
43 os.environ[
'TF_CPP_MIN_LOG_LEVEL'] =
'3'
45 import tensorflow
as tf
46 from omegaconf
import DictConfig
50 from clearml
import Task
51 from clearml.backend_config.defs
import get_active_config_file
54 SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
55 sys.path.append(os.path.dirname(SCRIPT_DIR))
57 from common.utils
import mlflow_ini, set_gpu_memory_limit, get_random_seed, display_figures, log_to_file
58 from common.benchmarking
import benchmark, cloud_connect
59 from common.evaluation
import gen_load_val
60 from common.prediction
import gen_load_val_predict
61 from src.preprocessing
import preprocess
62 from src.utils
import get_config
63 from src.training
import train
64 from src.evaluation
import evaluate
65 from src.quantization
import quantize
66 from src.prediction
import predict
67 from deployment
import deploy, deploy_mpu
69 from typing
import Optional
def chain_qd(cfg: DictConfig = None, float_model_path: str = None,
             train_ds: tf.data.Dataset = None,
             quantization_ds: tf.data.Dataset = None) -> None:
    """
    @brief Executes the Quantization → Deployment pipeline (chain_qd).

    This chain is used when a float model is already trained and only needs to be
    quantized and then deployed onto the STM32N6 board.

    Quantization strategy (in order of priority):
        1. Use the dedicated quantization dataset if provided.
        2. Fall back to the training dataset if no quantization dataset is available.
        3. Use fake (synthetic) data if neither dataset is provided — accuracy will be degraded.

    After quantization, the model is deployed:
        - On MPU targets via deploy_mpu().
        - On MCU targets (e.g., STM32N6570-DK) via deploy().

    @param cfg               Hydra configuration dictionary loaded from user_config.yaml.
    @param float_model_path  Path to the float32 model file (.tflite, .h5, or .onnx).
    @param train_ds          TensorFlow dataset used as fallback for quantization calibration.
    @param quantization_ds   Dedicated TensorFlow dataset for INT8 quantization calibration.
    """
    # Cloud credentials are only fetched when the STM32Cube.AI tools run on the
    # cloud; initialize to None so the later deploy call never sees an unbound name.
    credentials = None
    if cfg.tools.stm32ai.on_cloud:
        _, _, credentials = cloud_connect(stm32ai_version=cfg.tools.stm32ai.version)

    # Quantize the float model, preferring the dedicated calibration dataset.
    if quantization_ds:
        print('[INFO] : Using the quantization dataset to quantize the model.')
        quantized_model_path = quantize(cfg=cfg, quantization_ds=quantization_ds,
                                        float_model_path=float_model_path)
    elif train_ds:
        print('[INFO] : Quantization dataset is not provided! Using the training set to quantize the model.')
        quantized_model_path = quantize(cfg=cfg, quantization_ds=train_ds,
                                        float_model_path=float_model_path)
    else:
        print('[INFO] : Neither quantization dataset nor training set are provided! '
              'Using fake data to quantize the model. The model performance will not be accurate.')
        quantized_model_path = quantize(cfg=cfg, fake=True)
    print('[INFO] : Quantization complete.')

    # Deploy on the target: MPU boards use a dedicated deployment path.
    if cfg.hardware_type == "MPU":
        deploy_mpu(cfg=cfg, model_path_to_deploy=quantized_model_path, credentials=credentials)
    else:
        deploy(cfg=cfg, model_path_to_deploy=quantized_model_path, credentials=credentials)
    print('[INFO] : Deployment complete.')

    if cfg.deployment.hardware_setup.board == "STM32N6570-DK":
        print('[INFO] : On STM32N6570-DK, please toggle the boot switches to the left and power cycle the board.')
def chain_eqeb(cfg: DictConfig = None, float_model_path: str = None,
               train_ds: tf.data.Dataset = None,
               valid_ds: tf.data.Dataset = None,
               quantization_ds: tf.data.Dataset = None,
               test_ds: tf.data.Dataset = None) -> None:
    """
    @brief Executes the Evaluation → Quantization → Evaluation → Benchmarking pipeline (chain_eqeb).

    This chain is used to fully characterize both the float and quantized versions of a model:
        1. Evaluate the float model to establish a baseline accuracy.
        2. Quantize to INT8 using the provided calibration dataset.
        3. Evaluate the quantized model to measure accuracy degradation.
        4. Benchmark the quantized model on the target STM32 board to measure real-world latency.

    @param cfg               Hydra configuration dictionary.
    @param float_model_path  Path to the float32 model.
    @param train_ds          Training dataset (used as fallback for quantization calibration).
    @param valid_ds          Validation dataset for evaluation.
    @param quantization_ds   Dedicated calibration dataset for INT8 quantization.
    @param test_ds           Test dataset (takes priority over valid_ds for evaluation).
    """
    # Initialize to None so benchmark() never sees an unbound name off-cloud.
    credentials = None
    if cfg.tools.stm32ai.on_cloud:
        _, _, credentials = cloud_connect(stm32ai_version=cfg.tools.stm32ai.version)

    # Baseline evaluation of the float model (test set takes priority).
    if test_ds:
        evaluate(cfg=cfg, eval_ds=test_ds, model_path_to_evaluate=float_model_path, name_ds="test_set")
    else:
        evaluate(cfg=cfg, eval_ds=valid_ds, model_path_to_evaluate=float_model_path, name_ds="validation_set")
    print('[INFO] : Evaluation complete.')

    # Quantize with the dedicated calibration set, falling back to the training set.
    if quantization_ds:
        print('[INFO] : Using the quantization dataset to quantize the model.')
        quantized_model_path = quantize(cfg=cfg, quantization_ds=quantization_ds,
                                        float_model_path=float_model_path)
    else:
        print('[INFO] : Quantization dataset is not provided! Using the training set to quantize the model.')
        quantized_model_path = quantize(cfg=cfg, quantization_ds=train_ds,
                                        float_model_path=float_model_path)
    print('[INFO] : Quantization complete.')

    # Evaluate the quantized model to measure the accuracy drop.
    if test_ds:
        evaluate(cfg=cfg, eval_ds=test_ds, model_path_to_evaluate=quantized_model_path, name_ds="test_set")
    else:
        evaluate(cfg=cfg, eval_ds=valid_ds, model_path_to_evaluate=quantized_model_path, name_ds="validation_set")
    print('[INFO] : Evaluation complete.')

    # On-device latency/memory measurement.
    benchmark(cfg=cfg, model_path_to_benchmark=quantized_model_path, credentials=credentials)
    print('[INFO] : Benchmarking complete.')
def chain_qb(cfg: DictConfig = None, float_model_path: str = None,
             train_ds: tf.data.Dataset = None,
             quantization_ds: tf.data.Dataset = None) -> None:
    """
    @brief Executes the Quantization → Benchmarking pipeline (chain_qb).

    Useful when accuracy evaluation is not needed and the goal is to quickly measure
    the on-device performance of a quantized model (latency, memory usage).

    @param cfg               Hydra configuration dictionary.
    @param float_model_path  Path to the float32 model to quantize.
    @param train_ds          Training dataset (fallback for quantization calibration).
    @param quantization_ds   Dedicated calibration dataset for INT8 quantization.
    """
    # Initialize to None so benchmark() never sees an unbound name off-cloud.
    credentials = None
    if cfg.tools.stm32ai.on_cloud:
        _, _, credentials = cloud_connect(stm32ai_version=cfg.tools.stm32ai.version)

    # Quantize: dedicated set → training set → fake (synthetic) data.
    if quantization_ds:
        print('[INFO] : Using the quantization dataset to quantize the model.')
        quantized_model_path = quantize(cfg=cfg, quantization_ds=quantization_ds,
                                        float_model_path=float_model_path)
    elif train_ds:
        print('[INFO] : Quantization dataset is not provided! Using the training set to quantize the model.')
        quantized_model_path = quantize(cfg=cfg, quantization_ds=train_ds,
                                        float_model_path=float_model_path)
    else:
        print('[INFO] : Neither quantization dataset nor training set are provided! '
              'Using fake data to quantize the model. The model performance will not be accurate.')
        quantized_model_path = quantize(cfg=cfg, fake=True)
    print('[INFO] : Quantization complete.')

    # On-device latency/memory measurement.
    benchmark(cfg=cfg, model_path_to_benchmark=quantized_model_path, credentials=credentials)
    print('[INFO] : Benchmarking complete.')
def chain_eqe(cfg: DictConfig = None, float_model_path: str = None,
              train_ds: tf.data.Dataset = None,
              valid_ds: tf.data.Dataset = None,
              quantization_ds: tf.data.Dataset = None,
              test_ds: tf.data.Dataset = None) -> None:
    """
    @brief Executes the Evaluation → Quantization → Evaluation pipeline (chain_eqe).

    Evaluates accuracy before and after INT8 quantization to measure the accuracy
    degradation introduced by the quantization process. No on-device benchmarking,
    so no cloud credentials are required.

    @param cfg               Hydra configuration dictionary.
    @param float_model_path  Path to the float32 model.
    @param train_ds          Training dataset (fallback for quantization calibration).
    @param valid_ds          Validation dataset for evaluation.
    @param quantization_ds   Dedicated calibration dataset for INT8 quantization.
    @param test_ds           Test dataset (takes priority over valid_ds).
    """
    # Baseline evaluation of the float model (test set takes priority).
    if test_ds:
        evaluate(cfg=cfg, eval_ds=test_ds, model_path_to_evaluate=float_model_path, name_ds="test_set")
    else:
        evaluate(cfg=cfg, eval_ds=valid_ds, model_path_to_evaluate=float_model_path, name_ds="validation_set")
    print('[INFO] : Evaluation complete.')

    # Quantize with the dedicated calibration set, falling back to the training set.
    if quantization_ds:
        print('[INFO] : Using the quantization dataset to quantize the model.')
        quantized_model_path = quantize(cfg=cfg, quantization_ds=quantization_ds,
                                        float_model_path=float_model_path)
    else:
        print('[INFO] : Quantization dataset is not provided! Using the training set to quantize the model.')
        quantized_model_path = quantize(cfg=cfg, quantization_ds=train_ds,
                                        float_model_path=float_model_path)
    print('[INFO] : Quantization complete.')

    # Evaluate the quantized model to measure the accuracy drop.
    if test_ds:
        evaluate(cfg=cfg, eval_ds=test_ds, model_path_to_evaluate=quantized_model_path, name_ds="test_set")
    else:
        evaluate(cfg=cfg, eval_ds=valid_ds, model_path_to_evaluate=quantized_model_path, name_ds="validation_set")
    print('[INFO] : Evaluation complete.')
def chain_tqeb(cfg: DictConfig = None, train_ds: tf.data.Dataset = None,
               valid_ds: tf.data.Dataset = None,
               quantization_ds: tf.data.Dataset = None,
               test_ds: tf.data.Dataset = None) -> None:
    """
    @brief Executes the full Training → Quantization → Evaluation → Benchmarking pipeline (chain_tqeb).

    This is the most complete pipeline, covering the entire model lifecycle from training
    to on-device performance measurement. It is particularly useful when starting from scratch
    or when fine-tuning a model for a new dataset.

    Steps:
        1. Train the model on the provided training dataset.
        2. Quantize the trained model to INT8.
        3. Evaluate the quantized model for accuracy.
        4. Benchmark on the target STM32 board.

    @param cfg               Hydra configuration dictionary.
    @param train_ds          Training dataset.
    @param valid_ds          Validation dataset.
    @param quantization_ds   Dedicated calibration dataset (falls back to train_ds if not provided).
    @param test_ds           Test dataset (takes priority over valid_ds for evaluation).
    """
    # Initialize to None so benchmark() never sees an unbound name off-cloud.
    credentials = None
    if cfg.tools.stm32ai.on_cloud:
        _, _, credentials = cloud_connect(stm32ai_version=cfg.tools.stm32ai.version)

    # Train: pass the test set through only when one is available.
    if test_ds:
        trained_model_path = train(cfg=cfg, train_ds=train_ds, valid_ds=valid_ds, test_ds=test_ds)
    else:
        trained_model_path = train(cfg=cfg, train_ds=train_ds, valid_ds=valid_ds)
    print('[INFO] : Training complete.')

    # Quantize with the dedicated calibration set, falling back to the training set.
    if quantization_ds:
        print('[INFO] : Using the quantization dataset to quantize the model.')
        quantized_model_path = quantize(cfg=cfg, quantization_ds=quantization_ds,
                                        float_model_path=trained_model_path)
    else:
        print('[INFO] : Quantization dataset is not provided! Using the training set to quantize the model.')
        quantized_model_path = quantize(cfg=cfg, quantization_ds=train_ds,
                                        float_model_path=trained_model_path)
    print('[INFO] : Quantization complete.')

    # Evaluate the quantized model (test set takes priority).
    if test_ds:
        evaluate(cfg=cfg, eval_ds=test_ds, model_path_to_evaluate=quantized_model_path, name_ds="test_set")
    else:
        evaluate(cfg=cfg, eval_ds=valid_ds, model_path_to_evaluate=quantized_model_path, name_ds="validation_set")
    print('[INFO] : Evaluation complete.')

    # On-device latency/memory measurement.
    benchmark(cfg=cfg, model_path_to_benchmark=quantized_model_path, credentials=credentials)
    print('[INFO] : Benchmarking complete.')
def chain_tqe(cfg: DictConfig = None, train_ds: tf.data.Dataset = None,
              valid_ds: tf.data.Dataset = None,
              quantization_ds: tf.data.Dataset = None,
              test_ds: tf.data.Dataset = None) -> None:
    """
    @brief Executes the Training → Quantization → Evaluation pipeline (chain_tqe).

    Similar to chain_tqeb but without the final on-device benchmarking step.
    Useful when the goal is to verify accuracy after quantization without needing
    to connect a physical STM32 board.

    @param cfg               Hydra configuration dictionary.
    @param train_ds          Training dataset.
    @param valid_ds          Validation dataset.
    @param quantization_ds   Dedicated calibration dataset (falls back to train_ds if not provided).
    @param test_ds           Test dataset (takes priority over valid_ds for evaluation).
    """
    # Train: pass the test set through only when one is available.
    if test_ds:
        trained_model_path = train(cfg=cfg, train_ds=train_ds, valid_ds=valid_ds, test_ds=test_ds)
    else:
        trained_model_path = train(cfg=cfg, train_ds=train_ds, valid_ds=valid_ds)
    print('[INFO] : Training complete.')

    # Quantize with the dedicated calibration set, falling back to the training set.
    if quantization_ds:
        print('[INFO] : Using the quantization dataset to quantize the model.')
        quantized_model_path = quantize(cfg=cfg, quantization_ds=quantization_ds,
                                        float_model_path=trained_model_path)
    else:
        print('[INFO] : Quantization dataset is not provided! Using the training set to quantize the model.')
        quantized_model_path = quantize(cfg=cfg, quantization_ds=train_ds,
                                        float_model_path=trained_model_path)
    print('[INFO] : Quantization complete.')

    # Evaluate the quantized model (test set takes priority).
    if test_ds:
        evaluate(cfg=cfg, eval_ds=test_ds, model_path_to_evaluate=quantized_model_path, name_ds="test_set")
    else:
        evaluate(cfg=cfg, eval_ds=valid_ds, model_path_to_evaluate=quantized_model_path, name_ds="validation_set")
    print('[INFO] : Evaluation complete.')
def process_mode(mode: str = None,
                 configs: DictConfig = None,
                 train_ds: tf.data.Dataset = None,
                 valid_ds: tf.data.Dataset = None,
                 quantization_ds: tf.data.Dataset = None,
                 test_ds: tf.data.Dataset = None,
                 float_model_path: Optional[str] = None,
                 fake: Optional[bool] = False) -> None:
    """
    @brief Dispatches execution to the appropriate pipeline based on the operation mode.

    This function acts as a central dispatcher. It reads the `operation_mode` field
    from the configuration and calls the corresponding function or chain:
        - 'training'     : Train a model.
        - 'evaluation'   : Evaluate model accuracy on a dataset.
        - 'quantization' : Quantize a float model to INT8.
        - 'deployment'   : Deploy the model onto the STM32 board (generates C code, compiles, flashes).
        - 'prediction'   : Run inference on new input data.
        - 'benchmarking' : Measure on-device performance metrics.
        - 'chain_tqeb'   : Training → Quantization → Evaluation → Benchmarking.
        - 'chain_tqe'    : Training → Quantization → Evaluation.
        - 'chain_eqe'    : Evaluation → Quantization → Evaluation.
        - 'chain_qb'     : Quantization → Benchmarking.
        - 'chain_eqeb'   : Evaluation → Quantization → Evaluation → Benchmarking.
        - 'chain_qd'     : Quantization → Deployment.

    @note In deployment mode for STM32N6570-DK, after flashing the user must manually
          toggle the boot switches and power-cycle the board.

    @param mode              Operation mode string (e.g., 'deployment', 'chain_qd').
    @param configs           Hydra configuration dictionary.
    @param train_ds          Training TensorFlow dataset.
    @param valid_ds          Validation TensorFlow dataset.
    @param quantization_ds   Calibration dataset for INT8 quantization.
    @param test_ds           Test TensorFlow dataset.
    @param float_model_path  Path to the float32 model file.
    @param fake              If True, use synthetic data for quantization calibration.

    @throws ValueError if an unsupported operation_mode is provided.
    """
    mlflow.log_param("model_path", configs.general.model_path)
    log_to_file(configs.output_dir, f'operation_mode: {mode}')

    if mode == 'training':
        if test_ds:
            train(cfg=configs, train_ds=train_ds, valid_ds=valid_ds, test_ds=test_ds)
        else:
            train(cfg=configs, train_ds=train_ds, valid_ds=valid_ds)
        display_figures(configs)
        print('[INFO] : Training complete.')

    elif mode == 'evaluation':
        # Generate/load the validation artifacts, then restore the working directory.
        gen_load_val(cfg=configs)
        os.chdir(os.path.dirname(os.path.realpath(__file__)))
        if test_ds:
            evaluate(cfg=configs, eval_ds=test_ds, name_ds="test_set")
        else:
            evaluate(cfg=configs, eval_ds=valid_ds, name_ds="validation_set")
        display_figures(configs)
        print('[INFO] : Evaluation complete.')

    elif mode == 'deployment':
        # NOTE(review): the deploy/deploy_mpu call lines are missing from this view;
        # reconstructed from the MPU/MCU split used in chain_qd — confirm arguments.
        if configs.hardware_type == "MPU":
            deploy_mpu(cfg=configs)
        else:
            deploy(cfg=configs)
        print('[INFO] : Deployment complete.')
        if configs.deployment.hardware_setup.board == "STM32N6570-DK":
            print('[INFO] : On STM32N6570-DK, please toggle the boot switches to the left and power cycle the board.')

    elif mode == 'quantization':
        # Select the calibration data by priority: dedicated set → training set → fake.
        if quantization_ds:
            input_ds = quantization_ds
            print('[INFO] : Using the quantization dataset to quantize the model.')
        elif train_ds:
            input_ds = train_ds
            print('[INFO] : Quantization dataset is not provided! Using the training set to quantize the model.')
        else:
            input_ds = None
            print('[INFO] : Neither quantization dataset nor training set are provided! '
                  'Using fake data to quantize the model. The model performance will not be accurate.')
        quantize(cfg=configs, quantization_ds=input_ds, fake=fake)
        print('[INFO] : Quantization complete.')

    elif mode == 'prediction':
        # Generate/load the prediction artifacts, then restore the working directory.
        gen_load_val_predict(cfg=configs)
        os.chdir(os.path.dirname(os.path.realpath(__file__)))
        # NOTE(review): the predict() call line is missing from this view —
        # reconstructed from the imported src.prediction.predict; confirm.
        predict(cfg=configs)
        print('[INFO] : Prediction complete.')

    elif mode == 'benchmarking':
        benchmark(cfg=configs)
        print('[INFO] : Benchmark complete.')

    elif mode == 'chain_tqeb':
        chain_tqeb(cfg=configs, train_ds=train_ds, valid_ds=valid_ds,
                   quantization_ds=quantization_ds, test_ds=test_ds)
        print('[INFO] : chain_tqeb complete.')

    elif mode == 'chain_tqe':
        chain_tqe(cfg=configs, train_ds=train_ds, valid_ds=valid_ds,
                  quantization_ds=quantization_ds, test_ds=test_ds)
        print('[INFO] : chain_tqe complete.')

    elif mode == 'chain_eqe':
        chain_eqe(cfg=configs, float_model_path=float_model_path, train_ds=train_ds,
                  valid_ds=valid_ds, quantization_ds=quantization_ds, test_ds=test_ds)
        print('[INFO] : chain_eqe complete.')

    elif mode == 'chain_qb':
        chain_qb(cfg=configs, float_model_path=float_model_path, train_ds=train_ds,
                 quantization_ds=quantization_ds)
        print('[INFO] : chain_qb complete.')

    elif mode == 'chain_eqeb':
        chain_eqeb(cfg=configs, float_model_path=float_model_path, train_ds=train_ds,
                   valid_ds=valid_ds, quantization_ds=quantization_ds, test_ds=test_ds)
        print('[INFO] : chain_eqeb complete.')

    elif mode == 'chain_qd':
        chain_qd(cfg=configs, float_model_path=float_model_path, train_ds=train_ds,
                 quantization_ds=quantization_ds)
        print('[INFO] : chain_qd complete.')

    else:
        raise ValueError(f"Invalid mode: {mode}")

    # Record artifacts, and benchmarking metadata when a board was involved.
    mlflow.log_artifact(configs.output_dir)
    if mode in ['benchmarking', 'chain_qb', 'chain_eqeb', 'chain_tqeb']:
        mlflow.log_param("stm32ai_version", configs.tools.stm32ai.version)
        mlflow.log_param("target", configs.benchmarking.board)
    log_to_file(configs.output_dir, f'operation finished: {mode}')

    # Attach the final configuration to the active ClearML task, if any.
    if get_active_config_file() is not None:
        print("[INFO] : ClearML task connection")
        task = Task.current_task()
        task.connect(configs)
@hydra.main(version_base=None, config_path="", config_name="user_config")
def main(cfg: DictConfig) -> None:
    """
    @brief Main entry point of the STM32AI Model Zoo Services script.

    This function is decorated with @hydra.main, which means Hydra automatically
    loads the configuration from `user_config.yaml` and passes it as a DictConfig object.

    Steps:
        1. Configure GPU memory limits (if specified in the config).
        2. Parse and validate the full configuration via get_config().
        3. Initialize MLflow experiment tracking.
        4. Optionally initialize ClearML task tracking.
        5. Set the global random seed for reproducibility.
        6. Load and preprocess datasets (if required by the selected mode).
        7. Dispatch to process_mode() based on cfg.operation_mode.

    @note The operation mode is read from the YAML field `operation_mode`.
          Modes requiring datasets (training, evaluation, etc.) will call preprocess()
          to load and prepare the data. Modes like deployment do not require datasets.

    @param cfg Hydra DictConfig object automatically populated from user_config.yaml.
    """
    # Optionally cap GPU memory so TensorFlow does not grab the whole device.
    if "general" in cfg and cfg.general:
        if "gpu_memory_limit" in cfg.general and cfg.general.gpu_memory_limit:
            set_gpu_memory_limit(cfg.general.gpu_memory_limit)
            print(f"[INFO] Setting upper limit of usable GPU memory to {int(cfg.general.gpu_memory_limit)}GBytes.")
        else:
            print("[WARNING] The usable GPU memory is unlimited.\n"
                  "Please consider setting the 'gpu_memory_limit' attribute "
                  "in the 'general' section of your configuration file.")

    # NOTE(review): the configuration-parsing lines are missing from this view; per
    # docstring step 2 and the imported src.utils.get_config, presumably — confirm.
    cfg = get_config(cfg)
    cfg.output_dir = HydraConfig.get().run.dir

    # NOTE(review): MLflow initialization lines are missing from this view; per
    # docstring step 3 and the imported common.utils.mlflow_ini, presumably — confirm.
    mlflow_ini(cfg)

    # Optional ClearML task tracking.
    print("[INFO] : ClearML config check")
    if get_active_config_file() is not None:
        print("[INFO] : ClearML initialization and configuration")
        task = Task.init(project_name=cfg.general.project_name, task_name='pe_modelzoo_task')
        task.connect_configuration(name=cfg.operation_mode, configuration=cfg)

    # Seed every RNG used by TensorFlow/Keras for reproducibility.
    seed = get_random_seed(cfg)
    print(f'[INFO] : The random seed for this simulation is {seed}')
    tf.keras.utils.set_random_seed(seed)

    mode = cfg.operation_mode

    # Modes that always need the full preprocessed datasets.
    valid_modes = ['training', 'evaluation', 'chain_tqeb', 'chain_tqe']
    if mode in valid_modes:
        # NOTE(review): the preprocess() call line is missing from this view —
        # reconstructed from the unpacking that follows; confirm.
        preprocess_output = preprocess(cfg=cfg)
        train_ds, valid_ds, quantization_ds, test_ds = preprocess_output
        process_mode(mode=mode, configs=cfg, train_ds=train_ds, valid_ds=valid_ds,
                     quantization_ds=quantization_ds, test_ds=test_ds)

    elif mode == 'quantization':
        # Quantization uses real datasets when any path is configured.
        if cfg.dataset.training_path or cfg.dataset.quantization_path:
            # NOTE(review): reconstructed missing preprocess() call line — confirm.
            preprocess_output = preprocess(cfg=cfg)
            train_ds, valid_ds, quantization_ds, test_ds = preprocess_output
            process_mode(mode=mode, configs=cfg, train_ds=train_ds, valid_ds=valid_ds,
                         quantization_ds=quantization_ds, test_ds=test_ds)
        else:
            # NOTE(review): original lines are missing here; quantization without any
            # dataset presumably falls back to fake calibration data — confirm.
            process_mode(mode=mode, configs=cfg, fake=True)

    else:
        if mode in ['chain_eqe', 'chain_qb', 'chain_eqeb', 'chain_qd']:
            # These chains use real datasets when available, otherwise run without.
            if cfg.dataset.training_path or cfg.dataset.quantization_path:
                # NOTE(review): reconstructed missing preprocess() call line — confirm.
                preprocess_output = preprocess(cfg=cfg)
                train_ds, valid_ds, quantization_ds, test_ds = preprocess_output
            else:
                train_ds = valid_ds = quantization_ds = test_ds = None
            process_mode(mode=mode, configs=cfg, train_ds=train_ds, valid_ds=valid_ds,
                         quantization_ds=quantization_ds, test_ds=test_ds,
                         float_model_path=cfg.general.model_path)
        else:
            # NOTE(review): original lines are missing here; the remaining modes
            # (deployment, benchmarking, prediction) presumably dispatch without
            # datasets — confirm against the original script.
            process_mode(mode=mode, configs=cfg, float_model_path=cfg.general.model_path)
if __name__ == "__main__":
    # Command-line interface: the config folder/name are forwarded to Hydra, and any
    # trailing positional params override fields of user_config.yaml.
    parser = argparse.ArgumentParser()
    parser.add_argument('--config-path', type=str, default='',
                        help='Path to folder containing configuration file')
    parser.add_argument('--config-name', type=str, default='user_config',
                        help='Name of the configuration file')
    parser.add_argument('params', nargs='*',
                        help='List of parameters to override in user_config.yaml')
    args = parser.parse_args()

    # NOTE(review): original lines are missing from this view between arg parsing and
    # the MLflow logging below; they most likely forward --config-path/--config-name
    # and the override params to Hydra and invoke main() — confirm.
    main()

    # Record how the run was launched.
    mlflow.log_param('config_path', args.config_path)
    mlflow.log_param('config_name', args.config_name)
None deploy_mpu(DictConfig cfg=None, Optional[str] model_path_to_deploy=None, list credentials=None)
DefaultMunch get_config(DictConfig config_data)
None chain_qd(DictConfig cfg=None, str float_model_path=None, tf.data.Dataset train_ds=None, tf.data.Dataset quantization_ds=None)
None chain_qb(DictConfig cfg=None, str float_model_path=None, tf.data.Dataset train_ds=None, tf.data.Dataset quantization_ds=None)
None chain_eqeb(DictConfig cfg=None, str float_model_path=None, tf.data.Dataset train_ds=None, tf.data.Dataset valid_ds=None, tf.data.Dataset quantization_ds=None, tf.data.Dataset test_ds=None)
None chain_tqeb(DictConfig cfg=None, tf.data.Dataset train_ds=None, tf.data.Dataset valid_ds=None, tf.data.Dataset quantization_ds=None, tf.data.Dataset test_ds=None)
None process_mode(str mode=None, DictConfig configs=None, tf.data.Dataset train_ds=None, tf.data.Dataset valid_ds=None, tf.data.Dataset quantization_ds=None, tf.data.Dataset test_ds=None, Optional[str] float_model_path=None, Optional[bool] fake=False)
None chain_tqe(DictConfig cfg=None, tf.data.Dataset train_ds=None, tf.data.Dataset valid_ds=None, tf.data.Dataset quantization_ds=None, tf.data.Dataset test_ds=None)
None chain_eqe(DictConfig cfg=None, str float_model_path=None, tf.data.Dataset train_ds=None, tf.data.Dataset valid_ds=None, tf.data.Dataset quantization_ds=None, tf.data.Dataset test_ds=None)
None main(DictConfig cfg)