gene-db commented on code in PR #46338:
URL: https://github.com/apache/spark/pull/46338#discussion_r1588311315


##########
python/pyspark/sql/variant_utils.py:
##########
@@ -245,47 +245,57 @@ def _get_string(cls, value: bytes, pos: int) -> str:
                 length = cls._read_long(value, pos + 1, VariantUtils.U32_SIZE, signed=False)
             cls._check_index(start + length - 1, len(value))
             return value[start : start + length].decode("utf-8")
-        raise PySparkValueError(error_class="MALFORMED_VARIANT")
+        raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
 
     @classmethod
     def _get_double(cls, value: bytes, pos: int) -> float:
         cls._check_index(pos, len(value))
         basic_type, type_info = cls._get_type_info(value, pos)
         if basic_type != VariantUtils.PRIMITIVE:
-            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+            raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
         if type_info == VariantUtils.FLOAT:
             cls._check_index(pos + 4, len(value))
             return struct.unpack("<f", value[pos + 1 : pos + 5])[0]
         elif type_info == VariantUtils.DOUBLE:
             cls._check_index(pos + 8, len(value))
             return struct.unpack("<d", value[pos + 1 : pos + 9])[0]
-        raise PySparkValueError(error_class="MALFORMED_VARIANT")
+        raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
+
+    @classmethod
+    def _check_decimal(cls, unscaled: int, scale: int, max_unscaled: int, max_scale: int):
+        # max_unscaled == 10**max_scale, but we pass a literal parameter to avoid redundant
+        # computation.
+        if unscaled >= max_unscaled or unscaled <= -max_unscaled or scale > max_scale:
+            raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
 
     @classmethod
     def _get_decimal(cls, value: bytes, pos: int) -> decimal.Decimal:
         cls._check_index(pos, len(value))
         basic_type, type_info = cls._get_type_info(value, pos)
         if basic_type != VariantUtils.PRIMITIVE:
-            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+            raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
         scale = value[pos + 1]
         unscaled = 0
         if type_info == VariantUtils.DECIMAL4:
             unscaled = cls._read_long(value, pos + 2, 4, signed=True)
+            cls._check_decimal(unscaled, scale, 1000000000, 9)
         elif type_info == VariantUtils.DECIMAL8:
             unscaled = cls._read_long(value, pos + 2, 8, signed=True)
+            cls._check_decimal(unscaled, scale, 1000000000000000000, 18)
         elif type_info == VariantUtils.DECIMAL16:
             cls._check_index(pos + 17, len(value))
             unscaled = int.from_bytes(value[pos + 2 : pos + 18], byteorder="little", signed=True)
+            cls._check_decimal(unscaled, scale, 100000000000000000000000000000000000000, 38)

Review Comment:
   Let's make these literals into constants.
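   A minimal sketch of what that could look like, assuming class-level constants on `VariantUtils` (the names `MAX_DECIMAL4_PRECISION`, `MAX_DECIMAL4_VALUE`, etc. are hypothetical, not taken from this PR):
   
   ```python
   from pyspark.errors import PySparkValueError
   
   
   class VariantUtils:
       # Hypothetical constant names (sketch only): for each DECIMALn width, the
       # unscaled value must satisfy |unscaled| < 10**precision and scale <= precision.
       MAX_DECIMAL4_PRECISION = 9
       MAX_DECIMAL4_VALUE = 10**MAX_DECIMAL4_PRECISION
       MAX_DECIMAL8_PRECISION = 18
       MAX_DECIMAL8_VALUE = 10**MAX_DECIMAL8_PRECISION
       MAX_DECIMAL16_PRECISION = 38
       MAX_DECIMAL16_VALUE = 10**MAX_DECIMAL16_PRECISION
   
       @classmethod
       def _check_decimal(cls, unscaled: int, scale: int, max_unscaled: int, max_scale: int) -> None:
           # Same check as in the PR, but the bounds come in as named constants.
           if unscaled >= max_unscaled or unscaled <= -max_unscaled or scale > max_scale:
               raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
   
   
   # A call site in _get_decimal would then read, e.g.:
   # cls._check_decimal(unscaled, scale, VariantUtils.MAX_DECIMAL4_VALUE, VariantUtils.MAX_DECIMAL4_PRECISION)
   ```
   
   Deriving each `*_VALUE` from its `*_PRECISION` keeps the pair in sync and drops the magic literals from the call sites.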



