Using CASE and WHEN
Let us understand how to perform conditional operations using CASE and WHEN in Spark.
- CASE and WHEN are typically used to apply transformations based on conditions. We can use CASE and WHEN just as in SQL by passing the expression to expr or selectExpr.
- If we want to use the DataFrame APIs instead, Spark provides the functions when and otherwise. when is available as part of pyspark.sql.functions. It returns a Column, and we can invoke otherwise on that Column.
Let us start spark context for this Notebook so that we can execute the code provided.
from pyspark.sql import SparkSession
import getpass
username = getpass.getuser()
spark = SparkSession. \
    builder. \
    config('spark.ui.port', '0'). \
    config("spark.sql.warehouse.dir", f"/user/{username}/warehouse"). \
    enableHiveSupport(). \
    appName(f'{username} | Python - Processing Column Data'). \
    master('yarn'). \
    getOrCreate()
If you are going to use CLIs, you can launch Spark SQL using one of the following 3 approaches.
Using Spark SQL
spark2-sql \
    --master yarn \
    --conf spark.ui.port=0 \
    --conf spark.sql.warehouse.dir=/user/${USER}/warehouse
Using Scala
spark2-shell \
    --master yarn \
    --conf spark.ui.port=0 \
    --conf spark.sql.warehouse.dir=/user/${USER}/warehouse
Using Pyspark
pyspark2 \
    --master yarn \
    --conf spark.ui.port=0 \
    --conf spark.sql.warehouse.dir=/user/${USER}/warehouse
# bonus holds a value, None, or an empty string so that
# both the null case and the empty case can be handled below
employees = [
    (1, "Scott", "Tiger", 1000.0, 10,
     "united states", "+1 123 456 7890", "123 45 6789"),
    (2, "Henry", "Ford", 1250.0, None,
     "India", "+91 234 567 8901", "456 78 9123"),
    (3, "Nick", "Junior", 750.0, '',
     "united KINGDOM", "+44 111 111 1111", "222 33 4444"),
    (4, "Bill", "Gomes", 1500.0, 10,
     "AUSTRALIA", "+61 987 654 3210", "789 12 6118")
]
employeesDF = spark. \
    createDataFrame(employees,
                    schema="""employee_id INT, first_name STRING, 
                    last_name STRING, salary FLOAT, bonus STRING, nationality STRING,
                    phone_number STRING, ssn STRING"""
                   )
employeesDF.show()
+-----------+----------+---------+------+-----+--------------+----------------+-----------+
|employee_id|first_name|last_name|salary|bonus|   nationality|    phone_number|        ssn|
+-----------+----------+---------+------+-----+--------------+----------------+-----------+
|          1|     Scott|    Tiger|1000.0|   10| united states| +1 123 456 7890|123 45 6789|
|          2|     Henry|     Ford|1250.0| null|         India|+91 234 567 8901|456 78 9123|
|          3|      Nick|   Junior| 750.0|     |united KINGDOM|+44 111 111 1111|222 33 4444|
|          4|      Bill|    Gomes|1500.0|   10|     AUSTRALIA|+61 987 654 3210|789 12 6118|
+-----------+----------+---------+------+-----+--------------+----------------+-----------+
- Let us transform bonus to 0 in case of null or empty, otherwise return the bonus amount. 
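from pyspark.sql.functions import coalesce, lit, col
# Casting an empty string to int yields null (as the output below shows),
# so coalesce replaces both null and '' with 0.
employeesDF. \
    withColumn('bonus1', coalesce(col('bonus').cast('int'), lit(0))). \
    show()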
+-----------+----------+---------+------+-----+--------------+----------------+-----------+------+
|employee_id|first_name|last_name|salary|bonus|   nationality|    phone_number|        ssn|bonus1|
+-----------+----------+---------+------+-----+--------------+----------------+-----------+------+
|          1|     Scott|    Tiger|1000.0|   10| united states| +1 123 456 7890|123 45 6789|    10|
|          2|     Henry|     Ford|1250.0| null|         India|+91 234 567 8901|456 78 9123|     0|
|          3|      Nick|   Junior| 750.0|     |united KINGDOM|+44 111 111 1111|222 33 4444|     0|
|          4|      Bill|    Gomes|1500.0|   10|     AUSTRALIA|+61 987 654 3210|789 12 6118|    10|
+-----------+----------+---------+------+-----+--------------+----------------+-----------+------+
from pyspark.sql.functions import expr
employeesDF. \
    withColumn(
        'bonus', 
        expr("""
            CASE WHEN bonus IS NULL OR bonus = '' THEN 0
            ELSE bonus
            END
            """)
    ). \
    show()
+-----------+----------+---------+------+-----+--------------+----------------+-----------+
|employee_id|first_name|last_name|salary|bonus|   nationality|    phone_number|        ssn|
+-----------+----------+---------+------+-----+--------------+----------------+-----------+
|          1|     Scott|    Tiger|1000.0|   10| united states| +1 123 456 7890|123 45 6789|
|          2|     Henry|     Ford|1250.0|    0|         India|+91 234 567 8901|456 78 9123|
|          3|      Nick|   Junior| 750.0|    0|united KINGDOM|+44 111 111 1111|222 33 4444|
|          4|      Bill|    Gomes|1500.0|   10|     AUSTRALIA|+61 987 654 3210|789 12 6118|
+-----------+----------+---------+------+-----+--------------+----------------+-----------+
from pyspark.sql.functions import when
when?
Signature: when(condition, value)
Docstring:
Evaluates a list of conditions and returns one of multiple possible result expressions.
If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.
:param condition: a boolean :class:`Column` expression.
:param value: a literal value, or a :class:`Column` expression.
>>> df.select(when(df['age'] == 2, 3).otherwise(4).alias("age")).collect()
[Row(age=3), Row(age=4)]
>>> df.select(when(df.age == 2, df.age + 1).alias("age")).collect()
[Row(age=3), Row(age=None)]
.. versionadded:: 1.4
File:      /usr/hdp/current/spark2-client/python/pyspark/sql/functions.py
Type:      function
employeesDF. \
    withColumn(
        'bonus',
        when((col('bonus').isNull()) | (col('bonus') == lit('')), 0).otherwise(col('bonus'))
    ). \
    show()
+-----------+----------+---------+------+-----+--------------+----------------+-----------+
|employee_id|first_name|last_name|salary|bonus|   nationality|    phone_number|        ssn|
+-----------+----------+---------+------+-----+--------------+----------------+-----------+
|          1|     Scott|    Tiger|1000.0|   10| united states| +1 123 456 7890|123 45 6789|
|          2|     Henry|     Ford|1250.0|    0|         India|+91 234 567 8901|456 78 9123|
|          3|      Nick|   Junior| 750.0|    0|united KINGDOM|+44 111 111 1111|222 33 4444|
|          4|      Bill|    Gomes|1500.0|   10|     AUSTRALIA|+61 987 654 3210|789 12 6118|
+-----------+----------+---------+------+-----+--------------+----------------+-----------+
- Create a DataFrame from a list called persons and categorize each person based on the following rules. The age is expressed in months.
| Age range | Category | 
|---|---|
| 0 to 2 Months | New Born | 
| 2+ Months to 12 Months | Infant | 
| 12+ Months to 48 Months | Toddler | 
| 48+ Months to 144 Months | Kid | 
| 144+ Months | Teenager or Adult | 
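Before expressing these rules in Spark, it can help to pin down the boundaries in plain Python. The sketch below uses a hypothetical helper named categorize, which is not part of Spark. It treats each upper bound as inclusive, uses the label Kid for the fourth range, and sends anything outside the first four ranges to the last category.

```python
def categorize(age_months):
    """Return the category for an age given in months (hypothetical helper)."""
    if 0 <= age_months <= 2:
        return 'New Born'
    if 2 < age_months <= 12:
        return 'Infant'
    if 12 < age_months <= 48:
        return 'Toddler'
    if 48 < age_months <= 144:
        return 'Kid'
    # Everything else, including ages above 144 months, falls through to here.
    return 'Teenager or Adult'


# Ages chosen to include the boundary values 0 and 12.
ages = [1, 13, 18, 60, 120, 0, 12, 160]
categories = [categorize(age) for age in ages]
for age, category in zip(ages, categories):
    print(age, category)
```

Checking the boundary values (0, 2, 12, 48 and 144) this way makes it easier to confirm that a Spark expression applies the same inclusive and exclusive limits.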
persons = [
    (1, 1),
    (2, 13),
    (3, 18),
    (4, 60),
    (5, 120),
    (6, 0),
    (7, 12),
    (8, 160)
]
personsDF = spark.createDataFrame(persons, schema='id INT, age INT')
personsDF.show()
+---+---+
| id|age|
+---+---+
|  1|  1|
|  2| 13|
|  3| 18|
|  4| 60|
|  5|120|
|  6|  0|
|  7| 12|
|  8|160|
+---+---+
personsDF. \
    withColumn(
        'category',
        expr("""
            CASE
            WHEN age BETWEEN 0 AND 2 THEN 'New Born'
            WHEN age > 2 AND age <= 12 THEN 'Infant'
            WHEN age > 12 AND age <= 48 THEN 'Toddler'
            WHEN age > 48 AND age <= 144 THEN 'Kid'
            ELSE 'Teenager or Adult'
            END
        """)
    ). \
    show()
+---+---+-----------------+
| id|age|         category|
+---+---+-----------------+
|  1|  1|         New Born|
|  2| 13|          Toddler|
|  3| 18|          Toddler|
|  4| 60|              Kid|
|  5|120|              Kid|
|  6|  0|         New Born|
|  7| 12|           Infant|
|  8|160|Teenager or Adult|
+---+---+-----------------+
personsDF. \
    withColumn(
        'category',
        when(col('age').between(0, 2), 'New Born').
        when((col('age') > 2) & (col('age') <= 12), 'Infant').
        when((col('age') > 12) & (col('age') <= 48), 'Toddler').
        when((col('age') > 48) & (col('age') <= 144), 'Kid').
        otherwise('Teenager or Adult')
    ). \
    show()
+---+---+-----------------+
| id|age|         category|
+---+---+-----------------+
|  1|  1|         New Born|
|  2| 13|          Toddler|
|  3| 18|          Toddler|
|  4| 60|              Kid|
|  5|120|              Kid|
|  6|  0|         New Born|
|  7| 12|           Infant|
|  8|160|Teenager or Adult|
+---+---+-----------------+