Grouping consecutive rows in a PySpark DataFrame

I have the following example Spark DataFrame:
rdd = sc.parallelize([(1,"19:00:00", "19:30:00", 30), (1,"19:30:00", "19:40:00", 10),(1,"19:40:00", "19:43:00", 3), (2,"20:00:00", "20:10:00", 10), (1,"20:05:00", "20:15:00", 10),(1,"20:15:00", "20:35:00", 20)])
df = spark.createDataFrame(rdd, ["user_id", "start_time", "end_time", "duration"])
df.show()
+-------+----------+--------+--------+
|user_id|start_time|end_time|duration|
+-------+----------+--------+--------+
| 1| 19:00:00|19:30:00| 30|
| 1| 19:30:00|19:40:00| 10|
| 1| 19:40:00|19:43:00| 3|
| 2| 20:00:00|20:10:00| 10|
| 1| 20:05:00|20:15:00| 10|
| 1| 20:15:00|20:35:00| 20|
+-------+----------+--------+--------+
I want to group consecutive rows based on the start and end times. For instance, for the same user_id, if a row's start time is the same as the previous row's end time, I want to group them together and sum the duration.
The desired result is:
+-------+----------+--------+--------+
|user_id|start_time|end_time|duration|
+-------+----------+--------+--------+
| 1| 19:00:00|19:43:00| 43|
| 2| 20:00:00|20:10:00| 10|
| 1| 20:05:00|20:35:00| 30|
+-------+----------+--------+--------+
The first three rows of the dataframe were grouped together because they all correspond to user_id 1 and the start times and end times form a continuous timeline.
This was my initial approach:
Use the lag function with a negative offset (equivalent to lead) to get the next row's start time:
from pyspark.sql.functions import *
from pyspark.sql import Window
import sys
# compute next start time
window = Window.partitionBy('user_id').orderBy('start_time')
df = df.withColumn("next_start_time", lag(df.start_time, -1).over(window))
df.show()
+-------+----------+--------+--------+---------------+
|user_id|start_time|end_time|duration|next_start_time|
+-------+----------+--------+--------+---------------+
| 1| 19:00:00|19:30:00| 30| 19:30:00|
| 1| 19:30:00|19:40:00| 10| 19:40:00|
| 1| 19:40:00|19:43:00| 3| 20:05:00|
| 1| 20:05:00|20:15:00| 10| 20:15:00|
| 1| 20:15:00|20:35:00| 20| null|
| 2| 20:00:00|20:10:00| 10| null|
+-------+----------+--------+--------+---------------+
Get the difference between the current row's end time and the next row's start time:
time_fmt = "HH:mm:ss"
timeDiff = unix_timestamp('next_start_time', format=time_fmt) - unix_timestamp('end_time', format=time_fmt)
df = df.withColumn("difference", timeDiff)
df.show()
+-------+----------+--------+--------+---------------+----------+
|user_id|start_time|end_time|duration|next_start_time|difference|
+-------+----------+--------+--------+---------------+----------+
| 1| 19:00:00|19:30:00| 30| 19:30:00| 0|
| 1| 19:30:00|19:40:00| 10| 19:40:00| 0|
| 1| 19:40:00|19:43:00| 3| 20:05:00| 1320|
| 1| 20:05:00|20:15:00| 10| 20:15:00| 0|
| 1| 20:15:00|20:35:00| 20| null| null|
| 2| 20:00:00|20:10:00| 10| null| null|
+-------+----------+--------+--------+---------------+----------+
Now my idea was to use the sum function with a window to get the cumulative sum of duration and then do a groupBy. But my approach was flawed for many reasons.

Here's one approach:
Gather rows into groups, where a group is a set of rows with the same user_id that are consecutive (each start_time matches the previous end_time). Then you can use this group to do your aggregation.
A way to get there is to create intermediate indicator columns that tell you whether the user has changed or the time is not consecutive, and then take a cumulative sum over the indicator column to create the group id.
For example:
import pyspark.sql.functions as f
from pyspark.sql import Window
w1 = Window.orderBy("start_time")
df = df.withColumn(
        "userChange",
        (f.col("user_id") != f.lag("user_id").over(w1)).cast("int")
    )\
    .withColumn(
        "timeChange",
        (f.col("start_time") != f.lag("end_time").over(w1)).cast("int")
    )\
    .fillna(
        0,
        subset=["userChange", "timeChange"]
    )\
    .withColumn(
        "indicator",
        (~((f.col("userChange") == 0) & (f.col("timeChange") == 0))).cast("int")
    )\
    .withColumn(
        "group",
        f.sum(f.col("indicator")).over(w1.rangeBetween(Window.unboundedPreceding, 0))
    )
df.show()
#+-------+----------+--------+--------+----------+----------+---------+-----+
#|user_id|start_time|end_time|duration|userChange|timeChange|indicator|group|
#+-------+----------+--------+--------+----------+----------+---------+-----+
#| 1| 19:00:00|19:30:00| 30| 0| 0| 0| 0|
#| 1| 19:30:00|19:40:00| 10| 0| 0| 0| 0|
#| 1| 19:40:00|19:43:00| 3| 0| 0| 0| 0|
#| 2| 20:00:00|20:10:00| 10| 1| 1| 1| 1|
#| 1| 20:05:00|20:15:00| 10| 1| 1| 1| 2|
#| 1| 20:15:00|20:35:00| 20| 0| 0| 0| 2|
#+-------+----------+--------+--------+----------+----------+---------+-----+
Now that we have the group column, we can aggregate as follows to get the desired result:
df.groupBy("user_id", "group")\
.agg(
f.min("start_time").alias("start_time"),
f.max("end_time").alias("end_time"),
f.sum("duration").alias("duration")
)\
.drop("group")\
.show()
#+-------+----------+--------+--------+
#|user_id|start_time|end_time|duration|
#+-------+----------+--------+--------+
#| 1| 19:00:00|19:43:00| 43|
#| 1| 20:05:00|20:35:00| 30|
#| 2| 20:00:00|20:10:00| 10|
#+-------+----------+--------+--------+

Here is a working solution derived from Pault's answer:
Create the DataFrame:
rdd = sc.parallelize([(1,"19:00:00", "19:30:00", 30), (1,"19:30:00", "19:40:00", 10),(1,"19:40:00", "19:43:00", 3), (1,"20:05:00", "20:15:00", 10),(1,"20:15:00", "20:35:00", 20)])
df = spark.createDataFrame(rdd, ["user_id", "start_time", "end_time", "duration"])
df.show()
+-------+----------+--------+--------+
|user_id|start_time|end_time|duration|
+-------+----------+--------+--------+
| 1| 19:00:00|19:30:00| 30|
| 1| 19:30:00|19:40:00| 10|
| 1| 19:40:00|19:43:00| 3|
| 1| 20:05:00|20:15:00| 10|
| 1| 20:15:00|20:35:00| 20|
+-------+----------+--------+--------+
Create an indicator column that indicates whenever the time has changed, and use cumulative sum to give each group a unique id:
import pyspark.sql.functions as f
from pyspark.sql import Window
w1 = Window.partitionBy('user_id').orderBy('start_time')
df = df.withColumn(
        "indicator",
        (f.col("start_time") != f.lag("end_time").over(w1)).cast("int")
    )\
    .fillna(
        0,
        subset=["indicator"]
    )\
    .withColumn(
        "group",
        f.sum(f.col("indicator")).over(w1.rangeBetween(Window.unboundedPreceding, 0))
    )
df.show()
+-------+----------+--------+--------+---------+-----+
|user_id|start_time|end_time|duration|indicator|group|
+-------+----------+--------+--------+---------+-----+
| 1| 19:00:00|19:30:00| 30| 0| 0|
| 1| 19:30:00|19:40:00| 10| 0| 0|
| 1| 19:40:00|19:43:00| 3| 0| 0|
| 1| 20:05:00|20:15:00| 10| 1| 1|
| 1| 20:15:00|20:35:00| 20| 0| 1|
+-------+----------+--------+--------+---------+-----+
Now groupBy on user_id and the group variable to get the final result.
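For example, the same aggregation as in Pault's answer above can be reused here (a short sketch; the group column comes from the previous step):
df.groupBy("user_id", "group")\
    .agg(
        f.min("start_time").alias("start_time"),
        f.max("end_time").alias("end_time"),
        f.sum("duration").alias("duration")
    )\
    .drop("group")\
    .show()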
+-------+----------+--------+--------+
|user_id|start_time|end_time|duration|
+-------+----------+--------+--------+
| 1| 19:00:00|19:43:00| 43|
| 1| 20:05:00|20:35:00| 30|
+-------+----------+--------+--------+


How to generate date values into a dataframe in PySpark?

May I know what's wrong with the code below? It does not print anything.
from pyspark.sql import SparkSession
from pyspark.sql.functions import current_date, current_timestamp,last_day,next_day, date_format, date_add, year, month, dayofmonth, dayofyear, dayofweek, date_trunc, date_sub, to_date, add_months, weekofyear, quarter, col
from pyspark.sql.types import StructType,StructField,StringType, IntegerType
ss = SparkSession.builder.appName('DateDim').master('local[1]').getOrCreate()
df = ss.createDataFrame([],StructType([]))
current_date()
df = df.select(current_date().alias("current_date"),next_day(current_date(), 'sunday').alias("next_day"),dayofweek(current_date()).alias("day_of_week"),dayofmonth(current_date()).alias("day_of_month"),dayofyear(current_date()).alias("day_of_year"),last_day(current_date()).alias("last_day"),year(current_date()).alias("year"),month(current_date()).alias("month"), weekofyear(current_date()).alias("week_of_year"),quarter(current_date()).alias("quarter")).collect()
print(df)
for i in range(1, 1000):
    print(i)
for i in range(1, 1000):
    v_date = date_add(v_date, i)
    df.unionAll(df.select(v_date.alias("current_date"),next_day(v_date,'sunday').alias("next_day"),dayofweek(v_date).alias("day_of_week"),dayofmonth(v_date).alias("day_of_month"),dayofyear(v_date).alias("day_of_year"),last_day(v_date).alias("last_day"),year(v_date).alias("year"),month(v_date).alias("month"), weekofyear(v_date).alias("week_of_year"),quarter(v_date).alias("quarter")))
df.show()
You're getting zero rows because there are no rows in the initial df. Any columns being created will have no values as there are no rows in df.
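For illustration, a minimal sketch (reusing the ss session and the imports from the question) showing that a select on the empty dataframe still produces zero rows:
empty_df = ss.createDataFrame([], StructType([]))
# the select defines a column, but there are no rows to fill it
empty_df.select(current_date().alias("current_date")).show()
# +------------+
# |current_date|
# +------------+
# +------------+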
It seems you're trying to create a dataframe with 1000 dates starting from the current day. There's a simple approach using the sequence function.
from pyspark.sql import functions as func

data_sdf = spark.createDataFrame([(1,)], 'id string')
data_sdf. \
    withColumn('min_dt', func.current_date().cast('date')). \
    withColumn('max_dt', func.date_add('min_dt', 1000).cast('date')). \
    withColumn('all_dates', func.expr('sequence(min_dt, max_dt, interval 1 day)')). \
    withColumn('dates_exp', func.explode('all_dates')). \
    drop('id'). \
    show(10)
# +----------+----------+--------------------+----------+
# | min_dt| max_dt| all_dates| dates_exp|
# +----------+----------+--------------------+----------+
# |2022-08-27|2025-05-23|[2022-08-27, 2022...|2022-08-27|
# |2022-08-27|2025-05-23|[2022-08-27, 2022...|2022-08-28|
# |2022-08-27|2025-05-23|[2022-08-27, 2022...|2022-08-29|
# |2022-08-27|2025-05-23|[2022-08-27, 2022...|2022-08-30|
# |2022-08-27|2025-05-23|[2022-08-27, 2022...|2022-08-31|
# |2022-08-27|2025-05-23|[2022-08-27, 2022...|2022-09-01|
# |2022-08-27|2025-05-23|[2022-08-27, 2022...|2022-09-02|
# |2022-08-27|2025-05-23|[2022-08-27, 2022...|2022-09-03|
# |2022-08-27|2025-05-23|[2022-08-27, 2022...|2022-09-04|
# |2022-08-27|2025-05-23|[2022-08-27, 2022...|2022-09-05|
# +----------+----------+--------------------+----------+
# only showing top 10 rows
Select the dates_exp field for further use.
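For instance, a sketch of such further use (column names here are just illustrative) that rebuilds the exploded dates and derives a few of the attributes from the question:
from pyspark.sql import functions as func

date_dim_sdf = data_sdf. \
    withColumn('min_dt', func.current_date()). \
    withColumn('max_dt', func.date_add('min_dt', 1000)). \
    withColumn('all_dates', func.expr('sequence(min_dt, max_dt, interval 1 day)')). \
    withColumn('dates_exp', func.explode('all_dates')). \
    select(func.col('dates_exp').alias('date'),
           func.dayofweek('dates_exp').alias('day_of_week'),
           func.last_day('dates_exp').alias('last_day'),
           func.year('dates_exp').alias('year'),
           func.month('dates_exp').alias('month'),
           func.quarter('dates_exp').alias('quarter'))
date_dim_sdf.show(5)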
You want to use the range() function which generates rows (using sequence you will generate an array which you then need to explode into rows).
Here's how you can use it:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, next_day, dayofweek, dayofmonth, dayofyear, last_day, year, month, \
weekofyear, quarter, current_date
spark = SparkSession.builder.getOrCreate()
(
    spark
    .range(0, 1000)
    .alias("id")
    .select(
        (current_date() + col('id').cast("int")).alias("date")
    )
    .select(
        "date",
        next_day("date", 'sunday').alias("next_sunday"),
        dayofweek("date").alias("day_of_week"),
        dayofmonth("date").alias("day_of_month"),
        dayofyear("date").alias("day_of_year"),
        last_day("date").alias("last_day"),
        year("date").alias("year"),
        month("date").alias("month"),
        weekofyear("date").alias("week_of_year"),
        quarter("date").alias("quarter")
    )
).show()
It returns:
+----------+-----------+-----------+------------+-----------+----------+----+-----+------------+-------+
| date|next_sunday|day_of_week|day_of_month|day_of_year| last_day|year|month|week_of_year|quarter|
+----------+-----------+-----------+------------+-----------+----------+----+-----+------------+-------+
|2022-09-22| 2022-09-25| 5| 22| 265|2022-09-30|2022| 9| 38| 3|
|2022-09-23| 2022-09-25| 6| 23| 266|2022-09-30|2022| 9| 38| 3|
|2022-09-24| 2022-09-25| 7| 24| 267|2022-09-30|2022| 9| 38| 3|
|2022-09-25| 2022-10-02| 1| 25| 268|2022-09-30|2022| 9| 38| 3|
|2022-09-26| 2022-10-02| 2| 26| 269|2022-09-30|2022| 9| 39| 3|
|2022-09-27| 2022-10-02| 3| 27| 270|2022-09-30|2022| 9| 39| 3|
|2022-09-28| 2022-10-02| 4| 28| 271|2022-09-30|2022| 9| 39| 3|
|2022-09-29| 2022-10-02| 5| 29| 272|2022-09-30|2022| 9| 39| 3|
|2022-09-30| 2022-10-02| 6| 30| 273|2022-09-30|2022| 9| 39| 3|
|2022-10-01| 2022-10-02| 7| 1| 274|2022-10-31|2022| 10| 39| 4|
|2022-10-02| 2022-10-09| 1| 2| 275|2022-10-31|2022| 10| 39| 4|
|2022-10-03| 2022-10-09| 2| 3| 276|2022-10-31|2022| 10| 40| 4|
|2022-10-04| 2022-10-09| 3| 4| 277|2022-10-31|2022| 10| 40| 4|
|2022-10-05| 2022-10-09| 4| 5| 278|2022-10-31|2022| 10| 40| 4|
|2022-10-06| 2022-10-09| 5| 6| 279|2022-10-31|2022| 10| 40| 4|
|2022-10-07| 2022-10-09| 6| 7| 280|2022-10-31|2022| 10| 40| 4|
|2022-10-08| 2022-10-09| 7| 8| 281|2022-10-31|2022| 10| 40| 4|
|2022-10-09| 2022-10-16| 1| 9| 282|2022-10-31|2022| 10| 40| 4|
|2022-10-10| 2022-10-16| 2| 10| 283|2022-10-31|2022| 10| 41| 4|
|2022-10-11| 2022-10-16| 3| 11| 284|2022-10-31|2022| 10| 41| 4|
+----------+-----------+-----------+------------+-----------+----------+----+-----+------------+-------+
only showing top 20 rows

GroupBy a dataframe records and display all columns with PySpark

I have the following dataframe
dataframe - columnA, columnB, columnC, columnD, columnE
I want to groupBy columnC and then take the max value of columnE:
dataframe.select('*').groupBy('columnC').max('columnE')
expected output
dataframe - columnA, columnB, columnC, columnD, columnE
Real output
dataframe - columnC, columnE
Why aren't all the columns in the dataframe displayed as expected?
For Spark version >= 3.0.0 you can use max_by to select the additional columns.
import random
from pyspark.sql import functions as F

# create some test data
df = spark.createDataFrame(
    [[random.randint(1, 3)] + random.sample(range(0, 30), 4) for _ in range(10)],
    schema=["columnC", "columnB", "columnA", "columnD", "columnE"]) \
    .select("columnA", "columnB", "columnC", "columnD", "columnE")

df.groupBy("columnC") \
    .agg(F.max("columnE"),
         F.expr("max_by(columnA, columnE) as columnA"),
         F.expr("max_by(columnB, columnE) as columnB"),
         F.expr("max_by(columnD, columnE) as columnD")) \
    .show()
For the test data
+-------+-------+-------+-------+-------+
|columnA|columnB|columnC|columnD|columnE|
+-------+-------+-------+-------+-------+
| 25| 20| 2| 0| 2|
| 14| 2| 2| 24| 6|
| 26| 13| 3| 2| 1|
| 5| 24| 3| 19| 17|
| 22| 5| 3| 14| 21|
| 24| 5| 1| 8| 4|
| 7| 22| 3| 16| 20|
| 6| 17| 1| 5| 7|
| 24| 22| 2| 8| 3|
| 4| 14| 1| 16| 11|
+-------+-------+-------+-------+-------+
the result is
+-------+------------+-------+-------+-------+
|columnC|max(columnE)|columnA|columnB|columnD|
+-------+------------+-------+-------+-------+
| 1| 11| 4| 14| 16|
| 3| 21| 22| 5| 14|
| 2| 6| 14| 2| 24|
+-------+------------+-------+-------+-------+
What you want to achieve can be done with a window function rather than groupBy:
partition your data by columnC,
order the rows within each partition by columnE descending and rank them,
then filter for your desired result (rank 1).
from pyspark.sql.window import Window
from pyspark.sql.functions import rank
from pyspark.sql.functions import col
windowSpec = Window.partitionBy("columnC").orderBy(col("columnE").desc())
expectedDf = df.withColumn("rank", rank().over(windowSpec)) \
    .filter(col("rank") == 1)
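If you want the output to contain only the original columns, a small follow-up sketch is to drop the helper column afterwards:
expectedDf.drop("rank").show()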
You might wanna restructure your question.

pyspark - join with OR condition

I would like to join two pyspark dataframes if at least one of two conditions is satisfied.
Toy data:
df1 = spark.createDataFrame([
        (10, 1, 666),
        (20, 2, 777),
        (30, 1, 888),
        (40, 3, 999),
        (50, 1, 111),
        (60, 2, 222),
        (10, 4, 333),
        (50, None, 444),
        (10, 0, 555),
        (50, 0, 666)
    ],
    ['var1', 'var2', 'other_var']
)
df2 = spark.createDataFrame([
        (10, 1),
        (20, 2),
        (30, None),
        (30, 0)
    ],
    ['var1_', 'var2_']
)
I would like to maintain all those rows of df1 where var1 is present in the distinct values of df2.var1_ OR var2 is present in the distinct values of df2.var2_ (but not in the case where such value is 0).
So, the expected output would be
+----+----+---------+-----+-----+
|var1|var2|other_var|var1_|var2_|
+----+----+---------+-----+-----+
| 10| 1| 666| 10| 1| # join on both var1 and var2
| 20| 2| 777| 20| 2| # join on both var1 and var2
| 30| 1| 888| 10| 1| # join on both var1 and var2
| 50| 1| 111| 10| 1| # join on var2
| 60| 2| 222| 20| 2| # join on var2
| 10| 4| 333| 10| 1| # join on var1
| 10| 0| 555| 10| 1| # join on var1
+----+----+---------+-----+-----+
Among the other attempts, I tried
cond = [(df1.var1 == (df2.select('var1_').distinct()).var1_) | (df1.var2 == (df2.filter(F.col('var2_') != 0).select('var2_').distinct()).var2_)]
df1\
    .join(df2, how='inner', on=cond)\
    .show()
+----+----+---------+-----+-----+
|var1|var2|other_var|var1_|var2_|
+----+----+---------+-----+-----+
| 10| 1| 666| 10| 1|
| 20| 2| 777| 20| 2|
| 30| 1| 888| 10| 1|
| 50| 1| 111| 10| 1|
| 30| 1| 888| 30| null|
| 30| 1| 888| 30| 0|
| 60| 2| 222| 20| 2|
| 10| 4| 333| 10| 1|
| 10| 0| 555| 10| 1|
| 10| 0| 555| 30| 0|
| 50| 0| 666| 30| 0|
+----+----+---------+-----+-----+
but I obtained more rows than expected, and the rows where var2 == 0 were also preserved.
What am I doing wrong?
Note: I'm not using the .isin method because my actual df2 has around 20k rows and I've read here that this method can perform badly with a large number of IDs.
Try the condition below:
cond = (df2.var2_ != 0) & ((df1.var1 == df2.var1_) | (df1.var2 == df2.var2_))
df1\
    .join(df2, how='inner', on=cond)\
    .show()
+----+----+---------+-----+-----+
|var1|var2|other_var|var1_|var2_|
+----+----+---------+-----+-----+
| 10| 1| 666| 10| 1|
| 30| 1| 888| 10| 1|
| 20| 2| 777| 20| 2|
| 50| 1| 111| 10| 1|
| 60| 2| 222| 20| 2|
| 10| 4| 333| 10| 1|
| 10| 0| 555| 10| 1|
+----+----+---------+-----+-----+
The join condition should only reference columns from the two dataframes being joined. If you want to exclude var2_ = 0, you can express that restriction as part of the join condition rather than as a filter.
There is also no need for distinct: it does not affect the equality condition and only adds an unnecessary step.
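Equivalently, you could filter df2 up front and keep only the OR condition in the join; a sketch under that assumption (F is pyspark.sql.functions, as used in the question), which returns the same seven rows on this toy data:
import pyspark.sql.functions as F

# drop the var2_ == 0 rows (the null row is also dropped by the null comparison) before joining
df2_nonzero = df2.filter(F.col('var2_') != 0)
cond = (df1.var1 == df2_nonzero.var1_) | (df1.var2 == df2_nonzero.var2_)
df1.join(df2_nonzero, how='inner', on=cond).show()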

How can we set a flag for the last occurrence of a value in a column of a PySpark DataFrame

Requirement: when a row is the last occurrence of loyal == 1 for a consumer, set the flag to 1, otherwise 0.
Input:
+-----------+----------+----------+-------+-----+---------+-------+---+
|consumer_id|product_id| TRX_ID|pattern|loyal| trx_date|row_num| mx|
+-----------+----------+----------+-------+-----+---------+-------+---+
| 11| 1|1152397078| VVVVM| 1| 3/5/2020| 1| 5|
| 11| 1|1152944770| VVVVV| 1| 3/6/2020| 2| 5|
| 11| 1|1153856408| VVVVV| 1|3/15/2020| 3| 5|
| 11| 2|1155884040| MVVVV| 1| 4/2/2020| 4| 5|
| 11| 2|1156854301| MMVVV| 0|4/17/2020| 5| 5|
| 12| 1|1156854302| VVVVM| 1| 3/6/2020| 1| 3|
| 12| 1|1156854303| VVVVV| 1| 3/7/2020| 2| 3|
| 12| 2|1156854304| MVVVV| 1|3/16/2020| 3| 3|
+-----------+----------+----------+-------+-----+---------+-------+---+
df = spark.createDataFrame(
    [('11','1','1152397078','VVVVM',1,'3/5/2020',1,5),
     ('11','1','1152944770','VVVVV',1,'3/6/2020',2,5),
     ('11','1','1153856408','VVVVV',1,'3/15/2020',3,5),
     ('11','2','1155884040','MVVVV',1,'4/2/2020',4,5),
     ('11','2','1156854301','MMVVV',0,'4/17/2020',5,5),
     ('12','1','1156854302','VVVVM',1,'3/6/2020',1,3),
     ('12','1','1156854303','VVVVV',1,'3/7/2020',2,3),
     ('12','2','1156854304','MVVVV',1,'3/16/2020',3,3)],
    ["consumer_id","product_id","TRX_ID","pattern","loyal","trx_date","row_num","mx"])
df.show()
Output Required:
Note: the Flag only marks the last row where loyal is 1 within each consumer_id.
+-----------+----------+----------+-------+-----+---------+-------+---+----+
|consumer_id|product_id| TRX_ID|pattern|loyal| trx_date|row_num| mx|Flag|
+-----------+----------+----------+-------+-----+---------+-------+---+----+
| 11| 1|1152397078| VVVVM| 1| 3/5/2020| 1| 5| 0|
| 11| 1|1152944770| VVVVV| 1| 3/6/2020| 2| 5| 0|
| 11| 1|1153856408| VVVVV| 1|3/15/2020| 3| 5| 0|
| 11| 2|1155884040| MVVVV| 1| 4/2/2020| 4| 5| 1|
| 11| 2|1156854301| MMVVV| 0|4/17/2020| 5| 5| 0|
| 12| 1|1156854302| VVVVM| 1| 3/6/2020| 1| 3| 0|
| 12| 1|1156854303| VVVVV| 1| 3/7/2020| 2| 3| 0|
| 12| 2|1156854304| MVVVV| 1|3/16/2020| 3| 3| 1|
+-----------+----------+----------+-------+-----+---------+-------+---+----+
What I tried:
from pyspark.sql import functions as F
from pyspark.sql.window import Window
w2 = Window().partitionBy("consumer_id").orderBy('row_num')
df = spark.sql("""select * from inter_table""")
df = df.withColumn("Flag",F.when(F.last(F.col('loyal') == 1).over(w),1).otherwise(0))
Here there are two scenarios:
1. loyal value 1 followed by a 0 (for reference, row_num 4 for consumer_id 11)
2. loyal value 1 with no following row (for reference, row_num 3 for consumer_id 12)
To add to Murtaza's answer:
We can extend the condition so it also covers the second scenario, where lead returns null.
from pyspark.sql import functions as f
from pyspark.sql.window import Window

window = Window.partitionBy('Consumer_id').orderBy('row_num')
df.withColumn('Flag', f.when((f.col('loyal') == 1)
                             & ((f.lead(f.col('loyal')).over(window) == 0)
                                | (f.lead(f.col('loyal')).over(window).isNull())), f.lit('1')).otherwise(f.lit('0'))).show()
+-----------+----------+----------+-------+-----+---------+-------+---+----+
|consumer_id|product_id| TRX_ID|pattern|loyal| trx_date|row_num| mx|Flag|
+-----------+----------+----------+-------+-----+---------+-------+---+----+
| 11| 1|1152397078| VVVVM| 1| 3/5/2020| 1| 5| 0|
| 11| 1|1152944770| VVVVV| 1| 3/6/2020| 2| 5| 0|
| 11| 1|1153856408| VVVVV| 1|3/15/2020| 3| 5| 0|
| 11| 2|1155884040| MVVVV| 1| 4/2/2020| 4| 5| 1|
| 11| 2|1156854300| MMVVV| 0|4/17/2020| 5| 5| 0|
| 12| 1|1156854300| VVVVM| 1| 3/6/2020| 1| 4| 0|
| 12| 1|1156854300| VVVVV| 1| 3/7/2020| 2| 4| 0|
| 12| 2|1156854300| MVVVV| 1|3/16/2020| 3| 4| 0|
| 12| 1|1156854300| MVVVV| 1| 4/3/2020| 4| 4| 1|
+-----------+----------+----------+-------+-----+---------+-------+---+----+
Try this.
from pyspark.sql import functions as F
from pyspark.sql.window import Window
w = Window().partitionBy("product_id").orderBy('row_num')
df.withColumn("flag", F.when((F.col("loyal")==1)\
&(F.lead("loyal").over(w)==0),F.lit(1))\
.otherwise(F.lit(0))).show()
#+-----------+----------+----------+-------+-----+---------+-------+---+----+
#|consumer_id|product_id| TRX_ID|pattern|loyal| trx_date|row_num| mx|flag|
#+-----------+----------+----------+-------+-----+---------+-------+---+----+
#| 11| 1|1152397078| VVVVM| 1| 3/5/2020| 1| 5| 0|
#| 11| 1|1152944770| VVVVV| 1| 3/6/2020| 2| 5| 0|
#| 11| 1|1153856408| VVVVV| 1|3/15/2020| 3| 5| 0|
#| 11| 2|1155884040| MVVVV| 1| 4/2/2020| 4| 5| 1|
#| 11| 2|1156854300| MMVVV| 0|4/17/2020| 5| 5| 0|
#+-----------+----------+----------+-------+-----+---------+-------+---+----+
UPDATE:
from pyspark.sql import functions as F
from pyspark.sql.window import Window

w = Window().partitionBy("consumer_id").orderBy('row_num')
lead = F.lead("loyal").over(w)
df.withColumn("Flag", F.when(((F.col("loyal") == 1)
                              & ((lead == 0) | (lead.isNull()))), F.lit(1))
              .otherwise(F.lit(0))).show()

Partition pyspark dataframe based on the change in column value

I have a dataframe in pyspark.
Say it has some columns a, b, c...
I want to group the data into groups as the value of column A changes. Say:
A B
1 x
1 y
0 x
0 y
0 x
1 y
1 x
1 y
There will be 3 groups as (1x,1y),(0x,0y,0x),(1y,1x,1y)
And corresponding row data
If I understand correctly, you want to create a distinct group every time column A changes value.
First we'll create a monotonically increasing id to keep the row order as it is:
import pyspark.sql.functions as psf
df = sc.parallelize([[1,'x'],[1,'y'],[0,'x'],[0,'y'],[0,'x'],[1,'y'],[1,'x'],[1,'y']])\
    .toDF(['A', 'B'])\
    .withColumn("rn", psf.monotonically_increasing_id())
df.show()
+---+---+----------+
| A| B| rn|
+---+---+----------+
| 1| x| 0|
| 1| y| 1|
| 0| x| 2|
| 0| y| 3|
| 0| x|8589934592|
| 1| y|8589934593|
| 1| x|8589934594|
| 1| y|8589934595|
+---+---+----------+
Now we'll use a window function to create a column that contains 1 every time column A changes:
from pyspark.sql import Window
w = Window.orderBy('rn')
df = df.withColumn("changed", (df.A != psf.lag('A', 1, 0).over(w)).cast('int'))
+---+---+----------+-------+
| A| B| rn|changed|
+---+---+----------+-------+
| 1| x| 0| 1|
| 1| y| 1| 0|
| 0| x| 2| 1|
| 0| y| 3| 0|
| 0| x|8589934592| 0|
| 1| y|8589934593| 1|
| 1| x|8589934594| 0|
| 1| y|8589934595| 0|
+---+---+----------+-------+
Finally we'll use another window function to allocate different numbers to each group:
df = df.withColumn("group_id", psf.sum("changed").over(w)).drop("rn").drop("changed")
+---+---+--------+
| A| B|group_id|
+---+---+--------+
| 1| x| 1|
| 1| y| 1|
| 0| x| 2|
| 0| y| 2|
| 0| x| 2|
| 1| y| 3|
| 1| x| 3|
| 1| y| 3|
+---+---+--------+
Now you can build your groups.
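For example, a minimal sketch of one possible aggregation over those groups (collect_list is just an illustration here; row order inside the list is not guaranteed):
df.groupBy("A", "group_id") \
    .agg(psf.count("*").alias("rows_in_group"),
         psf.collect_list("B").alias("B_values")) \
    .orderBy("group_id") \
    .show()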
