diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder.py b/sagemaker-serve/src/sagemaker/serve/model_builder.py index 27eaaa8fa3..2416bc7b10 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder.py @@ -1076,12 +1076,105 @@ def _is_nova_model_for_telemetry(self) -> bool: except Exception: return False + def _select_nova_hosting_config_entry(self, configs, instance_type, identifier): + """Select a single hosting config entry from a list of Nova configs. + + Picks the entry matching ``instance_type`` when provided, otherwise the + entry with ``Profile == "Default"`` (falling back to the first entry). + + Args: + configs: List of hosting config dicts. + instance_type: Requested instance type, or None. + identifier: Model identifier used for error messages. + + Returns: + The selected hosting config dict. + + Raises: + ValueError: If ``instance_type`` is provided but no entry matches it. + """ + if instance_type: + config = next( + (c for c in configs if c.get("InstanceType") == instance_type), None + ) + if not config: + supported = [c.get("InstanceType") for c in configs] + raise ValueError( + f"Instance type '{instance_type}' not supported for '{identifier}'. " + f"Supported: {supported}" + ) + return config + return next((c for c in configs if c.get("Profile") == "Default"), configs[0]) + + def _get_nova_hosting_config_from_hub_document(self, instance_type=None): + """Resolve Nova hosting config from the JumpStart hub document, if present. + + Reads hosting configs published in the hub content document, matching the + standard schema used by other custom models. Looks first inside the + ``RecipeCollection`` entry whose ``Name`` matches the recipe, then falls + back to the top-level ``HostingConfigs``. + + Returns: + A dict with ``image_uri``, ``env_vars``, and ``instance_type`` when a + usable hosting config is found, otherwise ``None``. 
+ """ + try: + hub_document = self._fetch_hub_document_for_custom_model() + except Exception as e: # pragma: no cover - defensive, hub may be unavailable + logger.debug(f"Could not fetch hub document for Nova hosting config: {e}") + return None + + if not hub_document: + return None + + container = self._fetch_model_package().inference_specification.containers[0] + recipe_name = getattr(container.base_model, "recipe_name", None) or "" + + hosting_configs = None + for recipe in hub_document.get("RecipeCollection", []): + if recipe.get("Name") == recipe_name: + hosting_configs = recipe.get("HostingConfigs") + break + if not hosting_configs: + hosting_configs = hub_document.get("HostingConfigs") + + if not hosting_configs: + return None + + config = self._select_nova_hosting_config_entry( + hosting_configs, instance_type, recipe_name or "nova" + ) + + image_uri = config.get("EcrAddress") + if not image_uri: + # Hosting config present but no image override; let the hardcoded + # fallback supply the escrow image URI. + return None + + resolved_instance_type = config.get("InstanceType") or config.get( + "DefaultInstanceType" + ) + + return { + "image_uri": image_uri, + "env_vars": config.get("Environment", {}), + "instance_type": resolved_instance_type, + } + def _get_nova_hosting_config(self, instance_type=None): """Get Nova hosting config (image URI, env vars, instance type). - Nova training recipes don't have hosting configs in the JumpStart hub document. - This provides the hardcoded fallback, matching Rhinestone's getNovaHostingConfigs(). + Prefers hosting configs published in the JumpStart hub document (the + standard location used by other custom models). Falls back to the + hardcoded ``_NOVA_HOSTING_CONFIGS``, matching Rhinestone's + getNovaHostingConfigs(), when the hub document does not provide one. 
""" + hub_config = self._get_nova_hosting_config_from_hub_document( + instance_type=instance_type + ) + if hub_config: + return hub_config + model_package = self._fetch_model_package() hub_content_name = model_package.inference_specification.containers[0].base_model.hub_content_name @@ -1102,16 +1195,9 @@ def _get_nova_hosting_config(self, instance_type=None): image_uri = f"{escrow_account}.dkr.ecr.{region}.amazonaws.com/nova-inference-repo:SM-Inference-latest" - if instance_type: - config = next((c for c in configs if c["InstanceType"] == instance_type), None) - if not config: - supported = [c["InstanceType"] for c in configs] - raise ValueError( - f"Instance type '{instance_type}' not supported for '{hub_content_name}'. " - f"Supported: {supported}" - ) - else: - config = next((c for c in configs if c.get("Profile") == "Default"), configs[0]) + config = self._select_nova_hosting_config_entry( + configs, instance_type, hub_content_name + ) return { "image_uri": image_uri, diff --git a/sagemaker-serve/tests/unit/test_nova_hosting_config.py b/sagemaker-serve/tests/unit/test_nova_hosting_config.py new file mode 100644 index 0000000000..6f174aa452 --- /dev/null +++ b/sagemaker-serve/tests/unit/test_nova_hosting_config.py @@ -0,0 +1,208 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Unit tests for Nova hosting config resolution in ModelBuilder. 
# Verifies that hosting configs published in the JumpStart hub document take
# priority over the hardcoded ``_NOVA_HOSTING_CONFIGS`` fallback.

import unittest
from unittest.mock import MagicMock, patch

from sagemaker.serve.model_builder import ModelBuilder


def _make_builder(region="us-east-1"):
    """Build a ModelBuilder instance while bypassing ``__init__``."""
    builder = ModelBuilder.__new__(ModelBuilder)
    for attribute in ("image_uri", "env_vars", "instance_type"):
        setattr(builder, attribute, None)
    builder.sagemaker_session = MagicMock(boto_region_name=region)
    return builder


def _make_model_package(recipe_name="", hub_content_name="nova-textgeneration-lite"):
    """Build a stand-in model package exposing one container's base model."""
    base_model = MagicMock(recipe_name=recipe_name, hub_content_name=hub_content_name)
    package = MagicMock()
    package.inference_specification.containers = [MagicMock(base_model=base_model)]
    return package


class TestNovaHostingConfigResolution(unittest.TestCase):
    """Checks the priority order applied by ModelBuilder._get_nova_hosting_config."""

    @staticmethod
    def _resolve(model_package, instance_type=None, **hub_fetch):
        """Call ``_get_nova_hosting_config`` with both fetchers patched.

        ``hub_fetch`` is forwarded to the patch of
        ``_fetch_hub_document_for_custom_model`` (``return_value=...`` or
        ``side_effect=...``).
        """
        builder = _make_builder()
        with patch.object(
            ModelBuilder, "_fetch_hub_document_for_custom_model", **hub_fetch
        ):
            with patch.object(
                ModelBuilder, "_fetch_model_package", return_value=model_package
            ):
                return builder._get_nova_hosting_config(instance_type=instance_type)

    def test_hub_recipe_collection_config_takes_priority(self):
        """A matching RecipeCollection entry in the hub doc wins."""
        hosting_entry = {
            "Profile": "Default",
            "EcrAddress": "111.dkr.ecr.us-east-1.amazonaws.com/custom:tag",
            "InstanceType": "ml.p5.48xlarge",
            "Environment": {
                "CONTEXT_LENGTH": "999",
                "MAX_CONCURRENCY": "3",
            },
        }
        hub_doc = {
            "RecipeCollection": [
                {"Name": "my-nova-recipe", "HostingConfigs": [hosting_entry]}
            ]
        }
        package = _make_model_package(
            recipe_name="my-nova-recipe", hub_content_name="nova-textgeneration-lite"
        )

        cfg = self._resolve(package, return_value=hub_doc)

        self.assertEqual(
            cfg["image_uri"], "111.dkr.ecr.us-east-1.amazonaws.com/custom:tag"
        )
        self.assertEqual(
            cfg["env_vars"], {"CONTEXT_LENGTH": "999", "MAX_CONCURRENCY": "3"}
        )
        self.assertEqual(cfg["instance_type"], "ml.p5.48xlarge")

    def test_top_level_hosting_configs_used_when_no_recipe_match(self):
        """Top-level HostingConfigs applies when no recipe entry matches."""
        hub_doc = {
            "HostingConfigs": [
                {
                    "Profile": "Default",
                    "EcrAddress": "222.dkr.ecr.us-east-1.amazonaws.com/top:tag",
                    "InstanceType": "ml.g6.24xlarge",
                    "Environment": {"CONTEXT_LENGTH": "100"},
                }
            ]
        }
        package = _make_model_package(
            recipe_name="unmatched", hub_content_name="nova-textgeneration-micro"
        )

        cfg = self._resolve(package, return_value=hub_doc)

        self.assertEqual(
            cfg["image_uri"], "222.dkr.ecr.us-east-1.amazonaws.com/top:tag"
        )

    def test_hardcoded_fallback_when_hub_has_no_hosting_config(self):
        """An empty hub doc leads to the hardcoded escrow config."""
        package = _make_model_package(hub_content_name="nova-textgeneration-lite")

        cfg = self._resolve(package, return_value={})

        self.assertIn("nova-inference-repo:SM-Inference-latest", cfg["image_uri"])
        self.assertEqual(cfg["instance_type"], "ml.g6.48xlarge")

    def test_hardcoded_fallback_when_hub_fetch_raises(self):
        """A failing hub fetch leads to the hardcoded config."""
        package = _make_model_package(hub_content_name="nova-textgeneration-pro")

        cfg = self._resolve(package, side_effect=RuntimeError("hub unavailable"))

        self.assertEqual(cfg["instance_type"], "ml.p5.48xlarge")
        self.assertIn("nova-inference-repo:SM-Inference-latest", cfg["image_uri"])

    def test_missing_ecr_address_falls_through_to_hardcoded(self):
        """A hub entry lacking EcrAddress yields the escrow image instead."""
        hub_doc = {
            "RecipeCollection": [
                {
                    "Name": "r",
                    "HostingConfigs": [
                        {"Profile": "Default", "InstanceType": "ml.p5.48xlarge"}
                    ],
                }
            ]
        }
        package = _make_model_package(
            recipe_name="r", hub_content_name="nova-textgeneration-pro"
        )

        cfg = self._resolve(package, return_value=hub_doc)

        self.assertIn("nova-inference-repo:SM-Inference-latest", cfg["image_uri"])

    def test_instance_type_match_in_hub_config(self):
        """The hub entry for the requested instance type is the one returned."""
        default_entry = {
            "Profile": "Default",
            "EcrAddress": "333.dkr.ecr.us-east-1.amazonaws.com/a:tag",
            "InstanceType": "ml.p5.48xlarge",
            "Environment": {"CONTEXT_LENGTH": "1"},
        }
        requested_entry = {
            "EcrAddress": "333.dkr.ecr.us-east-1.amazonaws.com/b:tag",
            "InstanceType": "ml.g6.48xlarge",
            "Environment": {"CONTEXT_LENGTH": "2"},
        }
        hub_doc = {
            "RecipeCollection": [
                {"Name": "r", "HostingConfigs": [default_entry, requested_entry]}
            ]
        }
        package = _make_model_package(
            recipe_name="r", hub_content_name="nova-textgeneration-lite"
        )

        cfg = self._resolve(
            package, instance_type="ml.g6.48xlarge", return_value=hub_doc
        )

        self.assertEqual(
            cfg["image_uri"], "333.dkr.ecr.us-east-1.amazonaws.com/b:tag"
        )
        self.assertEqual(cfg["instance_type"], "ml.g6.48xlarge")

    def test_unsupported_instance_type_raises(self):
        """An instance type absent from the fallback configs raises ValueError."""
        package = _make_model_package(hub_content_name="nova-textgeneration-pro")

        with self.assertRaises(ValueError):
            self._resolve(package, instance_type="ml.invalid.type", return_value={})


if __name__ == "__main__":
    unittest.main()