The use of Window in Apache Spark

When processing data, we often want to calculate variables over certain subsets of observations. For example, we might be interested in the average value or the maximum value of each group. The groupBy function, available in many programming and query languages, lets us do these calculations easily: we simply specify the variables that define the groups and the functions we want to apply within each group.

Normally, applying groupBy yields a single row for each of the groups defined. This means that if we wanted to compare the aggregates with the observations that generated them, we would have to write a specific function or join the original table with the results of the groupBy.

To make tasks like this easier, many programming languages provide window functions. A window works much like groupBy, except that it does not change the number of rows of the DataFrame. In this post we explain how to use windows in Apache Spark, specifically in its PySpark implementation.

To compare the behaviour of groupBy and Window, let's consider the following problem: we have a set of students, and for each one we know the class they were in and the grade they obtained. From this data, we want to determine whether each student got a grade higher than the average of their class.

Let’s generate some simulated data for testing:

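~~~~ python
from pyspark.sql import SparkSession, Window
# Functions used throughout this post; note that importing sum this way
# shadows Python's built-in sum
from pyspark.sql.functions import (avg, col, datediff, lag, last, lit, mean,
                                   rank, sum, to_date, unix_timestamp, when)

spark = SparkSession.builder.getOrCreate()
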
grades = [('ES001', 'English', 18), 
          ('ES002', 'English', 14),
          ('ES003', 'English', 19), 
          ('ES001', 'French', 13), 
          ('ES002', 'French', 16), 
          ('ES003', 'French', 11), 
          ('ES001', 'Maths', 15), 
          ('ES002', 'Maths', 17), 
          ('ES003', 'Maths', 19)]

gradesSpark = spark.createDataFrame(grades, schema=['student', 'course', 'grade']).cache() 

gradesSpark.show()
~~~~

~~~~ console
+-------+-------+-----+
|student| course|grade|
+-------+-------+-----+
|  ES001|English|   18|
|  ES002|English|   14|
|  ES003|English|   19|
|  ES001| French|   13|
|  ES002| French|   16|
|  ES003| French|   11|
|  ES001|  Maths|   15|
|  ES002|  Maths|   17|
|  ES003|  Maths|   19|
+-------+-------+-----+
~~~~

To solve this problem with groupBy alone, we first group by course to obtain the average grade and then join that result back to the original data:

~~~~ python
# Average grade per course, joined back onto the original rows
courseAvgs = gradesSpark.groupBy('course').agg(avg(col('grade')).alias('courseAvg'))
result = gradesSpark.join(courseAvgs, on=['course'], how='left')

result = result.withColumn('aboveAvg', when(col('grade')>col('courseAvg'), lit(1)).otherwise(lit(0)))

result.show()
~~~~

~~~~ console
+-------+-------+-----+------------------+--------+
| course|student|grade|         courseAvg|aboveAvg|
+-------+-------+-----+------------------+--------+
|English|  ES001|   18|              17.0|       1|
|English|  ES002|   14|              17.0|       0|
|English|  ES003|   19|              17.0|       1|
| French|  ES001|   13|13.333333333333334|       0|
| French|  ES002|   16|13.333333333333334|       1|
| French|  ES003|   11|13.333333333333334|       0|
|  Maths|  ES001|   15|              17.0|       0|
|  Maths|  ES002|   17|              17.0|       0|
|  Maths|  ES003|   19|              17.0|       1|
+-------+-------+-----+------------------+--------+
~~~~

The same example is much simpler with a window: we only have to define the window over which we want to apply the function, and then create the variable with the over() method, passing it the window to use for the calculation:

~~~~ python
win = Window.partitionBy('course')

result = gradesSpark.withColumn('courseAvg', avg(col('grade')).over(win))

result = result.withColumn('aboveAvg', when(col('grade')>col('courseAvg'), lit(1)).otherwise(lit(0)))

result.show()
~~~~

~~~~ console
+-------+-------+-----+------------------+--------+
|student| course|grade|         courseAvg|aboveAvg|
+-------+-------+-----+------------------+--------+
|  ES003|English|   19|              17.0|       1|
|  ES001|English|   18|              17.0|       1|
|  ES002|English|   14|              17.0|       0|
|  ES001| French|   13|13.333333333333334|       0|
|  ES002| French|   16|13.333333333333334|       1|
|  ES003| French|   11|13.333333333333334|       0|
|  ES001|  Maths|   15|              17.0|       0|
|  ES002|  Maths|   17|              17.0|       0|
|  ES003|  Maths|   19|              17.0|       1|
+-------+-------+-----+------------------+--------+
~~~~

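The code with Window is simpler and often more efficient, since it avoids a join, which can be very computationally expensive. Computing the window still requires shuffling the data by the partitionBy columns, though, so the actual gain depends on the data.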

Defining windows

Spark provides several functions that we can combine to define many types of windows easily:

partitionBy

This function, which we used in the previous example, defines the groups within the data. As with groupBy, its arguments are the columns that delimit those groups.

orderBy

It sorts the data within each group. In addition, when a window is defined with orderBy (alone or together with partitionBy), calculations are by default performed over all rows whose value precedes or is equal to that of the current row. So if two rows have the same value for the ordering variable, both are included in the calculation.

To illustrate orderBy, let's use data on the contribution of different sectors to GDP, and imagine that we want to keep the largest sectors whose combined contribution is at least 80%. For this, we can use a window with orderBy to sort the sectors in descending order of their share of GDP and compute a cumulative sum:

~~~~ python
gdp = [('Agriculture', 5), 
       ('Telecom', 25), 
       ('Tourism', 40), 
       ('Petrochemical', 25),
       ('Construction', 5)]

gdpSpark = spark.createDataFrame(gdp, schema=['sector', 'percGdp'])

win = Window().orderBy(col('percGdp').desc())

gdpSpark.withColumn('percBiggerSectors', sum(col('percGdp')).over(win)).show()
~~~~

~~~~ console
+-------------+-------+-----------------+
|       sector|percGdp|percBiggerSectors|
+-------------+-------+-----------------+
|      Tourism|     40|               40|
|      Telecom|     25|               90|
|Petrochemical|     25|               90|
| Construction|      5|              100|
|  Agriculture|      5|              100|
+-------------+-------+-----------------+
~~~~

Note that for Telecom the value of percBiggerSectors is 90 = 40 + 25 + 25: with orderBy alone, Spark does not take into account that Petrochemical comes in the next row, only that both rows have the same value, 25. This example also shows that partitionBy is not required when using orderBy. If partitionBy is omitted, all the data forms a single group; Spark then moves all rows to a single partition (and logs a warning about it), which can be slow on large datasets.

orderBy is also useful when we want to create variables based on order. For example, with the grades data from the previous section we could rank the students within each class according to their grade:

~~~~ python
win = Window.partitionBy('course').orderBy(col('grade').desc())

result = gradesSpark.withColumn('ranking', rank().over(win))

result.show()
~~~~

~~~~ console
+-------+-------+-----+-------+
|student| course|grade|ranking|
+-------+-------+-----+-------+
|  ES003|English|   19|      1|
|  ES001|English|   18|      2|
|  ES002|English|   14|      3|
|  ES002| French|   16|      1|
|  ES001| French|   13|      2|
|  ES003| French|   11|      3|
|  ES003|  Maths|   19|      1|
|  ES002|  Maths|   17|      2|
|  ES001|  Maths|   15|      3|
+-------+-------+-----+-------+
~~~~

rowsBetween

This function lets us define dynamic windows that depend on the row for which the calculation is being made. It takes two arguments that indicate the start and the end of the window, as integer offsets relative to the current row, counted in the order given by orderBy. For example, rowsBetween(-1, 1) means that the window goes from the row before the current row to the row after it. Where possible, however, Apache Spark recommends using the following constants of the Window class to define these boundaries, in order to minimise human error and make the code easier to read:

  • Window.currentRow indicates the current row where the calculation is being performed.
  • Window.unboundedPreceding indicates the first row of the group.
  • Window.unboundedFollowing indicates the last row of the group.

To observe how rowsBetween works, let’s assume we have the following data about the
price of a stock:

~~~~ python
prices = [('AAPL', '2021-01-01', 110), 
          ('AAPL', '2021-01-02', 120),
          ('AAPL', '2021-01-03', 110), 
          ('AAPL', '2021-01-04', 115), 
          ('AAPL', '2021-01-05', 150), 
          ('MSFT', '2021-01-01', 150), 
          ('MSFT', '2021-01-02', 130), 
          ('MSFT', '2021-01-03', 140), 
          ('MSFT', '2021-01-04', 120), 
          ('MSFT', '2021-01-05', 140)]

pricesSpark = spark.createDataFrame(prices, schema=['company', 'date', 'price'])\
                    .withColumn('date', to_date(col('date'))).cache() 


pricesSpark.show()
~~~~

~~~~ console
+-------+----------+-----+
|company|      date|price|
+-------+----------+-----+
|   AAPL|2021-01-01|  110|
|   AAPL|2021-01-02|  120|
|   AAPL|2021-01-03|  110|
|   AAPL|2021-01-04|  115|
|   AAPL|2021-01-05|  150|
|   MSFT|2021-01-01|  150|
|   MSFT|2021-01-02|  130|
|   MSFT|2021-01-03|  140|
|   MSFT|2021-01-04|  120|
|   MSFT|2021-01-05|  140|
+-------+----------+-----+
~~~~

If we want to calculate the average price from the first date up to the date of the current row (a cumulative average), we can do the following:

~~~~ python
win = Window.partitionBy('company').orderBy('date').rowsBetween(Window.unboundedPreceding, Window.currentRow)
win2 = Window.partitionBy('company').orderBy('date')

pricesSpark.withColumn('histAvg', avg(col('price')).over(win))\
            .withColumn('histAvg2', avg(col('price')).over(win2)).show()
~~~~

~~~~ console
+-------+----------+-----+------------------+------------------+
|company|      date|price|           histAvg|          histAvg2|
+-------+----------+-----+------------------+------------------+
|   AAPL|2021-01-01|  110|             110.0|             110.0|
|   AAPL|2021-01-02|  120|             115.0|             115.0|
|   AAPL|2021-01-03|  110|113.33333333333333|113.33333333333333|
|   AAPL|2021-01-04|  115|            113.75|            113.75|
|   AAPL|2021-01-05|  150|             121.0|             121.0|
|   MSFT|2021-01-01|  150|             150.0|             150.0|
|   MSFT|2021-01-02|  130|             140.0|             140.0|
|   MSFT|2021-01-03|  140|             140.0|             140.0|
|   MSFT|2021-01-04|  120|             135.0|             135.0|
|   MSFT|2021-01-05|  140|             136.0|             136.0|
+-------+----------+-----+------------------+------------------+
~~~~

Note that the win and win2 windows give the same result here, because there are no repeated dates within each group, so the tie behaviour of orderBy does not show up. However, if we go back to the GDP example, we see that with rowsBetween the sum only includes rows up to the current row, so in percBiggerSectors2 the Telecom value is only 65 = 40 + 25. Since Telecom and Petrochemical are tied, which of the two gets 65 and which gets 90 depends on the order in which Spark happens to process them:

~~~~ python
win = Window().orderBy(col('percGdp').desc())
win2 = Window().orderBy(col('percGdp').desc()).rowsBetween(Window.unboundedPreceding, Window.currentRow)

gdpSpark.withColumn('percBiggerSectors', sum(col('percGdp')).over(win))\
        .withColumn('percBiggerSectors2', sum(col('percGdp')).over(win2)).show()
~~~~

~~~~ console
+-------------+-------+-----------------+------------------+
|       sector|percGdp|percBiggerSectors|percBiggerSectors2|
+-------------+-------+-----------------+------------------+
|      Tourism|     40|               40|                40|
|      Telecom|     25|               90|                65|
|Petrochemical|     25|               90|                90|
|  Agriculture|      5|              100|                95|
| Construction|      5|              100|               100|
+-------------+-------+-----------------+------------------+
~~~~

rangeBetween

This function is similar to rowsBetween, but its boundaries refer to the values of the variable used in orderBy rather than to row positions. A window defined with orderBy alone already behaves like a range window that goes from the first value to the current one; rangeBetween lets us choose a different range. Its two arguments indicate how far below and above the current value the window extends. Because these boundaries are values, when they are numeric offsets (rather than Window.unboundedPreceding, Window.currentRow or Window.unboundedFollowing) the window must, in practice, be ordered by a single numeric column.

For example, if a row has a salary of 10 euros and the window is defined as Window.orderBy('salary').rangeBetween(-5, 5), any calculation over this window will, for that row, take into account all rows whose salary is between 5 and 15 euros, both included.

Using the students' data, we might want to compare each student with classmates who got similar grades: for example, with the average grade of the students in their class (including themselves) whose grade is within ±2 points of theirs. Using rangeBetween, the code would be:

~~~~ python
win = Window.partitionBy('course').orderBy('grade').rangeBetween(-2, 2)

result = gradesSpark.withColumn('similarStudentsAvg', mean(col('grade')).over(win))

result = result.withColumn('aboveAvg', when(col('grade')>=col('similarStudentsAvg'), lit(1)).otherwise(lit(0)))

result.show()
~~~~

~~~~ console
+-------+-------+-----+------------------+--------+
|student| course|grade|similarStudentsAvg|aboveAvg|
+-------+-------+-----+------------------+--------+
|  ES002|English|   14|              14.0|       1|
|  ES001|English|   18|              18.5|       0|
|  ES003|English|   19|              18.5|       1|
|  ES003| French|   11|              12.0|       0|
|  ES001| French|   13|              12.0|       1|
|  ES002| French|   16|              16.0|       1|
|  ES001|  Maths|   15|              16.0|       0|
|  ES002|  Maths|   17|              17.0|       1|
|  ES003|  Maths|   19|              18.0|       1|
+-------+-------+-----+------------------+--------+
~~~~

Functions to create variables with windows

In Apache Spark, the functions that can be applied over a window fall into two main groups. Users can also define their own functions, just as with groupBy, although Python UDFs should be avoided where possible, because they usually perform much worse than Spark's built-in functions.

Analytical functions

They return a value for each row of the window, and that value may differ from row to row.

  • rank → Assigns each row within a group a natural number that indicates its position according to the variable used in orderBy. Tied rows get the same number, and the following numbers are skipped (for example 1, 2, 2, 4).
  • dense_rank → Same as rank, except that no numbers are skipped after tied rows (for example 1, 2, 2, 3).
  • percent_rank → Gives the relative position of the row within its group according to the variable used in orderBy, computed as (rank - 1) / (number of rows in the group - 1), so it goes from 0 to 1.
  • row_number → Assigns each row within a group a distinct natural number according to the variable used in orderBy. Ties are broken arbitrarily, so the numbering of tied rows may change from one run to another.
  • ntile(n) → Distributes the rows of a group into n buckets of (as far as possible) equal size, numbered 1 to n. It is very useful for turning a numeric variable into an ordinal one; for example, 4 income levels with the same number of observations each.
  • lag(col, n, default) → Gives the value of column col n rows before the current row.
    A negative n returns the value n rows after the current row instead. If the group does not have enough rows before or after the current row, lag returns null, unless the optional default parameter is given, in which case it returns that value. The order of the rows is given by the window's orderBy, so the window must include this clause when using lag.
  • lead(col, n, default) → Same as lag, except that positive numbers refer to later rows and
    negative numbers to earlier rows.
  • cume_dist → Gives the cumulative distribution of each row according to the variable used in orderBy. In other words, it tells us the proportion of rows in the group whose value of that variable is less than or equal to the value in the current row.
  • last(col, ignorenulls) → Returns the last value of column col within the window frame. It is only meaningful when the window has an orderBy; with the default frame of an ordered window, the frame ends at the current row (including rows tied with it), so on its own it mostly returns the current row's value. With ignorenulls=True, missing values are skipped, so the most recent non-null value up to and including the current row is returned.

Aggregation functions

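These are all the functions that reduce a set of rows to a single value representing it, such as sum, min, max, avg and count. Over a window, the value is computed over each row's window frame, so with partitionBy alone it is the same for every row of the group.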

Additional usage examples

rangeBetween with dates

To use rangeBetween with dates, we can convert the dates to a timestamp and then to a long (the number of seconds since 1970-01-01), and express the window boundaries in seconds: 86400 per day, or 3600 per hour. In the following example, we want a three-day moving average (the current date plus up to three days before it). We cannot use rowsBetween as before because some dates are missing, but we can use rangeBetween:

~~~~ python
prices2 = [('AAPL', '2021-01-05', 110), 
           ('AAPL', '2021-01-07', 120),
           ('AAPL', '2021-01-13', 110), 
           ('AAPL', '2021-01-14', 115), 
           ('AAPL', '2021-01-15', 150), 
           ('MSFT', '2021-01-01', 150), 
           ('MSFT', '2021-01-05', 130), 
           ('MSFT', '2021-01-06', 140), 
           ('MSFT', '2021-01-09', 120), 
           ('MSFT', '2021-01-12', 140)]

pricesSpark2 = spark.createDataFrame(prices2, schema=['company', 'date', 'price'])\
                     .withColumn('date', to_date(col('date'))).cache() 

pricesSpark2.show()
~~~~

~~~~ console
+-------+----------+-----+
|company|      date|price|
+-------+----------+-----+
|   AAPL|2021-01-05|  110|
|   AAPL|2021-01-07|  120|
|   AAPL|2021-01-13|  110|
|   AAPL|2021-01-14|  115|
|   AAPL|2021-01-15|  150|
|   MSFT|2021-01-01|  150|
|   MSFT|2021-01-05|  130|
|   MSFT|2021-01-06|  140|
|   MSFT|2021-01-09|  120|
|   MSFT|2021-01-12|  140|
+-------+----------+-----+
~~~~

~~~~ python
days = lambda i: i * 86400 
win = Window.partitionBy('company')\
            .orderBy(col('date').cast("timestamp").cast("long"))\
            .rangeBetween(-days(3), 0)

pricesSpark2.withColumn('movingAvg3d', avg(col('price')).over(win)).show()
~~~~

~~~~ console
+-------+----------+-----+-----------+
|company|      date|price|movingAvg3d|
+-------+----------+-----+-----------+
|   AAPL|2021-01-05|  110|      110.0|
|   AAPL|2021-01-07|  120|      115.0|
|   AAPL|2021-01-13|  110|      110.0|
|   AAPL|2021-01-14|  115|      112.5|
|   AAPL|2021-01-15|  150|      125.0|
|   MSFT|2021-01-01|  150|      150.0|
|   MSFT|2021-01-05|  130|      130.0|
|   MSFT|2021-01-06|  140|      135.0|
|   MSFT|2021-01-09|  120|      130.0|
|   MSFT|2021-01-12|  140|      130.0|
+-------+----------+-----+-----------+
~~~~

Creating lags for time series models

If a dataset has no missing dates, obtaining a lagged variable is as simple as using lag with the appropriate window. Let's apply this procedure to pricesSpark2 and look at what happens after a gap in the dates (for example, AAPL on 2021-01-13):

~~~~ python
win = Window.partitionBy('company')\
            .orderBy(col('date').cast("timestamp").cast("long"))
pricesSpark2.withColumn('price_1', lag(col('price')).over(win)).show()
~~~~

~~~~ console
+-------+----------+-----+-------+
|company|      date|price|price_1|
+-------+----------+-----+-------+
|   AAPL|2021-01-05|  110|   null|
|   AAPL|2021-01-07|  120|    110|
|   AAPL|2021-01-13|  110|    120|
|   AAPL|2021-01-14|  115|    110|
|   AAPL|2021-01-15|  150|    115|
|   MSFT|2021-01-01|  150|   null|
|   MSFT|2021-01-05|  130|    150|
|   MSFT|2021-01-06|  140|    130|
|   MSFT|2021-01-09|  120|    140|
|   MSFT|2021-01-12|  140|    120|
+-------+----------+-----+-------+
~~~~

However, this approach does not work for pricesSpark2, since the value in the previous row may correspond to 1 day before or to several days before (AAPL's price_1 on 2021-01-13 is actually the price from 2021-01-07). In this case we can use any aggregation function (e.g. sum) together with a window containing rangeBetween, restricting the calculation to exactly the number of previous days we want. Since there is at most one row per company and date, the sum is simply the price on that day, or null if that day is missing:

~~~~ python
days = lambda i: i * 86400 
win = Window.partitionBy('company')\
            .orderBy(col('date').cast("timestamp").cast("long"))\
            .rangeBetween(-days(1),-days(1))
pricesSpark2.withColumn('price_1', sum(col('price')).over(win)).show()
~~~~

~~~~ console
+-------+----------+-----+-------+
|company|      date|price|price_1|
+-------+----------+-----+-------+
|   AAPL|2021-01-05|  110|   null|
|   AAPL|2021-01-07|  120|   null|
|   AAPL|2021-01-13|  110|   null|
|   AAPL|2021-01-14|  115|    110|
|   AAPL|2021-01-15|  150|    115|
|   MSFT|2021-01-01|  150|   null|
|   MSFT|2021-01-05|  130|   null|
|   MSFT|2021-01-06|  140|    130|
|   MSFT|2021-01-09|  120|   null|
|   MSFT|2021-01-12|  140|   null|
+-------+----------+-----+-------+
~~~~

Time that has passed since an event

In many situations we have datasets where each row is an event, and we want to know the time elapsed between some of them. For example, we could have a dataset with the logs of several machines and the errors they have had:

~~~~ python
events = [('m1', '2021-01-01', 'e1'), 
          ('m1', '2021-02-02', 'e1'),
          ('m1', '2021-03-03', 'e2'), 
          ('m1', '2021-03-04', 'critical'), 
          ('m1', '2021-07-05', 'e1'), 
          ('m2', '2021-01-01', 'e1'), 
          ('m2', '2021-06-02', 'critical'), 
          ('m2', '2021-08-03', 'e2'), 
          ('m2', '2021-09-04', 'e1'), 
          ('m2', '2021-10-17', 'critical'),
          ('m2', '2021-10-25', 'e1')]

eventsSpark = spark.createDataFrame(events, schema=['machine', 'date', 'error'])\
                   .withColumn('date', to_date(col('date'))).cache() 
    
eventsSpark.show()
~~~~

~~~~ console
+-------+----------+--------+
|machine|      date|   error|
+-------+----------+--------+
|     m1|2021-01-01|      e1|
|     m1|2021-02-02|      e1|
|     m1|2021-03-03|      e2|
|     m1|2021-03-04|critical|
|     m1|2021-07-05|      e1|
|     m2|2021-01-01|      e1|
|     m2|2021-06-02|critical|
|     m2|2021-08-03|      e2|
|     m2|2021-09-04|      e1|
|     m2|2021-10-17|critical|
|     m2|2021-10-25|      e1|
+-------+----------+--------+
~~~~

If we are interested in the time elapsed from the last error of each type to a critical error, we can use the last function together with some auxiliary variables that are null unless the error is of a given type:

~~~~ python
errors = ['e1', 'e2']

# Auxiliary columns: the event date if the error is of the given type, null otherwise
for i in errors:
    eventsSpark = eventsSpark.withColumn('date_'+i, when(col('error')==i, col('date')))

# For each error type: date of its last occurrence up to the current row,
# and the number of days elapsed since then
win = Window.partitionBy('machine').orderBy('date')
for i in errors:
    eventsSpark = eventsSpark.withColumn('dateLast_'+i, last(col('date_'+i), ignorenulls=True).over(win))\
                             .withColumn('daysFromLast_'+i, datediff(col('date'), col('dateLast_'+i)))

eventsSpark.filter(col('error')=='critical').show()
~~~~

~~~~ console
+-------+----------+--------+-------+-------+-----------+---------------+-----------+---------------+
|machine|      date|   error|date_e1|date_e2|dateLast_e1|daysFromLast_e1|dateLast_e2|daysFromLast_e2|
+-------+----------+--------+-------+-------+-----------+---------------+-----------+---------------+
|     m1|2021-03-04|critical|   null|   null| 2021-02-02|             30| 2021-03-03|              1|
|     m2|2021-06-02|critical|   null|   null| 2021-01-01|            152|       null|           null|
|     m2|2021-10-17|critical|   null|   null| 2021-09-04|             43| 2021-08-03|             75|
+-------+----------+--------+-------+-------+-----------+---------------+-----------+---------------+
~~~~

Conclusion

In this post we have learned what windows are in Spark, how to define them and the functions most commonly used with them. We have also seen examples of tasks that can be solved with windows, as well as the advantages of using them.

In future posts we will talk about how to define a simple demand model and a recursive strategy to find the optimal selling price. In the meantime, we encourage you to visit the Data Science category of the Damavis blog and see other articles similar to this one.

We invite you to share this article with your friends. Remember to tag us (@DamavisStudio) to let us know what you think about it. See you on social media!
Carlos Rodriguez