# Authors: Giacomo Colosio, Sebastiano Colosio, Patrizio Acquadro, Tito Nicola Drugman
25 from copy
import deepcopy
26 from pathlib
import Path
28 from hydra.core.hydra_config
import HydraConfig
29 from omegaconf
import OmegaConf, DictConfig
30 from munch
import DefaultMunch
31 from typing
import Dict
33 from common.utils
import postprocess_config_dict, check_config_attributes, parse_tools_section, parse_benchmarking_section, \
34 parse_mlflow_section, parse_top_level, parse_general_section, parse_quantization_section, \
35 parse_training_section, parse_prediction_section, parse_deployment_section, check_hardware_type, \
36 parse_evaluation_section, check_attributes
# _parse_postprocessing_section(cfg: DictConfig) -> None
# Validates the 'postprocessing' section of the configuration and fills in
# the default value of `plot_metrics`.
# NOTE(review): this is an extracted listing with dropped lines. The `def`
# line and original lines 43-44 and 46 (where `required` must be assigned)
# are not visible here; confirm the required-attribute list in the source.
# Attributes that may appear in the 'postprocessing' section.
42 legal = [
"kpts_conf_thresh",
"confidence_thresh",
"NMS_thresh",
"max_detection_boxes",
"plot_metrics"]
# check_config_attributes (common.utils) presumably rejects attributes outside
# `legal` and reports missing entries of `required` — confirm in common.utils.
45 check_config_attributes(cfg, specs={
"legal": legal,
"all": required}, section=
"postprocessing")
# `plot_metrics` is optional: a None value (presumably what the DefaultMunch
# returns when the attribute is absent) is normalised to False.
47 cfg.plot_metrics = cfg.plot_metrics
if cfg.plot_metrics
is not None else False
50 def _parse_dataset_section(cfg: DictConfig, hardware_type:str, mode: str =
None, mode_groups: DictConfig =
None) ->
None:
# Validates the 'dataset' section. It also fills in defaults for
# `validation_split` and `seed`, and checks that the split ratios lie
# strictly between 0 and 1.
#
# cfg: the 'dataset' section of the configuration.
# hardware_type: compared against "MPU" to decide whether
#     `keypoints_file_path` is mandatory in deployment modes.
# mode: the operation mode; it selects which attributes are mandatory.
# mode_groups: group name -> collection of operation modes in that group.
# Raises ValueError when a split ratio is outside (0, 1).
# NOTE(review): `mode_groups` defaults to None but `mode_groups.training` is
# dereferenced unconditionally below, so callers must always pass it.
# NOTE(review): original lines 51, 54-56 (presumably the initialisation of
# `required` and `one_or_more`) and 66-70 (the end of the
# check_config_attributes call) are missing from this extracted listing.
# Attributes that may appear in the 'dataset' section.
52 legal = [
"name",
"keypoints",
"keypoints_file_path",
"training_path",
"validation_path",
"validation_split",
"test_path",
53 "quantization_path",
"quantization_split",
"seed",
"class_names"]
# Training modes need the training set and the keypoints. Evaluation modes
# need at least one of the listed attributes.
# Because of the `elif`, a mode in both groups (e.g. chain_tqe) only gets the
# training requirements here.
57 if mode
in mode_groups.training:
58 required += [
"training_path",
"keypoints", ]
59 elif mode
in mode_groups.evaluation:
60 one_or_more += [
"training_path",
"test_path",
"keypoints", ]
# `keypoints` is mandatory except in the quantization/benchmarking/deployment
# modes listed here.
61 if mode
not in [
"quantization",
"benchmarking",
"chain_qb",
"deployment",
"chain_qd"]:
62 required += [
"keypoints", ]
# Deployment on an MPU target additionally needs the keypoints file.
63 if mode
in mode_groups.deployment
and hardware_type ==
"MPU":
64 required += [
"keypoints_file_path"]
65 check_config_attributes(cfg, specs={
"legal": legal,
"all": required,
"one_or_more": one_or_more},
# Defaults for optional attributes.
# NOTE(review): these truthiness tests also fire for an explicit 0 / 0.0.
# A user-supplied validation_split of 0 becomes 0.2, so it never reaches the
# range check below. A seed of 0 becomes 123.
71 if not cfg.validation_split:
72 cfg.validation_split = 0.2
73 cfg.seed = cfg.seed
if cfg.seed
else 123
# validation_split must satisfy 0 < split < 1.
76 if cfg.validation_split:
77 split = cfg.validation_split
78 if split <= 0.0
or split >= 1.0:
79 raise ValueError(f
"\nThe value of `validation_split` should be > 0 and < 1. Received {split}\n"
80 "Please check the 'dataset' section of your configuration file.")
# quantization_split is optional. When it is set (truthy), it must satisfy
# 0 < split < 1.
83 if cfg.quantization_split:
84 split = cfg.quantization_split
85 if split <= 0.0
or split >= 1.0:
86 raise ValueError(f
"\nThe value of `quantization_split` should be > 0 and < 1. Received {split}\n"
87 "Please check the 'dataset' section of your configuration file.")
# _parse_preprocessing_section(cfg: DictConfig, mode: str = None) -> None
# Validates the 'preprocessing' section. `cfg` is the WHOLE configuration:
# the code reads both `cfg.preprocessing` and `cfg.hardware_type`.
# Checks performed:
#   - allowed/required attributes;
#   - the `rescaling` and `resizing` sub-sections;
#   - the aspect-ratio method for the hardware type;
#   - the interpolation method;
#   - the color mode.
# Raises ValueError on any unsupported value.
# NOTE(review): the `def` line and original lines 94, 97-98 (presumably the
# `else:` branch that sets `required` outside deployment mode), 117 and 123
# are missing from this extracted listing.
# Attributes that may appear in the 'preprocessing' section.
92 legal = [
"rescaling",
"resizing",
"color_mode"]
# In 'deployment' mode only `resizing` and `color_mode` are mandatory.
93 if mode ==
'deployment':
95 required=[
"resizing",
"color_mode"]
96 check_config_attributes(cfg.preprocessing, specs={
"legal": legal,
"all": required}, section=
"preprocessing")
99 check_config_attributes(cfg.preprocessing, specs={
"legal": legal,
"all": required}, section=
"preprocessing")
# rescaling: `scale` and `offset` are the only allowed attributes and both
# are mandatory.
100 legal = [
"scale",
"offset"]
101 check_config_attributes(cfg.preprocessing.rescaling, specs={
"legal": legal,
"all": legal}, section=
"preprocessing.rescaling")
# resizing: `interpolation` and `aspect_ratio` are the only allowed attributes
# and both are mandatory.
103 legal = [
"interpolation",
"aspect_ratio"]
104 check_config_attributes(cfg.preprocessing.resizing, specs={
"legal": legal,
"all": legal}, section=
"preprocessing.resizing")
# Supported aspect-ratio methods depend on the target:
#   MCU: fit / crop / padding
#   MPU: fit / padding
#   any other hardware_type is rejected.
106 if cfg.hardware_type ==
"MCU":
107 if cfg.preprocessing.resizing.aspect_ratio
not in (
"fit",
"crop",
"padding"):
108 raise ValueError(
"\nSupported methods for resizing images are 'fit', 'crop' and 'padding'. "
109 f
"Received {cfg.preprocessing.resizing.aspect_ratio}\n"
110 "Please check the `resizing.aspect_ratio` attribute in "
111 "the 'preprocessing' section of your configuration file.")
113 elif cfg.hardware_type ==
"MPU":
114 if cfg.preprocessing.resizing.aspect_ratio
not in [
"fit",
"padding"]:
# NOTE(review): the two literals of this message concatenate with no space
# between "'padding'" and "('crop'".
115 raise ValueError(
"The only values of aspect_ratio that are supported at this point are 'fit' and 'padding'"
116 "('crop' is not supported).")
# Presumably reached via an `else:` on the missing original line 117.
118 raise ValueError(f
"Unsupported hardware_type: {cfg.hardware_type}. "
119 "Expected 'MCU' or 'MPU'.")
# Accepted interpolation methods.
# NOTE(review): the list continues on original line 123, which is missing
# here, so it may contain more entries than shown.
122 interpolation_methods = [
"bilinear",
"nearest",
"area",
"lanczos3",
"lanczos5",
"bicubic",
"gaussian",
124 if cfg.preprocessing.resizing.interpolation
not in interpolation_methods:
125 raise ValueError(f
"\nUnknown value for `interpolation` attribute. Received {cfg.preprocessing.resizing.interpolation}\n"
126 f
"Supported values: {interpolation_methods}\n"
127 "Please check the 'preprocessing.resizing' section of your configuration file.")
# Accepted color modes.
130 color_modes = [
"grayscale",
"rgb",
"rgba"]
131 if cfg.preprocessing.color_mode
not in color_modes:
132 raise ValueError(f
"\nUnknown value for `color_mode` attribute. Received {cfg.preprocessing.color_mode}\n"
133 f
"Supported values: {color_modes}\n"
134 "Please check the 'preprocessing' section of your configuration file.")
# _parse_data_augmentation_section(cfg: DictConfig) -> None
# NOTE(review): the `def` line, the docstring delimiters and several other
# original lines (e.g. 144-145, 147-152, 159, 162, 166, 172, 177) are missing
# from this extracted listing.
139 This function checks the data augmentation section of the config file.
140 The attribute that introduces the section is either `data_augmentation`
141 or `custom_data_augmentation`. If it is `custom_data_augmentation`,
142 the name of the data augmentation function that is provided must be
143 different from `data_augmentation` as this is a reserved name.
146 cfg (DictConfig): The entire configuration file as a DefaultMunch dictionary.
# The two ways of declaring data augmentation cannot be combined.
153 if cfg.data_augmentation
and cfg.custom_data_augmentation:
154 raise ValueError(
"\nThe `data_augmentation` and `custom_data_augmentation` attributes "
155 "are mutually exclusive.\nPlease check your configuration file.")
# Built-in augmentation: record it under the reserved function name
# "data_augmentation", with a deep copy of the section as its config.
157 if cfg.data_augmentation:
158 data_aug = DefaultMunch.fromDict({})
160 data_aug.function_name =
"data_augmentation"
161 data_aug.config = deepcopy(cfg.data_augmentation)
# Custom augmentation:
#   - `function_name` is mandatory and must not be the reserved name;
#   - its `config` is deep-copied only when present.
163 if cfg.custom_data_augmentation:
164 check_attributes(cfg.custom_data_augmentation,
165 expected=[
"function_name"],
167 section=
"custom_data_augmentation")
168 if cfg.custom_data_augmentation[
"function_name"] ==
"data_augmentation":
169 raise ValueError(
"\nThe function name `data_augmentation` is reserved.\n"
170 "Please use another name (attribute `function_name` in "
171 "the 'custom_data_augmentation' section).")
173 data_aug = DefaultMunch.fromDict({})
174 data_aug.function_name = cfg.custom_data_augmentation.function_name
175 if cfg.custom_data_augmentation.config:
176 data_aug.config = deepcopy(cfg.custom_data_augmentation.config)
# Replace `cfg.data_augmentation` with the normalised record
# (function_name plus an optional config).
# NOTE(review): `data_aug` is bound only inside the two branches above. If
# neither attribute is set, and no missing line binds it, this raises
# NameError, so callers should invoke this only when one of them is set.
178 cfg.data_augmentation = data_aug
# get_config(config_data: DictConfig) -> DefaultMunch
# Entry point. Converts the OmegaConf configuration into a DefaultMunch,
# validates each section according to `operation_mode`, and returns the result.
# NOTE(review): the `def` line, the docstring delimiters, the final `return`
# and many continuation lines are missing from this extracted listing. These
# include the call that starts on original line 216 and the parse calls on
# lines 242, 247 and 253. Verify against the source before editing.
183 Converts the configuration data, performs some checks and reformats
184 some sections so that they are easier to use later on.
187 config_data (DictConfig): dictionary containing the entire configuration file.
190 DefaultMunch: The configuration object.
# Convert to plain Python containers, post-process them with
# postprocess_config_dict (common.utils), then wrap them in a DefaultMunch.
193 config_dict = OmegaConf.to_container(config_data)
197 postprocess_config_dict(config_dict)
200 cfg = DefaultMunch.fromDict(config_dict)
# Operation modes grouped by the services they run. The membership tests
# below use these groups to decide which sections to parse.
201 mode_groups = DefaultMunch.fromDict({
202 "training": [
"training",
"chain_tqeb",
"chain_tqe"],
203 "evaluation": [
"evaluation",
"chain_tqeb",
"chain_tqe",
"chain_eqe",
"chain_eqeb"],
204 "quantization": [
"quantization",
"chain_tqeb",
"chain_tqe",
"chain_eqe",
205 "chain_qb",
"chain_eqeb",
"chain_qd"],
206 "benchmarking": [
"benchmarking",
"chain_tqeb",
"chain_qb",
"chain_eqeb"],
207 "deployment": [
"deployment",
"chain_qd"],
208 "prediction": [
"prediction"]
# Every accepted value of `operation_mode`.
210 mode_choices = [
"training",
"evaluation",
"deployment",
211 "quantization",
"benchmarking",
"chain_tqeb",
"chain_tqe",
212 "chain_eqe",
"chain_qb",
"chain_eqeb",
"chain_qd",
"prediction"]
# Allowed top-level sections.
213 legal = [
"general",
"operation_mode",
"dataset",
"preprocessing",
"data_augmentation",
214 "training",
"postprocessing",
"quantization",
"evaluation",
"prediction",
"tools",
215 "benchmarking",
"deployment",
"mlflow",
"hydra"]
# The start of this call (original line 216) is missing. Judging by the
# imports it is presumably parse_top_level(cfg, ...) — confirm.
217 mode_groups=mode_groups,
218 mode_choices=mode_choices,
220 print(f
"[INFO] : Running `{cfg.operation_mode}` operation mode")
# Presumably guarded by a missing `if not cfg.general:` (original line 223).
# Provides an empty 'general' section to validate.
224 cfg.general = DefaultMunch.fromDict({})
225 legal = [
"project_name",
"model_path",
"logs_dir",
"saved_models_dir",
"deterministic_ops",
226 "display_figures",
"global_seed",
"gpu_memory_limit",
"model_type",
"num_threads_tflite"]
# 'general' section; the Hydra run directory is passed as the output directory.
228 parse_general_section(cfg.general,
229 mode=cfg.operation_mode,
230 mode_groups=mode_groups,
233 output_dir = HydraConfig.get().runtime.output_dir)
# Hardware-type check; its remaining arguments are on lines missing here.
236 check_hardware_type(cfg,
# Dataset section; the parse call itself (original line 242) is missing.
241 cfg.dataset = DefaultMunch.fromDict({})
243 mode=cfg.operation_mode,
244 mode_groups=mode_groups,)
# Presumably the preprocessing parse call (its first line, 247, is missing).
248 mode=cfg.operation_mode)
# The data augmentation and training sections are parsed only in training modes.
251 if cfg.operation_mode
in mode_groups.training:
252 if cfg.data_augmentation
or cfg.custom_data_augmentation:
# Tell parse_training_section whether general.model_path and/or
# general.model_type were given.
254 model_path_used = bool(cfg.general.model_path)
255 model_type_used = bool(cfg.general.model_type)
256 legal = [
"model",
"batch_size",
"epochs",
"optimizer",
"dropout",
"frozen_layers",
257 "callbacks",
"trained_model_path",
"resume_training_from"]
258 parse_training_section(cfg.training,
259 model_path_used=model_path_used,
260 model_type_used=model_type_used,
# Prediction modes (the body, original lines 266-268, is missing here).
265 if cfg.operation_mode
in (mode_groups.prediction):
# Quantization section.
269 if cfg.operation_mode
in mode_groups.quantization:
270 legal = [
"quantizer",
"quantization_type",
"quantization_input_type",
271 "quantization_output_type",
"granularity",
"export_dir",
"optimize"]
272 parse_quantization_section(cfg.quantization,
# The evaluation section is optional: it is parsed only if present.
276 if cfg.operation_mode
in mode_groups.evaluation
and "evaluation" in cfg:
277 legal = [
"gen_npy_input",
"gen_npy_output",
"npy_in_name",
"npy_out_name",
"target",
278 "profile",
"input_type",
"output_type",
"input_chpos",
"output_chpos"]
279 parse_evaluation_section(cfg.evaluation,
283 if cfg.operation_mode ==
"prediction":
284 parse_prediction_section(cfg.prediction)
# The 'tools' section is needed for both benchmarking and deployment.
287 if cfg.operation_mode
in (mode_groups.benchmarking + mode_groups.deployment):
288 parse_tools_section(cfg.tools,
# NOTE(review): this reads cfg.benchmarking.board before
# parse_benchmarking_section (below) has validated the section.
# NOTE(review): the same MPU notice is printed again by the
# `cfg.hardware_type == "MPU"` check below. For an offline 'benchmarking' run
# on an STM32MP board (presumably classified as MPU) it can appear twice.
293 if cfg.operation_mode
in mode_groups.benchmarking:
294 if "STM32MP" in cfg.benchmarking.board:
295 if cfg.operation_mode ==
"benchmarking" and not(cfg.tools.stm32ai.on_cloud):
296 print(
"Target selected for benchmark :", cfg.benchmarking.board)
297 print(
"Offline benchmarking for MPU is not yet available. Please use online benchmarking.")
# Benchmarking section. Offline (not on_cloud) benchmarking is not available
# for MPU targets, so only a notice is printed.
301 if cfg.operation_mode
in mode_groups.benchmarking:
302 parse_benchmarking_section(cfg.benchmarking)
303 if cfg.hardware_type ==
"MPU":
304 if not (cfg.tools.stm32ai.on_cloud):
305 print(
"Target selected for benchmark :", cfg.benchmarking.board)
306 print(
"Offline benchmarking for MPU is not yet available. Please use online benchmarking.")
# Deployment: the allowed attributes differ between MCU and MPU targets.
310 if cfg.operation_mode
in mode_groups.deployment:
311 if cfg.hardware_type ==
"MCU":
312 legal = [
"c_project_path",
"IDE",
"verbosity",
"hardware_setup",
"build_conf"]
313 legal_hw = [
"serie",
"board",
"stlink_serial_number"]
314 parse_deployment_section(cfg.deployment,
# Presumably the MPU branch (its `else:` is on a missing line).
318 legal = [
"c_project_path",
"board_deploy_path",
"verbosity",
"hardware_setup"]
319 legal_hw = [
"serie",
"board",
"ip_address",
"stlink_serial_number"]
# MPU deployment only supports RGB input and the 'fit' aspect ratio. This is
# stricter than the preprocessing check, which also accepts 'padding' for MPU.
320 if cfg.preprocessing.color_mode !=
"rgb":
321 raise ValueError(
"\n Color mode used is not supported for deployment on MPU target \n Please use RGB format")
322 if cfg.preprocessing.resizing.aspect_ratio !=
"fit":
323 raise ValueError(
"\n Aspect ratio used is not supported for deployment on MPU target \n Please use FIT aspect ratio")
324 parse_deployment_section(cfg.deployment,
# MLflow section.
329 parse_mlflow_section(cfg.mlflow)
# Function index (signatures of the functions defined in this module):
#   _parse_data_augmentation_section(cfg: DictConfig) -> None
#   _parse_postprocessing_section(cfg: DictConfig) -> None
#   _parse_preprocessing_section(cfg: DictConfig, mode: str = None) -> None
#   get_config(config_data: DictConfig) -> DefaultMunch
#   _parse_dataset_section(cfg: DictConfig, hardware_type: str, mode: str = None, mode_groups: DictConfig = None) -> None