Aggregate data using groupBy
Let us go through the details related to aggregations using groupBy
in Spark.
Here are the APIs which we typically use to group the data by a key. As part of this topic, we will primarily focus on groupBy; a brief sketch of rollup and cube follows this list.
groupBy
rollup
cube
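Here is that sketch, assuming the airtraffic DataFrame that is created later in this notebook. rollup adds hierarchical subtotals over the grouping columns, while cube adds subtotals for every combination of them.
# Minimal sketch only; it assumes the airtraffic DataFrame created later in this notebook,
# which contains data for January 2008, so the subtotal output will be small.
from pyspark.sql.functions import count, lit

# rollup: counts by (Year, Month), subtotals by Year, and a grand total.
airtraffic. \
    rollup('Year', 'Month'). \
    agg(count(lit(1)).alias('FlightCount')). \
    show()

# cube: counts for every combination of Year and Month, including each column alone.
airtraffic. \
    cube('Year', 'Month'). \
    agg(count(lit(1)).alias('FlightCount')). \
    show()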
Here are the functions which we typically use to perform aggregations.
count
sum
avg
min
max
If we want to provide aliases to the aggregated fields, then we have to use agg after groupBy; a quick sketch of this pattern follows.
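Here is a minimal sketch of that pattern. It assumes the spark session created below and uses a tiny hypothetical DataFrame whose key and amount columns exist purely for illustration.
# Minimal sketch of groupBy followed by agg with aliases.
# Assumes the spark session created below; key and amount are hypothetical columns.
from pyspark.sql.functions import count, sum, avg, min, max, lit

df = spark.createDataFrame(
    [('a', 10.0), ('a', 20.0), ('b', 5.0)],
    ['key', 'amount']
)

df. \
    groupBy('key'). \
    agg(
        count(lit(1)).alias('record_count'),
        sum('amount').alias('total_amount'),
        avg('amount').alias('average_amount'),
        min('amount').alias('min_amount'),
        max('amount').alias('max_amount')
    ). \
    show()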
Let us start the Spark context for this Notebook so that we can execute the code provided. You can sign up for our 10-node state-of-the-art cluster/labs to learn Spark SQL using our unique integrated LMS.
from pyspark.sql import SparkSession
import getpass

username = getpass.getuser()
spark = SparkSession. \
    builder. \
    config('spark.ui.port', '0'). \
    config("spark.sql.warehouse.dir", f"/user/{username}/warehouse"). \
    enableHiveSupport(). \
    appName(f'{username} | Python - Basic Transformations'). \
    master('yarn'). \
    getOrCreate()
If you are going to use CLIs, you can run Spark SQL using one of the following 3 approaches.
Using Spark SQL
spark2-sql \
--master yarn \
--conf spark.ui.port=0 \
--conf spark.sql.warehouse.dir=/user/${USER}/warehouse
Using Scala
spark2-shell \
--master yarn \
--conf spark.ui.port=0 \
--conf spark.sql.warehouse.dir=/user/${USER}/warehouse
Using Pyspark
pyspark2 \
--master yarn \
--conf spark.ui.port=0 \
--conf spark.sql.warehouse.dir=/user/${USER}/warehouse
airtraffic_path = "/public/airtraffic_all/airtraffic-part/flightmonth=200801"
airtraffic = spark. \
    read. \
    parquet(airtraffic_path)
airtraffic.printSchema()
root
|-- Year: integer (nullable = true)
|-- Month: integer (nullable = true)
|-- DayofMonth: integer (nullable = true)
|-- DayOfWeek: integer (nullable = true)
|-- DepTime: string (nullable = true)
|-- CRSDepTime: integer (nullable = true)
|-- ArrTime: string (nullable = true)
|-- CRSArrTime: integer (nullable = true)
|-- UniqueCarrier: string (nullable = true)
|-- FlightNum: integer (nullable = true)
|-- TailNum: string (nullable = true)
|-- ActualElapsedTime: string (nullable = true)
|-- CRSElapsedTime: integer (nullable = true)
|-- AirTime: string (nullable = true)
|-- ArrDelay: string (nullable = true)
|-- DepDelay: string (nullable = true)
|-- Origin: string (nullable = true)
|-- Dest: string (nullable = true)
|-- Distance: string (nullable = true)
|-- TaxiIn: string (nullable = true)
|-- TaxiOut: string (nullable = true)
|-- Cancelled: integer (nullable = true)
|-- CancellationCode: string (nullable = true)
|-- Diverted: integer (nullable = true)
|-- CarrierDelay: string (nullable = true)
|-- WeatherDelay: string (nullable = true)
|-- NASDelay: string (nullable = true)
|-- SecurityDelay: string (nullable = true)
|-- LateAircraftDelay: string (nullable = true)
|-- IsArrDelayed: string (nullable = true)
|-- IsDepDelayed: string (nullable = true)
airtraffic.count()
605659
Get the number of flights scheduled each day for the month of January 2008
from pyspark.sql.functions import concat, lpad
airtraffic. \
    groupBy(concat("Year",
                   lpad("Month", 2, "0"),
                   lpad("DayOfMonth", 2, "0")
                  ).alias("FlightDate")
           ). \
    count(). \
    show(31)
+----------+-----+
|FlightDate|count|
+----------+-----+
| 20080120|18653|
| 20080130|19766|
| 20080115|19503|
| 20080118|20347|
| 20080122|19504|
| 20080104|20929|
| 20080125|20313|
| 20080102|20953|
| 20080105|18066|
| 20080111|20349|
| 20080109|19820|
| 20080127|18903|
| 20080101|19175|
| 20080128|20147|
| 20080119|16249|
| 20080106|19893|
| 20080123|19769|
| 20080117|20273|
| 20080116|19764|
| 20080112|16572|
| 20080103|20937|
| 20080126|16276|
| 20080108|19603|
| 20080110|20297|
| 20080121|20133|
| 20080129|19485|
| 20080131|20260|
| 20080124|20257|
| 20080107|20341|
| 20080113|18946|
| 20080114|20176|
+----------+-----+
from pyspark.sql.functions import count, lit
airtraffic. \
    groupBy(concat("Year",
                   lpad("Month", 2, "0"),
                   lpad("DayOfMonth", 2, "0")
                  ).alias("FlightDate")
           ). \
    agg(count(lit(1)).alias("FlightCount")). \
    show(31)
+----------+-----------+
|FlightDate|FlightCount|
+----------+-----------+
| 20080120| 18653|
| 20080130| 19766|
| 20080115| 19503|
| 20080118| 20347|
| 20080122| 19504|
| 20080104| 20929|
| 20080125| 20313|
| 20080102| 20953|
| 20080105| 18066|
| 20080111| 20349|
| 20080109| 19820|
| 20080127| 18903|
| 20080101| 19175|
| 20080128| 20147|
| 20080119| 16249|
| 20080106| 19893|
| 20080123| 19769|
| 20080117| 20273|
| 20080116| 19764|
| 20080112| 16572|
| 20080103| 20937|
| 20080126| 16276|
| 20080108| 19603|
| 20080110| 20297|
| 20080121| 20133|
| 20080129| 19485|
| 20080131| 20260|
| 20080124| 20257|
| 20080107| 20341|
| 20080113| 18946|
| 20080114| 20176|
+----------+-----------+
Get the count of flights departed, the total departure delay, and the average departure delay for each day of January 2008
from pyspark.sql.functions import sum, avg
airtraffic. \
    filter('Cancelled = 0'). \
    groupBy(concat("Year",
                   lpad("Month", 2, "0"),
                   lpad("DayOfMonth", 2, "0")
                  ).alias("FlightDate")
           ). \
    agg(
        count(lit(1)).alias("FlightCount"),
        sum('DepDelay').alias('TotalDepDelay'),
        avg('DepDelay').alias('AverageDepDelay')
    ). \
    show(31)
+----------+-----------+-------------+------------------+
|FlightDate|FlightCount|TotalDepDelay| AverageDepDelay|
+----------+-----------+-------------+------------------+
| 20080120| 18406| 117460.0| 6.381614690861675|
| 20080130| 19072| 129345.0| 6.781931627516778|
| 20080115| 19204| 75096.0|3.9104353259737556|
| 20080118| 20117| 223738.0|11.121837252075359|
| 20080122| 18716| 303796.0| 16.23188715537508|
| 20080104| 20160| 277373.0|13.758581349206349|
| 20080125| 19787| 229850.0|11.616212664880983|
| 20080102| 20442| 452979.0|22.159230995010272|
| 20080105| 17610| 306068.0|17.380352072685973|
| 20080111| 19825| 190918.0| 9.63016393442623|
| 20080109| 19443| 89595.0| 4.608085172041352|
| 20080127| 18265| 365491.0|20.010457158499865|
| 20080101| 18623| 354108.0| 19.01455189819041|
| 20080128| 19493| 220046.0|11.288462525008978|
| 20080119| 15373| 155488.0|10.114356339035972|
| 20080106| 19210| 323214.0| 16.82529932326913|
| 20080123| 19239| 190807.0| 9.917719216175477|
| 20080117| 19401| 341271.0|17.590381939075307|
| 20080116| 19232| 61021.0| 3.172888935108153|
| 20080112| 16346| 24876.0|1.5218402055548759|
| 20080103| 20462| 329690.0|16.112305737464567|
| 20080126| 15860| 92129.0| 5.808890290037831|
| 20080108| 19140| 200670.0|10.484326018808778|
| 20080110| 19956| 148603.0| 7.446532371216676|
| 20080121| 19658| 370196.0| 18.83182419371248|
| 20080129| 18596| 184855.0| 9.940578619057861|
| 20080131| 19179| 396280.0|20.662182595547215|
| 20080124| 19935| 158134.0| 7.932480561825934|
| 20080107| 19762| 238431.0|12.065124987349458|
| 20080113| 18587| 101753.0| 5.474417603701512|
| 20080114| 19267| 98261.0| 5.099963668448643|
+----------+-----------+-------------+------------------+
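Note that DepDelay is a string column in this dataset (see the schema above), so sum and avg rely on Spark's implicit cast to double, which is why TotalDepDelay is shown with values such as 117460.0. If we prefer to be explicit, we can cast the column ourselves; the sketch below is an assumed, equivalent variant (the values should match, though TotalDepDelay will then come back as a whole number).
# Optional variant with an explicit cast of DepDelay from string to int.
from pyspark.sql.functions import col, concat, lpad, count, lit, sum, avg

airtraffic. \
    filter('Cancelled = 0'). \
    groupBy(concat("Year",
                   lpad("Month", 2, "0"),
                   lpad("DayOfMonth", 2, "0")
                  ).alias("FlightDate")
           ). \
    agg(
        count(lit(1)).alias("FlightCount"),
        sum(col('DepDelay').cast('int')).alias('TotalDepDelay'),
        avg(col('DepDelay').cast('int')).alias('AverageDepDelay')
    ). \
    show(31)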
from pyspark.sql.functions import round
airtraffic. \
    filter('Cancelled = 0'). \
    groupBy(concat("Year",
                   lpad("Month", 2, "0"),
                   lpad("DayOfMonth", 2, "0")
                  ).alias("FlightDate")
           ). \
    agg(
        count(lit(1)).alias("FlightCount"),
        sum('DepDelay').alias('TotalDepDelay'),
        round(avg('DepDelay'), 2).alias('AverageDepDelay')
    ). \
    show(31)
+----------+-----------+-------------+---------------+
|FlightDate|FlightCount|TotalDepDelay|AverageDepDelay|
+----------+-----------+-------------+---------------+
| 20080120| 18406| 117460.0| 6.38|
| 20080130| 19072| 129345.0| 6.78|
| 20080115| 19204| 75096.0| 3.91|
| 20080118| 20117| 223738.0| 11.12|
| 20080122| 18716| 303796.0| 16.23|
| 20080104| 20160| 277373.0| 13.76|
| 20080125| 19787| 229850.0| 11.62|
| 20080102| 20442| 452979.0| 22.16|
| 20080105| 17610| 306068.0| 17.38|
| 20080111| 19825| 190918.0| 9.63|
| 20080109| 19443| 89595.0| 4.61|
| 20080127| 18265| 365491.0| 20.01|
| 20080101| 18623| 354108.0| 19.01|
| 20080128| 19493| 220046.0| 11.29|
| 20080119| 15373| 155488.0| 10.11|
| 20080106| 19210| 323214.0| 16.83|
| 20080123| 19239| 190807.0| 9.92|
| 20080117| 19401| 341271.0| 17.59|
| 20080116| 19232| 61021.0| 3.17|
| 20080112| 16346| 24876.0| 1.52|
| 20080103| 20462| 329690.0| 16.11|
| 20080126| 15860| 92129.0| 5.81|
| 20080108| 19140| 200670.0| 10.48|
| 20080110| 19956| 148603.0| 7.45|
| 20080121| 19658| 370196.0| 18.83|
| 20080129| 18596| 184855.0| 9.94|
| 20080131| 19179| 396280.0| 20.66|
| 20080124| 19935| 158134.0| 7.93|
| 20080107| 19762| 238431.0| 12.07|
| 20080113| 18587| 101753.0| 5.47|
| 20080114| 19267| 98261.0| 5.1|
+----------+-----------+-------------+---------------+
Using order_items, get the revenue for each order.
order_items_path = '/public/retail_db_json/order_items'
order_items = spark. \
    read. \
    json(order_items_path)
order_items.printSchema()
root
|-- order_item_id: long (nullable = true)
|-- order_item_order_id: long (nullable = true)
|-- order_item_product_id: long (nullable = true)
|-- order_item_product_price: double (nullable = true)
|-- order_item_quantity: long (nullable = true)
|-- order_item_subtotal: double (nullable = true)
order_items. \
    groupBy('order_item_order_id'). \
    sum('order_item_subtotal'). \
    show()
+-------------------+------------------------+
|order_item_order_id|sum(order_item_subtotal)|
+-------------------+------------------------+
| 29| 1109.85|
| 474| 774.8199999999999|
| 964| 739.8800000000001|
| 1677| 649.9200000000001|
| 1806| 789.94|
| 1950| 1015.8700000000001|
| 2214| 449.96|
| 2250| 889.94|
| 2453| 999.9300000000001|
| 2509| 889.94|
| 2529| 59.99|
| 2927| 999.9100000000001|
| 3091| 469.93000000000006|
| 3764| 95.98|
| 4590| 949.83|
| 4894| 899.94|
| 5385| 629.86|
| 5409| 699.9200000000001|
| 6721| 139.99|
| 7225| 774.86|
+-------------------+------------------------+
only showing top 20 rows
order_items. \
    groupBy('order_item_order_id'). \
    agg(sum('order_item_subtotal').alias('revenue_per_order')). \
    show()
+-------------------+------------------+
|order_item_order_id| revenue_per_order|
+-------------------+------------------+
| 29| 1109.85|
| 474| 774.8199999999999|
| 964| 739.8800000000001|
| 1677| 649.9200000000001|
| 1806| 789.94|
| 1950|1015.8700000000001|
| 2214| 449.96|
| 2250| 889.94|
| 2453| 999.9300000000001|
| 2509| 889.94|
| 2529| 59.99|
| 2927| 999.9100000000001|
| 3091|469.93000000000006|
| 3764| 95.98|
| 4590| 949.83|
| 4894| 899.94|
| 5385| 629.86|
| 5409| 699.9200000000001|
| 6721| 139.99|
| 7225| 774.86|
+-------------------+------------------+
only showing top 20 rows
order_items. \
    groupBy('order_item_order_id'). \
    agg(round(sum('order_item_subtotal'), 2).alias('revenue_per_order')). \
    show()
+-------------------+-----------------+
|order_item_order_id|revenue_per_order|
+-------------------+-----------------+
| 29| 1109.85|
| 474| 774.82|
| 964| 739.88|
| 1677| 649.92|
| 1806| 789.94|
| 1950| 1015.87|
| 2214| 449.96|
| 2250| 889.94|
| 2453| 999.93|
| 2509| 889.94|
| 2529| 59.99|
| 2927| 999.91|
| 3091| 469.93|
| 3764| 95.98|
| 4590| 949.83|
| 4894| 899.94|
| 5385| 629.86|
| 5409| 699.92|
| 6721| 139.99|
| 7225| 774.86|
+-------------------+-----------------+
only showing top 20 rows
Get the minimum and maximum order_item_subtotal for each order id.
from pyspark.sql.functions import min, max
order_items. \
    groupBy('order_item_order_id'). \
    agg(
        round(sum('order_item_subtotal'), 2).alias('revenue_per_order'),
        min('order_item_subtotal').alias('order_item_subtotal_min'),
        max('order_item_subtotal').alias('order_item_subtotal_max')
    ). \
    show()
+-------------------+-----------------+-----------------------+-----------------------+
|order_item_order_id|revenue_per_order|order_item_subtotal_min|order_item_subtotal_max|
+-------------------+-----------------+-----------------------+-----------------------+
| 39713| 599.97| 199.99| 399.98|
| 40395| 939.94| 50.0| 399.98|
| 40436| 229.98| 99.99| 129.99|
| 40557| 549.95| 50.0| 200.0|
| 40634| 1119.88| 99.99| 499.95|
| 41424| 829.95| 129.99| 299.98|
| 41895| 649.9| 50.0| 199.99|
| 41988| 669.9| 129.99| 299.95|
| 42126| 899.88| 249.9| 399.98|
| 42852| 1039.88| 129.99| 399.98|
| 42969| 561.96| 31.99| 399.98|
| 43367| 1079.87| 129.99| 399.98|
| 44134| 759.92| 119.98| 399.98|
| 44342| 119.98| 119.98| 119.98|
| 44901| 1129.87| 129.99| 399.98|
| 45166| 424.92| 50.0| 129.99|
| 45298| 759.93| 79.98| 199.99|
| 45726| 1149.88| 119.97| 399.98|
| 46044| 1159.9| 39.98| 399.98|
| 46424| 509.84| 109.94| 149.94|
+-------------------+-----------------+-----------------------+-----------------------+
only showing top 20 rows