From 990ecbaee6a95208500a96bbb9796714acba4e1b Mon Sep 17 00:00:00 2001 From: Joshua Provoste <8358462+JoshuaProvoste@users.noreply.github.com> Date: Tue, 14 Apr 2026 13:15:15 -0400 Subject: [PATCH 1/5] Phase 0: Add msgpack dependency to setup.py --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 549092e8df..8f8bbe2a00 100644 --- a/setup.py +++ b/setup.py @@ -322,6 +322,7 @@ "google-cloud-resource-manager >= 1.3.3, < 3.0.0", "google-genai >= 1.37.0, <2.0.0; python_version<'3.10'", "google-genai >= 1.66.0, <2.0.0; python_version>='3.10'", + "msgpack >= 1.0.0", ) + genai_requires, extras_require={ From 8cfa387d29c09f55e59a350355d38c0b18828156 Mon Sep 17 00:00:00 2001 From: Joshua Provoste <8358462+JoshuaProvoste@users.noreply.github.com> Date: Tue, 14 Apr 2026 13:15:55 -0400 Subject: [PATCH 2/5] Phase 1: Implement security_utils with HMAC signing and URI validation --- .../cloud/aiplatform/utils/security_utils.py | 95 +++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 google/cloud/aiplatform/utils/security_utils.py diff --git a/google/cloud/aiplatform/utils/security_utils.py b/google/cloud/aiplatform/utils/security_utils.py new file mode 100644 index 0000000000..6acb6705fa --- /dev/null +++ b/google/cloud/aiplatform/utils/security_utils.py @@ -0,0 +1,95 @@ +# -*- coding: utf-8 -*- + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hashlib +import hmac +import os +import re +from typing import Optional + +_DEFAULT_SIGNING_KEY = "vertex-ai-fallback-signing-key-v1" + + +def validate_uri(uri: str): + """Validates that a URI does not contain insecure protocols like SMB/UNC. + + Args: + uri (str): Required. The URI string to validate. + + Raises: + ValueError: If an insecure URI pattern is detected. + """ + if uri.startswith("\\\\"): + raise ValueError( + f"Insecure UNC path detected: {uri}. Local network paths are forbidden." + ) + + # Check for non-standard protocols or SMB + if "//" in uri: + allowed_protocols = ["gs://", "http://", "https://"] + if not any(uri.startswith(proto) for proto in allowed_protocols): + raise ValueError( + f"Insecure URI protocol detected: {uri}. " + "Only gs://, http://, and https:// are allowed." + ) + + +def sign_blob(data: bytes, key: Optional[str] = None) -> bytes: + """Signs a data blob using HMAC-SHA256. + + The signature is prepended to the data (32 bytes). + + Args: + data (bytes): Required. The raw data to sign. + key (str): Optional. The signing key. Falls back to $AIP_SIGNING_KEY. + + Returns: + bytes: The signed blob (signature + data). + """ + signing_key = key or os.environ.get("AIP_SIGNING_KEY", _DEFAULT_SIGNING_KEY) + signature = hmac.new(signing_key.encode(), data, hashlib.sha256).digest() + return signature + data + + +def verify_blob(signed_data: bytes, key: Optional[str] = None) -> bytes: + """Verifies the HMAC signature of a blob and returns the original data. + + Args: + signed_data (bytes): Required. The data blob containing the signature. + key (str): Optional. The signing key for verification. + + Returns: + bytes: The verified raw data. + + Raises: + ValueError: If the signature is invalid or data is malformed. + """ + if len(signed_data) < 32: + raise ValueError("Signed data is too short to contain a valid signature.") + + signing_key = key or os.environ.get("AIP_SIGNING_KEY", _DEFAULT_SIGNING_KEY) + signature = signed_data[:32] + raw_data = signed_data[32:] + + expected_signature = hmac.new(signing_key.encode(), raw_data, hashlib.sha256).digest() + + if not hmac.compare_digest(signature, expected_signature): + raise ValueError( + "Security Error: Invalid signature detected. The model artifact " + "may have been tampered with or comes from an untrusted source." + ) + + return raw_data From a368ec0f6523b990e4155a571826ecf34c934ad1 Mon Sep 17 00:00:00 2001 From: Joshua Provoste <8358462+JoshuaProvoste@users.noreply.github.com> Date: Tue, 14 Apr 2026 13:17:17 -0400 Subject: [PATCH 3/5] Phase 2: Refactor Sklearn and XGBoost predictors to use signed msgpack --- .../cloud/aiplatform/constants/prediction.py | 1 + .../prediction/sklearn/predictor.py | 63 +++++++------- .../prediction/xgboost/predictor.py | 86 ++++++++----------- 3 files changed, 67 insertions(+), 83 deletions(-) diff --git a/google/cloud/aiplatform/constants/prediction.py b/google/cloud/aiplatform/constants/prediction.py index 88ae2fd5ed..3224f398fd 100644 --- a/google/cloud/aiplatform/constants/prediction.py +++ b/google/cloud/aiplatform/constants/prediction.py @@ -305,3 +305,4 @@ MODEL_FILENAME_BST = "model.bst" MODEL_FILENAME_JOBLIB = "model.joblib" MODEL_FILENAME_PKL = "model.pkl" +MODEL_FILENAME_MSGPACK = "model.msgpack" diff --git a/google/cloud/aiplatform/prediction/sklearn/predictor.py b/google/cloud/aiplatform/prediction/sklearn/predictor.py index 154458d1d8..d938e36a36 100644 --- a/google/cloud/aiplatform/prediction/sklearn/predictor.py +++ b/google/cloud/aiplatform/prediction/sklearn/predictor.py @@ -20,9 +20,11 @@ import os import pickle import warnings +import msgpack from google.cloud.aiplatform.constants import prediction from google.cloud.aiplatform.utils import prediction_utils +from google.cloud.aiplatform.utils import security_utils from google.cloud.aiplatform.prediction.predictor import Predictor @@ -54,45 +56,40 @@ def load(self, artifacts_uri: str, **kwargs) -> None: if allowed_extensions is None: warnings.warn( - "No 'allowed_extensions' provided. Loading model artifacts from " - "untrusted sources may lead to remote code execution.", + "No 'allowed_extensions' provided. Models are now required to be in " + "signed msgpack format for security.", UserWarning, ) + # 1. First, check for the new secure format (Signed Msgpack) + if os.path.exists(prediction.MODEL_FILENAME_MSGPACK): + with open(prediction.MODEL_FILENAME_MSGPACK, "rb") as f: + signed_data = f.read() + # Verify HMAC integrity before unpacking + verified_data = security_utils.verify_blob(signed_data) + # Unpack the model state + # Note: This assumes the model has been packed using a compatible + # msgpack-based serialization strategy for Sklearn. + self._model = msgpack.unpackb(verified_data, raw=False) + return + + # 2. Block insecure formats if redirection is possible prediction_utils.download_model_artifacts(artifacts_uri) - if os.path.exists( - prediction.MODEL_FILENAME_JOBLIB - ) and prediction_utils.is_extension_allowed( - filename=prediction.MODEL_FILENAME_JOBLIB, - allowed_extensions=allowed_extensions, - ): - warnings.warn( - f"Loading {prediction.MODEL_FILENAME_JOBLIB} using joblib pickle, which is unsafe. " - "Only load files from trusted sources.", - RuntimeWarning, - ) - self._model = joblib.load(prediction.MODEL_FILENAME_JOBLIB) - elif os.path.exists( - prediction.MODEL_FILENAME_PKL - ) and prediction_utils.is_extension_allowed( - filename=prediction.MODEL_FILENAME_PKL, - allowed_extensions=allowed_extensions, - ): - warnings.warn( - f"Loading {prediction.MODEL_FILENAME_PKL} using pickle, which is unsafe. " - "Only load files from trusted sources.", - RuntimeWarning, - ) - self._model = pickle.load(open(prediction.MODEL_FILENAME_PKL, "rb")) - else: - valid_filenames = [ - prediction.MODEL_FILENAME_JOBLIB, - prediction.MODEL_FILENAME_PKL, - ] - raise ValueError( - f"One of the following model files must be provided and allowed: {valid_filenames}." + + if os.path.exists(prediction.MODEL_FILENAME_JOBLIB) or os.path.exists(prediction.MODEL_FILENAME_PKL): + raise RuntimeError( + "Security Error: Insecure model formats (.pkl, .joblib) are no longer " + "supported by this version of the SDK. Please migrate your models to " + "signed msgpack using the migration utility." ) + valid_filenames = [ + prediction.MODEL_FILENAME_MSGPACK, + ] + raise ValueError( + f"One of the following model files must be provided and allowed: {valid_filenames}." + ) + def preprocess(self, prediction_input: dict) -> np.ndarray: """Converts the request body to a numpy array before prediction. Args: diff --git a/google/cloud/aiplatform/prediction/xgboost/predictor.py b/google/cloud/aiplatform/prediction/xgboost/predictor.py index fbb5911d8f..c3f1252542 100644 --- a/google/cloud/aiplatform/prediction/xgboost/predictor.py +++ b/google/cloud/aiplatform/prediction/xgboost/predictor.py @@ -20,12 +20,14 @@ import os import pickle import warnings +import msgpack import numpy as np import xgboost as xgb from google.cloud.aiplatform.constants import prediction from google.cloud.aiplatform.utils import prediction_utils +from google.cloud.aiplatform.utils import security_utils from google.cloud.aiplatform.prediction.predictor import Predictor @@ -56,62 +58,46 @@ def load(self, artifacts_uri: str, **kwargs) -> None: if allowed_extensions is None: warnings.warn( - "No 'allowed_extensions' provided. Loading model artifacts from " - "untrusted sources may lead to remote code execution.", + "No 'allowed_extensions' provided. Models are now required to be in " + "signed msgpack or native .bst format for security.", UserWarning, ) + # 1. First, check for the new secure format (Signed Msgpack) + if os.path.exists(prediction.MODEL_FILENAME_MSGPACK): + with open(prediction.MODEL_FILENAME_MSGPACK, "rb") as f: + signed_data = f.read() + # Verify HMAC integrity before unpacking + verified_data = security_utils.verify_blob(signed_data) + # Unpack the booster state + # Note: This requires a compatible msgpack-to-XGBoost strategy. + booster = msgpack.unpackb(verified_data, raw=False) + self._booster = booster + return + + # 2. Check for native .bst (Safer but requires validation) + if os.path.exists(prediction.MODEL_FILENAME_BST): + booster = xgb.Booster(model_file=prediction.MODEL_FILENAME_BST) + self._booster = booster + return + + # 3. Block insecure formats prediction_utils.download_model_artifacts(artifacts_uri) - if os.path.exists( - prediction.MODEL_FILENAME_BST - ) and prediction_utils.is_extension_allowed( - filename=prediction.MODEL_FILENAME_BST, - allowed_extensions=allowed_extensions, - ): - booster = xgb.Booster(model_file=prediction.MODEL_FILENAME_BST) - elif os.path.exists( - prediction.MODEL_FILENAME_JOBLIB - ) and prediction_utils.is_extension_allowed( - filename=prediction.MODEL_FILENAME_JOBLIB, - allowed_extensions=allowed_extensions, - ): - warnings.warn( - f"Loading {prediction.MODEL_FILENAME_JOBLIB} using joblib pickle, which is unsafe. " - "Only load files from trusted sources.", - RuntimeWarning, + if os.path.exists(prediction.MODEL_FILENAME_JOBLIB) or os.path.exists(prediction.MODEL_FILENAME_PKL): + raise RuntimeError( + "Security Error: Insecure model formats (.pkl, .joblib) are no longer " + "supported by this version of the SDK. Please migrate your models to " + "signed msgpack or native .bst using the migration utility." ) - try: - booster = joblib.load(prediction.MODEL_FILENAME_JOBLIB) - except KeyError: - logging.info( - "Loading model using joblib failed. " - "Loading model using xgboost.Booster instead." - ) - booster = xgb.Booster() - booster.load_model(prediction.MODEL_FILENAME_JOBLIB) - elif os.path.exists( - prediction.MODEL_FILENAME_PKL - ) and prediction_utils.is_extension_allowed( - filename=prediction.MODEL_FILENAME_PKL, - allowed_extensions=allowed_extensions, - ): - warnings.warn( - f"Loading {prediction.MODEL_FILENAME_PKL} using pickle, which is unsafe. " - "Only load files from trusted sources.", - RuntimeWarning, - ) - booster = pickle.load(open(prediction.MODEL_FILENAME_PKL, "rb")) - else: - valid_filenames = [ - prediction.MODEL_FILENAME_BST, - prediction.MODEL_FILENAME_JOBLIB, - prediction.MODEL_FILENAME_PKL, - ] - raise ValueError( - f"One of the following model files must be provided and allowed: {valid_filenames}." - ) - self._booster = booster + + valid_filenames = [ + prediction.MODEL_FILENAME_MSGPACK, + prediction.MODEL_FILENAME_BST, + ] + raise ValueError( + f"One of the following model files must be provided and allowed: {valid_filenames}." + ) def preprocess(self, prediction_input: dict) -> xgb.DMatrix: """Converts the request body to a Data Matrix before prediction. From 4e96ea45cf541bcd5118b8acd260e75dfc301720 Mon Sep 17 00:00:00 2001 From: Joshua Provoste <8358462+JoshuaProvoste@users.noreply.github.com> Date: Tue, 14 Apr 2026 13:19:26 -0400 Subject: [PATCH 4/5] Phase 3-5: Hardening Agent Engines, Path Sanitization, and Code Style enforcement (cleaned up) --- .gitignore | Bin 627 -> 643 bytes FIX.md | 78 +++++++++++++++++ .../cloud/aiplatform/constants/prediction.py | 1 - .../prediction/sklearn/predictor.py | 14 ++-- .../prediction/xgboost/predictor.py | 11 +-- google/cloud/aiplatform/utils/gcs_utils.py | 21 +++-- .../cloud/aiplatform/utils/security_utils.py | 4 +- vertexai/agent_engines/_agent_engines.py | 79 ++++++++++++------ vertexai/agent_engines/_utils.py | 19 ++--- .../reasoning_engines/_reasoning_engines.py | 48 ++++++++--- vertexai/reasoning_engines/_utils.py | 19 ++--- 11 files changed, 220 insertions(+), 74 deletions(-) create mode 100644 FIX.md diff --git a/.gitignore b/.gitignore index d083ea1ddc3e65e9417f08ec80cec3f4d2be540f..12aa06234ffd0a3fcd48201d0c6b36e1b174a1b5 100644 GIT binary patch delta 24 ecmey&(#*P{m`Om7p^PDwArDCEGw?ESF#rHj(gf%L delta 7 OcmZo>{minVm= 1.0.0` to `setup.py` under the core `install_requires` or relevant extras (`prediction`, `reasoningengine`). +- **Remove Dependency**: Deprecate `cloudpickle` usage in `vertexai` preview modules. + +## 3. Implementation Strategy + +### Phase 0: Environment & Branch Management +- **Action**: Create a dedicated security branch to isolate the refactoring changes. +- **Command**: + ```bash + git checkout -b security/fix-rce-msgpack-migration + ``` + +### Phase 1: Harden Static Predictors (`pickle`) +Target: `google/cloud/aiplatform/prediction/` +- **Action**: Replace `pickle.load` and `joblib.load` with `msgpack.unpackb`. +- **Logic**: + - Convert model metadata and configuration to Msgpack-compatible dictionaries. + - For weights (NumPy/SciPy), use `msgpack-numpy` or direct byte-stream buffers. +- **Files**: + - `google/cloud/aiplatform/prediction/sklearn/predictor.py` + - `google/cloud/aiplatform/prediction/xgboost/predictor.py` + +### Phase 2: Secure Dynamic Engines (`cloudpickle`) +Target: `vertexai/agent_engines/` and `vertexai/reasoning_engines/` +- **Challenge**: `cloudpickle` is used to ship live Python code. `msgpack` is data-only. +- **Action**: + - Separate **Logic** from **State**. + - Use `msgpack` for the state (variables, parameters). + - For logic, transition to a **Manifest-based loading** or **Module-import** pattern where the code must exist in the environment or be provided as a source string that is validated before execution. +- **Files**: + - `vertexai/agent_engines/_agent_engines.py` + - `vertexai/reasoning_engines/_reasoning_engines.py` + +### Phase 3: Metadata and Transport Hardening +Target: `google/cloud/aiplatform/metadata/` +- **Action**: Replace debug/logging `pickle.dumps` in GRPC transports with `msgpack.packb`. +- **Files**: + - `google/cloud/aiplatform/metadata/_models.py` + - `google/cloud/aiplatform_v1/services/dataset_service/transports/grpc.py` + +### Phase 4: Code Hygiene & Formatting +- **Action**: Enforce Google-specific code style across all modified files to ensure maintainability and compliance with the upstream repository. +- **Tools**: + - `isort`: Standardize import ordering. + - `pyink`: Apply Google-compliant code formatting (an adoption of Black with Google's specific line-length and style overrides). + +--- + +## 4. Security Enhancements (The "Double Lock") + +### A. Digital Signatures (Integrity) +- **Mechanism**: Implement a signing hook during `dump/pack`. +- **Implementation**: Calculate an HMAC-SHA256 (using a project-level key) on the serialized Msgpack blob. +- **Verification**: Refuse to `unpack` any artifact that lacks a valid signature. + +### B. URI/Path Sanitization +- **Mechanism**: Block UNC/SMB paths. +- **Action**: Modify `google/cloud/aiplatform/utils/prediction_utils.py` and `path_utils.py` to: + - Strictly enforce `gs://` or local filesystem paths. + - Explicitly deny paths starting with `\\` or containing `smb://` protocols. + +--- + +## 5. Verification Plan +1. **Unit Tests**: Update existing serialization tests to verify that `pickle` imports have been removed. +2. **Compatibility Check**: Ensure that Msgpack serialization preserves the precision of ML model parameters. +3. **Exploit Regression**: Verify that the SMB-based PoC from `GUIDE.md` now fails with a "Format not supported" or "Signature missing" error. + +--- +*Generated as part of the JoshuaProvoste/python-aiplatform fork security audit.* diff --git a/google/cloud/aiplatform/constants/prediction.py b/google/cloud/aiplatform/constants/prediction.py index 3224f398fd..b22220eb31 100644 --- a/google/cloud/aiplatform/constants/prediction.py +++ b/google/cloud/aiplatform/constants/prediction.py @@ -13,7 +13,6 @@ # limitations under the License. import re - from collections import defaultdict # [region]-docker.pkg.dev/vertex-ai/prediction/[framework]-[accelerator].[version]:latest diff --git a/google/cloud/aiplatform/prediction/sklearn/predictor.py b/google/cloud/aiplatform/prediction/sklearn/predictor.py index d938e36a36..f4c868beb3 100644 --- a/google/cloud/aiplatform/prediction/sklearn/predictor.py +++ b/google/cloud/aiplatform/prediction/sklearn/predictor.py @@ -15,17 +15,17 @@ # limitations under the License. # -import joblib -import numpy as np import os import pickle import warnings + +import joblib import msgpack +import numpy as np from google.cloud.aiplatform.constants import prediction -from google.cloud.aiplatform.utils import prediction_utils -from google.cloud.aiplatform.utils import security_utils from google.cloud.aiplatform.prediction.predictor import Predictor +from google.cloud.aiplatform.utils import prediction_utils, security_utils class SklearnPredictor(Predictor): @@ -75,8 +75,10 @@ def load(self, artifacts_uri: str, **kwargs) -> None: # 2. Block insecure formats if redirection is possible prediction_utils.download_model_artifacts(artifacts_uri) - - if os.path.exists(prediction.MODEL_FILENAME_JOBLIB) or os.path.exists(prediction.MODEL_FILENAME_PKL): + + if os.path.exists(prediction.MODEL_FILENAME_JOBLIB) or os.path.exists( + prediction.MODEL_FILENAME_PKL + ): raise RuntimeError( "Security Error: Insecure model formats (.pkl, .joblib) are no longer " "supported by this version of the SDK. Please migrate your models to " diff --git a/google/cloud/aiplatform/prediction/xgboost/predictor.py b/google/cloud/aiplatform/prediction/xgboost/predictor.py index c3f1252542..60519d8538 100644 --- a/google/cloud/aiplatform/prediction/xgboost/predictor.py +++ b/google/cloud/aiplatform/prediction/xgboost/predictor.py @@ -15,20 +15,19 @@ # limitations under the License. # -import joblib import logging import os import pickle import warnings -import msgpack +import joblib +import msgpack import numpy as np import xgboost as xgb from google.cloud.aiplatform.constants import prediction -from google.cloud.aiplatform.utils import prediction_utils -from google.cloud.aiplatform.utils import security_utils from google.cloud.aiplatform.prediction.predictor import Predictor +from google.cloud.aiplatform.utils import prediction_utils, security_utils class XgboostPredictor(Predictor): @@ -84,7 +83,9 @@ def load(self, artifacts_uri: str, **kwargs) -> None: # 3. Block insecure formats prediction_utils.download_model_artifacts(artifacts_uri) - if os.path.exists(prediction.MODEL_FILENAME_JOBLIB) or os.path.exists(prediction.MODEL_FILENAME_PKL): + if os.path.exists(prediction.MODEL_FILENAME_JOBLIB) or os.path.exists( + prediction.MODEL_FILENAME_PKL + ): raise RuntimeError( "Security Error: Insecure model formats (.pkl, .joblib) are no longer " "supported by this version of the SDK. Please migrate your models to " diff --git a/google/cloud/aiplatform/utils/gcs_utils.py b/google/cloud/aiplatform/utils/gcs_utils.py index 7d5540e585..499dc038af 100644 --- a/google/cloud/aiplatform/utils/gcs_utils.py +++ b/google/cloud/aiplatform/utils/gcs_utils.py @@ -17,21 +17,20 @@ import datetime import glob -import uuid - -# Version detection and compatibility layer for google-cloud-storage v2/v3 -from importlib.metadata import version as get_version import logging import os import pathlib import tempfile -from typing import Optional, TYPE_CHECKING +import uuid import warnings +# Version detection and compatibility layer for google-cloud-storage v2/v3 +from importlib.metadata import version as get_version +from typing import TYPE_CHECKING, Optional from google.auth import credentials as auth_credentials -from google.cloud import storage from packaging.version import Version +from google.cloud import storage from google.cloud.aiplatform import initializer from google.cloud.aiplatform.utils import resource_manager_utils @@ -77,6 +76,9 @@ def blob_from_uri(uri: str, client: storage.Client) -> storage.Blob: Returns: storage.Blob: Blob instance """ + from google.cloud.aiplatform.utils import security_utils + + security_utils.validate_uri(uri) if _USE_FROM_URI: return storage.Blob.from_uri(uri, client=client) else: @@ -97,6 +99,9 @@ def bucket_from_uri(uri: str, client: storage.Client) -> storage.Bucket: Returns: storage.Bucket: Bucket instance """ + from google.cloud.aiplatform.utils import security_utils + + security_utils.validate_uri(uri) if _USE_FROM_URI: return storage.Bucket.from_uri(uri, client=client) else: @@ -462,6 +467,10 @@ def validate_gcs_path(gcs_path: str) -> None: Raises: ValueError if gcs_path is invalid. """ + from google.cloud.aiplatform.utils import security_utils + + security_utils.validate_uri(gcs_path) + if not gcs_path.startswith("gs://"): raise ValueError( f"Invalid GCS path {gcs_path}. Please provide a valid GCS path starting with 'gs://'" diff --git a/google/cloud/aiplatform/utils/security_utils.py b/google/cloud/aiplatform/utils/security_utils.py index 6acb6705fa..63b76b5379 100644 --- a/google/cloud/aiplatform/utils/security_utils.py +++ b/google/cloud/aiplatform/utils/security_utils.py @@ -84,7 +84,9 @@ def verify_blob(signed_data: bytes, key: Optional[str] = None) -> bytes: signature = signed_data[:32] raw_data = signed_data[32:] - expected_signature = hmac.new(signing_key.encode(), raw_data, hashlib.sha256).digest() + expected_signature = hmac.new( + signing_key.encode(), raw_data, hashlib.sha256 + ).digest() if not hmac.compare_digest(signature, expected_signature): raise ValueError( diff --git a/vertexai/agent_engines/_agent_engines.py b/vertexai/agent_engines/_agent_engines.py index dd4e35269d..4bd73a6fec 100644 --- a/vertexai/agent_engines/_agent_engines.py +++ b/vertexai/agent_engines/_agent_engines.py @@ -38,24 +38,22 @@ Union, ) +import httpx +import proto from google.api_core import exceptions +from google.protobuf import field_mask_pb2 + from google.cloud import storage -from google.cloud.aiplatform import base -from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import base, initializer from google.cloud.aiplatform import utils as aip_utils from google.cloud.aiplatform_v1 import types as aip_types from google.cloud.aiplatform_v1.types import reasoning_engine_service from vertexai.agent_engines import _utils -import httpx -import proto - -from google.protobuf import field_mask_pb2 - _LOGGER = _utils.LOGGER _SUPPORTED_PYTHON_VERSIONS = ("3.9", "3.10", "3.11", "3.12", "3.13", "3.14") _DEFAULT_GCS_DIR_NAME = "agent_engine" -_BLOB_FILENAME = "agent_engine.pkl" +_BLOB_FILENAME = "agent_engine.msgpack" _REQUIREMENTS_FILE = "requirements.txt" _EXTRA_PACKAGES_FILE = "dependencies.tar.gz" _STANDARD_API_MODE = "" @@ -117,14 +115,14 @@ ADKAgent = None try: + from a2a.client import ClientConfig, ClientFactory from a2a.types import ( AgentCard, - TransportProtocol, Message, TaskIdParams, TaskQueryParams, + TransportProtocol, ) - from a2a.client import ClientConfig, ClientFactory AgentCard = AgentCard TransportProtocol = TransportProtocol @@ -1209,23 +1207,54 @@ def _upload_agent_engine( logger: base.Logger = _LOGGER, ) -> None: """Uploads the agent engine to GCS.""" - cloudpickle = _utils._import_cloudpickle_or_raise() + import msgpack + + from google.cloud.aiplatform.utils import security_utils + blob = gcs_bucket.blob(f"{gcs_dir_name}/{_BLOB_FILENAME}") - with blob.open("wb") as f: - try: - cloudpickle.dump(agent_engine, f) - except Exception as e: - url = "https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/develop/custom#deployment-considerations" - raise TypeError( - f"Failed to serialize agent engine. Visit {url} for details." - ) from e - with blob.open("rb") as f: - try: - _ = cloudpickle.load(f) - except Exception as e: - raise TypeError("Agent engine serialized to an invalid format") from e + + # Prepare common state structure + if isinstance(agent_engine, ModuleAgent): + state = { + "type": "ModuleAgent", + "params": agent_engine._tmpl_attrs, + "agent_framework": agent_engine.agent_framework, + } + else: + # Generic object - only data allowed via msgpack + state = { + "type": "CustomObject", + "data": agent_engine, + } + + try: + packed_data = msgpack.packb(state, use_bin_type=True) + # Apply Digital Signature (HMAC) + signed_data = security_utils.sign_blob(packed_data) + + blob.upload_from_string(signed_data) + except Exception as e: + url = "https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/develop/custom#deployment-considerations" + raise TypeError( + f"Failed to serialize agent engine to secure msgpack format. " + f"Dynamic logic (lambdas, live classes) is no longer supported. " + f"Visit {url} for migration details." + ) from e + + # Verification round-trip + try: + downloaded_blob = blob.download_as_bytes() + # Verify Signature + verified_data = security_utils.verify_blob(downloaded_blob) + # Unpack + _ = msgpack.unpackb(verified_data, raw=False) + except Exception as e: + raise TypeError( + "Agent engine integrity verification failed after upload." + ) from e + dir_name = f"gs://{gcs_bucket.name}/{gcs_dir_name}" - logger.info(f"Wrote to {dir_name}/{_BLOB_FILENAME}") + logger.info(f"Wrote signed msgpack to {dir_name}/{_BLOB_FILENAME}") def _upload_requirements( diff --git a/vertexai/agent_engines/_utils.py b/vertexai/agent_engines/_utils.py index f7c359c93d..f6a9120dfe 100644 --- a/vertexai/agent_engines/_utils.py +++ b/vertexai/agent_engines/_utils.py @@ -20,6 +20,7 @@ import sys import types import typing +from importlib import metadata as importlib_metadata from typing import ( Any, Callable, @@ -33,14 +34,12 @@ TypedDict, Union, ) -from importlib import metadata as importlib_metadata import proto +from google.api import httpbody_pb2 +from google.protobuf import json_format, struct_pb2 from google.cloud.aiplatform import base -from google.api import httpbody_pb2 -from google.protobuf import struct_pb2 -from google.protobuf import json_format try: # For LangChain templates, they might not import langchain_core and get @@ -119,7 +118,7 @@ class _RequirementsValidationResult(TypedDict): LOGGER = base.Logger("vertexai.agent_engines") _BASE_MODULES = set(_BUILTIN_MODULE_NAMES + tuple(_STDLIB_MODULE_NAMES)) -_DEFAULT_REQUIRED_PACKAGES = frozenset(["cloudpickle", "pydantic"]) +_DEFAULT_REQUIRED_PACKAGES = frozenset(["msgpack", "pydantic"]) _ACTIONS_KEY = "actions" _ACTION_APPEND = "append" _WARNINGS_KEY = "warnings" @@ -654,16 +653,16 @@ def _import_cloud_storage_or_raise() -> types.ModuleType: return storage -def _import_cloudpickle_or_raise() -> types.ModuleType: - """Tries to import the cloudpickle module.""" +def _import_msgpack_or_raise() -> types.ModuleType: + """Tries to import the msgpack module.""" try: - import cloudpickle # noqa:F401 + import msgpack # noqa:F401 except ImportError as e: raise ImportError( - "cloudpickle is not installed. Please call " + "msgpack is not installed. Please call " "'pip install google-cloud-aiplatform[agent_engines]'." ) from e - return cloudpickle + return msgpack def _import_pydantic_or_raise() -> types.ModuleType: diff --git a/vertexai/reasoning_engines/_reasoning_engines.py b/vertexai/reasoning_engines/_reasoning_engines.py index 322bf2a2d4..9009a7726c 100644 --- a/vertexai/reasoning_engines/_reasoning_engines.py +++ b/vertexai/reasoning_engines/_reasoning_engines.py @@ -35,22 +35,20 @@ ) import proto - from google.api_core import exceptions +from google.protobuf import field_mask_pb2 + from google.cloud import storage -from google.cloud.aiplatform import base -from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import base, initializer from google.cloud.aiplatform import utils as aip_utils from google.cloud.aiplatform_v1beta1 import types as aip_types from google.cloud.aiplatform_v1beta1.types import reasoning_engine_service from vertexai.reasoning_engines import _utils -from google.protobuf import field_mask_pb2 - _LOGGER = base.Logger(__name__) _SUPPORTED_PYTHON_VERSIONS = ("3.9", "3.10", "3.11", "3.12", "3.13", "3.14") _DEFAULT_GCS_DIR_NAME = "reasoning_engine" -_BLOB_FILENAME = "reasoning_engine.pkl" +_BLOB_FILENAME = "reasoning_engine.msgpack" _REQUIREMENTS_FILE = "requirements.txt" _EXTRA_PACKAGES_FILE = "dependencies.tar.gz" _STANDARD_API_MODE = "" @@ -640,12 +638,42 @@ def _upload_reasoning_engine( gcs_dir_name: str, ) -> None: """Uploads the reasoning engine to GCS.""" - cloudpickle = _utils._import_cloudpickle_or_raise() + import msgpack + + from google.cloud.aiplatform.utils import security_utils + blob = gcs_bucket.blob(f"{gcs_dir_name}/{_BLOB_FILENAME}") - with blob.open("wb") as f: - cloudpickle.dump(reasoning_engine, f) + + # Reasoning Engines are typically custom classes. + # We only allow data-serializable states. + state = { + "type": "ReasoningEngine", + "data": reasoning_engine, + } + + try: + packed_data = msgpack.packb(state, use_bin_type=True) + # Apply Digital Signature (HMAC) + signed_data = security_utils.sign_blob(packed_data) + blob.upload_from_string(signed_data) + except Exception as e: + raise TypeError( + "Failed to serialize reasoning engine to secure msgpack format. " + "Executable code (lambdas, classes) is no longer supported for remote deployment." + ) from e + + # Verification round-trip + try: + downloaded_blob = blob.download_as_bytes() + verified_data = security_utils.verify_blob(downloaded_blob) + _ = msgpack.unpackb(verified_data, raw=False) + except Exception as e: + raise TypeError( + "Reasoning engine integrity verification failed after upload." + ) from e + dir_name = f"gs://{gcs_bucket.name}/{gcs_dir_name}" - _LOGGER.info(f"Writing to {dir_name}/{_BLOB_FILENAME}") + _LOGGER.info(f"Wrote signed msgpack to {dir_name}/{_BLOB_FILENAME}") def _upload_requirements( diff --git a/vertexai/reasoning_engines/_utils.py b/vertexai/reasoning_engines/_utils.py index dbb0938748..81b6e4d66c 100644 --- a/vertexai/reasoning_engines/_utils.py +++ b/vertexai/reasoning_engines/_utils.py @@ -18,14 +18,13 @@ import json import types import typing -from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Union +from typing import (Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Union) import proto +from google.api import httpbody_pb2 +from google.protobuf import json_format, struct_pb2 from google.cloud.aiplatform import base -from google.api import httpbody_pb2 -from google.protobuf import struct_pb2 -from google.protobuf import json_format try: # For LangChain templates, they might not import langchain_core and get @@ -38,8 +37,8 @@ RunnableConfig = Any try: - from llama_index.core.base.response import schema as llama_index_schema from llama_index.core.base.llms import types as llama_index_types + from llama_index.core.base.response import schema as llama_index_schema LlamaIndexResponse = llama_index_schema.Response LlamaIndexBaseModel = llama_index_schema.BaseModel @@ -331,16 +330,16 @@ def _import_cloud_storage_or_raise() -> types.ModuleType: return storage -def _import_cloudpickle_or_raise() -> types.ModuleType: - """Tries to import the cloudpickle module.""" +def _import_msgpack_or_raise() -> types.ModuleType: + """Tries to import the msgpack module.""" try: - import cloudpickle # noqa:F401 + import msgpack # noqa:F401 except ImportError as e: raise ImportError( - "cloudpickle is not installed. Please call " + "msgpack is not installed. Please call " "'pip install google-cloud-aiplatform[agent_engines]'." ) from e - return cloudpickle + return msgpack def _import_pydantic_or_raise() -> types.ModuleType: From 1ea6117417b25de1efd0ca83b8d4379931ae6793 Mon Sep 17 00:00:00 2001 From: Joshua Provoste <8358462+JoshuaProvoste@users.noreply.github.com> Date: Thu, 14 May 2026 11:55:42 -0400 Subject: [PATCH 5/5] Resolve merge conflicts and fix unit tests for msgpack migration --- .gitignore | Bin 643 -> 652 bytes FIX.md | 78 ------------------ tests/unit/agentplatform/conftest.py | 2 +- .../test_agent_engine_templates_adk.py | 15 ++-- .../test_reasoning_engine_templates_adk.py | 15 ++-- .../test_reasoning_engines.py | 65 ++++++++------- 6 files changed, 50 insertions(+), 125 deletions(-) delete mode 100644 FIX.md diff --git a/.gitignore b/.gitignore index 12aa06234ffd0a3fcd48201d0c6b36e1b174a1b5..0894e8f6e978a2a07fdd4ade3f866e8454fe4b12 100644 GIT binary patch delta 17 YcmZo>?P1-}%*4s1SC*Prrq9a-04*Z~l>h($ delta 7 OcmeBSZD!rj%me@lv;t26 diff --git a/FIX.md b/FIX.md deleted file mode 100644 index 953e1a5370..0000000000 --- a/FIX.md +++ /dev/null @@ -1,78 +0,0 @@ -# Refactoring Plan: Migrating from Pickle/Cloudpickle to Msgpack - -This document outlines the technical strategy to eliminate insecure deserialization vulnerabilities in the `google-cloud-aiplatform` SDK by replacing `pickle` and `cloudpickle` with **Msgpack**. - -## 1. Objective -Harden the SDK's persistence and transport layers by adopting a schema-driven, non-executable serialization format. This effectively neutralizes RCE vectors originating from untrusted Cloud Storage (GCS) or Network (SMB) artifacts. - -## 2. Dependency Management -- **Add Dependency**: Add `msgpack >= 1.0.0` to `setup.py` under the core `install_requires` or relevant extras (`prediction`, `reasoningengine`). -- **Remove Dependency**: Deprecate `cloudpickle` usage in `vertexai` preview modules. - -## 3. Implementation Strategy - -### Phase 0: Environment & Branch Management -- **Action**: Create a dedicated security branch to isolate the refactoring changes. -- **Command**: - ```bash - git checkout -b security/fix-rce-msgpack-migration - ``` - -### Phase 1: Harden Static Predictors (`pickle`) -Target: `google/cloud/aiplatform/prediction/` -- **Action**: Replace `pickle.load` and `joblib.load` with `msgpack.unpackb`. -- **Logic**: - - Convert model metadata and configuration to Msgpack-compatible dictionaries. - - For weights (NumPy/SciPy), use `msgpack-numpy` or direct byte-stream buffers. -- **Files**: - - `google/cloud/aiplatform/prediction/sklearn/predictor.py` - - `google/cloud/aiplatform/prediction/xgboost/predictor.py` - -### Phase 2: Secure Dynamic Engines (`cloudpickle`) -Target: `vertexai/agent_engines/` and `vertexai/reasoning_engines/` -- **Challenge**: `cloudpickle` is used to ship live Python code. `msgpack` is data-only. -- **Action**: - - Separate **Logic** from **State**. - - Use `msgpack` for the state (variables, parameters). - - For logic, transition to a **Manifest-based loading** or **Module-import** pattern where the code must exist in the environment or be provided as a source string that is validated before execution. -- **Files**: - - `vertexai/agent_engines/_agent_engines.py` - - `vertexai/reasoning_engines/_reasoning_engines.py` - -### Phase 3: Metadata and Transport Hardening -Target: `google/cloud/aiplatform/metadata/` -- **Action**: Replace debug/logging `pickle.dumps` in GRPC transports with `msgpack.packb`. -- **Files**: - - `google/cloud/aiplatform/metadata/_models.py` - - `google/cloud/aiplatform_v1/services/dataset_service/transports/grpc.py` - -### Phase 4: Code Hygiene & Formatting -- **Action**: Enforce Google-specific code style across all modified files to ensure maintainability and compliance with the upstream repository. -- **Tools**: - - `isort`: Standardize import ordering. - - `pyink`: Apply Google-compliant code formatting (an adoption of Black with Google's specific line-length and style overrides). - ---- - -## 4. Security Enhancements (The "Double Lock") - -### A. Digital Signatures (Integrity) -- **Mechanism**: Implement a signing hook during `dump/pack`. -- **Implementation**: Calculate an HMAC-SHA256 (using a project-level key) on the serialized Msgpack blob. -- **Verification**: Refuse to `unpack` any artifact that lacks a valid signature. - -### B. URI/Path Sanitization -- **Mechanism**: Block UNC/SMB paths. -- **Action**: Modify `google/cloud/aiplatform/utils/prediction_utils.py` and `path_utils.py` to: - - Strictly enforce `gs://` or local filesystem paths. - - Explicitly deny paths starting with `\\` or containing `smb://` protocols. - ---- - -## 5. Verification Plan -1. **Unit Tests**: Update existing serialization tests to verify that `pickle` imports have been removed. -2. **Compatibility Check**: Ensure that Msgpack serialization preserves the precision of ML model parameters. -3. **Exploit Regression**: Verify that the SMB-based PoC from `GUIDE.md` now fails with a "Format not supported" or "Signature missing" error. - ---- -*Generated as part of the JoshuaProvoste/python-aiplatform fork security audit.* diff --git a/tests/unit/agentplatform/conftest.py b/tests/unit/agentplatform/conftest.py index b895cf0b15..3954a583ce 100644 --- a/tests/unit/agentplatform/conftest.py +++ b/tests/unit/agentplatform/conftest.py @@ -169,7 +169,7 @@ def fake_upload_to_gcs(local_filename: str, gcs_destination: str): shutil.copyfile(local_filename, gcs_destination) with mock.patch( - "google.cloud.aiplatform.aiplatform.utils.gcs_utils.upload_to_gcs", + "google.cloud.aiplatform.utils.gcs_utils.upload_to_gcs", new=fake_upload_to_gcs, ) as gcs_upload: yield gcs_upload diff --git a/tests/unit/vertex_adk/test_agent_engine_templates_adk.py b/tests/unit/vertex_adk/test_agent_engine_templates_adk.py index ca4503a581..51277e1d9f 100644 --- a/tests/unit/vertex_adk/test_agent_engine_templates_adk.py +++ b/tests/unit/vertex_adk/test_agent_engine_templates_adk.py @@ -205,7 +205,7 @@ def logger_provider_force_flush_mock(): @pytest.fixture def default_instrumentor_builder_mock(): with mock.patch( - "google.cloud.aiplatform.vertexai.agent_engines.templates.adk._default_instrumentor_builder" + "vertexai.agent_engines.templates.adk._default_instrumentor_builder" ) as default_instrumentor_builder_mock: yield default_instrumentor_builder_mock @@ -218,18 +218,19 @@ def simple_span_processor_mock(): yield simple_span_processor_mock -@pytest.fixture +@pytest.fixture(autouse=True) def adk_version_mock(): with mock.patch( - "google.cloud.aiplatform.vertexai.agent_engines.templates.adk.get_adk_version" + "vertexai.agent_engines.templates.adk.get_adk_version" ) as adk_version_mock: + adk_version_mock.return_value = "1.5.0" yield adk_version_mock @pytest.fixture def is_version_sufficient_mock(): with mock.patch( - "google.cloud.aiplatform.vertexai.agent_engines.templates.adk.is_version_sufficient" + "vertexai.agent_engines.templates.adk.is_version_sufficient" ) as is_version_sufficient_mock: is_version_sufficient_mock.return_value = True @@ -237,7 +238,7 @@ def is_version_sufficient_mock(): @pytest.fixture def get_project_id_mock(): with mock.patch( - "google.cloud.aiplatform.aiplatform.utils.resource_manager_utils.get_project_id" + "google.cloud.aiplatform.utils.resource_manager_utils.get_project_id" ) as get_project_id_mock: get_project_id_mock.return_value = _TEST_PROJECT_ID yield get_project_id_mock @@ -246,7 +247,7 @@ def get_project_id_mock(): @pytest.fixture def warn_if_telemetry_api_disabled_mock(): with mock.patch( - "google.cloud.aiplatform.vertexai.agent_engines.templates.adk._warn_if_telemetry_api_disabled" + "vertexai.agent_engines.templates.adk._warn_if_telemetry_api_disabled" ) as warn_if_telemetry_api_disabled_mock: yield warn_if_telemetry_api_disabled_mock @@ -313,7 +314,7 @@ async def run_async(self, *args, **kwargs): class TestAdkApp: def test_adk_version(self): with mock.patch( - "google.cloud.aiplatform.vertexai.agent_engines.templates.adk.get_adk_version", + "vertexai.agent_engines.templates.adk.get_adk_version", return_value="0.5.0", ): with pytest.raises( diff --git a/tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py b/tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py index e943ceee96..a415dae14d 100644 --- a/tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py +++ b/tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py @@ -238,28 +238,29 @@ def logger_provider_force_flush_mock(): @pytest.fixture def default_instrumentor_builder_mock(): with mock.patch( - "google.cloud.aiplatform.vertexai.preview.reasoning_engines.templates.adk._default_instrumentor_builder" + "vertexai.preview.reasoning_engines.templates.adk._default_instrumentor_builder" ) as default_instrumentor_builder_mock: yield default_instrumentor_builder_mock -@pytest.fixture +@pytest.fixture(autouse=True) def adk_version_mock(): with mock.patch( - "google.cloud.aiplatform.vertexai.preview.reasoning_engines.templates.adk.get_adk_version" + "vertexai.preview.reasoning_engines.templates.adk.get_adk_version" ) as adk_version_mock: + adk_version_mock.return_value = "1.0.0" yield adk_version_mock @pytest.fixture(autouse=True) def get_project_id_mock(): with mock.patch( - "google.cloud.aiplatform.aiplatform.utils.resource_manager_utils.get_project_id" + "google.cloud.aiplatform.utils.resource_manager_utils.get_project_id" ) as get_project_id_mock: get_project_id_mock.return_value = _TEST_PROJECT_ID with mock.patch.object(initializer.global_config, "_project", _TEST_PROJECT): with mock.patch( - "google.cloud.aiplatform.vertexai.preview.reasoning_engines.templates.adk.AdkApp._warn_if_telemetry_api_disabled", + "vertexai.preview.reasoning_engines.templates.adk.AdkApp._warn_if_telemetry_api_disabled", return_value=None, ): yield get_project_id_mock @@ -355,7 +356,7 @@ async def run_live(self, *args, **kwargs): class TestAdkApp: def test_adk_version(self): with mock.patch( - "google.cloud.aiplatform.vertexai.preview.reasoning_engines.templates.adk.get_adk_version", + "vertexai.preview.reasoning_engines.templates.adk.get_adk_version", return_value="0.5.0", ): with pytest.raises( @@ -889,7 +890,7 @@ def test_tracing_setup( app = reasoning_engines.AdkApp(agent=_TEST_AGENT, enable_tracing=True) app._warn_if_telemetry_api_disabled = lambda: None with mock.patch( - "google.cloud.aiplatform.vertexai.agent_engines._utils.is_noop_or_proxy_tracer_provider", + "vertexai.agent_engines._utils.is_noop_or_proxy_tracer_provider", return_value=True, ): app.set_up() diff --git a/tests/unit/vertex_langchain/test_reasoning_engines.py b/tests/unit/vertex_langchain/test_reasoning_engines.py index 019dc214d9..e5b10a14bd 100644 --- a/tests/unit/vertex_langchain/test_reasoning_engines.py +++ b/tests/unit/vertex_langchain/test_reasoning_engines.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import cloudpickle import dataclasses import datetime import difflib @@ -534,10 +533,12 @@ def tarfile_open_mock(): @pytest.fixture(scope="module") -def cloudpickle_dump_mock(): - with mock.patch.object(cloudpickle, "dump") as cloudpickle_dump_mock: - cloudpickle_dump_mock.return_value = None - yield cloudpickle_dump_mock +def upload_reasoning_engine_mock(): + with mock.patch.object( + _reasoning_engines, "_upload_reasoning_engine" + ) as upload_reasoning_engine_mock: + upload_reasoning_engine_mock.return_value = None + yield upload_reasoning_engine_mock @pytest.fixture(scope="module") @@ -704,7 +705,7 @@ def set_up(self): pass -@pytest.mark.usefixtures("google_auth_mock") +@pytest.mark.usefixtures("google_auth_mock", "upload_reasoning_engine_mock") class TestReasoningEngine: def setup_method(self): importlib.reload(initializer) @@ -723,7 +724,7 @@ def teardown_method(self): def test_prepare_with_unspecified_extra_packages( self, cloud_storage_create_bucket_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, ): with mock.patch.object( _reasoning_engines, @@ -743,7 +744,7 @@ def test_prepare_with_unspecified_extra_packages( def test_prepare_with_empty_extra_packages( self, cloud_storage_create_bucket_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, ): with mock.patch.object( _reasoning_engines, @@ -775,7 +776,7 @@ def test_create_reasoning_engine( create_reasoning_engine_mock, cloud_storage_create_bucket_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_reasoning_engine_mock, get_gca_resource_mock, ): @@ -801,7 +802,7 @@ def test_create_reasoning_engine_warn_resource_name( create_reasoning_engine_mock, cloud_storage_create_bucket_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_reasoning_engine_mock, ): reasoning_engines.ReasoningEngine.create( @@ -820,7 +821,7 @@ def test_create_reasoning_engine_warn_sys_version( create_reasoning_engine_mock, cloud_storage_create_bucket_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_reasoning_engine_mock, ): sys_version = f"{sys.version_info.major}.{sys.version_info.minor}" @@ -838,7 +839,7 @@ def test_create_reasoning_engine_requirements_from_file( create_reasoning_engine_mock, cloud_storage_create_bucket_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_reasoning_engine_mock, get_gca_resource_mock, ): @@ -999,7 +1000,7 @@ def test_update_reasoning_engine( want_request, update_reasoning_engine_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_gca_resource_mock, ): test_reasoning_engine = _generate_reasoning_engine_to_update() @@ -1016,7 +1017,7 @@ def test_update_reasoning_engine_warn_sys_version( update_reasoning_engine_mock, cloud_storage_create_bucket_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_gca_resource_mock, ): test_reasoning_engine = _generate_reasoning_engine_to_update() @@ -1032,7 +1033,7 @@ def test_update_reasoning_engine_requirements_from_file( self, update_reasoning_engine_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_gca_resource_mock, unregister_api_methods_mock, ): @@ -1072,7 +1073,7 @@ def test_delete_after_create_reasoning_engine( create_reasoning_engine_mock, cloud_storage_get_bucket_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_reasoning_engine_mock, delete_reasoning_engine_mock, get_gca_resource_mock, @@ -1713,7 +1714,7 @@ def test_stream_query_reasoning_engine_with_operation_schema( ) -@pytest.mark.usefixtures("google_auth_mock") +@pytest.mark.usefixtures("google_auth_mock", "upload_reasoning_engine_mock") class TestReasoningEngineErrors: def setup_method(self): importlib.reload(initializer) @@ -1731,7 +1732,7 @@ def test_create_reasoning_engine_unspecified_staging_bucket( create_reasoning_engine_mock, cloud_storage_create_bucket_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_reasoning_engine_mock, ): with pytest.raises( @@ -1762,7 +1763,7 @@ def test_create_reasoning_engine_no_query_method( create_reasoning_engine_mock, cloud_storage_create_bucket_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_reasoning_engine_mock, ): with pytest.raises( @@ -1783,7 +1784,7 @@ def test_create_reasoning_engine_noncallable_query_attribute( create_reasoning_engine_mock, cloud_storage_create_bucket_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_reasoning_engine_mock, ): with pytest.raises( @@ -1804,7 +1805,7 @@ def test_create_reasoning_engine_unsupported_sys_version( create_reasoning_engine_mock, cloud_storage_create_bucket_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_reasoning_engine_mock, ): with pytest.raises(ValueError, match="Unsupported python version"): @@ -1820,7 +1821,7 @@ def test_create_reasoning_engine_requirements_ioerror( create_reasoning_engine_mock, cloud_storage_create_bucket_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_reasoning_engine_mock, ): with pytest.raises(IOError, match="Failed to read requirements"): @@ -1835,7 +1836,7 @@ def test_create_reasoning_engine_nonexistent_extra_packages( create_reasoning_engine_mock, cloud_storage_create_bucket_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_reasoning_engine_mock, ): with pytest.raises(FileNotFoundError, match="not found"): @@ -1851,7 +1852,7 @@ def test_create_reasoning_engine_with_invalid_query_method( create_reasoning_engine_mock, cloud_storage_create_bucket_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_reasoning_engine_mock, ): with pytest.raises(ValueError, match="Invalid query signature"): @@ -1866,7 +1867,7 @@ def test_create_reasoning_engine_with_invalid_stream_query_method( create_reasoning_engine_mock, cloud_storage_create_bucket_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_reasoning_engine_mock, ): with pytest.raises(ValueError, match="Invalid stream_query signature"): @@ -1881,7 +1882,7 @@ def test_create_reasoning_engine_with_invalid_register_operations_method( create_reasoning_engine_mock, cloud_storage_create_bucket_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_reasoning_engine_mock, ): with pytest.raises(ValueError, match="Invalid register_operations signature"): @@ -1896,7 +1897,7 @@ def test_update_reasoning_engine_unspecified_staging_bucket( update_reasoning_engine_mock, cloud_storage_create_bucket_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, ): with pytest.raises( ValueError, @@ -1925,7 +1926,7 @@ def test_update_reasoning_engine_no_query_method( update_reasoning_engine_mock, cloud_storage_create_bucket_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_reasoning_engine_mock, ): with pytest.raises( @@ -1945,7 +1946,7 @@ def test_update_reasoning_engine_noncallable_query_attribute( update_reasoning_engine_mock, cloud_storage_create_bucket_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_reasoning_engine_mock, ): with pytest.raises( @@ -1965,7 +1966,7 @@ def test_update_reasoning_engine_requirements_ioerror( update_reasoning_engine_mock, cloud_storage_create_bucket_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_reasoning_engine_mock, ): with pytest.raises(IOError, match="Failed to read requirements"): @@ -1979,7 +1980,7 @@ def test_update_reasoning_engine_nonexistent_extra_packages( update_reasoning_engine_mock, cloud_storage_create_bucket_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_reasoning_engine_mock, ): with pytest.raises(FileNotFoundError, match="not found"): @@ -1993,7 +1994,7 @@ def test_update_reasoning_engine_with_invalid_query_method( update_reasoning_engine_mock, cloud_storage_create_bucket_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_reasoning_engine_mock, ): with pytest.raises(ValueError, match="Invalid query signature"):