STM32N6 NPU Deployment — Politecnico di Milano, version 1.0
Documentation for Neural Network Deployment on STM32N6 NPU - Politecnico di Milano 2024-2025
stm32ai_main.py
Go to the documentation of this file.
1 
36 
37 import os
38 import sys
39 from hydra.core.hydra_config import HydraConfig
40 import hydra
41 import warnings
42 warnings.filterwarnings("ignore")
43 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
44 
45 import tensorflow as tf
46 from omegaconf import DictConfig
47 import mlflow
48 import argparse
49 import logging
50 from clearml import Task
51 from clearml.backend_config.defs import get_active_config_file
52 
53 
54 SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
55 sys.path.append(os.path.dirname(SCRIPT_DIR))
56 
57 from common.utils import mlflow_ini, set_gpu_memory_limit, get_random_seed, display_figures, log_to_file
58 from common.benchmarking import benchmark, cloud_connect
59 from common.evaluation import gen_load_val
60 from common.prediction import gen_load_val_predict
61 from src.preprocessing import preprocess
62 from src.utils import get_config
63 from src.training import train
64 from src.evaluation import evaluate
65 from src.quantization import quantize
66 from src.prediction import predict
67 from deployment import deploy, deploy_mpu
68 
69 from typing import Optional
70 
71 
def chain_qd(cfg: DictConfig = None, float_model_path: str = None,
             train_ds: tf.data.Dataset = None,
             quantization_ds: tf.data.Dataset = None) -> None:
    """
    @brief Executes the Quantization → Deployment pipeline (chain_qd).

    @details
    Used when a float model is already trained and only needs to be quantized
    and then deployed onto the STM32N6 board.

    Calibration data for quantization is chosen by priority:
    1. The dedicated quantization dataset, if provided.
    2. The training dataset, as a fallback.
    3. Fake (synthetic) data if neither is available — accuracy will be degraded.

    After quantization the model is deployed via deploy_mpu() for MPU targets
    or deploy() for MCU targets (e.g., STM32N6570-DK).

    @param cfg Hydra configuration dictionary loaded from user_config.yaml.
    @param float_model_path Path to the float32 model file (.tflite, .h5, or .onnx).
    @param train_ds TensorFlow dataset used as fallback for quantization calibration.
    @param quantization_ds Dedicated TensorFlow dataset for INT8 quantization calibration.

    @return None
    """
    # Remote benchmarking/deployment requires STM32Cube.AI Developer Cloud credentials.
    credentials = None
    if cfg.tools.stm32ai.on_cloud:
        _, _, credentials = cloud_connect(stm32ai_version=cfg.tools.stm32ai.version)

    # Pick the calibration source: dedicated set > training set > fake data.
    if quantization_ds:
        print('[INFO] : Using the quantization dataset to quantize the model.')
        calibration_ds = quantization_ds
    elif train_ds:
        print('[INFO] : Quantization dataset is not provided! Using the training set to quantize the model.')
        calibration_ds = train_ds
    else:
        print('[INFO] : Neither quantization dataset nor training set are provided! '
              'Using fake data to quantize the model. The model performance will not be accurate.')
        calibration_ds = None

    if calibration_ds is not None:
        quantized_model_path = quantize(cfg=cfg, quantization_ds=calibration_ds,
                                        float_model_path=float_model_path)
    else:
        quantized_model_path = quantize(cfg=cfg, fake=True)
    print('[INFO] : Quantization complete.')

    # MPU and MCU targets have distinct deployment back-ends.
    deployer = deploy_mpu if cfg.hardware_type == "MPU" else deploy
    deployer(cfg=cfg, model_path_to_deploy=quantized_model_path, credentials=credentials)
    print('[INFO] : Deployment complete.')

    # The STM32N6570-DK needs a manual boot-switch toggle after flashing.
    if cfg.deployment.hardware_setup.board == "STM32N6570-DK":
        print('[INFO] : On STM32N6570-DK, please toggle the boot switches to the left and power cycle the board.')
128 
129 
def chain_eqeb(cfg: DictConfig = None, float_model_path: str = None,
               train_ds: tf.data.Dataset = None,
               valid_ds: tf.data.Dataset = None,
               quantization_ds: tf.data.Dataset = None,
               test_ds: tf.data.Dataset = None) -> None:
    """
    @brief Executes the Evaluation → Quantization → Evaluation → Benchmarking pipeline (chain_eqeb).

    @details
    Fully characterizes both the float and quantized versions of a model:
    1. Evaluate the float model to establish a baseline accuracy.
    2. Quantize to INT8 using the selected calibration dataset.
    3. Evaluate the quantized model to measure accuracy degradation.
    4. Benchmark the quantized model on the target STM32 board to measure real-world latency.

    Calibration data is selected by priority: dedicated quantization dataset,
    then training dataset, then fake (synthetic) data — the same fallback
    chain used by chain_qd and chain_qb.

    @param cfg Hydra configuration dictionary.
    @param float_model_path Path to the float32 model.
    @param train_ds Training dataset (used as fallback for quantization calibration).
    @param valid_ds Validation dataset for evaluation.
    @param quantization_ds Dedicated calibration dataset for INT8 quantization.
    @param test_ds Test dataset (takes priority over valid_ds for evaluation).

    @return None
    """
    credentials = None
    if cfg.tools.stm32ai.on_cloud:
        _, _, credentials = cloud_connect(stm32ai_version=cfg.tools.stm32ai.version)

    # Step 1: Evaluate the float model (test set takes priority over validation set)
    if test_ds:
        evaluate(cfg=cfg, eval_ds=test_ds, model_path_to_evaluate=float_model_path, name_ds="test_set")
    else:
        evaluate(cfg=cfg, eval_ds=valid_ds, model_path_to_evaluate=float_model_path, name_ds="validation_set")
    print('[INFO] : Evaluation complete.')
    display_figures(cfg)

    # Step 2: Quantize the model to INT8
    if quantization_ds:
        print('[INFO] : Using the quantization dataset to quantize the model.')
        quantized_model_path = quantize(cfg=cfg, quantization_ds=quantization_ds,
                                        float_model_path=float_model_path)
    elif train_ds:
        print('[INFO] : Quantization dataset is not provided! Using the training set to quantize the model.')
        quantized_model_path = quantize(cfg=cfg, quantization_ds=train_ds,
                                        float_model_path=float_model_path)
    else:
        # Fix: previously this chain passed quantization_ds=None when no dataset was
        # available; fall back to fake calibration data, consistent with chain_qd/chain_qb.
        print('[INFO] : Neither quantization dataset nor training set are provided! '
              'Using fake data to quantize the model. The model performance will not be accurate.')
        quantized_model_path = quantize(cfg=cfg, fake=True)
    print('[INFO] : Quantization complete.')

    # Step 3: Evaluate the quantized model
    if test_ds:
        evaluate(cfg=cfg, eval_ds=test_ds, model_path_to_evaluate=quantized_model_path, name_ds="test_set")
    else:
        evaluate(cfg=cfg, eval_ds=valid_ds, model_path_to_evaluate=quantized_model_path, name_ds="validation_set")
    print('[INFO] : Evaluation complete.')
    display_figures(cfg)

    # Step 4: Benchmark on the target board
    benchmark(cfg=cfg, model_path_to_benchmark=quantized_model_path, credentials=credentials)
    print('[INFO] : Benchmarking complete.')
188 
189 
def chain_qb(cfg: DictConfig = None, float_model_path: str = None,
             train_ds: tf.data.Dataset = None,
             quantization_ds: tf.data.Dataset = None) -> None:
    """
    @brief Executes the Quantization → Benchmarking pipeline (chain_qb).

    @details
    Useful when accuracy evaluation is not needed and the goal is to quickly
    measure the on-device performance of a quantized model (latency, memory usage).

    @param cfg Hydra configuration dictionary.
    @param float_model_path Path to the float32 model to quantize.
    @param train_ds Training dataset (fallback for quantization calibration).
    @param quantization_ds Dedicated calibration dataset for INT8 quantization.

    @return None
    """
    # Cloud credentials are only fetched when remote benchmarking is enabled.
    credentials = None
    if cfg.tools.stm32ai.on_cloud:
        _, _, credentials = cloud_connect(stm32ai_version=cfg.tools.stm32ai.version)

    # Pick the calibration source: dedicated set > training set > fake data.
    if quantization_ds:
        print('[INFO] : Using the quantization dataset to quantize the model.')
        calibration_ds = quantization_ds
    elif train_ds:
        print('[INFO] : Quantization dataset is not provided! Using the training set to quantize the model.')
        calibration_ds = train_ds
    else:
        print('[INFO] : Neither quantization dataset nor training set are provided! '
              'Using fake data to quantize the model. The model performance will not be accurate.')
        calibration_ds = None

    if calibration_ds is not None:
        quantized_model_path = quantize(cfg=cfg, quantization_ds=calibration_ds,
                                        float_model_path=float_model_path)
    else:
        quantized_model_path = quantize(cfg=cfg, fake=True)
    print('[INFO] : Quantization complete.')

    # Measure on-device performance of the quantized model.
    benchmark(cfg=cfg, model_path_to_benchmark=quantized_model_path, credentials=credentials)
    print('[INFO] : Benchmarking complete.')
227 
228 
def chain_eqe(cfg: DictConfig = None, float_model_path: str = None,
              train_ds: tf.data.Dataset = None,
              valid_ds: tf.data.Dataset = None,
              quantization_ds: tf.data.Dataset = None,
              test_ds: tf.data.Dataset = None) -> None:
    """
    @brief Executes the Evaluation → Quantization → Evaluation pipeline (chain_eqe).

    @details
    Evaluates accuracy before and after INT8 quantization to measure the accuracy
    degradation introduced by the quantization process. No on-device benchmarking.

    Calibration data is selected by priority: dedicated quantization dataset,
    then training dataset, then fake (synthetic) data — the same fallback
    chain used by chain_qd and chain_qb.

    @param cfg Hydra configuration dictionary.
    @param float_model_path Path to the float32 model.
    @param train_ds Training dataset (fallback for quantization calibration).
    @param valid_ds Validation dataset for evaluation.
    @param quantization_ds Dedicated calibration dataset for INT8 quantization.
    @param test_ds Test dataset (takes priority over valid_ds).

    @return None
    """
    # Evaluate the float model (test set takes priority over validation set)
    if test_ds:
        evaluate(cfg=cfg, eval_ds=test_ds, model_path_to_evaluate=float_model_path, name_ds="test_set")
    else:
        evaluate(cfg=cfg, eval_ds=valid_ds, model_path_to_evaluate=float_model_path, name_ds="validation_set")
    print('[INFO] : Evaluation complete.')
    display_figures(cfg)

    # Quantize the model to INT8
    if quantization_ds:
        print('[INFO] : Using the quantization dataset to quantize the model.')
        quantized_model_path = quantize(cfg=cfg, quantization_ds=quantization_ds,
                                        float_model_path=float_model_path)
    elif train_ds:
        print('[INFO] : Quantization dataset is not provided! Using the training set to quantize the model.')
        quantized_model_path = quantize(cfg=cfg, quantization_ds=train_ds,
                                        float_model_path=float_model_path)
    else:
        # Fix: previously this chain passed quantization_ds=None when no dataset was
        # available; fall back to fake calibration data, consistent with chain_qd/chain_qb.
        print('[INFO] : Neither quantization dataset nor training set are provided! '
              'Using fake data to quantize the model. The model performance will not be accurate.')
        quantized_model_path = quantize(cfg=cfg, fake=True)
    print('[INFO] : Quantization complete.')

    # Evaluate the quantized model
    if test_ds:
        evaluate(cfg=cfg, eval_ds=test_ds, model_path_to_evaluate=quantized_model_path, name_ds="test_set")
    else:
        evaluate(cfg=cfg, eval_ds=valid_ds, model_path_to_evaluate=quantized_model_path, name_ds="validation_set")
    print('[INFO] : Evaluation complete.')
    display_figures(cfg)
276 
277 
def chain_tqeb(cfg: DictConfig = None, train_ds: tf.data.Dataset = None,
               valid_ds: tf.data.Dataset = None,
               quantization_ds: tf.data.Dataset = None,
               test_ds: tf.data.Dataset = None) -> None:
    """
    @brief Executes the full Training → Quantization → Evaluation → Benchmarking pipeline (chain_tqeb).

    @details
    The most complete pipeline, covering the model lifecycle from training to
    on-device performance measurement. Particularly useful when starting from
    scratch or fine-tuning a model for a new dataset.

    Pipeline steps:
    1. Train the model on the provided training dataset.
    2. Quantize the trained model to INT8.
    3. Evaluate the quantized model for accuracy.
    4. Benchmark on the target STM32 board.

    @param cfg Hydra configuration dictionary.
    @param train_ds Training dataset.
    @param valid_ds Validation dataset.
    @param quantization_ds Dedicated calibration dataset (falls back to train_ds if not provided).
    @param test_ds Test dataset (takes priority over valid_ds for evaluation).

    @return None
    """
    # Cloud credentials are only fetched when remote benchmarking is enabled.
    credentials = None
    if cfg.tools.stm32ai.on_cloud:
        _, _, credentials = cloud_connect(stm32ai_version=cfg.tools.stm32ai.version)

    # --- Step 1: training (forward the test set only when one exists) -------
    train_kwargs = dict(cfg=cfg, train_ds=train_ds, valid_ds=valid_ds)
    if test_ds:
        train_kwargs['test_ds'] = test_ds
    trained_model_path = train(**train_kwargs)
    print('[INFO] : Training complete.')

    # --- Step 2: quantization (dedicated calibration set preferred) ---------
    if quantization_ds:
        print('[INFO] : Using the quantization dataset to quantize the model.')
        calibration_ds = quantization_ds
    else:
        print('[INFO] : Quantization dataset is not provided! Using the training set to quantize the model.')
        calibration_ds = train_ds
    quantized_model_path = quantize(cfg=cfg, quantization_ds=calibration_ds,
                                    float_model_path=trained_model_path)
    print('[INFO] : Quantization complete.')

    # --- Step 3: accuracy check of the quantized model ----------------------
    eval_ds, name_ds = (test_ds, "test_set") if test_ds else (valid_ds, "validation_set")
    evaluate(cfg=cfg, eval_ds=eval_ds, model_path_to_evaluate=quantized_model_path, name_ds=name_ds)
    print('[INFO] : Evaluation complete.')
    display_figures(cfg)

    # --- Step 4: on-device benchmarking -------------------------------------
    benchmark(cfg=cfg, model_path_to_benchmark=quantized_model_path, credentials=credentials)
    print('[INFO] : Benchmarking complete.')
337 
338 
def chain_tqe(cfg: DictConfig = None, train_ds: tf.data.Dataset = None,
              valid_ds: tf.data.Dataset = None,
              quantization_ds: tf.data.Dataset = None,
              test_ds: tf.data.Dataset = None) -> None:
    """
    @brief Executes the Training → Quantization → Evaluation pipeline (chain_tqe).

    @details
    Same as chain_tqeb minus the final on-device benchmarking step. Useful to
    verify post-quantization accuracy without connecting a physical STM32 board.

    @param cfg Hydra configuration dictionary.
    @param train_ds Training dataset.
    @param valid_ds Validation dataset.
    @param quantization_ds Dedicated calibration dataset (falls back to train_ds if not provided).
    @param test_ds Test dataset (takes priority over valid_ds for evaluation).

    @return None
    """
    # --- Training (forward the test set only when one exists) ---------------
    train_kwargs = dict(cfg=cfg, train_ds=train_ds, valid_ds=valid_ds)
    if test_ds:
        train_kwargs['test_ds'] = test_ds
    trained_model_path = train(**train_kwargs)
    print('[INFO] : Training complete.')

    # --- Quantization (dedicated calibration set preferred) -----------------
    if quantization_ds:
        print('[INFO] : Using the quantization dataset to quantize the model.')
        calibration_ds = quantization_ds
    else:
        print('[INFO] : Quantization dataset is not provided! Using the training set to quantize the model.')
        calibration_ds = train_ds
    quantized_model_path = quantize(cfg=cfg, quantization_ds=calibration_ds,
                                    float_model_path=trained_model_path)
    print('[INFO] : Quantization complete.')

    # --- Accuracy check of the quantized model ------------------------------
    eval_ds, name_ds = (test_ds, "test_set") if test_ds else (valid_ds, "validation_set")
    evaluate(cfg=cfg, eval_ds=eval_ds, model_path_to_evaluate=quantized_model_path, name_ds=name_ds)
    print('[INFO] : Evaluation complete.')
    display_figures(cfg)
381 
382 
def process_mode(mode: str = None,
                 configs: DictConfig = None,
                 train_ds: tf.data.Dataset = None,
                 valid_ds: tf.data.Dataset = None,
                 quantization_ds: tf.data.Dataset = None,
                 test_ds: tf.data.Dataset = None,
                 float_model_path: Optional[str] = None,
                 fake: Optional[bool] = False) -> None:
    """
    @brief Dispatches execution to the appropriate pipeline based on the operation mode.

    @details
    Central dispatcher: reads the requested `mode` and calls the corresponding
    function or chain.

    Supported modes:
    - 'training'     : Train a model.
    - 'evaluation'   : Evaluate model accuracy on a dataset.
    - 'quantization' : Quantize a float model to INT8.
    - 'deployment'   : Deploy the model onto the STM32 board (generates C code, compiles, flashes).
    - 'prediction'   : Run inference on new input data.
    - 'benchmarking' : Measure on-device performance metrics.
    - 'chain_tqeb'   : Training → Quantization → Evaluation → Benchmarking.
    - 'chain_tqe'    : Training → Quantization → Evaluation.
    - 'chain_eqe'    : Evaluation → Quantization → Evaluation.
    - 'chain_qb'     : Quantization → Benchmarking.
    - 'chain_eqeb'   : Evaluation → Quantization → Evaluation → Benchmarking.
    - 'chain_qd'     : Quantization → Deployment.

    @note In deployment mode for STM32N6570-DK, after flashing the user must manually
    toggle the boot switches and power-cycle the board.

    @param mode Operation mode string (e.g., 'deployment', 'chain_qd').
    @param configs Hydra configuration dictionary.
    @param train_ds Training TensorFlow dataset.
    @param valid_ds Validation TensorFlow dataset.
    @param quantization_ds Calibration dataset for INT8 quantization.
    @param test_ds Test TensorFlow dataset.
    @param float_model_path Path to the float32 model file.
    @param fake If True, use synthetic data for quantization calibration.

    @return None
    @throws ValueError if an unsupported operation mode is provided.
    """
    mlflow.log_param("model_path", configs.general.model_path)
    log_to_file(configs.output_dir, f'operation_mode: {mode}')

    if mode == 'training':
        if test_ds:
            train(cfg=configs, train_ds=train_ds, valid_ds=valid_ds, test_ds=test_ds)
        else:
            train(cfg=configs, train_ds=train_ds, valid_ds=valid_ds)
        display_figures(configs)
        print('[INFO] : Training complete.')

    elif mode == 'evaluation':
        # Generate and load the model on the STM32N6 via stedgeai, then validate on device
        gen_load_val(cfg=configs)
        # Restore the working directory (the device step may have changed it)
        os.chdir(os.path.dirname(os.path.realpath(__file__)))
        if test_ds:
            evaluate(cfg=configs, eval_ds=test_ds, name_ds="test_set")
        else:
            evaluate(cfg=configs, eval_ds=valid_ds, name_ds="validation_set")
        display_figures(configs)
        print('[INFO] : Evaluation complete.')

    elif mode == 'deployment':
        # Select MPU or MCU deployment path based on hardware_type config field
        if configs.hardware_type == "MPU":
            deploy_mpu(cfg=configs)
        else:
            deploy(cfg=configs)
        print('[INFO] : Deployment complete.')
        if configs.deployment.hardware_setup.board == "STM32N6570-DK":
            print('[INFO] : On STM32N6570-DK, please toggle the boot switches to the left and power cycle the board.')

    elif mode == 'quantization':
        # Select calibration data source: dedicated set > training set > fake data
        if quantization_ds:
            input_ds = quantization_ds
            fake = False
            print('[INFO] : Using the quantization dataset to quantize the model.')
        elif train_ds:
            print('[INFO] : Quantization dataset is not provided! Using the training set to quantize the model.')
            input_ds = train_ds
            fake = False
        else:
            input_ds = None
            fake = True
            print('[INFO] : Neither quantization dataset nor training set are provided! '
                  'Using fake data to quantize the model. The model performance will not be accurate.')
        quantize(cfg=configs, quantization_ds=input_ds, fake=fake)
        print('[INFO] : Quantization complete.')

    elif mode == 'prediction':
        # Load model on device, then run prediction pipeline
        gen_load_val_predict(cfg=configs)
        os.chdir(os.path.dirname(os.path.realpath(__file__)))
        predict(cfg=configs)
        print('[INFO] : Prediction complete.')

    elif mode == 'benchmarking':
        benchmark(cfg=configs)
        print('[INFO] : Benchmark complete.')

    elif mode == 'chain_tqeb':
        chain_tqeb(cfg=configs, train_ds=train_ds, valid_ds=valid_ds,
                   quantization_ds=quantization_ds, test_ds=test_ds)
        print('[INFO] : chain_tqeb complete.')

    elif mode == 'chain_tqe':
        chain_tqe(cfg=configs, train_ds=train_ds, valid_ds=valid_ds,
                  quantization_ds=quantization_ds, test_ds=test_ds)
        print('[INFO] : chain_tqe complete.')

    elif mode == 'chain_eqe':
        chain_eqe(cfg=configs, float_model_path=float_model_path, train_ds=train_ds,
                  valid_ds=valid_ds, quantization_ds=quantization_ds, test_ds=test_ds)
        print('[INFO] : chain_eqe complete.')

    elif mode == 'chain_qb':
        chain_qb(cfg=configs, float_model_path=float_model_path, train_ds=train_ds,
                 quantization_ds=quantization_ds)
        print('[INFO] : chain_qb complete.')

    elif mode == 'chain_eqeb':
        chain_eqeb(cfg=configs, float_model_path=float_model_path, train_ds=train_ds,
                   valid_ds=valid_ds, quantization_ds=quantization_ds, test_ds=test_ds)
        print('[INFO] : chain_eqeb complete.')

    elif mode == 'chain_qd':
        chain_qd(cfg=configs, float_model_path=float_model_path, train_ds=train_ds,
                 quantization_ds=quantization_ds)
        print('[INFO] : chain_qd complete.')

    else:
        raise ValueError(f"Invalid mode: {mode}")

    # Log all output artifacts and configuration to MLflow
    mlflow.log_artifact(configs.output_dir)
    if mode in ['benchmarking', 'chain_qb', 'chain_eqeb', 'chain_tqeb']:
        mlflow.log_param("stm32ai_version", configs.tools.stm32ai.version)
        mlflow.log_param("target", configs.benchmarking.board)
    log_to_file(configs.output_dir, f'operation finished: {mode}')

    # Optional ClearML task logging
    if get_active_config_file() is not None:
        # Fix: removed extraneous f-string prefix (no placeholders) — output unchanged.
        print("[INFO] : ClearML task connection")
        task = Task.current_task()
        task.connect(configs)
533 
534 
@hydra.main(version_base=None, config_path="", config_name="user_config")
def main(cfg: DictConfig) -> None:
    """
    @brief Main entry point of the STM32AI Model Zoo Services script.

    @details
    Decorated with @hydra.main: Hydra automatically loads the configuration
    from `user_config.yaml` and passes it as a DictConfig object.

    Execution flow:
    1. Configure GPU memory limits (if specified in the config).
    2. Parse and validate the full configuration via get_config().
    3. Initialize MLflow experiment tracking.
    4. Optionally initialize ClearML task tracking.
    5. Set the global random seed for reproducibility.
    6. Load and preprocess datasets (if required by the selected mode).
    7. Dispatch to process_mode() based on cfg.operation_mode.

    @note Modes requiring datasets (training, evaluation, etc.) call preprocess()
    to load and prepare the data. Modes like deployment do not require datasets.

    @param cfg Hydra DictConfig object automatically populated from user_config.yaml.

    @return None
    """
    # Configure GPU memory ceiling to avoid OOM errors during training/quantization
    if "general" in cfg and cfg.general:
        if "gpu_memory_limit" in cfg.general and cfg.general.gpu_memory_limit:
            set_gpu_memory_limit(cfg.general.gpu_memory_limit)
            print(f"[INFO] Setting upper limit of usable GPU memory to {int(cfg.general.gpu_memory_limit)}GBytes.")
        else:
            print("[WARNING] The usable GPU memory is unlimited.\n"
                  "Please consider setting the 'gpu_memory_limit' attribute "
                  "in the 'general' section of your configuration file.")

    # Parse and validate the configuration file
    cfg = get_config(cfg)
    cfg.output_dir = HydraConfig.get().run.dir
    mlflow_ini(cfg)

    # ClearML initialization (optional — only if a valid config file exists)
    # Fix: removed extraneous f-string prefixes (no placeholders) — output unchanged.
    print("[INFO] : ClearML config check")
    if get_active_config_file() is not None:
        print("[INFO] : ClearML initialization and configuration")
        task = Task.init(project_name=cfg.general.project_name, task_name='pe_modelzoo_task')
        task.connect_configuration(name=cfg.operation_mode, configuration=cfg)

    # Set global random seed for reproducibility across TensorFlow operations
    seed = get_random_seed(cfg)
    print(f'[INFO] : The random seed for this simulation is {seed}')
    if seed is not None:
        tf.keras.utils.set_random_seed(seed)

    # Dispatch based on operation mode (flattened else/if pyramid into elif chain)
    mode = cfg.operation_mode

    if mode in ['training', 'evaluation', 'chain_tqeb', 'chain_tqe']:
        # These modes always need datasets.
        train_ds, valid_ds, quantization_ds, test_ds = preprocess(cfg=cfg)
        process_mode(mode=mode, configs=cfg, train_ds=train_ds, valid_ds=valid_ds,
                     quantization_ds=quantization_ds, test_ds=test_ds)

    elif mode == 'quantization':
        # Real calibration data when a dataset path is configured, fake data otherwise.
        if cfg.dataset.training_path or cfg.dataset.quantization_path:
            train_ds, valid_ds, quantization_ds, test_ds = preprocess(cfg=cfg)
            process_mode(mode=mode, configs=cfg, train_ds=train_ds, valid_ds=valid_ds,
                         quantization_ds=quantization_ds, test_ds=test_ds)
        else:
            process_mode(mode=mode, configs=cfg, fake=True)

    elif mode in ['chain_eqe', 'chain_qb', 'chain_eqeb', 'chain_qd']:
        # These chains also need the float model path from the config.
        if cfg.dataset.training_path or cfg.dataset.quantization_path:
            train_ds, valid_ds, quantization_ds, test_ds = preprocess(cfg=cfg)
        else:
            train_ds = valid_ds = quantization_ds = test_ds = None
        process_mode(mode=mode, configs=cfg, train_ds=train_ds, valid_ds=valid_ds,
                     quantization_ds=quantization_ds, test_ds=test_ds,
                     float_model_path=cfg.general.model_path)

    else:
        # Dataset-free modes (e.g. deployment, benchmarking, prediction).
        process_mode(mode=mode, configs=cfg)
624 
625 
if __name__ == "__main__":
    # Parse CLI arguments locally so they can be logged to MLflow after the run.
    # NOTE(review): Hydra's @hydra.main reads sys.argv itself inside main(); this
    # argparse setup mirrors the expected arguments purely for logging — confirm
    # the two argument conventions stay in sync.
    parser = argparse.ArgumentParser()
    parser.add_argument('--config-path', type=str, default='',
                        help='Path to folder containing configuration file')
    parser.add_argument('--config-name', type=str, default='user_config',
                        help='Name of the configuration file')
    parser.add_argument('params', nargs='*',
                        help='List of parameters to override in user_config.yaml')
    args = parser.parse_args()

    # Run the Hydra-decorated entry point.
    main()

    # Record which configuration was used, then close the MLflow run.
    mlflow.log_param('config_path', args.config_path)
    mlflow.log_param('config_name', args.config_name)
    mlflow.end_run()
Definition: deploy.py:1
None deploy_mpu(DictConfig cfg=None, Optional[str] model_path_to_deploy=None, list credentials=None)
Definition: deploy.py:199
DefaultMunch get_config(DictConfig config_data)
None chain_qd(DictConfig cfg=None, str float_model_path=None, tf.data.Dataset train_ds=None, tf.data.Dataset quantization_ds=None)
Definition: stm32ai_main.py:74
None chain_qb(DictConfig cfg=None, str float_model_path=None, tf.data.Dataset train_ds=None, tf.data.Dataset quantization_ds=None)
None chain_eqeb(DictConfig cfg=None, str float_model_path=None, tf.data.Dataset train_ds=None, tf.data.Dataset valid_ds=None, tf.data.Dataset quantization_ds=None, tf.data.Dataset test_ds=None)
None chain_tqeb(DictConfig cfg=None, tf.data.Dataset train_ds=None, tf.data.Dataset valid_ds=None, tf.data.Dataset quantization_ds=None, tf.data.Dataset test_ds=None)
None process_mode(str mode=None, DictConfig configs=None, tf.data.Dataset train_ds=None, tf.data.Dataset valid_ds=None, tf.data.Dataset quantization_ds=None, tf.data.Dataset test_ds=None, Optional[str] float_model_path=None, Optional[bool] fake=False)
None chain_tqe(DictConfig cfg=None, tf.data.Dataset train_ds=None, tf.data.Dataset valid_ds=None, tf.data.Dataset quantization_ds=None, tf.data.Dataset test_ds=None)
None chain_eqe(DictConfig cfg=None, str float_model_path=None, tf.data.Dataset train_ds=None, tf.data.Dataset valid_ds=None, tf.data.Dataset quantization_ds=None, tf.data.Dataset test_ds=None)
None main(DictConfig cfg)