chenhao-db commented on code in PR #46338: URL: https://github.com/apache/spark/pull/46338#discussion_r1588326114
########## python/pyspark/sql/variant_utils.py: ########## @@ -245,47 +245,57 @@ def _get_string(cls, value: bytes, pos: int) -> str: length = cls._read_long(value, pos + 1, VariantUtils.U32_SIZE, signed=False) cls._check_index(start + length - 1, len(value)) return value[start : start + length].decode("utf-8") - raise PySparkValueError(error_class="MALFORMED_VARIANT") + raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={}) @classmethod def _get_double(cls, value: bytes, pos: int) -> float: cls._check_index(pos, len(value)) basic_type, type_info = cls._get_type_info(value, pos) if basic_type != VariantUtils.PRIMITIVE: - raise PySparkValueError(error_class="MALFORMED_VARIANT") + raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={}) if type_info == VariantUtils.FLOAT: cls._check_index(pos + 4, len(value)) return struct.unpack("<f", value[pos + 1 : pos + 5])[0] elif type_info == VariantUtils.DOUBLE: cls._check_index(pos + 8, len(value)) return struct.unpack("<d", value[pos + 1 : pos + 9])[0] - raise PySparkValueError(error_class="MALFORMED_VARIANT") + raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={}) + + @classmethod + def _check_decimal(cls, unscaled: int, scale: int, max_unscaled: int, max_scale: int): + # max_unscaled == 10**max_scale, but we pass a literal parameter to avoid redundant + # computation. + if unscaled >= max_unscaled or unscaled <= -max_unscaled or scale > max_scale: + raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={}) @classmethod def _get_decimal(cls, value: bytes, pos: int) -> decimal.Decimal: cls._check_index(pos, len(value)) basic_type, type_info = cls._get_type_info(value, pos) if basic_type != VariantUtils.PRIMITIVE: - raise PySparkValueError(error_class="MALFORMED_VARIANT") + raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={}) scale = value[pos + 1] unscaled = 0 if type_info == VariantUtils.DECIMAL4: unscaled = cls._read_long(value, pos + 2, 4, signed=True) + cls._check_decimal(unscaled, scale, 1000000000, 9) elif type_info == VariantUtils.DECIMAL8: unscaled = cls._read_long(value, pos + 2, 8, signed=True) + cls._check_decimal(unscaled, scale, 1000000000000000000, 18) elif type_info == VariantUtils.DECIMAL16: cls._check_index(pos + 17, len(value)) unscaled = int.from_bytes(value[pos + 2 : pos + 18], byteorder="little", signed=True) + cls._check_decimal(unscaled, scale, 100000000000000000000000000000000000000, 38) Review Comment: Done. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org