From 1db2651240277b85687b9f1511872421e62406c2 Mon Sep 17 00:00:00 2001 From: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> Date: Fri, 22 Sep 2023 09:51:48 +0200 Subject: [PATCH] Add image reshaping for statically reshaped OpenVINO model (#428) * add vae image processor * add image reshaping for statically reshaped model * add test * format * fix pipeline * fix reshaping * disable reshaping for inpaint SD models * add reshaping for inpaint --- optimum/intel/openvino/modeling_diffusion.py | 250 +++++++++++++++++-- tests/openvino/test_stable_diffusion.py | 46 ++-- 2 files changed, 256 insertions(+), 40 deletions(-) diff --git a/optimum/intel/openvino/modeling_diffusion.py b/optimum/intel/openvino/modeling_diffusion.py index c2884ee57..cb9d92a15 100644 --- a/optimum/intel/openvino/modeling_diffusion.py +++ b/optimum/intel/openvino/modeling_diffusion.py @@ -22,6 +22,7 @@ import numpy as np import openvino +import PIL from diffusers import ( DDIMScheduler, LMSDiscreteScheduler, @@ -351,6 +352,13 @@ def width(self) -> int: return -1 return width.get_length() * self.vae_scale_factor + @property + def _batch_size(self) -> int: + batch_size = self.unet.model.inputs[0].get_partial_shape()[0] + if batch_size.is_dynamic: + return -1 + return batch_size.get_length() + def _reshape_unet( self, model: openvino.runtime.Model, @@ -649,6 +657,7 @@ def __call__( width = width or self.unet.config.get("sample_size", 64) * self.vae_scale_factor _height = self.height _width = self.width + expected_batch_size = self._batch_size if _height != -1 and height != _height: logger.warning( @@ -664,11 +673,15 @@ def __call__( ) width = _width - if guidance_scale is not None and guidance_scale <= 1 and not self.is_dynamic: - raise ValueError( - f"`guidance_scale` was set to {guidance_scale}, static shapes are only supported for `guidance_scale` > 1, " - "please set `dynamic_shapes` to `True` when loading the model." - ) + if expected_batch_size != -1: + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = kwargs.get("prompt_embeds").shape[0] + + _raise_invalid_batch_size(expected_batch_size, batch_size, num_images_per_prompt, guidance_scale) return StableDiffusionPipelineMixin.__call__( self, @@ -684,16 +697,115 @@ def __call__( class OVStableDiffusionImg2ImgPipeline(OVStableDiffusionPipelineBase, StableDiffusionImg2ImgPipelineMixin): - def __call__(self, *args, **kwargs): - # TODO : add default height and width if model statically reshaped - # resize image if doesn't match height and width given during reshaping - return StableDiffusionImg2ImgPipelineMixin.__call__(self, *args, **kwargs) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + image: Union[np.ndarray, PIL.Image.Image] = None, + strength: float = 0.8, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: int = 1, + **kwargs, + ): + _height = self.height + _width = self.width + expected_batch_size = self._batch_size + + if _height != -1 and _width != -1: + image = self.image_processor.preprocess(image, height=_height, width=_width).transpose(0, 2, 3, 1) + + if expected_batch_size != -1: + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = kwargs.get("prompt_embeds").shape[0] + + _raise_invalid_batch_size(expected_batch_size, batch_size, num_images_per_prompt, guidance_scale) + + return StableDiffusionImg2ImgPipelineMixin.__call__( + self, + prompt=prompt, + image=image, + strength=strength, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + **kwargs, + ) class OVStableDiffusionInpaintPipeline(OVStableDiffusionPipelineBase, StableDiffusionInpaintPipelineMixin): - def __call__(self, *args, **kwargs): - # TODO : add default height and width if model statically reshaped - return StableDiffusionInpaintPipelineMixin.__call__(self, *args, **kwargs) + def __call__( + self, + prompt: Optional[Union[str, List[str]]], + image: PIL.Image.Image, + mask_image: PIL.Image.Image, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: int = 1, + **kwargs, + ): + height = height or self.unet.config.get("sample_size", 64) * self.vae_scale_factor + width = width or self.unet.config.get("sample_size", 64) * self.vae_scale_factor + _height = self.height + _width = self.width + expected_batch_size = self._batch_size + + if _height != -1 and _width != -1: + if height != _height: + logger.warning( + f"`height` was set to {height} but the static model will output images of height {_height}." + "To fix the height, please reshape your model accordingly using the `.reshape()` method." + ) + height = _height + + if width != _width: + logger.warning( + f"`width` was set to {width} but the static model will output images of width {_width}." + "To fix the width, please reshape your model accordingly using the `.reshape()` method." + ) + width = _width + + if isinstance(image, list): + image = [self.image_processor.resize(i, _height, _width) for i in image] + else: + image = self.image_processor.resize(image, _height, _width) + + if isinstance(mask_image, list): + mask_image = [self.image_processor.resize(i, _height, _width) for i in mask_image] + else: + mask_image = self.image_processor.resize(mask_image, _height, _width) + + if expected_batch_size != -1: + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = kwargs.get("prompt_embeds").shape[0] + + _raise_invalid_batch_size(expected_batch_size, batch_size, num_images_per_prompt, guidance_scale) + + return StableDiffusionInpaintPipelineMixin.__call__( + self, + prompt=prompt, + image=image, + mask_image=mask_image, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + **kwargs, + ) class OVStableDiffusionXLPipelineBase(OVStableDiffusionPipelineBase): @@ -718,10 +830,116 @@ def __init__(self, *args, add_watermarker: Optional[bool] = None, **kwargs): class OVStableDiffusionXLPipeline(OVStableDiffusionXLPipelineBase, StableDiffusionXLPipelineMixin): - def __call__(self, *args, **kwargs): - return StableDiffusionXLPipelineMixin.__call__(self, *args, **kwargs) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: int = 1, + **kwargs, + ): + height = height or self.unet.config["sample_size"] * self.vae_scale_factor + width = width or self.unet.config["sample_size"] * self.vae_scale_factor + _height = self.height + _width = self.width + expected_batch_size = self._batch_size + + if _height != -1 and height != _height: + logger.warning( + f"`height` was set to {height} but the static model will output images of height {_height}." + "To fix the height, please reshape your model accordingly using the `.reshape()` method." + ) + height = _height + + if _width != -1 and width != _width: + logger.warning( + f"`width` was set to {width} but the static model will output images of width {_width}." + "To fix the width, please reshape your model accordingly using the `.reshape()` method." + ) + width = _width + + if expected_batch_size != -1: + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = kwargs.get("prompt_embeds").shape[0] + + _raise_invalid_batch_size(expected_batch_size, batch_size, num_images_per_prompt, guidance_scale) + + return StableDiffusionXLPipelineMixin.__call__( + self, + prompt=prompt, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + **kwargs, + ) class OVStableDiffusionXLImg2ImgPipeline(OVStableDiffusionXLPipelineBase, StableDiffusionXLImg2ImgPipelineMixin): - def __call__(self, *args, **kwargs): - return StableDiffusionXLImg2ImgPipelineMixin.__call__(self, *args, **kwargs) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + image: Union[np.ndarray, PIL.Image.Image] = None, + strength: float = 0.3, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: int = 1, + **kwargs, + ): + _height = self.height + _width = self.width + expected_batch_size = self._batch_size + + if _height != -1 and _width != -1: + image = self.image_processor.preprocess(image, height=_height, width=_width).transpose(0, 2, 3, 1) + + if expected_batch_size != -1: + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = kwargs.get("prompt_embeds").shape[0] + + _raise_invalid_batch_size(expected_batch_size, batch_size, num_images_per_prompt, guidance_scale) + + return StableDiffusionXLImg2ImgPipelineMixin.__call__( + self, + prompt=prompt, + image=image, + strength=strength, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + **kwargs, + ) + + +def _raise_invalid_batch_size( + expected_batch_size: int, batch_size: int, num_images_per_prompt: int, guidance_scale: float +): + current_batch_size = batch_size * num_images_per_prompt * (1 if guidance_scale <= 1 else 2) + + if expected_batch_size != current_batch_size: + msg = "" + if guidance_scale is not None and guidance_scale <= 1: + msg = f"`guidance_scale` was set to {guidance_scale}, static shapes are currently only supported for `guidance_scale` > 1 " + + raise ValueError( + "The model was statically reshaped and the pipeline inputs do not match the expected shapes. " + f"The `batch_size`, `num_images_per_prompt` and `guidance_scale` were respectively set to {batch_size}, {num_images_per_prompt} and {guidance_scale}. " + f"The static model expects an input of size equal to {expected_batch_size} and got the following value instead : {current_batch_size}. " + f"To fix this, please either provide a different inputs to your model so that `batch_size` * `num_images_per_prompt` * 2 is equal to {expected_batch_size} " + "or reshape it again accordingly using the `.reshape()` method by setting `batch_size` to -1. " + msg + ) diff --git a/tests/openvino/test_stable_diffusion.py b/tests/openvino/test_stable_diffusion.py index 781fbe0ec..0e2ea91e4 100644 --- a/tests/openvino/test_stable_diffusion.py +++ b/tests/openvino/test_stable_diffusion.py @@ -184,10 +184,11 @@ def test_num_images_per_prompt_static_model(self, model_arch: str): pipeline = self.MODEL_CLASS.from_pretrained(model_id, export=True, compile=False, dynamic_shapes=False) batch_size, num_images, height, width = 2, 3, 128, 64 pipeline.half() - inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) pipeline.reshape(batch_size=batch_size, height=height, width=width, num_images_per_prompt=num_images) - outputs = pipeline(**inputs, num_images_per_prompt=num_images, generator=np.random.RandomState(0)).images - self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3)) + for _height in [height, height + 16]: + inputs = self.generate_inputs(height=_height, width=width, batch_size=batch_size) + outputs = pipeline(**inputs, num_images_per_prompt=num_images).images + self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3)) def generate_inputs(self, height=128, width=128, batch_size=1, input_type="np"): inputs = _generate_inputs(batch_size) @@ -264,21 +265,15 @@ def test_num_images_per_prompt_static_model(self, model_arch: str): model_id = MODEL_NAMES[model_arch] pipeline = self.MODEL_CLASS.from_pretrained(model_id, export=True, compile=False) batch_size, num_images, height, width = 3, 4, 128, 64 - prompt = "sailing ship in storm by Leonardo da Vinci" pipeline.half() pipeline.reshape(batch_size=batch_size, height=height, width=width, num_images_per_prompt=num_images) self.assertFalse(pipeline.is_dynamic) pipeline.compile() - # Verify output shapes requirements not matching the static model don't impact the final outputs - outputs = pipeline( - [prompt] * batch_size, - num_inference_steps=2, - num_images_per_prompt=num_images, - height=height + 8, - width=width, - output_type="np", - ).images - self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3)) + # Verify output shapes requirements not matching the static model doesn't impact the final outputs + for _height in [height, height + 16]: + inputs = _generate_inputs(batch_size) + outputs = pipeline(**inputs, num_images_per_prompt=num_images, height=_height, width=width).images + self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3)) @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_height_width_properties(self, model_arch: str): @@ -341,10 +336,11 @@ def test_num_images_per_prompt_static_model(self, model_arch: str): pipeline = self.MODEL_CLASS.from_pretrained(model_id, export=True, compile=False, dynamic_shapes=False) batch_size, num_images, height, width = 1, 3, 128, 64 pipeline.half() - inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) pipeline.reshape(batch_size=batch_size, height=height, width=width, num_images_per_prompt=num_images) - outputs = pipeline(**inputs, num_images_per_prompt=num_images, generator=np.random.RandomState(0)).images - self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3)) + for _height in [height, height + 16]: + inputs = self.generate_inputs(height=_height, width=width, batch_size=batch_size) + outputs = pipeline(**inputs, num_images_per_prompt=num_images).images + self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3)) def generate_inputs(self, height=128, width=128, batch_size=1): inputs = super(OVStableDiffusionInpaintPipelineTest, self).generate_inputs(height, width, batch_size) @@ -432,10 +428,11 @@ def test_num_images_per_prompt_static_model(self, model_arch: str): pipeline.reshape(batch_size=batch_size, height=height, width=width, num_images_per_prompt=num_images) self.assertFalse(pipeline.is_dynamic) pipeline.compile() - # Verify output shapes requirements not matching the static model don't impact the final outputs - inputs = _generate_inputs(batch_size) - outputs = pipeline(**inputs, num_images_per_prompt=num_images, height=height, width=width).images - self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3)) + + for _height in [height, height + 16]: + inputs = _generate_inputs(batch_size) + outputs = pipeline(**inputs, num_images_per_prompt=num_images, height=_height, width=width).images + self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3)) class OVStableDiffusionXLImg2ImgPipelineTest(unittest.TestCase): @@ -467,10 +464,11 @@ def test_num_images_per_prompt_static_model(self, model_arch: str): pipeline = self.MODEL_CLASS.from_pretrained(model_id, export=True, compile=False, dynamic_shapes=False) batch_size, num_images, height, width = 2, 3, 128, 64 pipeline.half() - inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) pipeline.reshape(batch_size=batch_size, height=height, width=width, num_images_per_prompt=num_images) - outputs = pipeline(**inputs, num_images_per_prompt=num_images, generator=np.random.RandomState(0)).images - self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3)) + for _height in [height, height + 16]: + inputs = self.generate_inputs(height=_height, width=width, batch_size=batch_size) + outputs = pipeline(**inputs, num_images_per_prompt=num_images).images + self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3)) def generate_inputs(self, height=128, width=128, batch_size=1, input_type="np"): inputs = _generate_inputs(batch_size)