Using CASE and WHEN

Let us understand how to perform conditional operations using CASE and WHEN in Spark.

  • CASE WHEN expressions are typically used to apply transformations based on conditions. We can write CASE WHEN exactly as in SQL and pass it to expr or selectExpr.

  • If we prefer the DataFrame API, Spark provides when and otherwise. when is a function in pyspark.sql.functions; it returns a Column, and on that Column we can invoke otherwise to supply the value for rows that match none of the conditions.

Let us start the Spark session for this notebook so that we can execute the code provided.

from pyspark.sql import SparkSession

import getpass
username = getpass.getuser()

spark = SparkSession. \
    builder. \
    config('spark.ui.port', '0'). \
    config("spark.sql.warehouse.dir", f"/user/{username}/warehouse"). \
    enableHiveSupport(). \
    appName(f'{username} | Python - Processing Column Data'). \
    master('yarn'). \
    getOrCreate()

If you prefer to work from the command line, you can launch Spark in one of three ways.

Using Spark SQL

spark2-sql \
    --master yarn \
    --conf spark.ui.port=0 \
    --conf spark.sql.warehouse.dir=/user/${USER}/warehouse

Using Scala

spark2-shell \
    --master yarn \
    --conf spark.ui.port=0 \
    --conf spark.sql.warehouse.dir=/user/${USER}/warehouse

Using PySpark

pyspark2 \
    --master yarn \
    --conf spark.ui.port=0 \
    --conf spark.sql.warehouse.dir=/user/${USER}/warehouse
Let us create a DataFrame of employees to work with. Note that bonus is declared as a STRING column and holds a number, a null, or an empty string depending on the row.

employees = [
    (1, "Scott", "Tiger", 1000.0, 10,
     "united states", "+1 123 456 7890", "123 45 6789"
    ),
    (2, "Henry", "Ford", 1250.0, None,
     "India", "+91 234 567 8901", "456 78 9123"
    ),
    (3, "Nick", "Junior", 750.0, '',
     "united KINGDOM", "+44 111 111 1111", "222 33 4444"
    ),
    (4, "Bill", "Gomes", 1500.0, 10,
     "AUSTRALIA", "+61 987 654 3210", "789 12 6118"
    )
]
employeesDF = spark. \
    createDataFrame(employees,
                    schema="""employee_id INT, first_name STRING,
                    last_name STRING, salary FLOAT, bonus STRING, nationality STRING,
                    phone_number STRING, ssn STRING"""
                   )
employeesDF.show()
+-----------+----------+---------+------+-----+--------------+----------------+-----------+
|employee_id|first_name|last_name|salary|bonus|   nationality|    phone_number|        ssn|
+-----------+----------+---------+------+-----+--------------+----------------+-----------+
|          1|     Scott|    Tiger|1000.0|   10| united states| +1 123 456 7890|123 45 6789|
|          2|     Henry|     Ford|1250.0| null|         India|+91 234 567 8901|456 78 9123|
|          3|      Nick|   Junior| 750.0|     |united KINGDOM|+44 111 111 1111|222 33 4444|
|          4|      Bill|    Gomes|1500.0|   10|     AUSTRALIA|+61 987 654 3210|789 12 6118|
+-----------+----------+---------+------+-----+--------------+----------------+-----------+
  • Let us set bonus to 0 when it is null or empty; otherwise keep the bonus amount.

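from pyspark.sql.functions import coalesce, lit, col
# With ANSI mode off (the default in Spark 2.x and 3.x), casting '' to INT yields null,
# so coalesce replaces both null and empty bonus values with 0
employeesDF. \
    withColumn('bonus1', coalesce(col('bonus').cast('int'), lit(0))). \
    show()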
+-----------+----------+---------+------+-----+--------------+----------------+-----------+------+
|employee_id|first_name|last_name|salary|bonus|   nationality|    phone_number|        ssn|bonus1|
+-----------+----------+---------+------+-----+--------------+----------------+-----------+------+
|          1|     Scott|    Tiger|1000.0|   10| united states| +1 123 456 7890|123 45 6789|    10|
|          2|     Henry|     Ford|1250.0| null|         India|+91 234 567 8901|456 78 9123|     0|
|          3|      Nick|   Junior| 750.0|     |united KINGDOM|+44 111 111 1111|222 33 4444|     0|
|          4|      Bill|    Gomes|1500.0|   10|     AUSTRALIA|+61 987 654 3210|789 12 6118|    10|
+-----------+----------+---------+------+-----+--------------+----------------+-----------+------+
from pyspark.sql.functions import expr
employeesDF. \
    withColumn(
        'bonus', 
        expr("""
            CASE WHEN bonus IS NULL OR bonus = '' THEN 0
            ELSE bonus
            END
            """)
    ). \
    show()
+-----------+----------+---------+------+-----+--------------+----------------+-----------+
|employee_id|first_name|last_name|salary|bonus|   nationality|    phone_number|        ssn|
+-----------+----------+---------+------+-----+--------------+----------------+-----------+
|          1|     Scott|    Tiger|1000.0|   10| united states| +1 123 456 7890|123 45 6789|
|          2|     Henry|     Ford|1250.0|    0|         India|+91 234 567 8901|456 78 9123|
|          3|      Nick|   Junior| 750.0|    0|united KINGDOM|+44 111 111 1111|222 33 4444|
|          4|      Bill|    Gomes|1500.0|   10|     AUSTRALIA|+61 987 654 3210|789 12 6118|
+-----------+----------+---------+------+-----+--------------+----------------+-----------+
from pyspark.sql.functions import when
when?
Signature: when(condition, value)
Docstring:
Evaluates a list of conditions and returns one of multiple possible result expressions.
If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.

:param condition: a boolean :class:`Column` expression.
:param value: a literal value, or a :class:`Column` expression.

>>> df.select(when(df['age'] == 2, 3).otherwise(4).alias("age")).collect()
[Row(age=3), Row(age=4)]

>>> df.select(when(df.age == 2, df.age + 1).alias("age")).collect()
[Row(age=3), Row(age=None)]

.. versionadded:: 1.4
File:      /usr/hdp/current/spark2-client/python/pyspark/sql/functions.py
Type:      function
employeesDF. \
    withColumn(
        'bonus',
        when((col('bonus').isNull()) | (col('bonus') == lit('')), 0).otherwise(col('bonus'))
    ). \
    show()
+-----------+----------+---------+------+-----+--------------+----------------+-----------+
|employee_id|first_name|last_name|salary|bonus|   nationality|    phone_number|        ssn|
+-----------+----------+---------+------+-----+--------------+----------------+-----------+
|          1|     Scott|    Tiger|1000.0|   10| united states| +1 123 456 7890|123 45 6789|
|          2|     Henry|     Ford|1250.0|    0|         India|+91 234 567 8901|456 78 9123|
|          3|      Nick|   Junior| 750.0|    0|united KINGDOM|+44 111 111 1111|222 33 4444|
|          4|      Bill|    Gomes|1500.0|   10|     AUSTRALIA|+61 987 654 3210|789 12 6118|
+-----------+----------+---------+------+-----+--------------+----------------+-----------+
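The extra parentheses around each condition in the when call are required, not cosmetic: in Python, | binds more tightly than ==, so without them the expression is grouped differently from what the SQL version means. This sketch uses only the standard library ast module to show how Python parses the two spellings; nothing is evaluated, so col and lit need not be defined.

```python
import ast

# How Python groups the condition without and with parentheses
without_parens = ast.parse("col('bonus').isNull() | col('bonus') == lit('')", mode="eval").body
with_parens = ast.parse("(col('bonus').isNull()) | (col('bonus') == lit(''))", mode="eval").body

# Without parentheses: (isNull() | col('bonus')) is computed first, then compared to lit('')
print(type(without_parens).__name__)       # Compare
print(type(without_parens.left).__name__)  # BinOp
# With parentheses: the comparison happens first and | combines two boolean Columns
print(type(with_parens).__name__)          # BinOp
```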
  • Create a DataFrame from a list called persons and categorize each person according to the following rules (age is in months).

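+--------------------------+-------------------+
| Age range                | Category          |
+--------------------------+-------------------+
| 0 to 2 months            | New Born          |
| 2+ months to 12 months   | Infant            |
| 12+ months to 48 months  | Toddler           |
| 48+ months to 144 months | Kid               |
| 144+ months              | Teenager or Adult |
+--------------------------+-------------------+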
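Range rules like these are easy to get wrong at the boundaries. A small sketch that restates the table as a plain Python function (categorize is a hypothetical helper, not part of Spark) makes the inclusive and exclusive ends explicit before writing the Spark version. It assumes non-null integer ages in months; like the SQL ELSE branch, anything outside the listed ranges falls through to the last category.

```python
def categorize(age_in_months):
    """Plain-Python mirror of the category table (hypothetical helper for boundary checks)."""
    if 0 <= age_in_months <= 2:
        return 'New Born'
    elif 2 < age_in_months <= 12:
        return 'Infant'
    elif 12 < age_in_months <= 48:
        return 'Toddler'
    elif 48 < age_in_months <= 144:
        return 'Kid'
    return 'Teenager or Adult'

# Boundary values are where off-by-one mistakes usually hide
for months in [0, 2, 3, 12, 13, 48, 49, 144, 145]:
    print(months, categorize(months))
```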

persons = [
    (1, 1),
    (2, 13),
    (3, 18),
    (4, 60),
    (5, 120),
    (6, 0),
    (7, 12),
    (8, 160)
]
personsDF = spark.createDataFrame(persons, schema='id INT, age INT')
personsDF.show()
+---+---+
| id|age|
+---+---+
|  1|  1|
|  2| 13|
|  3| 18|
|  4| 60|
|  5|120|
|  6|  0|
|  7| 12|
|  8|160|
+---+---+
personsDF. \
    withColumn(
        'category',
        expr("""
            CASE
            WHEN age BETWEEN 0 AND 2 THEN 'New Born'
            WHEN age > 2 AND age <= 12 THEN 'Infant'
            WHEN age > 12 AND age <= 48 THEN 'Toddler'
            WHEN age > 48 AND age <= 144 THEN 'Kid'
            ELSE 'Teenager or Adult'
            END
        """)
    ). \
    show()
+---+---+-----------------+
| id|age|         category|
+---+---+-----------------+
|  1|  1|         New Born|
|  2| 13|          Toddler|
|  3| 18|          Toddler|
|  4| 60|              Kid|
|  5|120|              Kid|
|  6|  0|         New Born|
|  7| 12|           Infant|
|  8|160|Teenager or Adult|
+---+---+-----------------+
personsDF. \
    withColumn(
        'category',
        when(col('age').between(0, 2), 'New Born').
        when((col('age') > 2) & (col('age') <= 12), 'Infant').
        when((col('age') > 12) & (col('age') <= 48), 'Toddler').
        when((col('age') > 48) & (col('age') <= 144), 'Kid').
        otherwise('Teenager or Adult')
    ). \
    show()
+---+---+-----------------+
| id|age|         category|
+---+---+-----------------+
|  1|  1|         New Born|
|  2| 13|          Toddler|
|  3| 18|          Toddler|
|  4| 60|              Kid|
|  5|120|              Kid|
|  6|  0|         New Born|
|  7| 12|           Infant|
|  8|160|Teenager or Adult|
+---+---+-----------------+