#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import functools
import shutil
import tempfile
import warnings
from contextlib import contextmanager
from distutils.version import LooseVersion
import decimal
from typing import Any, Union, TYPE_CHECKING

import pyspark.pandas as ps
from pyspark.pandas.frame import DataFrame
from pyspark.pandas.indexes import Index
from pyspark.pandas.series import Series
from pyspark.pandas.utils import SPARK_CONF_ARROW_ENABLED
from pyspark.testing.sqlutils import ReusedSQLTestCase
from pyspark.errors import PySparkAssertionError

tabulate_requirement_message = None
try:
    from tabulate import tabulate
except ImportError as e:
    # If tabulate requirement is not satisfied, skip related tests.
    tabulate_requirement_message = str(e)
have_tabulate = tabulate_requirement_message is None

matplotlib_requirement_message = None
try:
    import matplotlib
except ImportError as e:
    # If matplotlib requirement is not satisfied, skip related tests.
    matplotlib_requirement_message = str(e)
have_matplotlib = matplotlib_requirement_message is None

plotly_requirement_message = None
try:
    import plotly
except ImportError as e:
    # If plotly requirement is not satisfied, skip related tests.
    plotly_requirement_message = str(e)
have_plotly = plotly_requirement_message is None

try:
    from pyspark.sql.pandas.utils import require_minimum_pandas_version

    require_minimum_pandas_version()
    import pandas as pd
except ImportError:
    pass

__all__ = ["assertPandasOnSparkEqual"]


def _assert_pandas_equal(
    left: Union[pd.DataFrame, pd.Series, pd.Index],
    right: Union[pd.DataFrame, pd.Series, pd.Index],
    checkExact: bool,
):
    """
    Assert that two pandas objects (DataFrame, Series, or Index) are equal,
    raising :class:`PySparkAssertionError` with a rendered diff on mismatch.

    Parameters
    ----------
    left, right : pd.DataFrame, pd.Series, or pd.Index
        Objects to compare; both must be of the same pandas type.
    checkExact : bool
        If True, values are compared exactly; if False, pandas' approximate
        comparison (``check_exact=False``) is used.

    Raises
    ------
    PySparkAssertionError
        If the objects differ.
    ValueError
        If the inputs are not a matching pair of pandas objects.
    """
    from pandas.core.dtypes.common import is_numeric_dtype
    from pandas.testing import assert_frame_equal, assert_index_equal, assert_series_equal

    if isinstance(left, pd.DataFrame) and isinstance(right, pd.DataFrame):
        try:
            # check_freq was only introduced in pandas 1.1.
            if LooseVersion(pd.__version__) >= LooseVersion("1.1"):
                kwargs = dict(check_freq=False)
            else:
                kwargs = dict()
            if LooseVersion(pd.__version__) < LooseVersion("1.1.1"):
                # Due to https://github.com/pandas-dev/pandas/issues/35446
                checkExact = (
                    checkExact
                    and all([is_numeric_dtype(dtype) for dtype in left.dtypes])
                    and all([is_numeric_dtype(dtype) for dtype in right.dtypes])
                )
            assert_frame_equal(
                left,
                right,
                check_index_type=("equiv" if len(left.index) > 0 else False),
                check_column_type=("equiv" if len(left.columns) > 0 else False),
                check_exact=checkExact,
                **kwargs,
            )
        except AssertionError:
            raise PySparkAssertionError(
                error_class="DIFFERENT_PANDAS_DATAFRAME",
                message_parameters={
                    "left": left.to_string(),
                    "left_dtype": str(left.dtypes),
                    "right": right.to_string(),
                    "right_dtype": str(right.dtypes),
                },
            )
    elif isinstance(left, pd.Series) and isinstance(right, pd.Series):
        try:
            if LooseVersion(pd.__version__) >= LooseVersion("1.1"):
                kwargs = dict(check_freq=False)
            else:
                kwargs = dict()
            if LooseVersion(pd.__version__) < LooseVersion("1.1.1"):
                # Due to https://github.com/pandas-dev/pandas/issues/35446
                checkExact = (
                    checkExact and is_numeric_dtype(left.dtype) and is_numeric_dtype(right.dtype)
                )
            assert_series_equal(
                left,
                right,
                check_index_type=("equiv" if len(left.index) > 0 else False),
                check_exact=checkExact,
                **kwargs,
            )
        except AssertionError:
            raise PySparkAssertionError(
                error_class="DIFFERENT_PANDAS_SERIES",
                message_parameters={
                    "left": left.to_string(),
                    "left_dtype": str(left.dtype),
                    "right": right.to_string(),
                    "right_dtype": str(right.dtype),
                },
            )
    elif isinstance(left, pd.Index) and isinstance(right, pd.Index):
        try:
            if LooseVersion(pd.__version__) < LooseVersion("1.1.1"):
                # Due to https://github.com/pandas-dev/pandas/issues/35446
                checkExact = (
                    checkExact and is_numeric_dtype(left.dtype) and is_numeric_dtype(right.dtype)
                )
            assert_index_equal(left, right, check_exact=checkExact)
        except AssertionError:
            raise PySparkAssertionError(
                error_class="DIFFERENT_PANDAS_INDEX",
                message_parameters={
                    "left": left,
                    "left_dtype": str(left.dtype),
                    "right": right,
                    "right_dtype": str(right.dtype),
                },
            )
    else:
        raise ValueError("Unexpected values: (%s, %s)" % (left, right))


def _assert_pandas_almost_equal(
    left: Union[pd.DataFrame, pd.Series, pd.Index],
    right: Union[pd.DataFrame, pd.Series, pd.Index],
    rtol: float = 1e-5,
    atol: float = 1e-8,
):
    """
    This function checks if given pandas objects approximately same,
    which means the conditions below:
      - Both objects are nullable
      - Compare decimals and floats, where two values a and b are approximately equal
        if they satisfy the following formula:
        absolute(a - b) <= (atol + rtol * absolute(b))
        where rtol=1e-5 and atol=1e-8 by default

    Raises
    ------
    PySparkAssertionError
        If the objects differ (shape, column names/labels, null placement, or values).
    """

    def compare_vals_approx(val1, val2):
        # Compare vals for approximate equality.
        # BUG FIX: the original implementation read the enclosing loop variables
        # `lval`/`rval` via closure instead of its own parameters; it only worked
        # because every call site happened to pass those same variables. Use the
        # parameters so the helper is self-contained.
        if isinstance(val1, (float, decimal.Decimal)) or isinstance(
            val2, (float, decimal.Decimal)
        ):
            if abs(float(val1) - float(val2)) > (atol + rtol * abs(float(val2))):
                return False
        elif val1 != val2:
            return False
        return True

    if isinstance(left, pd.DataFrame) and isinstance(right, pd.DataFrame):
        if left.shape != right.shape:
            raise PySparkAssertionError(
                error_class="DIFFERENT_PANDAS_DATAFRAME",
                message_parameters={
                    "left": left.to_string(),
                    "left_dtype": str(left.dtypes),
                    "right": right.to_string(),
                    "right_dtype": str(right.dtypes),
                },
            )
        for lcol, rcol in zip(left.columns, right.columns):
            if lcol != rcol:
                raise PySparkAssertionError(
                    error_class="DIFFERENT_PANDAS_DATAFRAME",
                    message_parameters={
                        "left": left.to_string(),
                        "left_dtype": str(left.dtypes),
                        "right": right.to_string(),
                        "right_dtype": str(right.dtypes),
                    },
                )
            # Nulls must appear at the same positions in both columns.
            for lnull, rnull in zip(left[lcol].isnull(), right[rcol].isnull()):
                if lnull != rnull:
                    raise PySparkAssertionError(
                        error_class="DIFFERENT_PANDAS_DATAFRAME",
                        message_parameters={
                            "left": left.to_string(),
                            "left_dtype": str(left.dtypes),
                            "right": right.to_string(),
                            "right_dtype": str(right.dtypes),
                        },
                    )
            # Compare the non-null values approximately.
            for lval, rval in zip(left[lcol].dropna(), right[rcol].dropna()):
                if not compare_vals_approx(lval, rval):
                    raise PySparkAssertionError(
                        error_class="DIFFERENT_PANDAS_DATAFRAME",
                        message_parameters={
                            "left": left.to_string(),
                            "left_dtype": str(left.dtypes),
                            "right": right.to_string(),
                            "right_dtype": str(right.dtypes),
                        },
                    )
        if left.columns.names != right.columns.names:
            raise PySparkAssertionError(
                error_class="DIFFERENT_PANDAS_DATAFRAME",
                message_parameters={
                    "left": left.to_string(),
                    "left_dtype": str(left.dtypes),
                    "right": right.to_string(),
                    "right_dtype": str(right.dtypes),
                },
            )
    elif isinstance(left, pd.Series) and isinstance(right, pd.Series):
        if left.name != right.name or len(left) != len(right):
            raise PySparkAssertionError(
                error_class="DIFFERENT_PANDAS_SERIES",
                message_parameters={
                    "left": left.to_string(),
                    "left_dtype": str(left.dtype),
                    "right": right.to_string(),
                    "right_dtype": str(right.dtype),
                },
            )
        for lnull, rnull in zip(left.isnull(), right.isnull()):
            if lnull != rnull:
                raise PySparkAssertionError(
                    error_class="DIFFERENT_PANDAS_SERIES",
                    message_parameters={
                        "left": left.to_string(),
                        "left_dtype": str(left.dtype),
                        "right": right.to_string(),
                        "right_dtype": str(right.dtype),
                    },
                )
        for lval, rval in zip(left.dropna(), right.dropna()):
            if not compare_vals_approx(lval, rval):
                raise PySparkAssertionError(
                    error_class="DIFFERENT_PANDAS_SERIES",
                    message_parameters={
                        "left": left.to_string(),
                        "left_dtype": str(left.dtype),
                        "right": right.to_string(),
                        "right_dtype": str(right.dtype),
                    },
                )
    elif isinstance(left, pd.MultiIndex) and isinstance(right, pd.MultiIndex):
        # NOTE: must be checked before pd.Index since MultiIndex is a subclass.
        if len(left) != len(right):
            raise PySparkAssertionError(
                error_class="DIFFERENT_PANDAS_MULTIINDEX",
                message_parameters={
                    "left": left,
                    "left_dtype": str(left.dtype),
                    "right": right,
                    "right_dtype": str(right.dtype),
                },
            )
        for lval, rval in zip(left, right):
            if not compare_vals_approx(lval, rval):
                raise PySparkAssertionError(
                    error_class="DIFFERENT_PANDAS_MULTIINDEX",
                    message_parameters={
                        "left": left,
                        "left_dtype": str(left.dtype),
                        "right": right,
                        "right_dtype": str(right.dtype),
                    },
                )
    elif isinstance(left, pd.Index) and isinstance(right, pd.Index):
        if len(left) != len(right):
            raise PySparkAssertionError(
                error_class="DIFFERENT_PANDAS_INDEX",
                message_parameters={
                    "left": left,
                    "left_dtype": str(left.dtype),
                    "right": right,
                    "right_dtype": str(right.dtype),
                },
            )
        for lnull, rnull in zip(left.isnull(), right.isnull()):
            if lnull != rnull:
                raise PySparkAssertionError(
                    error_class="DIFFERENT_PANDAS_INDEX",
                    message_parameters={
                        "left": left,
                        "left_dtype": str(left.dtype),
                        "right": right,
                        "right_dtype": str(right.dtype),
                    },
                )
        for lval, rval in zip(left.dropna(), right.dropna()):
            if not compare_vals_approx(lval, rval):
                raise PySparkAssertionError(
                    error_class="DIFFERENT_PANDAS_INDEX",
                    message_parameters={
                        "left": left,
                        "left_dtype": str(left.dtype),
                        "right": right,
                        "right_dtype": str(right.dtype),
                    },
                )
    else:
        if not isinstance(left, (pd.DataFrame, pd.Series, pd.Index)):
            raise PySparkAssertionError(
                error_class="INVALID_TYPE_DF_EQUALITY_ARG",
                message_parameters={
                    "expected_type": f"{pd.DataFrame.__name__}, "
                    f"{pd.Series.__name__}, "
                    f"{pd.Index.__name__}, ",
                    "arg_name": "left",
                    "actual_type": type(left),
                },
            )
        elif not isinstance(right, (pd.DataFrame, pd.Series, pd.Index)):
            raise PySparkAssertionError(
                error_class="INVALID_TYPE_DF_EQUALITY_ARG",
                message_parameters={
                    "expected_type": f"{pd.DataFrame.__name__}, "
                    f"{pd.Series.__name__}, "
                    f"{pd.Index.__name__}, ",
                    "arg_name": "right",
                    "actual_type": type(right),
                },
            )
def assertPandasOnSparkEqual(
    actual: Union[DataFrame, Series, Index],
    expected: Union[DataFrame, pd.DataFrame, Series, pd.Series, Index, pd.Index],
    checkExact: bool = True,
    almost: bool = False,
    rtol: float = 1e-5,
    atol: float = 1e-8,
    checkRowOrder: bool = True,
):
    r"""
    A util function to assert equality between actual (pandas-on-Spark object) and expected
    (pandas-on-Spark or pandas object).

    .. versionadded:: 3.5.0

    .. deprecated:: 3.5.1
        `assertPandasOnSparkEqual` will be removed in Spark 4.0.0.

    Parameters
    ----------
    actual: pandas-on-Spark DataFrame, Series, or Index
        The object that is being compared or tested.
    expected: pandas-on-Spark or pandas DataFrame, Series, or Index
        The expected object, for comparison with the actual result.
    checkExact: bool, optional
        A flag indicating whether to compare exact equality.
        If set to 'True' (default), the data is compared exactly.
        If set to 'False', the data is compared less precisely, following pandas
        assert_frame_equal approximate comparison (see documentation for more details).
    almost: bool, optional
        A flag indicating whether to use unittest `assertAlmostEqual` or `assertEqual`.
        If set to 'True', the comparison is delegated to `unittest`'s `assertAlmostEqual`
        (see documentation for more details).
        If set to 'False' (default), the data is compared exactly with `unittest`'s
        `assertEqual`.
    rtol : float, optional
        The relative tolerance, used in asserting almost equality for float values in actual
        and expected. Set to 1e-5 by default. (See Notes)
    atol : float, optional
        The absolute tolerance, used in asserting almost equality for float values in actual
        and expected. Set to 1e-8 by default. (See Notes)
    checkRowOrder : bool, optional
        A flag indicating whether the order of rows should be considered in the comparison.
        If set to `False`, the row order is not taken into account.
        If set to `True` (default), the order of rows will be checked during comparison.
        (See Notes)

    Notes
    -----
    For `checkRowOrder`, note that pandas-on-Spark DataFrame ordering is non-deterministic,
    unless explicitly sorted.

    When `almost` is set to True, approximate equality will be asserted, where two values
    a and b are approximately equal if they satisfy the following formula:

    ``absolute(a - b) <= (atol + rtol * absolute(b))``

    Examples
    --------
    >>> import pyspark.pandas as ps
    >>> psdf1 = ps.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6], 'c': [7, 8, 9]})
    >>> psdf2 = ps.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6], 'c': [7, 8, 9]})
    >>> assertPandasOnSparkEqual(psdf1, psdf2)  # pass, ps.DataFrames are equal
    >>> s1 = ps.Series([212.32, 100.0001])
    >>> s2 = ps.Series([212.32, 100.0])
    >>> assertPandasOnSparkEqual(s1, s2, checkExact=False)  # pass, ps.Series are approx equal
    >>> s1 = ps.Index([212.300001, 100.000])
    >>> s2 = ps.Index([212.3, 100.0001])
    >>> assertPandasOnSparkEqual(s1, s2, almost=True)  # pass, ps.Index obj are almost equal
    """
    warnings.warn(
        "`assertPandasOnSparkEqual` will be removed in Spark 4.0.0. ",
        FutureWarning,
    )
    if actual is None and expected is None:
        return True
    elif actual is None or expected is None:
        return False

    if not isinstance(actual, (DataFrame, Series, Index)):
        raise PySparkAssertionError(
            error_class="INVALID_TYPE_DF_EQUALITY_ARG",
            message_parameters={
                "expected_type": f"{DataFrame.__name__}, {Series.__name__}, {Index.__name__}",
                "arg_name": "actual",
                "actual_type": type(actual),
            },
        )
    elif not isinstance(expected, (DataFrame, pd.DataFrame, Series, pd.Series, Index, pd.Index)):
        raise PySparkAssertionError(
            error_class="INVALID_TYPE_DF_EQUALITY_ARG",
            message_parameters={
                # BUG FIX: the original concatenation was
                # f"{Index.__name__}" f"{pd.Index.__name__}, ", which rendered
                # "...IndexIndex, " — a missing ", " separator between the last
                # two type names plus a stray trailing separator.
                "expected_type": f"{DataFrame.__name__}, "
                f"{pd.DataFrame.__name__}, "
                f"{Series.__name__}, "
                f"{pd.Series.__name__}, "
                f"{Index.__name__}, "
                f"{pd.Index.__name__}",
                "arg_name": "expected",
                "actual_type": type(expected),
            },
        )
    else:
        # Normalize both sides to pandas objects before delegating to the
        # pandas-level comparison helpers.
        if not isinstance(actual, (pd.DataFrame, pd.Index, pd.Series)):
            actual = actual.to_pandas()
        if not isinstance(expected, (pd.DataFrame, pd.Index, pd.Series)):
            expected = expected.to_pandas()
        if not checkRowOrder:
            # Sort by the first column so row order does not affect the comparison.
            if isinstance(actual, pd.DataFrame) and len(actual.columns) > 0:
                actual = actual.sort_values(by=actual.columns[0], ignore_index=True)
            if isinstance(expected, pd.DataFrame) and len(expected.columns) > 0:
                expected = expected.sort_values(by=expected.columns[0], ignore_index=True)
        if almost:
            _assert_pandas_almost_equal(actual, expected, rtol=rtol, atol=atol)
        else:
            _assert_pandas_equal(actual, expected, checkExact=checkExact)
class PandasOnSparkTestUtils:
    """Mixin with pandas-on-Spark equality assertion helpers for test cases."""

    def convert_str_to_lambda(self, func: str):
        """
        This function converts `func` str to lambda call
        """
        return lambda x: getattr(x, func)()

    def assertPandasEqual(self, left: Any, right: Any, check_exact: bool = True):
        # Exact (or pandas-approximate) comparison of two pandas objects.
        _assert_pandas_equal(left, right, check_exact)

    def assertPandasAlmostEqual(
        self,
        left: Any,
        right: Any,
        rtol: float = 1e-5,
        atol: float = 1e-8,
    ):
        # Tolerance-based comparison of two pandas objects.
        _assert_pandas_almost_equal(left, right, rtol=rtol, atol=atol)

    def assert_eq(
        self,
        left: Any,
        right: Any,
        check_exact: bool = True,
        almost: bool = False,
        rtol: float = 1e-5,
        atol: float = 1e-8,
        check_row_order: bool = True,
    ):
        """
        Asserts if two arbitrary objects are equal or not. If given objects are
        Koalas DataFrame or Series, they are converted into pandas' and compared.

        :param left: object to compare
        :param right: object to compare
        :param check_exact: if this is False, the comparison is done less precisely.
        :param almost: if this is enabled, the comparison asserts approximate equality
            for float and decimal values, where two values a and b are approximately equal
            if they satisfy the following formula:
            absolute(a - b) <= (atol + rtol * absolute(b))
        :param rtol: The relative tolerance, used in asserting approximate equality for
            float values. Set to 1e-5 by default.
        :param atol: The absolute tolerance, used in asserting approximate equality for
            float values in actual and expected. Set to 1e-8 by default.
        :param check_row_order: A flag indicating whether the order of rows should be
            considered in the comparison. If set to False, row order will be ignored.
        """
        import pandas as pd
        from pandas.api.types import is_list_like

        # for pandas-on-Spark DataFrames, allow choice to ignore row order
        if isinstance(left, (ps.DataFrame, ps.Series, ps.Index)):
            return assertPandasOnSparkEqual(
                left,
                right,
                checkExact=check_exact,
                almost=almost,
                rtol=rtol,
                atol=atol,
                checkRowOrder=check_row_order,
            )
        lobj = self._to_pandas(left)
        robj = self._to_pandas(right)
        if isinstance(lobj, (pd.DataFrame, pd.Series, pd.Index)):
            if almost:
                _assert_pandas_almost_equal(lobj, robj, rtol=rtol, atol=atol)
            else:
                _assert_pandas_equal(lobj, robj, checkExact=check_exact)
        elif is_list_like(lobj) and is_list_like(robj):
            # Compare list-likes element-wise, recursing into assert_eq.
            self.assertTrue(len(left) == len(right))
            for litem, ritem in zip(left, right):
                self.assert_eq(litem, ritem, check_exact=check_exact, almost=almost)
        elif (lobj is not None and pd.isna(lobj)) and (robj is not None and pd.isna(robj)):
            # Both sides are scalar NA/NaN (and not None): treat as equal.
            pass
        else:
            # Fall back to unittest's scalar assertions.
            if almost:
                self.assertAlmostEqual(lobj, robj)
            else:
                self.assertEqual(lobj, robj)

    @staticmethod
    def _to_pandas(obj: Any):
        # Convert a pandas-on-Spark object to pandas; pass anything else through.
        if isinstance(obj, (DataFrame, Series, Index)):
            return obj.to_pandas()
        else:
            return obj


class PandasOnSparkTestCase(ReusedSQLTestCase, PandasOnSparkTestUtils):
    """Base test case that enables Arrow-based conversion for pandas-on-Spark tests."""

    @classmethod
    def setUpClass(cls):
        super(PandasOnSparkTestCase, cls).setUpClass()
        cls.spark.conf.set(SPARK_CONF_ARROW_ENABLED, True)


class TestUtils:
    """Mixin providing temporary directory/file context managers for tests."""

    @contextmanager
    def temp_dir(self):
        # Create a temp directory and remove it (and its contents) on exit.
        tmp = tempfile.mkdtemp()
        try:
            yield tmp
        finally:
            shutil.rmtree(tmp)

    @contextmanager
    def temp_file(self):
        # Create a temp file inside a temp dir; cleanup is delegated to temp_dir().
        with self.temp_dir() as tmp:
            yield tempfile.mkstemp(dir=tmp)[1]


class ComparisonTestBase(PandasOnSparkTestCase):
    """Base class for tests that compare a pandas frame against its ps counterpart."""

    @property
    def psdf(self):
        # pandas-on-Spark DataFrame derived from the pandas fixture `pdf`.
        # NOTE(review): subclasses are expected to override `pdf` or `psdf`;
        # as written the two properties are mutually recursive — confirm usage.
        return ps.from_pandas(self.pdf)

    @property
    def pdf(self):
        # pandas DataFrame derived from the pandas-on-Spark fixture `psdf`.
        return self.psdf.to_pandas()


def compare_both(f=None, almost=True):
    # Decorator: run the wrapped generator once with the pandas fixture and once
    # with the pandas-on-Spark fixture, comparing the yielded results pairwise.
    # Supports @compare_both, @compare_both(almost=...), and compare_both(bool).
    if f is None:
        return functools.partial(compare_both, almost=almost)
    elif isinstance(f, bool):
        return functools.partial(compare_both, almost=f)

    @functools.wraps(f)
    def wrapped(self):
        if almost:
            compare = self.assertPandasAlmostEqual
        else:
            compare = self.assertPandasEqual

        for result_pandas, result_spark in zip(f(self, self.pdf), f(self, self.psdf)):
            compare(result_pandas, result_spark.to_pandas())

    return wrapped


@contextmanager
def assert_produces_warning(
    expected_warning=Warning,
    filter_level="always",
    check_stacklevel=True,
    raise_on_extra_warnings=True,
):
    """
    Context manager for running code expected to either raise a specific
    warning, or not raise any warnings. Verifies that the code raises the
    expected warning, and that it does not raise any other unexpected
    warnings. It is basically a wrapper around ``warnings.catch_warnings``.

    Notes
    -----
    Replicated from pandas/_testing/_warnings.py.

    Parameters
    ----------
    expected_warning : {Warning, False, None}, default Warning
        The type of Exception raised. ``exception.Warning`` is the base
        class for all warnings. To check that no warning is returned,
        specify ``False`` or ``None``.
    filter_level : str or None, default "always"
        Specifies whether warnings are ignored, displayed, or turned
        into errors.
        Valid values are:

        * "error" - turns matching warnings into exceptions
        * "ignore" - discard the warning
        * "always" - always emit a warning
        * "default" - print the warning the first time it is generated
          from each location
        * "module" - print the warning the first time it is generated
          from each module
        * "once" - print the warning the first time it is generated

    check_stacklevel : bool, default True
        If True, displays the line that called the function containing
        the warning to show were the function is called. Otherwise, the
        line that implements the function is displayed.
    raise_on_extra_warnings : bool, default True
        Whether extra warnings not of the type `expected_warning` should
        cause the test to fail.

    Examples
    --------
    >>> import warnings
    >>> with assert_produces_warning():
    ...     warnings.warn(UserWarning())
    ...
    >>> with assert_produces_warning(False): # doctest: +SKIP
    ...     warnings.warn(RuntimeWarning())
    ...
    Traceback (most recent call last):
        ...
    AssertionError: Caused unexpected warning(s): ['RuntimeWarning'].
    >>> with assert_produces_warning(UserWarning): # doctest: +SKIP
    ...     warnings.warn(RuntimeWarning())
    Traceback (most recent call last):
        ...
    AssertionError: Did not see expected warning of class 'UserWarning'

    ..warn:: This is *not* thread-safe.
    """
    __tracebackhide__ = True

    with warnings.catch_warnings(record=True) as w:
        saw_warning = False
        warnings.simplefilter(filter_level)
        yield w
        extra_warnings = []

        for actual_warning in w:
            if expected_warning and issubclass(actual_warning.category, expected_warning):
                saw_warning = True

                if check_stacklevel and issubclass(
                    actual_warning.category, (FutureWarning, DeprecationWarning)
                ):
                    from inspect import getframeinfo, stack

                    # stack()[2] is the frame of the caller of the warning-emitting
                    # function, i.e. where a correct stacklevel should point.
                    caller = getframeinfo(stack()[2][0])
                    msg = (
                        "Warning not set with correct stacklevel. ",
                        "File where warning is raised: {} != ".format(actual_warning.filename),
                        "{}. Warning message: {}".format(caller.filename, actual_warning.message),
                    )
                    assert actual_warning.filename == caller.filename, msg
            else:
                # Collect warnings that don't match the expected category.
                extra_warnings.append(
                    (
                        actual_warning.category.__name__,
                        actual_warning.message,
                        actual_warning.filename,
                        actual_warning.lineno,
                    )
                )
        if expected_warning:
            msg = "Did not see expected warning of class {}".format(
                repr(expected_warning.__name__)
            )
            assert saw_warning, msg
        if raise_on_extra_warnings and extra_warnings:
            raise AssertionError("Caused unexpected warning(s): {}".format(repr(extra_warnings)))