I'm trying to convert some pandas code to Spark so it can scale. myfunc is a wrapper around a complex API that takes a string and returns a new string (which means I can't use vectorized functions).

def myfunc(ds):
    # call the API once per cell; attribute is the column name
    for attribute, value in ds.items():
        ds[attribute] = api_function(attribute, value)
    return ds

df = df.apply(myfunc, axis='columns')

myfunc takes a Series, breaks it up into individual cells, calls the API for each cell, and builds a new Series with the same column names. This effectively modifies every cell in the DataFrame.

I'm new to Spark and I want to translate this logic using PySpark. I've converted my pandas DataFrame to Spark:

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType

spark = SparkSession.builder.appName('My app').getOrCreate()
spark_schema = StructType([StructField(c, StringType(), True) for c in df.columns])
spark_df = spark.createDataFrame(df, schema=spark_schema)

This is where I get lost. Do I need a UDF, a pandas_udf? How do I iterate across all cells and return a new string for each using myfunc? spark_df.foreach() doesn't return anything and it doesn't have a map() function.

I can modify myfunc from Series -> Series to string -> string if necessary.
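For reference, a minimal sketch of what that string -> string variant could look like, assuming api_function is in scope and still needs the column name as its first argument:

def myfunc(attribute, value):
    # one API call per cell; attribute is the column name
    return api_function(attribute, value)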


Accepted answer

The solution is:

from pyspark.sql.functions import udf, lit
from pyspark.sql.types import StringType

# wrap the two-argument api_function: (attribute, value) -> new value
udf_func = udf(api_function, StringType())
for col_name in spark_df.columns:
    spark_df = spark_df.withColumn(col_name, udf_func(lit(col_name), col_name))
result = spark_df.toPandas()

There are 3 key insights that helped me figure this out:

  1. If you use withColumn with the name of an existing column (col_name), Spark "overwrites"/shadows the original column. This essentially gives the appearance of editing the column directly as if it were mutable.
  2. By looping across the original columns and reusing the same DataFrame variable spark_df, I use the same principle to simulate a mutable DataFrame, building a chain of column-wise transformations that each "overwrite" a column (per #1; see the unrolled example below).
  3. Spark UDFs expect all parameters to be Column types, which means it attempts to resolve column values for each parameter. Because api_function's first parameter is a literal value that will be the same for all rows in the vector, you must use the lit() function. Simply passing col_name to the function will attempt to extract the column values for that column. As far as I could tell, passing col_name is equivalent to passing col(col_name).

Assuming 3 columns 'a', 'b' and 'c', unrolling this concept would look like this:

spark_df = (spark_df.withColumn('a', udf_func(lit('a'), 'a'))
                    .withColumn('b', udf_func(lit('b'), 'b'))
                    .withColumn('c', udf_func(lit('c'), 'c')))
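
Putting it all together, here is a self-contained sketch you can run locally; api_function here is a hypothetical stand-in that just tags each value with its column name:

from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, lit
from pyspark.sql.types import StringType

spark = SparkSession.builder.appName('My app').getOrCreate()
spark_df = spark.createDataFrame([('x', 'y'), ('p', 'q')], ['a', 'b'])

def api_function(attribute, value):
    # stand-in for the real API: string in, string out
    return f'{attribute}:{value}'

udf_func = udf(api_function, StringType())
for col_name in spark_df.columns:
    spark_df = spark_df.withColumn(col_name, udf_func(lit(col_name), col_name))

spark_df.show()
# +---+---+
# |  a|  b|
# +---+---+
# |a:x|b:y|
# |a:p|b:q|
# +---+---+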

Another answer

Option 1: Use a UDF on One Column at a Time

The simplest approach would be to rewrite your function to take a string as an argument (so that it is string -> string) and use a UDF. This works on one column at a time, so if your DataFrame has a reasonable number of columns you can apply the UDF to each column in turn:

from pyspark.sql.functions import col

# here my_udf is your wrapped function, e.g. my_udf = udf(myfunc, StringType())
new_df = df.select(my_udf(col("col1")), my_udf(col("col2")), ...)

Example

df = sc.parallelize([[1, 4], [2, 5], [3, 6]]).toDF(["col1", "col2"])
df.show()
+----+----+
|col1|col2|
+----+----+
|   1|   4|
|   2|   5|
|   3|   6|
+----+----+

def plus1_udf(x):
    return x + 1

# register returns a UDF that is also usable in the DataFrame API
plus1 = spark.udf.register("plus1", plus1_udf)

new_df = df.select(plus1(col("col1")), plus1(col("col2")))
new_df.show()
+-----------+-----------+
|plus1(col1)|plus1(col2)|
+-----------+-----------+
|          2|          5|
|          3|          6|
|          4|          7|
+-----------+-----------+
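
If you want to keep the original column names rather than plus1(col1) and plus1(col2), you can alias each result (a small addition to the example above):

new_df = df.select(plus1(col("col1")).alias("col1"), plus1(col("col2")).alias("col2"))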

Option 2: Map the entire DataFrame at once

map is available for Scala DataFrames, but, at the moment, not in PySpark. The lower-level RDD API does have a map function in PySpark. So, if you have too many columns to transform one at a time, you could operate on every single cell in the DataFrame like this:

def map_fn(row):
    # iterate over the row's cells; asDict() keys follow the column order
    return [api_function(column, x) for (column, x) in row.asDict().items()]

column_names = df.columns
new_df = df.rdd.map(map_fn).toDF(column_names)

Example

df = sc.parallelize([[1, 4], [2, 5], [3, 6]]).toDF(["col1", "col2"])

def map_fn(row):
    return [value + 1 for (_, value) in row.asDict().items()]

columns = df.columns
new_df = df.rdd.map(map_fn).toDF(columns)
new_df.show()
+----+----+
|col1|col2|
+----+----+
|   2|   5|
|   3|   6|
|   4|   7|
+----+----+

Context

The documentation of foreach only gives the example of printing, but we can verify by looking at the source that it indeed returns nothing.
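
A quick way to check this yourself (sketch):

# foreach is an action run for its side effects; it returns None
result = spark_df.foreach(lambda row: print(row))
print(result)  # None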

As for pandas_udf, it seems most suited to vectorized functions, which, as you pointed out, you can't use because of api_function.
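
For completeness, a scalar pandas_udf receives and returns whole pandas Series batches. You could still call a per-value API inside via Series.apply, but that forfeits the vectorization that makes pandas_udf attractive (sketch, assuming Spark 3.x and a single hardcoded column name):

import pandas as pd
from pyspark.sql.functions import pandas_udf

@pandas_udf("string")
def myfunc_batch(values: pd.Series) -> pd.Series:
    # still one API call per value; nothing is actually vectorized
    return values.apply(lambda v: api_function("col1", v))

spark_df = spark_df.withColumn("col1", myfunc_batch("col1"))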

