From ffeef4fbfe340541bf5d7169255c10a02827fb0a Mon Sep 17 00:00:00 2001 From: Ami Oka Date: Fri, 9 Sep 2022 15:04:08 -0700 Subject: [PATCH] move getMapIdByValue to FieldMask.h Summary: Moved getMapIdByValue to FieldMask.h, and changed Patch.cpp and Object.h to use the function. Reviewed By: Mizuchi Differential Revision: D39364529 fbshipit-source-id: f4633036b311bd2fe97ab9c52613d290cbd2f4b7 --- .../src/thrift/lib/cpp2/detail/FieldMask.cpp | 14 +++++++++ .../thrift/src/thrift/lib/cpp2/detail/FieldMask.h | 5 ++++ .../thrift/src/thrift/lib/cpp2/protocol/Patch.cpp | 23 +++----------- .../src/thrift/lib/cpp2/protocol/detail/Object.cpp | 35 ---------------------- .../src/thrift/lib/cpp2/protocol/detail/Object.h | 8 ++--- 5 files changed, 26 insertions(+), 59 deletions(-) delete mode 100644 third-party/thrift/src/thrift/lib/cpp2/protocol/detail/Object.cpp diff --git a/third-party/thrift/src/thrift/lib/cpp2/detail/FieldMask.cpp b/third-party/thrift/src/thrift/lib/cpp2/detail/FieldMask.cpp index e56c7f49fe2..8a44083293c 100644 --- a/third-party/thrift/src/thrift/lib/cpp2/detail/FieldMask.cpp +++ b/third-party/thrift/src/thrift/lib/cpp2/detail/FieldMask.cpp @@ -297,4 +297,18 @@ void throwIfContainsMapMask(const Mask& mask) { } } +// Returns the MapId for the given key. +MapId findMapIdByValue(const Mask& mask, const Value& newKey) { + MapId mapId = MapId{reinterpret_cast(&newKey)}; + if (!(mask.includes_map_ref() || mask.excludes_map_ref())) { + return mapId; + } + const auto& mapIdToMask = mask.includes_map_ref() ? *mask.includes_map_ref() + : *mask.excludes_map_ref(); + auto it = std::find_if( + mapIdToMask.begin(), mapIdToMask.end(), [&newKey](const auto& kv) { + return *(reinterpret_cast(kv.first)) == newKey; + }); + return it == mapIdToMask.end() ? mapId : MapId{it->first}; +} } // namespace apache::thrift::protocol::detail diff --git a/third-party/thrift/src/thrift/lib/cpp2/detail/FieldMask.h b/third-party/thrift/src/thrift/lib/cpp2/detail/FieldMask.h index d29271973c9..1b09b19bf45 100644 --- a/third-party/thrift/src/thrift/lib/cpp2/detail/FieldMask.h +++ b/third-party/thrift/src/thrift/lib/cpp2/detail/FieldMask.h @@ -404,4 +404,9 @@ void compare_impl( mask[fieldId] = field_mask_constants::allMask(); }); } + +// Returns the MapId in map mask of the given Value key. +// If it doesn't exist, it returns the new MapId (pointer to the key). +// Assumes the map mask uses pointers to keys. +MapId findMapIdByValue(const Mask& mask, const Value& key); } // namespace apache::thrift::protocol::detail diff --git a/third-party/thrift/src/thrift/lib/cpp2/protocol/Patch.cpp b/third-party/thrift/src/thrift/lib/cpp2/protocol/Patch.cpp index 0e8291df2cd..d74330eec74 100644 --- a/third-party/thrift/src/thrift/lib/cpp2/protocol/Patch.cpp +++ b/third-party/thrift/src/thrift/lib/cpp2/protocol/Patch.cpp @@ -463,21 +463,6 @@ void ApplyPatch::operator()(const Object& patch, Object& value) const { applyFieldPatch(patchFields); } } - -// Returns the MapId for the given key. -int64_t getMapIdByValue(Mask& mask, const Value& newKey) { - int64_t mapId = reinterpret_cast(&newKey); - if (!mask.includes_map_ref()) { - return mapId; - } - const auto& includes = *mask.includes_map_ref(); - auto it = - std::find_if(includes.begin(), includes.end(), [&newKey](const auto& kv) { - return *(reinterpret_cast(kv.first)) == newKey; - }); - return it == includes.end() ? mapId : it->first; -} - // Inserts the next mask to getIncludesRef(mask)[id]. // Skips if mask is allMask (already includes all fields), or next is noneMask. template @@ -517,15 +502,15 @@ void insertFieldsToMask( } } else if (auto* map = patchFields.if_map()) { for (const auto& [key, value] : *map) { - auto readId = getMapIdByValue(masks.read, key); - auto writeId = getMapIdByValue(masks.write, key); + auto readId = static_cast(findMapIdByValue(masks.read, key)); + auto writeId = static_cast(findMapIdByValue(masks.write, key)); insertNextMask( masks, value, readId, writeId, recursive, getIncludesMapRef); } } else { // set of map keys (Remove) for (const auto& key : patchFields.as_set()) { - auto readId = getMapIdByValue(masks.read, key); - auto writeId = getMapIdByValue(masks.write, key); + auto readId = static_cast(findMapIdByValue(masks.read, key)); + auto writeId = static_cast(findMapIdByValue(masks.write, key)); insertMask(masks.read, readId, allMask(), getIncludesMapRef); insertMask(masks.write, writeId, allMask(), getIncludesMapRef); } diff --git a/third-party/thrift/src/thrift/lib/cpp2/protocol/detail/Object.cpp b/third-party/thrift/src/thrift/lib/cpp2/protocol/detail/Object.cpp deleted file mode 100644 index 4fed0642750..00000000000 --- a/third-party/thrift/src/thrift/lib/cpp2/protocol/detail/Object.cpp +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -namespace apache::thrift::protocol::detail { - -int64_t getMaskKey(MaskRef ref, const Value& newKey) { - if (ref.isMapMask()) { - MapIdToMask map = ref.mask.includes_map_ref() - ? ref.mask.includes_map_ref().value() - : ref.mask.excludes_map_ref().value(); - for (auto& [key, next] : map) { - if (*(reinterpret_cast(key)) == newKey) { - return key; - } - } - } - return 0; -} - -} // namespace apache::thrift::protocol::detail diff --git a/third-party/thrift/src/thrift/lib/cpp2/protocol/detail/Object.h b/third-party/thrift/src/thrift/lib/cpp2/protocol/detail/Object.h index f3b723410f9..fbf908903e8 100644 --- a/third-party/thrift/src/thrift/lib/cpp2/protocol/detail/Object.h +++ b/third-party/thrift/src/thrift/lib/cpp2/protocol/detail/Object.h @@ -476,9 +476,6 @@ void setMaskedDataFull( type::ValueId{apache::thrift::util::i32ToZigzag(values.size() - 1)}; } -// Returns the map mask for the given key. -int64_t getMaskKey(MaskRef ref, const Value& newKey); - // parseValue with readMask and writeMask template MaskedDecodeResultValue parseValue( @@ -541,9 +538,10 @@ MaskedDecodeResultValue parseValue( prot.readMapBegin(keyType, valType, size); for (uint32_t i = 0; i < size; i++) { auto keyValue = parseValue(prot, keyType, string_to_binary); - MaskRef nextRead = readMask.get(MapId{getMaskKey(readMask, keyValue)}); + MaskRef nextRead = + readMask.get(findMapIdByValue(readMask.mask, keyValue)); MaskRef nextWrite = - writeMask.get(MapId{getMaskKey(writeMask, keyValue)}); + writeMask.get(findMapIdByValue(writeMask.mask, keyValue)); MaskedDecodeResultValue nestedResult = parseValue( prot, valType, nextRead, nextWrite, protocolData, string_to_binary); // Set nested MaskedDecodeResult if not empty. -- 2.11.4.GIT