diff --git a/pymongo/message.py b/pymongo/message.py index fdac2b4daa..2354539623 100644 --- a/pymongo/message.py +++ b/pymongo/message.py @@ -171,14 +171,6 @@ def _convert_write_result( elif operation == "update": if "upserted" in result: res["upserted"] = [{"index": 0, "_id": result["upserted"]}] - # Versions of MongoDB before 2.6 don't return the _id for an - # upsert if _id is not an ObjectId. - elif result.get("updatedExisting") is False and affected == 1: - # If _id is in both the update document *and* the query spec - # the update document _id takes precedence. - update = command["updates"][0] - _id = update["u"].get("_id", update["q"].get("_id")) - res["upserted"] = [{"index": 0, "_id": _id}] return res @@ -660,7 +652,7 @@ def _raise_document_too_large(operation: str, doc_size: int, max_size: int) -> N else: # There's nothing intelligent we can say # about size for update and delete - raise DocumentTooLarge(f"{operation!r} command document too large") + raise DocumentTooLarge(f"{operation} command document too large") # From the Client Side Encryption spec: diff --git a/test/test_message.py b/test/test_message.py new file mode 100644 index 0000000000..fb79920338 --- /dev/null +++ b/test/test_message.py @@ -0,0 +1,451 @@ +# Copyright 2026-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for message.py.""" + +from __future__ import annotations + +import struct +import sys +from unittest.mock import MagicMock + +sys.path[0:0] = [""] + +from test import unittest + +from bson import CodecOptions, encode +from pymongo.compression_support import ZlibContext, _have_zlib +from pymongo.errors import DocumentTooLarge, OperationFailure +from pymongo.message import ( + _convert_client_bulk_exception, + _convert_exception, + _convert_write_result, + _gen_find_command, + _gen_get_more_command, + _maybe_add_read_preference, + _op_msg, + _raise_document_too_large, +) +from pymongo.read_concern import ReadConcern +from pymongo.read_preferences import ReadPreference, SecondaryPreferred + +_OPTS = CodecOptions() + + +class TestMessage(unittest.TestCase): + # _gen_get_more_command helper + def _make_conn(self, max_wire_version=9): + conn = MagicMock() + conn.max_wire_version = max_wire_version + return conn + + # _maybe_add_read_preference + + def test_primary_no_read_preference_added(self): + spec: dict = {"find": "col"} + result = _maybe_add_read_preference(spec, ReadPreference.PRIMARY) + self.assertNotIn("$readPreference", result) + self.assertNotIn("$query", result) + + def test_secondary_adds_read_preference(self): + spec: dict = {"find": "col"} + result = _maybe_add_read_preference(spec, ReadPreference.SECONDARY) + self.assertIn("$readPreference", result) + self.assertEqual(result["$readPreference"]["mode"], "secondary") + self.assertIn("$query", result) + + def test_secondary_preferred_no_tags_does_not_add(self): + spec: dict = {"find": "col"} + result = _maybe_add_read_preference(spec, ReadPreference.SECONDARY_PREFERRED) + self.assertNotIn("$readPreference", result) + + def test_secondary_preferred_with_tags_adds_read_preference(self): + pref = SecondaryPreferred(tag_sets=[{"dc": "east"}]) + spec: dict = {"find": "col"} + result = _maybe_add_read_preference(spec, pref) + self.assertIn("$readPreference", result) + self.assertEqual(result["$readPreference"]["mode"], "secondaryPreferred") + self.assertIn("$query", result) + + def test_existing_query_wrapper_preserved(self): + spec: dict = {"$query": {"x": 1}, "other": 2} + result = _maybe_add_read_preference(spec, ReadPreference.SECONDARY) + self.assertIn("$readPreference", result) + self.assertEqual(result["$query"], {"x": 1}) + + # _convert_exception / _convert_client_bulk_exception + + def test_basic_exception(self): + exc = ValueError("bad value") + doc = _convert_exception(exc) + self.assertEqual(doc["errmsg"], "bad value") + self.assertEqual(doc["errtype"], "ValueError") + + def test_client_bulk_exception_includes_code(self): + exc = OperationFailure("failed", code=11000) + doc = _convert_client_bulk_exception(exc) + self.assertEqual(doc["errmsg"], "failed") + self.assertEqual(doc["code"], 11000) + self.assertEqual(doc["errtype"], "OperationFailure") + + # _convert_write_result + # In the update command spec, `q` is the query/filter and `u` is the update document. + + def test_insert_basic(self): + cmd = {"documents": [{"_id": 1}, {"_id": 2}]} + result = _convert_write_result("insert", cmd, {"n": 0}) + self.assertEqual(result["ok"], 1) + self.assertEqual(result["n"], 2) + + def test_update_basic(self): + cmd = {"updates": [{"q": {}, "u": {"$set": {"x": 1}}}]} + result = _convert_write_result("update", cmd, {"n": 1, "updatedExisting": True}) + self.assertEqual(result["ok"], 1) + self.assertNotIn("upserted", result) + + def test_update_with_upserted_id(self): + cmd = {"updates": [{"q": {}, "u": {"_id": 42}}]} + result = _convert_write_result("update", cmd, {"n": 1, "upserted": 42}) + self.assertIn("upserted", result) + self.assertEqual(result["upserted"][0]["_id"], 42) + + def test_delete_basic(self): + cmd = {"deletes": [{"q": {}, "limit": 1}]} + result = _convert_write_result("delete", cmd, {"n": 1}) + self.assertEqual(result["ok"], 1) + self.assertEqual(result["n"], 1) + + def test_write_error(self): + cmd = {"documents": [{"_id": 1}]} + gle = {"n": 0, "err": "duplicate key error", "code": 11000} + result = _convert_write_result("insert", cmd, gle) + self.assertIn("writeErrors", result) + self.assertEqual(result["writeErrors"][0]["code"], 11000) + + def test_write_concern_timeout(self): + cmd = {"documents": [{"_id": 1}]} + gle = {"n": 1, "errmsg": "timeout", "wtimeout": True} + result = _convert_write_result("insert", cmd, gle) + self.assertIn("writeConcernError", result) + self.assertEqual(result["writeConcernError"]["code"], 64) + + def test_write_error_with_err_info(self): + # Covers the `if "errInfo" in result:` branch, which test_write_error does not enter. + cmd = {"documents": [{"_id": 1}]} + gle = {"n": 0, "err": "err", "code": 123, "errInfo": {"detail": "x"}} + result = _convert_write_result("insert", cmd, gle) + self.assertIn("errInfo", result["writeErrors"][0]) + + # _op_msg + + def test_op_msg_max_doc_size_zero_without_docs(self): + max_doc_size = _op_msg(0, {"ping": 1}, "testdb", None, _OPTS)[3] + self.assertEqual(max_doc_size, 0) + + def test_op_msg_max_doc_size_matches_largest_encoded_doc(self): + docs = [{"_id": 1}, {"_id": 2, "data": "a" * 100}] + cmd: dict = {"insert": "col", "documents": docs} + max_doc_size = _op_msg(0, cmd, "testdb", None, _OPTS)[3] + self.assertEqual(max_doc_size, max(len(encode(d)) for d in docs)) + + def test_op_msg_read_preference_added_for_non_primary(self): + cmd: dict = {"find": "col"} + _op_msg(0, cmd, "testdb", ReadPreference.SECONDARY, _OPTS) + self.assertIn("$readPreference", cmd) + + def test_op_msg_read_preference_skipped_if_already_present(self): + cmd: dict = {"find": "col", "$readPreference": {"mode": "nearest"}} + _op_msg(0, cmd, "testdb", ReadPreference.SECONDARY, _OPTS) + self.assertEqual(cmd["$readPreference"]["mode"], "nearest") + + def test_op_msg_documents_field_is_restored(self): + docs = [{"_id": 1}] + cmd: dict = {"insert": "col", "documents": docs} + _op_msg(0, cmd, "testdb", None, _OPTS) + self.assertIn("documents", cmd) + self.assertEqual(cmd["documents"], docs) + + @unittest.skipUnless(_have_zlib(), "zlib not available") + def test_op_msg_compressed_zlib_header(self): + # Verify the compressed path is taken and produces a valid OP_COMPRESSED frame. + # Header layout (little-endian): [msgLen(4), reqId(4), responseTo(4), opCode(4), + # originalOpcode(4), uncompressedSize(4), compressorId(1)] + ctx = ZlibContext(6) + _, msg, _, _ = _op_msg(0, {"ping": 1}, "testdb", None, _OPTS, ctx=ctx) + (opcode,) = struct.unpack_from("