STM32N6 NPU Deployment — Politecnico di Milano  1.0
Documentation for Neural Network Deployment on STM32N6 NPU - Politecnico di Milano 2024-2025
models_mgt.py
Go to the documentation of this file.
1 # /*---------------------------------------------------------------------------------------------
2 # * Copyright (c) 2024 STMicroelectronics.
3 # * All rights reserved.
4 # *
5 # * This software is licensed under terms that can be found in the LICENSE file in
6 # * the root directory of this software component.
7 # * If no LICENSE file comes with this software, it is provided AS-IS.
8 # *--------------------------------------------------------------------------------------------*/
9 
10 
19 Authors: Giacomo Colosio, Sebastiano Colosio, Patrizio Acquadro, Tito Nicola Drugman
20 #
21 # @copyright Copyright (c) 2023-2024 STMicroelectronics. All rights reserved.
22 
23 
24 import os
25 from pathlib import Path
26 import tensorflow as tf
27 from onnx import ModelProto
28 import onnxruntime
29 from omegaconf import DictConfig
30 import numpy as np
31 
32 from common.utils import check_model_support, check_attributes
33 from src.models import st_movenet_lightning_heatmaps, custom
34 
35 
def ai_runner_invoke(image_processed, ai_runner_interpreter):
    """Run inference through an AI runner interpreter and normalize output shapes.

    Each prediction returned by ``ai_runner_interpreter.invoke`` has the
    size-1 axes located strictly between its first and last axis removed
    (the original comment notes this is requested by the legacy API). It is
    then copied, so every returned array is independent of the runner's
    buffers.

    Args:
        image_processed: Pre-processed input, forwarded unchanged to
            ``ai_runner_interpreter.invoke``.
        ai_runner_interpreter: Object whose ``invoke(inputs)`` returns a
            ``(predictions, extra)`` pair; ``extra`` is ignored here.

    Returns:
        list: One array per prediction, with inner singleton axes dropped.
    """

    def reduce_shape(x):
        """Drop the inner size-1 axes of ``x``, always keeping its first and last axes."""
        shape = x.shape
        # With fewer than two axes there is nothing "inner" to drop. The
        # generic path below would duplicate the only axis (a 1-D array of
        # size n would be reshaped to [n, n] and fail) or index an empty
        # shape (0-d array), so return such arrays unchanged.
        if len(shape) < 2:
            return x
        new_shape = [shape[0]]
        new_shape.extend(dim for dim in shape[1:-1] if dim != 1)
        new_shape.append(shape[-1])
        return x.reshape(new_shape)

    preds, _ = ai_runner_interpreter.invoke(image_processed)
    # .copy() detaches each result from whatever memory reshape() may share.
    return [reduce_shape(pred).copy() for pred in preds]
52 
def _get_zoo_model(cfg: DictConfig):
    """
    Build a Model Zoo Keras model from the configuration.

    ``cfg.general.model_type`` is first validated against the supported model
    types (currently only ``'heatmaps_spe'``). The model family is then
    selected by ``cfg.training.model.name``:
      - ``'st_movenet_lightning_heatmaps'``: requires ``name``, ``alpha`` and
        ``input_shape``; ``pretrained_weights`` is optional.
      - ``'custom'``: requires ``name`` and ``input_shape``.

    Args:
        cfg (DictConfig): Full configuration. Reads ``general.model_type``,
            ``training.model`` and ``dataset.keypoints``.

    Returns:
        tf.keras.Model: The model built by the selected model factory.

    Raises:
        ValueError: If ``training.model.name`` is not one of the supported names.
        Errors raised by ``check_model_support`` / ``check_attributes`` are
        propagated unchanged.
    """
    # Supported model types and their versions
    supported_models = {
        'heatmaps_spe': None
    }

    model_name = cfg.general.model_type
    message = "\nPlease check the 'general' section of your configuration file."
    check_model_support(model_name, supported_models=supported_models, message=message)

    cft = cfg.training.model
    input_shape = cft.input_shape
    nb_keypoints = cfg.dataset.keypoints
    section = "training.model"

    if cft.name == 'st_movenet_lightning_heatmaps':
        check_attributes(cft, expected=["name", "alpha", "input_shape"],
                         optional=["pretrained_weights"], section=section)
        return st_movenet_lightning_heatmaps(input_shape=input_shape,
                                             nb_keypoints=nb_keypoints,
                                             alpha=cft.alpha,
                                             pretrained_weights=cft.pretrained_weights)

    if cft.name == "custom":
        check_attributes(cft, expected=["name", "input_shape"], section=section)
        return custom(input_shape=input_shape,
                      nb_keypoints=nb_keypoints)

    # Previously an unknown name silently returned None, which only failed
    # later and far from the configuration mistake.
    raise ValueError(f"\nUnsupported model name `{cft.name}` in the '{section}' section "
                     "of your configuration file.\n"
                     "Supported names are: 'st_movenet_lightning_heatmaps', 'custom'.")
95 
96 
def load_model_for_training(cfg: DictConfig) -> "tf.keras.Model":
    """
    Load a model for training.

    The model is taken from the first of these configuration entries that is
    set. They are checked in this order:
      - ``training.model``: a model from the Model Zoo, built by
        ``_get_zoo_model``.
      - ``general.model_path``: a user model (BYOM), loaded uncompiled with
        ``tf.keras.models.load_model``.
      - ``training.resume_training_from``: a model from a previously
        interrupted training.

    When a training is run, the following files are saved in the saved_models
    directory:
        base_model.h5:
            Model saved before the training started. Weights are random.
        best_weights.h5:
            Best weights obtained since the beginning of the training.
        last_weights.h5:
            Weights saved at the end of the last epoch.

    To resume a training, ``base_model.h5`` is loaded from
    ``<training.resume_training_from>/<general.saved_models_dir>``. The
    weights in ``last_weights.h5`` from that same directory are then loaded
    into it.

    Args:
        cfg (DictConfig): Full configuration. Reads ``general.model_type``,
            ``general.model_path``, ``general.saved_models_dir``,
            ``training.model`` and ``training.resume_training_from``.

    Returns:
        tf.keras.Model: The model to train, or None if none of the three
        entries above is set.

    Raises:
        ValueError: If a BYOM model has an unspecified (None) input dimension.
        FileNotFoundError: If the resume directory, ``base_model.h5`` or
            ``last_weights.h5`` cannot be found when resuming a training.
    """
    model_type = cfg.general.model_type
    model = None

    # Train a model from the Model Zoo
    if cfg.training.model:
        print("[INFO] : Loading Model Zoo model:", model_type)
        model = _get_zoo_model(cfg)

        cft = cfg.training.model
        if cft.pretrained_weights:
            print(f"[INFO] : Loaded pretrained weights: `{cft.pretrained_weights}`")
        else:
            print("[INFO] : No pretrained weights were loaded.")

    # Bring your own model
    elif cfg.general.model_path:
        print("[INFO] : Loading model", cfg.general.model_path)
        model = tf.keras.models.load_model(cfg.general.model_path, compile=False)

        # Training needs a fully specified input shape (batch axis excluded).
        input_shape = tuple(model.input.shape[1:])
        if None in input_shape:
            raise ValueError(f"\nThe model input shape is unspecified. Got {str(input_shape)}\n"
                             "Unable to proceed with training.")

    # Resume a previously interrupted training
    elif cfg.training.resume_training_from:
        resume_dir = os.path.join(cfg.training.resume_training_from, cfg.general.saved_models_dir)
        print(f"[INFO] : Resuming training from directory {resume_dir}\n")

        message = "\nUnable to resume training."
        if not os.path.isdir(resume_dir):
            raise FileNotFoundError(f"\nCould not find resume directory {resume_dir}{message}")
        model_path = os.path.join(resume_dir, "base_model.h5")
        if not os.path.isfile(model_path):
            raise FileNotFoundError(f"\nCould not find model file {model_path}{message}\n")
        last_weights_path = os.path.join(resume_dir, "last_weights.h5")
        if not os.path.isfile(last_weights_path):
            raise FileNotFoundError(f"\nCould not find model weights file {last_weights_path}{message}\n")

        # Rebuild the architecture from base_model.h5, then restore the
        # weights of the last completed epoch.
        model = tf.keras.models.load_model(model_path, compile=False)
        model.load_weights(last_weights_path)

    return model
tuple load_model_for_training(DictConfig cfg)
Definition: models_mgt.py:97
def ai_runner_invoke(image_processed, ai_runner_interpreter)
Definition: models_mgt.py:36
def _get_zoo_model(DictConfig cfg)
Definition: models_mgt.py:53