19 Authors: Giacomo Colosio, Sebastiano Colosio, Patrizio Acquadro, Tito Nicola Drugman
25 from pathlib
import Path
26 import tensorflow
as tf
27 from onnx
import ModelProto
29 from omegaconf
import DictConfig
32 from common.utils
import check_model_support, check_attributes
33 from src.models
import st_movenet_lightning_heatmaps, custom
39 n_shape = [old_shape[0]]
40 for v
in x.shape[1:len(x.shape) - 1]:
43 n_shape.append(old_shape[-1])
44 return x.reshape(n_shape)
46 preds, _ = ai_runner_interpreter.invoke(image_processed)
50 predictions.append(x.copy())
55 Returns a Keras model object based on the specified configuration and parameters.
58 cfg (DictConfig): A dictionary containing the configuration for the model.
59 num_classes (int): The number of classes for the model.
60 dropout (float): The dropout rate for the model.
61 section (str): The section of the model to be used.
64 tf.keras.Model: A Keras model object based on the specified configuration and parameters.
72 model_name = cfg.general.model_type
73 message =
"\nPlease check the 'general' section of your configuration file."
74 check_model_support(model_name, supported_models=supported_models, message=message)
76 cft = cfg.training.model
77 input_shape = cft.input_shape
78 nb_keypoints = cfg.dataset.keypoints
79 random_resizing =
True if cfg.data_augmentation
and cfg.data_augmentation.config.random_periodic_resizing
else False
80 section =
"training.model"
83 if cft.name ==
'st_movenet_lightning_heatmaps':
84 check_attributes(cft, expected=[
"name",
"alpha",
"input_shape"], optional=[
"pretrained_weights"], section=section)
85 model = st_movenet_lightning_heatmaps(input_shape=input_shape,
86 nb_keypoints=nb_keypoints,
88 pretrained_weights=cft.pretrained_weights)
89 elif cft.name ==
"custom":
90 check_attributes(cft, expected=[
"name",
"input_shape"], section=section)
91 model = custom(input_shape=input_shape,
92 nb_keypoints=nb_keypoints)
99 Loads a model for training.
101 The model to train can be:
102 - a model from the Model Zoo
103 - a user model (BYOM)
104 - a model previously trained during a training that was interrupted.
106 When a training is run, the following files are saved in the saved_models
109 Model saved before the training started. Weights are random.
111 Best weights obtained since the beginning of the training.
113 Weights saved at the end of the last epoch.
115 To resume a training, the last weights are loaded into the base model.
118 model_type = cfg.general.model_type
122 if cfg.training.model:
123 print(
"[INFO] : Loading Model Zoo model:", model_type)
126 cft = cfg.training.model
127 if cft.pretrained_weights:
128 print(f
"[INFO] : Loaded pretrained weights: `{cft.pretrained_weights}`")
130 print(f
"[INFO] : No pretrained weights were loaded.")
133 elif cfg.general.model_path:
134 print(
"[INFO] : Loading model", cfg.general.model_path)
135 model = tf.keras.models.load_model(cfg.general.model_path, compile=
False)
138 input_shape = tuple(model.input.shape[1:])
139 if None in input_shape:
140 raise ValueError(f
"\nThe model input shape is unspecified. Got {str(input_shape)}\n"
141 "Unable to proceed with training.")
144 elif cfg.training.resume_training_from:
145 resume_dir = os.path.join(cfg.training.resume_training_from, cfg.general.saved_models_dir)
146 print(f
"[INFO] : Resuming training from directory {resume_dir}\n")
148 message =
"\nUnable to resume training."
149 if not os.path.isdir(resume_dir):
150 raise FileNotFoundError(f
"\nCould not find resume directory {resume_dir}{message}")
151 model_path = os.path.join(resume_dir,
"base_model.h5")
152 if not os.path.isfile(model_path):
153 raise FileNotFoundError(f
"\nCould not find model file {model_path}{message}\n")
154 last_weights_path = os.path.join(resume_dir,
"last_weights.h5")
155 if not os.path.isfile(last_weights_path):
156 raise FileNotFoundError(f
"\nCould not find model weights file {last_weights_path}{message}\n")
158 model = tf.keras.models.load_model(model_path, compile=
False)
159 model.load_weights(last_weights_path)
tuple load_model_for_training(DictConfig cfg)
def ai_runner_invoke(image_processed, ai_runner_interpreter)
def _get_zoo_model(DictConfig cfg)