# Authors: Giacomo Colosio, Sebastiano Colosio, Patrizio Acquadro, Tito Nicola Drugman
27 import tensorflow
as tf
31 from hydra.core.hydra_config
import HydraConfig
32 from omegaconf
import DictConfig
34 from common.optimization
import model_formatting_ptq_per_tensor
35 from common.utils
import get_model_name_and_its_input_shape, tf_dataset_to_np_array
36 from common.quantization
import quantize_onnx
37 from common.evaluation
import model_is_quantized
38 from src.preprocessing
import apply_rescaling
40 from typing
import Optional
def _tflite_ptq_quantizer(model: tf.keras.Model = None, quantization_ds: tf.data.Dataset = None,
                          fake: bool = False, output_dir: str = None,
                          export_dir: Optional[str] = None, input_shape: tuple = None,
                          quantization_granularity: str = None, quantization_input_type: str = None,
                          quantization_output_type: str = None, quantization_split: str = None,
                          quantization_path: str = None) -> None:
    """
    Perform post-training quantization on a TensorFlow Lite model.

    Args:
        model (tf.keras.Model): The TensorFlow model to be quantized.
        quantization_ds (tf.data.Dataset): The quantization dataset if it's provided by the user else the training
            dataset. Defaults to None.
        fake (bool): Whether to use fake data for representative dataset generation.
        output_dir (str): Path to the output directory. Defaults to None.
        export_dir (str): Name of the export directory. Defaults to None.
        input_shape (tuple): The input shape of the model. Defaults to None.
        quantization_granularity (str): 'per_tensor' or 'per_channel'. Defaults to None.
        quantization_input_type (str): The quantization type for the input. Defaults to None.
        quantization_output_type (str): The quantization type for the output. Defaults to None.
        quantization_split (str): The fraction of the data to use for the quantization.
        quantization_path (str): the quantization dataset path if it's provided by the user. Defaults to None.

    Returns:
        None
    """

    def _representative_data_gen():
        """
        Generate representative data for post-training quantization.

        Yields:
            List[tf.Tensor]: A list of TensorFlow tensors representing the input data.
        """
        if fake:
            # No real data available: feed a handful of random samples with the
            # model's input shape (assumes channels-last HWC — TODO confirm).
            for _ in tqdm.tqdm(range(5)):
                data = np.random.rand(1, input_shape[0], input_shape[1], input_shape[2])
                yield [data.astype(np.float32)]
        elif not quantization_split:
            print("[INFO] : Quantizing by using the provided dataset fully...")
            for images, labels in tqdm.tqdm(quantization_ds, total=len(quantization_ds)):
                # Iterate per-sample: the converter expects one input at a time.
                # (Original code referenced `image` without binding it — fixed.)
                for image in images:
                    image = tf.cast(image, dtype=tf.float32)
                    image = tf.expand_dims(image, 0)
                    yield [image]
        else:
            print(f'[INFO] : Quantizing by using {quantization_split * 100} % of the provided dataset...')
            quantization_ds_size = len(quantization_ds)
            splited_ds = quantization_ds.take(int(quantization_ds_size * float(quantization_split)))
            for images, labels in tqdm.tqdm(splited_ds, total=len(splited_ds)):
                for image in images:
                    image = tf.cast(image, dtype=tf.float32)
                    image = tf.expand_dims(image, 0)
                    yield [image]

    converter = tf.lite.TFLiteConverter.from_keras_model(model)

    # Create the output directory for the quantized model if needed.
    tflite_models_dir = pathlib.Path(os.path.join(output_dir, "{}/".format(export_dir)))
    tflite_models_dir.mkdir(exist_ok=True, parents=True)

    # Set the input tensor quantization type (float32 is kept otherwise).
    if quantization_input_type == 'int8':
        converter.inference_input_type = tf.int8
    elif quantization_input_type == 'uint8':
        converter.inference_input_type = tf.uint8

    # Set the output tensor quantization type (float32 is kept otherwise).
    if quantization_output_type == 'int8':
        converter.inference_output_type = tf.int8
    elif quantization_output_type == 'uint8':
        converter.inference_output_type = tf.uint8

    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = _representative_data_gen

    # Per-tensor granularity: disable the converter's default per-channel
    # weight quantization (relies on a private TFLite converter attribute).
    if quantization_granularity == 'per_tensor':
        converter._experimental_disable_per_channel = True

    # Convert and serialize the quantized model.
    tflite_model_quantized = converter.convert()
    tflite_model_quantized_file = tflite_models_dir / "quantized_model.tflite"
    tflite_model_quantized_file.write_bytes(tflite_model_quantized)
def quantize(cfg: DictConfig = None, quantization_ds: Optional[tf.data.Dataset] = None,
             fake: Optional[bool] = False, float_model_path: Optional[str] = None) -> str:
    """
    Quantize the TensorFlow model with training data.

    Args:
        cfg (DictConfig): The configuration dictionary. Defaults to None.
        quantization_ds (tf.data.Dataset): The quantization dataset if it's provided by the user else the training
            dataset. Defaults to None.
        fake (bool, optional): Whether to use fake data for representative dataset generation. Defaults to False.
        float_model_path (str, optional): Model path to quantize.

    Returns:
        quantized model path (str)

    Raises:
        TypeError: If the quantizer/quantization-type combination in `cfg` is not supported.
    """
    # Prefer an explicitly supplied model path over the one in the config.
    model_path = float_model_path if float_model_path else cfg.general.model_path
    _, input_shape = get_model_name_and_its_input_shape(model_path=model_path)

    # ONNX models go through the ONNX PTQ quantizer.
    if model_path.split('.')[-1] == 'onnx':
        if (cfg['quantization']['quantizer'].lower() == "onnx_quantizer" and
                cfg['quantization']['quantization_type'] == "PTQ"):
            if model_is_quantized(model_path):
                # Nothing to do — return the input model unchanged.
                # (Original code printed this but lost the return — fixed.)
                print('[INFO]: The input model is already quantized!\n\tReturning the same model!')
                return model_path
            if not fake:
                # Rescale the calibration data the same way training data was preprocessed.
                quantization_ds = apply_rescaling(dataset=quantization_ds,
                                                 scale=cfg.preprocessing.rescaling.scale,
                                                 offset=cfg.preprocessing.rescaling.offset)
                # Use the full dataset when no split fraction is configured.
                quant_split = cfg.dataset.quantization_split if cfg.dataset.quantization_split else 1.0
                print(f'[INFO] : Quantizing by using {quant_split * 100} % of the provided dataset...')
                quantization_ds_size = len(quantization_ds)
                splited_ds = quantization_ds.take(int(quantization_ds_size * float(quant_split)))
                data, _ = tf_dataset_to_np_array(splited_ds)
            else:
                print(f'[INFO] : Quantizing by using fake dataset...')
                data = None
            quantized_model_path = quantize_onnx(quantization_samples=data, configs=cfg)
            return quantized_model_path
        raise TypeError("Quantizer and quantization type not supported."
                        "Check the `quantization` section of your user_config.yaml file!")

    # Keras/H5 models go through the TFLite converter.
    float_model = tf.keras.models.load_model(model_path)
    output_dir = HydraConfig.get().runtime.output_dir
    export_dir = cfg.quantization.export_dir
    print("[INFO] : Quantizing the model ... This might take few minutes ...")
    if cfg['quantization']['quantizer'] == "TFlite_converter" and cfg['quantization']['quantization_type'] == "PTQ":
        quantization_granularity = cfg.quantization.granularity
        quantization_optimize = cfg.quantization.optimize
        print(f'[INFO] : Quantization granularity : {quantization_granularity}')

        # Optionally reformat the float model for better per-tensor quantization,
        # and keep a copy of the optimized float model next to the quantized one.
        if quantization_granularity == 'per_tensor' and quantization_optimize:
            print("[INFO] : Optimizing the model for improved per_tensor quantization...")
            float_model = model_formatting_ptq_per_tensor(model_origin=float_model)
            optimized_model_path = os.path.join(output_dir, export_dir, "optimized_model.h5")
            float_model.save(optimized_model_path)

        print("[INFO] : Quantizing the model ... This might take few minutes ...")
        if fake:
            # Calibrate on random data — no dataset arguments are passed.
            _tflite_ptq_quantizer(model=float_model, fake=True, output_dir=output_dir,
                                  export_dir=export_dir, input_shape=input_shape,
                                  quantization_granularity=quantization_granularity,
                                  quantization_input_type=cfg.quantization.quantization_input_type,
                                  quantization_output_type=cfg.quantization.quantization_output_type)
        else:
            quantization_split = cfg.dataset.quantization_split
            quantization_path = cfg.dataset.quantization_path
            # Rescale the calibration data the same way training data was preprocessed.
            quantization_ds = apply_rescaling(dataset=quantization_ds,
                                              scale=cfg.preprocessing.rescaling.scale,
                                              offset=cfg.preprocessing.rescaling.offset)
            _tflite_ptq_quantizer(model=float_model, quantization_ds=quantization_ds,
                                  output_dir=output_dir,
                                  export_dir=export_dir, input_shape=input_shape,
                                  quantization_granularity=quantization_granularity,
                                  quantization_input_type=cfg.quantization.quantization_input_type,
                                  quantization_output_type=cfg.quantization.quantization_output_type,
                                  quantization_split=quantization_split, quantization_path=quantization_path)
        quantized_model_path = os.path.join(output_dir, export_dir, "quantized_model.tflite")
        return quantized_model_path
    raise TypeError("Quantizer and quantization type not supported."
                    "Check the `quantization` section of your user_config.yaml file!")
# Module API summary (auto-generated signature listing, kept as a comment):
#   apply_rescaling(dataset: tf.data.Dataset = None, scale: float = None, offset: float = None)
#   quantize(cfg: DictConfig = None, quantization_ds: Optional[tf.data.Dataset] = None, fake: Optional[bool] = False, float_model_path: Optional[str] = None) -> str
#   _tflite_ptq_quantizer(model: tf.keras.Model = None, quantization_ds: tf.data.Dataset = None, fake: bool = False, output_dir: str = None, export_dir: Optional[str] = None, input_shape: tuple = None, quantization_granularity: str = None, quantization_input_type: str = None, quantization_output_type: str = None, quantization_split: str = None, quantization_path: str = None) -> None