Authors: Giacomo Colosio, Sebastiano Colosio, Patrizio Acquadro, Tito Nicola Drugman
25 from omegaconf
import DictConfig
27 import tensorflow
as tf
28 from typing
import Tuple
32 from common.utils
import get_model_name_and_its_input_shape
33 from .data_loader
import load_dataset
def preprocess(cfg: DictConfig = None) -> Tuple:
    """
    Loads the datasets described by the configuration.

    The model input shape is resolved, in order of priority, from
    ``cfg.general.model_path``, then ``cfg.training.model.input_shape``, then
    the model referenced by ``cfg.training.resume_training_from``. It is used
    to choose the image size passed to ``load_dataset``.

    Args:
        cfg (DictConfig): Configuration object containing the settings.

    Returns:
        Tuple: A tuple containing the following:
            - train_ds: Training dataset.
            - valid_ds: Validation dataset.
            - quantization_ds: Quantization dataset.
            - test_ds: Test dataset.
    """
    model_path = cfg.general.model_path
    if model_path:
        _, input_shape = get_model_name_and_its_input_shape(model_path)
    elif cfg.training.model:
        # Training a model defined in the 'training' section of the config file.
        input_shape = cfg.training.model.input_shape
    else:
        # Resuming a previous training run: read the shape from that model.
        _, input_shape = get_model_name_and_its_input_shape(cfg.training.resume_training_from)

    interpolation = cfg.preprocessing.resizing.interpolation
    aspect_ratio = cfg.preprocessing.resizing.aspect_ratio

    # Fall back to a batch size of 32 when there is no 'training' section.
    batch_size = cfg.training.batch_size if cfg.training else 32

    # For ONNX models the spatial size is taken from input_shape[1:], otherwise
    # from input_shape[:2]. NOTE(review): this presumably reflects ONNX shapes
    # being channel-first (C, H, W) vs Keras (H, W, C) — confirm.
    is_onnx_model = bool(model_path) and model_path.split('.')[-1] == 'onnx'
    image_size = input_shape[1:] if is_onnx_model else input_shape[:2]

    train_ds, valid_ds, quantization_ds, test_ds = load_dataset(
        dataset_name=cfg.dataset.name,
        training_path=cfg.dataset.training_path,
        validation_path=cfg.dataset.validation_path,
        quantization_path=cfg.dataset.quantization_path,
        test_path=cfg.dataset.test_path,
        validation_split=cfg.dataset.validation_split,
        nbr_keypoints=cfg.dataset.keypoints,
        image_size=image_size,
        interpolation=interpolation,
        aspect_ratio=aspect_ratio,
        color_mode=cfg.preprocessing.color_mode,
        batch_size=batch_size,
        seed=cfg.dataset.seed)

    return train_ds, valid_ds, quantization_ds, test_ds
def apply_rescaling(dataset: tf.data.Dataset = None, scale: float = None, offset: float = None):
    """
    Applies a Keras ``Rescaling`` layer (``x * scale + offset``) to a dataset.

    Args:
        dataset (tf.data.Dataset): Dataset yielding ``(x, y)`` pairs; only ``x``
            is rescaled, ``y`` is passed through unchanged.
        scale (float): The scaling factor.
        offset (float): The offset added after scaling. ``None`` is treated as 0.

    Returns:
        tf.data.Dataset: The dataset with rescaled inputs.

    Raises:
        ValueError: If ``dataset`` or ``scale`` is None.
    """
    if dataset is None:
        raise ValueError("apply_rescaling: `dataset` must be provided")
    if scale is None:
        raise ValueError("apply_rescaling: `scale` must be provided")
    # Rescaling adds `offset` directly, so a None offset would fail at call time.
    if offset is None:
        offset = 0.0

    # A single preprocessing layer is callable on its own; no Sequential wrapper needed.
    rescaling = tf.keras.layers.Rescaling(scale=scale, offset=offset)

    rescaled_dataset = dataset.map(lambda x, y: (rescaling(x), y))
    return rescaled_dataset
def preprocess_input(image: np.ndarray, input_details: dict) -> tf.Tensor:
    """
    Preprocesses an input image according to input details.

    The image is:
      1. resized to the spatial size read from ``input_details['shape']``;
      2. quantized when the target dtype is an integer type with a non-zero
         quantization scale;
      3. cast to the target dtype;
      4. given a leading batch dimension of size 1.

    Args:
        image: Input image as a NumPy array (presumably (H, W, C) — TODO confirm).
        input_details: Dictionary with at least the keys ``'shape'``, ``'dtype'``
            and ``'quantization'`` (a ``(scale, zero_point)`` pair). Presumably
            one entry of ``tf.lite.Interpreter.get_input_details()`` — verify
            against callers.

    Returns:
        tf.Tensor: Preprocessed image with a leading batch dimension.
    """
    input_shape = input_details['shape']
    target_dtype = input_details['dtype']
    scale = input_details['quantization'][0]
    zero_point = input_details['quantization'][1]

    # Shape indices 1 and 2 are used as (height, width), i.e. an NHWC layout is
    # assumed — NOTE(review): confirm for every model fed through here.
    image = tf.image.resize(image, (input_shape[1], input_shape[2]))

    # Accept both NumPy scalar types and tf.DType values for the target dtype.
    np_dtype = np.dtype(getattr(target_dtype, 'as_numpy_dtype', target_dtype))

    # Quantize integer inputs; a zero scale means the tensor is not quantized,
    # and dividing by it would produce inf/NaN.
    if np.issubdtype(np_dtype, np.integer) and scale != 0:
        limits = np.iinfo(np_dtype)
        # Stay in TensorFlow ops (no NumPy round-trip) so this also works in graph mode;
        # tf.round rounds half to even, like np.round.
        image = tf.round(image / float(scale) + float(zero_point))
        image = tf.clip_by_value(image, float(limits.min), float(limits.max))

    image_processed = tf.cast(image, dtype=target_dtype)
    image_processed = tf.expand_dims(image_processed, 0)
    return image_processed
Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset] load_dataset(str dataset_name=None, str training_path=None, str validation_path=None, str quantization_path=None, str test_path=None, float validation_split=None, int nbr_keypoints=None, tuple[int] image_size=None, str interpolation=None, str aspect_ratio=None, str color_mode=None, int batch_size=None, int seed=None)
tf.Tensor preprocess_input(np.ndarray image, dict input_details)
Tuple preprocess(DictConfig cfg=None)
def apply_rescaling(tf.data.Dataset dataset=None, float scale=None, float offset=None)