#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys
from typing import (
    Any,
    Callable,
    List,
    Optional,
    Union,
    no_type_check,
    overload,
    TYPE_CHECKING,
)
from warnings import warn

from pyspark.errors.exceptions.captured import unwrap_spark_exception
from pyspark.rdd import _load_from_socket
from pyspark.sql.pandas.serializers import ArrowCollectSerializer
from pyspark.sql.pandas.types import _dedup_names
from pyspark.sql.types import ArrayType, MapType, TimestampType, StructType, DataType, _create_row
from pyspark.sql.utils import is_timestamp_ntz_preferred
from pyspark.traceback_utils import SCCallSiteSync
from pyspark.errors import PySparkTypeError

if TYPE_CHECKING:
    import numpy as np
    import pyarrow as pa
    from py4j.java_gateway import JavaObject

    from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike
    from pyspark.sql import DataFrame


class PandasConversionMixin:
    """
    Mix-in for the conversion from Spark to pandas. Currently, only
    :class:`DataFrame` can use this class.
    """

    def toPandas(self) -> "PandasDataFrameLike":
        """
        Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``.

        This is only available if Pandas is installed and available.

        .. versionadded:: 1.3.0

        .. versionchanged:: 3.4.0
            Supports Spark Connect.

        Notes
        -----
        This method should only be used if the resulting Pandas ``pandas.DataFrame`` is
        expected to be small, as all the data is loaded into the driver's memory.

        Usage with ``spark.sql.execution.arrow.pyspark.enabled=True`` is experimental.
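
        As a hedged illustration (the configuration key below is the real Arrow switch;
        the tiny ``range`` DataFrame is only an example), the Arrow path can be enabled
        per session before collecting:

        >>> spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")  # doctest: +SKIP
        >>> spark.range(3).toPandas()  # doctest: +SKIP
           id
        0   0
        1   1
        2   2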

        Examples
        --------
        >>> df.toPandas()  # doctest: +SKIP
           age   name
        0    2  Alice
        1    5    Bob
        """
        from pyspark.sql.dataframe import DataFrame

        assert isinstance(self, DataFrame)

        from pyspark.sql.pandas.types import _create_converter_to_pandas
        from pyspark.sql.pandas.utils import require_minimum_pandas_version

        require_minimum_pandas_version()

        import pandas as pd

        jconf = self.sparkSession._jconf

        if jconf.arrowPySparkEnabled():
            use_arrow = True
            try:
                from pyspark.sql.pandas.types import to_arrow_schema
                from pyspark.sql.pandas.utils import require_minimum_pyarrow_version

                require_minimum_pyarrow_version()
                to_arrow_schema(self.schema)
            except Exception as e:
                if jconf.arrowPySparkFallbackEnabled():
                    msg = (
                        "toPandas attempted Arrow optimization because "
                        "'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, "
                        "failed by the reason below:\n%s\n"
                        "Attempting non-optimization as "
                        "'spark.sql.execution.arrow.pyspark.fallback.enabled' is set to "
                        "true." % str(e)
                    )
                    warn(msg)
                    use_arrow = False
                else:
                    msg = (
                        "toPandas attempted Arrow optimization because "
                        "'spark.sql.execution.arrow.pyspark.enabled' is set to true, but has "
                        "reached the error below and will not continue because automatic fallback "
                        "with 'spark.sql.execution.arrow.pyspark.fallback.enabled' has been set to "
                        "false.\n%s" % str(e)
                    )
                    warn(msg)
                    raise

            # Try to use Arrow optimization when the schema is supported and the required version
            # of PyArrow is found, if 'spark.sql.execution.arrow.pyspark.enabled' is enabled.
            if use_arrow:
                try:
                    import pyarrow

                    self_destruct = jconf.arrowPySparkSelfDestructEnabled()
                    batches = self._collect_as_arrow(split_batches=self_destruct)
                    if len(batches) > 0:
                        table = pyarrow.Table.from_batches(batches)
                        # Ensure only the table has a reference to the batches, so that
                        # self_destruct (if enabled) is effective
                        del batches
                        # Pandas DataFrame created from PyArrow uses datetime64[ns] for date type
                        # values, but we should use datetime.date to match the behavior with when
                        # Arrow optimization is disabled.
                        pandas_options = {"date_as_object": True}
                        if self_destruct:
                            # Configure PyArrow to use as little memory as possible:
                            # self_destruct - free columns as they are converted
                            # split_blocks - create a separate Pandas block for each column
                            # use_threads - convert one column at a time
                            pandas_options.update(
                                {
                                    "self_destruct": True,
                                    "split_blocks": True,
                                    "use_threads": False,
                                }
                            )
                        # Rename columns to avoid duplicated column names.
                        pdf = table.rename_columns(
                            [f"col_{i}" for i in range(table.num_columns)]
                        ).to_pandas(**pandas_options)
                        # Rename back to the original column names.
                        pdf.columns = self.columns
                    else:
                        pdf = pd.DataFrame(columns=self.columns)

                    if len(pdf.columns) > 0:
                        timezone = jconf.sessionLocalTimeZone()
                        struct_in_pandas = jconf.pandasStructHandlingMode()

                        error_on_duplicated_field_names = False
                        if struct_in_pandas == "legacy":
                            error_on_duplicated_field_names = True
                            struct_in_pandas = "dict"

                        return pd.concat(
                            [
                                _create_converter_to_pandas(
                                    field.dataType,
                                    field.nullable,
                                    timezone=timezone,
                                    struct_in_pandas=struct_in_pandas,
                                    error_on_duplicated_field_names=error_on_duplicated_field_names,
                                )(pser)
                                for (_, pser), field in zip(pdf.items(), self.schema.fields)
                            ],
                            axis="columns",
                        )
                    else:
                        return pdf
                except Exception as e:
                    # We might have to allow fallback here as well but multiple Spark jobs can
                    # be executed. So, simply fail in this case for now.
                    msg = (
                        "toPandas attempted Arrow optimization because "
                        "'spark.sql.execution.arrow.pyspark.enabled' is set to true, but has "
                        "reached the error below and can not continue. "
                        "Note that "
                        "'spark.sql.execution.arrow.pyspark.fallback.enabled' does not have an "
                        "effect on failures in the middle of "
                        "computation.\n%s" % str(e)
                    )
                    warn(msg)
                    raise

        # Below is toPandas without Arrow optimization.
        rows = self.collect()
        if len(rows) > 0:
            pdf = pd.DataFrame.from_records(
                rows, index=range(len(rows)), columns=self.columns  # type: ignore[arg-type]
            )
        else:
            pdf = pd.DataFrame(columns=self.columns)

        if len(pdf.columns) > 0:
            timezone = jconf.sessionLocalTimeZone()
            struct_in_pandas = jconf.pandasStructHandlingMode()

            return pd.concat(
                [
                    _create_converter_to_pandas(
                        field.dataType,
                        field.nullable,
                        timezone=timezone,
                        struct_in_pandas=(
                            "row" if struct_in_pandas == "legacy" else struct_in_pandas
                        ),
                        error_on_duplicated_field_names=False,
                        timestamp_utc_localized=False,
                    )(pser)
                    for (_, pser), field in zip(pdf.items(), self.schema.fields)
                ],
                axis="columns",
            )
        else:
            return pdf

    def _collect_as_arrow(self, split_batches: bool = False) -> List["pa.RecordBatch"]:
        """
        Returns all records as a list of ArrowRecordBatches. PyArrow must be installed
        and available on the driver and worker Python environments.
        This is an experimental feature.

        :param split_batches: split batches such that each column is in its own allocation, so
            that the selfDestruct optimization is effective; default False.

        .. note:: Experimental.
        """
        from pyspark.sql.dataframe import DataFrame

        assert isinstance(self, DataFrame)

        with SCCallSiteSync(self._sc):
            (
                port,
                auth_secret,
                jsocket_auth_server,
            ) = self._jdf.collectAsArrowToPython()

        # Collect list of un-ordered batches where last element is a list of correct order indices
        try:
            batch_stream = _load_from_socket((port, auth_secret), ArrowCollectSerializer())
            if split_batches:
                # When spark.sql.execution.arrow.pyspark.selfDestruct.enabled, ensure
                # each column in each record batch is contained in its own allocation.
                # Otherwise, selfDestruct does nothing; it frees each column as it's
                # converted, but each column will actually be a list of slices of record
                # batches, and so no memory is actually freed until all columns are
                # converted.
                import pyarrow as pa

                results = []
                for batch_or_indices in batch_stream:
                    if isinstance(batch_or_indices, pa.RecordBatch):
                        batch_or_indices = pa.RecordBatch.from_arrays(
                            [
                                # This call actually reallocates the array
                                pa.concat_arrays([array])
                                for array in batch_or_indices
                            ],
                            schema=batch_or_indices.schema,
                        )
                    results.append(batch_or_indices)
            else:
                results = list(batch_stream)
        finally:
            with unwrap_spark_exception():
                # Join serving thread and raise any exceptions from collectAsArrowToPython
                jsocket_auth_server.getResult()

        # Separate RecordBatches from batch order indices in results
        batches = results[:-1]
        batch_order = results[-1]

        # Re-order the batch list using the correct order
        return [batches[i] for i in batch_order]


class SparkConversionMixin:
    """
    Mix-in for the conversion from pandas to Spark. Currently, only
    :class:`SparkSession` can use this class.
"""_jsparkSession:"JavaObject"@overloaddefcreateDataFrame(self,data:"PandasDataFrameLike",samplingRatio:Optional[float]=...)->"DataFrame":...@overloaddefcreateDataFrame(self,data:"PandasDataFrameLike",schema:Union[StructType,str],verifySchema:bool=...,)->"DataFrame":...defcreateDataFrame(# type: ignore[misc]self,data:"PandasDataFrameLike",schema:Optional[Union[StructType,List[str]]]=None,samplingRatio:Optional[float]=None,verifySchema:bool=True,)->"DataFrame":frompyspark.sqlimportSparkSessionassertisinstance(self,SparkSession)frompyspark.sql.pandas.utilsimportrequire_minimum_pandas_versionrequire_minimum_pandas_version()timezone=self._jconf.sessionLocalTimeZone()# If no schema supplied by user then get the names of columns onlyifschemaisNone:schema=[str(x)ifnotisinstance(x,str)elsexforxindata.columns]ifself._jconf.arrowPySparkEnabled()andlen(data)>0:try:returnself._create_from_pandas_with_arrow(data,schema,timezone)exceptExceptionase:ifself._jconf.arrowPySparkFallbackEnabled():msg=("createDataFrame attempted Arrow optimization because ""'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, ""failed by the reason below:\n%s\n""Attempting non-optimization as ""'spark.sql.execution.arrow.pyspark.fallback.enabled' is set to ""true."%str(e))warn(msg)else:msg=("createDataFrame attempted Arrow optimization because ""'spark.sql.execution.arrow.pyspark.enabled' is set to true, but has ""reached the error below and will not continue because automatic ""fallback with 'spark.sql.execution.arrow.pyspark.fallback.enabled' ""has been set to false.\n%s"%str(e))warn(msg)raiseconverted_data=self._convert_from_pandas(data,schema,timezone)returnself._create_dataframe(converted_data,schema,samplingRatio,verifySchema)def_convert_from_pandas(self,pdf:"PandasDataFrameLike",schema:Union[StructType,str,List[str]],timezone:str)->List:""" Convert a pandas.DataFrame to list of records that can be used to make a DataFrame Returns ------- list list of records 
"""importpandasaspdfrompyspark.sqlimportSparkSessionassertisinstance(self,SparkSession)iftimezoneisnotNone:frompyspark.sql.pandas.typesimport(_check_series_convert_timestamps_tz_local,_get_local_timezone,)frompandas.core.dtypes.commonimportis_datetime64tz_dtype,is_timedelta64_dtypecopied=Falseifisinstance(schema,StructType):def_create_converter(data_type:DataType)->Callable[[pd.Series],pd.Series]:ifisinstance(data_type,TimestampType):defcorrect_timestamp(pser:pd.Series)->pd.Series:return_check_series_convert_timestamps_tz_local(pser,timezone)returncorrect_timestampdef_converter(dt:DataType)->Optional[Callable[[Any],Any]]:ifisinstance(dt,ArrayType):element_conv=_converter(dt.elementType)or(lambdax:x)defconvert_array(value:Any)->Any:ifvalueisNone:returnNoneelse:return[element_conv(v)forvinvalue]returnconvert_arrayelifisinstance(dt,MapType):key_conv=_converter(dt.keyType)or(lambdax:x)value_conv=_converter(dt.valueType)or(lambdax:x)defconvert_map(value:Any)->Any:ifvalueisNone:returnNoneelse:return{key_conv(k):value_conv(v)fork,vinvalue.items()}returnconvert_mapelifisinstance(dt,StructType):field_names=dt.namesdedup_field_names=_dedup_names(field_names)field_convs=[_converter(f.dataType)or(lambdax:x)forfindt.fields]defconvert_struct(value:Any)->Any:ifvalueisNone:returnNoneelifisinstance(value,dict):_values=[field_convs[i](value.get(name,None))fori,nameinenumerate(dedup_field_names)]return_create_row(field_names,_values)else:_values=[field_convs[i](value[i])fori,nameinenumerate(value)]return_create_row(field_names,_values)returnconvert_structelifisinstance(dt,TimestampType):defconvert_timestamp(value:Any)->Any:ifvalueisNone:returnNoneelse:return(pd.Timestamp(value).tz_localize(timezone,ambiguous=False)# type: ignore.tz_convert(_get_local_timezone()).tz_localize(None).to_pydatetime())returnconvert_timestampelse:returnNoneconv=_converter(data_type)ifconvisnotNone:returnlambdapser:pser.apply(conv)# type: ignore[return-value]else:returnlambdapser:pseriflen(pdf.columns)>0:pdf=pd.concat([_create_converter(field.dataType)(pser)for(_,pser),fieldinzip(pdf.items(),schema.fields)],axis="columns",)copied=Trueelse:should_localize=notis_timestamp_ntz_preferred()forcolumn,seriesinpdf.items():s=seriesifshould_localizeandis_datetime64tz_dtype(s.dtype)ands.dt.tzisnotNone:s=_check_series_convert_timestamps_tz_local(series,timezone)ifsisnotseries:ifnotcopied:# Copy once if the series is modified to prevent the original# Pandas DataFrame from being updatedpdf=pdf.copy()copied=Truepdf[column]=sforcolumn,seriesinpdf.items():ifis_timedelta64_dtype(series):ifnotcopied:pdf=pdf.copy()copied=True# Explicitly set the timedelta as object so the output of numpy records can# hold the timedelta instances as are. 
                    # Otherwise, it converts to the internal numeric values.
                    ser = pdf[column]
                    pdf[column] = pd.Series(
                        ser.dt.to_pytimedelta(), index=ser.index, dtype="object", name=ser.name
                    )

        # Convert pandas.DataFrame to list of numpy records
        np_records = pdf.set_axis(
            [f"col_{i}" for i in range(len(pdf.columns))], axis="columns"  # type: ignore[arg-type]
        ).to_records(index=False)

        # Check if any columns need to be fixed for Spark to infer properly
        if len(np_records) > 0:
            record_dtype = self._get_numpy_record_dtype(np_records[0])
            if record_dtype is not None:
                return [r.astype(record_dtype).tolist() for r in np_records]

        # Convert list of numpy records to python lists
        return [r.tolist() for r in np_records]

    def _get_numpy_record_dtype(self, rec: "np.recarray") -> Optional["np.dtype"]:
        """
        Used when converting a pandas.DataFrame to Spark using to_records(), this will correct
        the dtypes of fields in a record so they can be properly loaded into Spark.

        Parameters
        ----------
        rec : numpy.record
            a numpy record to check field dtypes

        Returns
        -------
        numpy.dtype
            corrected dtype for a numpy.record or None if no correction needed
        """
        import numpy as np

        cur_dtypes = rec.dtype
        col_names = cur_dtypes.names
        record_type_list = []
        has_rec_fix = False
        for i in range(len(cur_dtypes)):
            curr_type = cur_dtypes[i]
            # If type is a datetime64 timestamp, convert to microseconds
            # NOTE: if dtype is datetime[ns] then np.record.tolist() will output values as longs,
            # conversion from [us] or lower will lead to py datetime objects, see SPARK-22417
            if curr_type == np.dtype("datetime64[ns]"):
                curr_type = "datetime64[us]"
                has_rec_fix = True
            record_type_list.append((str(col_names[i]), curr_type))
        return np.dtype(record_type_list) if has_rec_fix else None

    def _create_from_pandas_with_arrow(
        self, pdf: "PandasDataFrameLike", schema: Union[StructType, List[str]], timezone: str
    ) -> "DataFrame":
        """
        Create a DataFrame from a given pandas.DataFrame by slicing it into partitions, converting
        to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the
        data types will be used to coerce the data in Pandas to Arrow conversion.
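
        A hedged sketch of the batching knob used below (the configuration key is real;
        the value and the ``pdf`` name are placeholders):

        >>> spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", 10000)  # doctest: +SKIP
        >>> spark.createDataFrame(pdf)  # doctest: +SKIP
        DataFrame[...]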
"""frompyspark.sqlimportSparkSessionfrompyspark.sql.dataframeimportDataFrameassertisinstance(self,SparkSession)frompyspark.sql.pandas.serializersimportArrowStreamPandasSerializerfrompyspark.sql.typesimportTimestampTypefrompyspark.sql.pandas.typesimport(from_arrow_type,to_arrow_type,_deduplicate_field_names,)frompyspark.sql.pandas.utilsimport(require_minimum_pandas_version,require_minimum_pyarrow_version,)require_minimum_pandas_version()require_minimum_pyarrow_version()frompandas.api.typesimport(# type: ignore[attr-defined]is_datetime64_dtype,is_datetime64tz_dtype,)importpyarrowaspa# Create the Spark schema from list of names passed in with Arrow typesifisinstance(schema,(list,tuple)):arrow_schema=pa.Schema.from_pandas(pdf,preserve_index=False)struct=StructType()prefer_timestamp_ntz=is_timestamp_ntz_preferred()forname,fieldinzip(schema,arrow_schema):struct.add(name,from_arrow_type(field.type,prefer_timestamp_ntz),nullable=field.nullable)schema=struct# Determine arrow types to coerce data when creating batchesifisinstance(schema,StructType):spark_types=[_deduplicate_field_names(f.dataType)forfinschema.fields]elifisinstance(schema,DataType):raisePySparkTypeError(error_class="UNSUPPORTED_DATA_TYPE_FOR_ARROW",message_parameters={"data_type":str(schema)},)else:# Any timestamps must be coerced to be compatible with Sparkspark_types=[TimestampType()ifis_datetime64_dtype(t)oris_datetime64tz_dtype(t)elseNonefortinpdf.dtypes]# Slice the DataFrame to be batchedstep=self._jconf.arrowMaxRecordsPerBatch()step=stepifstep>0elselen(pdf)pdf_slices=(pdf.iloc[start:start+step]forstartinrange(0,len(pdf),step))# Create list of Arrow (columns, arrow_type, spark_type) for serializer dump_streamarrow_data=[[(c,to_arrow_type(t)iftisnotNoneelseNone,t)for(_,c),tinzip(pdf_slice.items(),spark_types)]forpdf_sliceinpdf_slices]jsparkSession=self._jsparkSessionsafecheck=self._jconf.arrowSafeTypeConversion()ser=ArrowStreamPandasSerializer(timezone,safecheck)@no_type_checkdefreader_func(temp_filename):returnself._jvm.PythonSQLUtils.readArrowStreamFromFile(temp_filename)@no_type_checkdefcreate_iter_server():returnself._jvm.ArrowIteratorServer()# Create Spark DataFrame from Arrow stream file, using one batch per partitionjiter=self._sc._serialize_to_jvm(arrow_data,ser,reader_func,create_iter_server)assertself._jvmisnotNonejdf=self._jvm.PythonSQLUtils.toDataFrame(jiter,schema.json(),jsparkSession)df=DataFrame(jdf,self)df._schema=schemareturndfdef_test()->None:importdoctestfrompyspark.sqlimportSparkSessionimportpyspark.sql.pandas.conversionglobs=pyspark.sql.pandas.conversion.__dict__.copy()spark=(SparkSession.builder.master("local[4]").appName("sql.pandas.conversion tests").getOrCreate())globs["spark"]=spark(failure_count,test_count)=doctest.testmod(pyspark.sql.pandas.conversion,globs=globs,optionflags=doctest.ELLIPSIS|doctest.NORMALIZE_WHITESPACE|doctest.REPORT_NDIFF,)spark.stop()iffailure_count:sys.exit(-1)if__name__=="__main__":_test()