Total Aggregations

Let us go through the details related to total aggregations using Spark.

  • We can perform total aggregations directly on a DataFrame, or we can perform aggregations after grouping by one or more keys.

  • Here are the functions which we typically use to perform aggregations.

    • count

    • sum, avg

    • min, max

Let us start the Spark session for this notebook so that we can execute the code provided. You can sign up for our 10-node state-of-the-art cluster/labs to learn Spark SQL using our unique integrated LMS.

from pyspark.sql import SparkSession

import getpass
username = getpass.getuser()

spark = SparkSession. \
    builder. \
    config('spark.ui.port', '0'). \
    config("spark.sql.warehouse.dir", f"/user/{username}/warehouse"). \
    enableHiveSupport(). \
    appName(f'{username} | Python - Basic Transformations'). \
    master('yarn'). \
    getOrCreate()

If you are going to use CLIs, you can use Spark SQL using one of the following 3 approaches.

Using Spark SQL

spark2-sql \
    --master yarn \
    --conf spark.ui.port=0 \
    --conf spark.sql.warehouse.dir=/user/${USER}/warehouse

Using Scala

spark2-shell \
    --master yarn \
    --conf spark.ui.port=0 \
    --conf spark.sql.warehouse.dir=/user/${USER}/warehouse

Using Pyspark

pyspark2 \
    --master yarn \
    --conf spark.ui.port=0 \
    --conf spark.sql.warehouse.dir=/user/${USER}/warehouse

Let us read the air traffic data for January 2008 into a DataFrame.

airtraffic_path = "/public/airtraffic_all/airtraffic-part/flightmonth=200801"
airtraffic = spark. \
    read. \
    parquet(airtraffic_path)
  • Get the number of flights in January 2008.

airtraffic.count()
605659
from pyspark.sql.functions import count
airtraffic.select(count("*").alias('count')).show()
+------+
| count|
+------+
|605659|
+------+
from pyspark.sql.functions import lit
airtraffic.select(count(lit(1)).alias('count')).show()
+------+
| count|
+------+
|605659|
+------+
airtraffic. \
    select('Year', 'Month', 'DayOfMonth'). \
    describe(). \
    show()
+-------+------+------+------------------+
|summary|  Year| Month|        DayOfMonth|
+-------+------+------+------------------+
|  count|605659|605659|            605659|
|   mean|2008.0|   1.0|15.908469947610785|
| stddev|   0.0|   0.0| 8.994294747375292|
|    min|  2008|     1|                 1|
|    max|  2008|     1|                31|
+-------+------+------+------------------+
airtraffic. \
    select('Year', 'Month', 'DayOfMonth'). \
    summary(). \
    show()
+-------+------+------+------------------+
|summary|  Year| Month|        DayOfMonth|
+-------+------+------+------------------+
|  count|605659|605659|            605659|
|   mean|2008.0|   1.0|15.908469947610785|
| stddev|   0.0|   0.0| 8.994294747375292|
|    min|  2008|     1|                 1|
|    25%|  2008|     1|                 8|
|    50%|  2008|     1|                16|
|    75%|  2008|     1|                24|
|    max|  2008|     1|                31|
+-------+------+------+------------------+
  • Get the number of distinct dates from the airtraffic DataFrame, which is created using January 2008 data.

airtraffic. \
    select('Year', 'Month', 'DayOfMonth'). \
    distinct(). \
    count()
31
from pyspark.sql.functions import countDistinct
airtraffic. \
    select(countDistinct('Year', 'Month', 'DayOfMonth').alias('countDistinct')). \
    show()
+-------------+
|countDistinct|
+-------------+
|           31|
+-------------+
from pyspark.sql.functions import concat, lpad
airtraffic. \
    select(countDistinct(
        concat('Year', 
               lpad('Month', 2, '0'), 
               lpad('DayOfMonth', 2, '0')
              )).alias('countDistinct')). \
    show()
+-------------+
|countDistinct|
+-------------+
|           31|
+-------------+
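  • Get the total bonus amount from the employees data set. Here bonus is a percentage of salary stored as a string, so we cast it to an integer and replace null with 0 using coalesce. We need to use sum to get the total bonus amount. We also have functions such as min, max, and avg to take care of common aggregations.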

employees = [(1, "Scott", "Tiger", 1000.0, 10,
                      "united states", "+1 123 456 7890", "123 45 6789"
                     ),
                     (2, "Henry", "Ford", 1250.0, None,
                      "India", "+91 234 567 8901", "456 78 9123"
                     ),
                     (3, "Nick", "Junior", 750.0, '',
                      "united KINGDOM", "+44 111 111 1111", "222 33 4444"
                     ),
                     (4, "Bill", "Gomes", 1500.0, 10,
                      "AUSTRALIA", "+61 987 654 3210", "789 12 6118"
                     )
                ]
employeesDF = spark. \
    createDataFrame(employees,
                    schema="""employee_id INT, first_name STRING, 
                    last_name STRING, salary FLOAT, bonus STRING, nationality STRING,
                    phone_number STRING, ssn STRING"""
                   )
employeesDF.show()
+-----------+----------+---------+------+-----+--------------+----------------+-----------+
|employee_id|first_name|last_name|salary|bonus|   nationality|    phone_number|        ssn|
+-----------+----------+---------+------+-----+--------------+----------------+-----------+
|          1|     Scott|    Tiger|1000.0|   10| united states| +1 123 456 7890|123 45 6789|
|          2|     Henry|     Ford|1250.0| null|         India|+91 234 567 8901|456 78 9123|
|          3|      Nick|   Junior| 750.0|     |united KINGDOM|+44 111 111 1111|222 33 4444|
|          4|      Bill|    Gomes|1500.0|   10|     AUSTRALIA|+61 987 654 3210|789 12 6118|
+-----------+----------+---------+------+-----+--------------+----------------+-----------+
from pyspark.sql.functions import col, coalesce, sum
employeesDF. \
    select(((sum(coalesce(col('bonus').cast('int'), lit(0)) * col('salary'))) / lit(100)).alias('total_bonus')). \
    show()
+-----------+
|total_bonus|
+-----------+
|      250.0|
+-----------+
employeesDF. \
    selectExpr('sum((coalesce(cast(bonus AS INT), 0) * salary) / 100) AS total_bonus'). \
    show()
+-----------+
|total_bonus|
+-----------+
|      250.0|
+-----------+
  • Get the revenue generated for a given order from order_items.

order_items = spark.read.json('/public/retail_db_json/order_items')
order_id = input('Enter order_id:')
Enter order_id: 2
order_items. \
    filter(f'order_item_order_id = {int(order_id)}'). \
    show()
+-------------+-------------------+---------------------+------------------------+-------------------+-------------------+
|order_item_id|order_item_order_id|order_item_product_id|order_item_product_price|order_item_quantity|order_item_subtotal|
+-------------+-------------------+---------------------+------------------------+-------------------+-------------------+
|            2|                  2|                 1073|                  199.99|                  1|             199.99|
|            3|                  2|                  502|                    50.0|                  5|              250.0|
|            4|                  2|                  403|                  129.99|                  1|             129.99|
+-------------+-------------------+---------------------+------------------------+-------------------+-------------------+
order_items. \
    filter(f'order_item_order_id = {int(order_id)}'). \
    select(sum('order_item_subtotal').alias('order_revenue')). \
    show()
+-------------+
|order_revenue|
+-------------+
|       579.98|
+-------------+
order_items. \
    filter(col('order_item_order_id') == lit(int(order_id))). \
    select(sum('order_item_subtotal').alias('order_revenue')). \
    show()
+-------------+
|order_revenue|
+-------------+
|       579.98|
+-------------+