diff --git a/flytekit/image_spec/image_spec.py b/flytekit/image_spec/image_spec.py index 37f87549d0..4e2896867e 100644 --- a/flytekit/image_spec/image_spec.py +++ b/flytekit/image_spec/image_spec.py @@ -216,6 +216,27 @@ def build_image(self, image_spec: ImageSpec) -> Optional[str]: """ raise NotImplementedError("This method is not implemented in the base class.") + def should_build(self, image_spec: ImageSpec) -> bool: + """ + Whether or not the builder should build the ImageSpec. + + Args: + image_spec: image spec of the task. + + Returns: + True if the image should be built, otherwise it returns False. + """ + img_name = image_spec.image_name() + if not image_spec.exist(): + click.secho(f"Image {img_name} not found. building...", fg="blue") + return True + if image_spec._is_force_push: + click.secho(f"Image {img_name} found but overwriting existing image.", fg="blue") + return True + + click.secho(f"Image {img_name} found. Skip building.", fg="blue") + return False + class ImageBuildEngine: """ @@ -252,18 +273,11 @@ def build(cls, image_spec: ImageSpec): builder = image_spec.builder img_name = image_spec.image_name() - if image_spec.exist(): - if image_spec._is_force_push: - click.secho(f"Image {img_name} found. but overwriting existing image.", fg="blue") - cls._build_image(builder, image_spec, img_name) - else: - click.secho(f"Image {img_name} found. Skip building.", fg="blue") - else: - click.secho(f"Image {img_name} not found. building...", fg="blue") + if cls._get_builder(builder).should_build(image_spec): cls._build_image(builder, image_spec, img_name) @classmethod - def _build_image(cls, builder, image_spec, img_name): + def _get_builder(cls, builder: str) -> ImageSpecBuilder: if builder not in cls._REGISTRY: raise Exception(f"Builder {builder} is not registered.") if builder == "envd": @@ -275,7 +289,11 @@ def _build_image(cls, builder, image_spec, img_name): f"envd version {envd_version} is not compatible with flytekit>v1.10.2." f" Please upgrade envd to v0.3.39+." ) - fully_qualified_image_name = cls._REGISTRY[builder][0].build_image(image_spec) + return cls._REGISTRY[builder][0] + + @classmethod + def _build_image(cls, builder: str, image_spec: ImageSpec, img_name: str): + fully_qualified_image_name = cls._get_builder(builder).build_image(image_spec) if fully_qualified_image_name is not None: cls._IMAGE_NAME_TO_REAL_NAME[img_name] = fully_qualified_image_name diff --git a/tests/flytekit/unit/core/image_spec/test_image_spec.py b/tests/flytekit/unit/core/image_spec/test_image_spec.py index 4a596c1e1e..011828d4ce 100644 --- a/tests/flytekit/unit/core/image_spec/test_image_spec.py +++ b/tests/flytekit/unit/core/image_spec/test_image_spec.py @@ -104,14 +104,14 @@ def test_image_spec_engine_priority(): def test_build_existing_image_with_force_push(): - image_spec = Mock() - image_spec.exist.return_value = True - image_spec._is_force_push = True + image_spec = ImageSpec(name="hello", builder="test").force_push() - ImageBuildEngine._build_image = Mock() + builder = Mock() + builder.build_image.return_value = "new_image_name" + ImageBuildEngine.register("test", builder) ImageBuildEngine.build(image_spec) - ImageBuildEngine._build_image.assert_called_once() + builder.build_image.assert_called_once() def test_custom_tag():