I am trying to generate multiple lags for the 'status' column in the dataframe below. The data is a time series and may contain gaps. I want to generate the lags, but wherever there is a gap the lagged value should be set to null.
Input DF:
+---+----------+------+
| id| s_date|status|
+---+----------+------+
|123|2007-01-31| 1|
|123|2007-02-28| 1|
|123|2007-03-31| 2|
|123|2007-04-30| 2|
|123|2007-05-31| 1|
|123|2007-06-30| 1|
|123|2007-07-31| 2|
|123|2007-08-31| 2|
|345|2007-08-31| 3|
|123|2007-09-30| 2|
|345|2007-09-30| 2|
|123|2007-10-31| 1|
|345|2007-10-31| 1|
|123|2007-11-30| 1|
|345|2007-11-30| 2|
|123|2008-01-31| 3|
|345|2007-12-31| 2|
|567|2007-12-31| 3|
|123|2008-03-31| 4|
|345|2008-01-31| 2|
+---+----------+------+
from datetime import date
rdd = sc.parallelize([
[123,date(2007,1,31),1],
[123,date(2007,2,28),1],
[123,date(2007,3,31),2],
[123,date(2007,4,30),2],
[123,date(2007,5,31),1],
[123,date(2007,6,30),1],
[123,date(2007,7,31),2],
[123,date(2007,8,31),2],
[345,date(2007,8,31),3],
[123,date(2007,9,30),2],
[345,date(2007,9,30),2],
[123,date(2007,10,31),1],
[345,date(2007,10,31),1],
[123,date(2007,11,30),1],
[345,date(2007,11,30),2],
[123,date(2008,1,31),3],
[345,date(2007,12,31),2],
[567,date(2007,12,31),3],
[123,date(2008,3,31),4],
[345,date(2008,1,31),2],
[567,date(2008,1,31),3],
[123,date(2008,4,30),3],
[123,date(2008,5,31),2],
[123,date(2008,6,30),1]
])
df = rdd.toDF(['id','s_date','status'])
df.show()
# Below is the code that works
import pyspark.sql.functions as fn
from pyspark.sql.window import Window
w = Window().partitionBy('id').orderBy('s_date')
for i in range(1, 25):
    df = df.withColumn(f"lag_{i}", fn.lag(fn.col('status'), i).over(w))\
        .withColumn(f"lag_month_{i}", fn.lag(fn.col('s_date'), i).over(w))\
        .withColumn(f"lag_status_{i}", fn.expr(f"case when {i} = 1 and (last_day(add_months(lag_month_{i}, 1)) = last_day(s_date)) then lag_{i} else null end"))
In the code above, the column 'lag_status_1' is populated correctly for i = 1. It is null for Jan and March 2008, which is what I want for every lag. However, when I add the lines below to handle the other lags (i.e. lag_2, lag_3, etc.), the code does not work.
.withColumn(f"lag_status_{i}", fn.expr("case when "f"{i} = 1 and (last_day(add_months("f"lag_month_{i}"",1)) = last_day(s_date)) then "f"lag_{i}"" " +
"when "f"{i} > 1 and (last_day(add_months("f"lag_month_{i}"",1)) = last_day("f"lag_month_{i-1})) then "f"lag_{i}"" else null end"))\
Output DF: lag_status_2 and lag_status_3 are null, but I would like to populate them with the same logic as lag_status_1 (each with respect to its own lag)
+---+----------+------+---------+---------+---------+
| id| s_date|status|lag_status_1|lag_status_2|lag_status_3|
+---+----------+------+---------+---------+---------+
|123|2007-01-31| 1| null| null| null|
|123|2007-02-28| 1| 1| null| null|
|123|2007-03-31| 2| 1| null| null|
|123|2007-04-30| 2| 2| null| null|
|123|2007-05-31| 1| 2| null| null|
|123|2007-06-30| 1| 1| null| null|
|123|2007-07-31| 2| 1| null| null|
|123|2007-08-31| 2| 2| null| null|
|123|2007-09-30| 2| 2| null| null|
|123|2007-10-31| 1| 2| null| null|
|123|2007-11-30| 1| 1| null| null|
|123|2008-01-31| 3| null| null| null|
|123|2008-03-31| 4| null| null| null|
|123|2008-04-30| 3| 4| null| null|
|123|2008-05-31| 2| 3| null| null|
|123|2008-06-30| 1| 2| null| null|
|345|2007-08-31| 3| null| null| null|
|345|2007-09-30| 2| 3| null| null|
|345|2007-10-31| 1| 2| null| null|
|345|2007-11-30| 2| 1| null| null|
+---+----------+------+---------+---------+---------+
Can you please advise on how to resolve this? If there is a better or more efficient solution, please suggest it.
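To pin down the expected values before writing the Spark expression, here is a minimal plain-Python reference (not Spark code). It assumes that "lag i" means the status exactly i calendar months earlier for the same id, and that the lag is null when that month is absent. The helper names month_index and gap_aware_lags are hypothetical and chosen only for this sketch.

```python
from datetime import date


def month_index(d):
    """Map a date to a running month number, so consecutive months differ by 1."""
    return d.year * 12 + d.month


def gap_aware_lags(rows, max_lag):
    """rows: list of (id, s_date, status) tuples.

    Returns {(id, s_date): [lag_1, ..., lag_max_lag]}.
    Assumption: lag i is the status exactly i months earlier for the same id,
    or None if there is no row for that month.
    """
    by_month = {(rid, month_index(d)): status for rid, d, status in rows}
    return {
        (rid, d): [by_month.get((rid, month_index(d) - i)) for i in range(1, max_lag + 1)]
        for rid, d, _ in rows
    }


# A few rows for id 123 taken from the sample data (Dec 2007 and Feb 2008 are missing).
sample = [
    (123, date(2007, 10, 31), 1),
    (123, date(2007, 11, 30), 1),
    (123, date(2008, 1, 31), 3),
    (123, date(2008, 3, 31), 4),
    (123, date(2008, 4, 30), 3),
]
lags = gap_aware_lags(sample, 3)
print(lags[(123, date(2008, 1, 31))])  # [None, 1, 1]
print(lags[(123, date(2008, 4, 30))])  # [4, None, 3]
```

Under this reading, lag_1 is null for Jan and March 2008, as in the output table. If the intended rule is instead "null as soon as any month inside the lag window is missing", the lookup would have to check every intermediate month as well.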
Related
I wanted to re-label healthy examples (0) as failure (1) for 2 days before the actual failure for all serial numbers in the failure column. Here is my code:
import findspark
findspark.init()
import pyspark
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('spark3.2show').getOrCreate()
print('Spark info :')
spark
url="https://gist.githubusercontent.com/JishanAhmed2019/e464ca4da5c871428ca9ed9264467aa0/raw/da3921c1953fefbc66dddc3ce238dac53142dba8/failure.csv"
from pyspark import SparkFiles
spark.sparkContext.addFile(url)
df=spark.read.csv(SparkFiles.get("failure.csv"), header=True,sep='\t')
I wanted to re-label the red marked 0 as 1. Also, Serial C was mistakenly present in the database as healthy even after the actual failure.
I would recast the date column as a Timestamp because this will allow you to take the difference between any two Timestamps, which we will need to do.
Then you can create a new column called failure_dates that contains the date whenever a failure occurs, and is null otherwise.
Next, create a new column called 2_days_to_failure: partition by serial_number and take the difference between the max value in the failure_dates column and each date inside the partition to get the number of days to failure. The column is true whenever there are 2 days or fewer to failure.
Finally, we can create a column called failure_relabeled by combining the information from the columns 2_days_to_failure and the original failure column.
import pyspark.sql.functions as F
from pyspark.sql.window import Window
window = Window.partitionBy("serial_number")
df.withColumn(
    # lowercase d is day-of-month; uppercase D would mean day-of-year
    'date', F.to_timestamp(F.col('date'), 'M/d/yyyy')
).withColumn(
    # the failure date on failure rows, null otherwise
    "failure_dates", F.when(F.col('failure') == 1, F.col('date'))
).withColumn(
    # days from each row to the latest failure date of the same serial number
    "2_days_to_failure", F.datediff(F.max(F.col('failure_dates')).over(window), F.col('date')) <= 2
).withColumn(
    "failure_relabeled", F.when((F.col('2_days_to_failure') | (F.col('failure') == 1)), F.lit(1)).otherwise(F.lit(0))
).orderBy('serial_number','date').show()
+-------------------+-------------+-------+-----------+-------------+-------------------+-----------------+-----------------+
| date|serial_number|failure|smart_5_raw|smart_187_raw| failure_dates|2_days_to_failure|failure_relabeled|
+-------------------+-------------+-------+-----------+-------------+-------------------+-----------------+-----------------+
|2014-01-01 00:00:00| A| 0| 0| 60| null| false| 0|
|2014-01-02 00:00:00| A| 0| 0| 180| null| false| 0|
|2014-01-03 00:00:00| A| 0| 0| 140| null| true| 1|
|2014-01-04 00:00:00| A| 0| 0| 280| null| true| 1|
|2014-01-05 00:00:00| A| 1| 0| 400|2014-01-05 00:00:00| true| 1|
|2014-01-01 00:00:00| B| 0| 0| 40| null| null| 0|
|2014-01-02 00:00:00| B| 0| 0| 160| null| null| 0|
|2014-01-03 00:00:00| B| 0| 0| 100| null| null| 0|
|2014-01-04 00:00:00| B| 0| 0| 320| null| null| 0|
|2014-01-05 00:00:00| B| 0| 0| 340| null| null| 0|
|2014-01-06 00:00:00| B| 0| 0| 400| null| null| 0|
|2014-01-01 00:00:00| C| 0| 0| 80| null| true| 1|
|2014-01-02 00:00:00| C| 0| 0| 200| null| true| 1|
|2014-01-03 00:00:00| C| 1| 0| 120|2014-01-03 00:00:00| true| 1|
|2014-01-04 00:00:00| D| 0| 0| 300| null| null| 0|
|2014-01-05 00:00:00| D| 0| 0| 360| null| null| 0|
+-------------------+-------------+-------+-----------+-------------+-------------------+-----------------+-----------------+
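The relabeling rule described above can be cross-checked on a handful of rows with a plain-Python sketch (not Spark code). The function name relabel_failures and its days_before parameter are hypothetical; the rows are serial A and part of serial B from the output table.

```python
from datetime import date


def relabel_failures(rows, days_before=2):
    """rows: list of (serial_number, date, failure). Returns relabeled flags in row order."""
    # Latest failure date per serial number (mirrors max(failure_dates) over the window).
    last_failure = {}
    for serial, d, failure in rows:
        if failure == 1:
            last_failure[serial] = max(d, last_failure.get(serial, d))

    labels = []
    for serial, d, failure in rows:
        fail_date = last_failure.get(serial)
        # Same comparison as datediff(...) <= 2; serials without a failure never match.
        near = fail_date is not None and (fail_date - d).days <= days_before
        labels.append(1 if (failure == 1 or near) else 0)
    return labels


rows = [
    ("A", date(2014, 1, 1), 0),
    ("A", date(2014, 1, 2), 0),
    ("A", date(2014, 1, 3), 0),
    ("A", date(2014, 1, 4), 0),
    ("A", date(2014, 1, 5), 1),
    ("B", date(2014, 1, 1), 0),
    ("B", date(2014, 1, 2), 0),
]
labels = relabel_failures(rows)
print(labels)  # [0, 0, 1, 1, 1, 0, 0]
```

Because the comparison is only "<= 2", rows dated after the failure also get a difference below the threshold and are relabeled as 1, which is the same behaviour as the datediff expression in the answer.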
I have a dataframe like the one below in PySpark
df = spark.createDataFrame([(2,'john',1,1),
(2,'john',1,2),
(3,'pete',8,3),
(3,'pete',8,4),
(5,'steve',9,5)],
['id','/na/me','val/ue', 'rank/'])
df.show()
+---+------+------+-----+
| id|/na/me|val/ue|rank/|
+---+------+------+-----+
| 2| john| 1| 1|
| 2| john| 1| 2|
| 3| pete| 8| 3|
| 3| pete| 8| 4|
| 5| steve| 9| 5|
+---+------+------+-----+
Now I want to rename the columns in this data frame, replacing / with an underscore _. But if the / comes at the start or end of the column name, it should be removed rather than replaced with _.
I have done like below
for name in df.schema.names:
    df = df.withColumnRenamed(name, name.replace('/', '_'))
>>> df
DataFrame[id: bigint, _na_me: string, val_ue: bigint, rank_: bigint]
>>>df.show()
+---+------+------+-----+
| id|_na_me|val_ue|rank_|
+---+------+------+-----+
| 2| john| 1| 1|
| 2| john| 1| 2|
| 3| pete| 8| 3|
| 3| pete| 8| 4|
| 5| steve| 9| 5|
+---+------+------+-----+
How can I achieve my desired result which is below
+---+------+------+-----+
| id| na_me|val_ue| rank|
+---+------+------+-----+
| 2| john| 1| 1|
| 2| john| 1| 2|
| 3| pete| 8| 3|
| 3| pete| 8| 4|
| 5| steve| 9| 5|
+---+------+------+-----+
Try a regular-expression replace (re.sub) in plain Python.
import re
df = spark.createDataFrame([(2,'john',1,1),
(2,'john',1,2),
(3,'pete',8,3),
(3,'pete',8,4),
(5,'steve',9,5)],
['id','/na/me','val/ue', 'rank/'])
# replace every / with _, then strip a leading or trailing _
cols=[re.sub(r'(^_|_$)','',f.replace("/","_")) for f in df.columns]
df.toDF(*cols).show()
#+---+-----+------+----+
#| id|na_me|val_ue|rank|
#+---+-----+------+----+
#| 2| john| 1| 1|
#| 2| john| 1| 2|
#| 3| pete| 8| 3|
#| 3| pete| 8| 4|
#| 5|steve| 9| 5|
#+---+-----+------+----+
#or using for loop on schema.names
for name in df.schema.names:
    df = df.withColumnRenamed(name, re.sub(r'(^_|_$)','',name.replace('/', '_')))
df.show()
#+---+-----+------+----+
#| id|na_me|val_ue|rank|
#+---+-----+------+----+
#| 2| john| 1| 1|
#| 2| john| 1| 2|
#| 3| pete| 8| 3|
#| 3| pete| 8| 4|
#| 5|steve| 9| 5|
#+---+-----+------+----+
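A possible variation, sketched in plain Python: strip the leading and trailing slashes first and only then replace the remaining ones. Unlike stripping underscores after the replacement, this leaves a column that genuinely starts or ends with an underscore untouched. The helper name clean_column_name is hypothetical.

```python
def clean_column_name(name):
    """Drop leading/trailing '/' and turn the remaining '/' into '_'."""
    return name.strip("/").replace("/", "_")


old_names = ["id", "/na/me", "val/ue", "rank/"]
new_names = [clean_column_name(n) for n in old_names]
print(new_names)  # ['id', 'na_me', 'val_ue', 'rank']
```

The resulting list can be passed to df.toDF(*new_names) in the same way as cols above.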
I have a dataframe like as shown below
df = pd.DataFrame({
'subject_id':[1,1,1,1,2,2,2,2,3,3,4,4,4,4,4],
'readings' : ['READ_1','READ_2','READ_1','READ_3','READ_1','READ_5','READ_6','READ_8','READ_10','READ_12','READ_11','READ_14','READ_09','READ_08','READ_07'],
'val' :[5,6,7,11,5,7,16,12,13,56,32,13,45,43,46],
})
My above input dataframe looks like this
The code below works fine in pandas (thanks to Jezrael), but when I apply it to my real data (more than 4M records) it runs for a long time, so I am trying to use PySpark instead. Please note that I already tried Dask, Modin and pandarallel, which are pandas equivalents for large-scale processing, but they didn't help either. The code generates summary statistics for each subject and each reading. The expected output below should give you an idea.
df_op = (df.groupby(['subject_id','readings'])['val']
.describe()
.unstack()
.swaplevel(0,1,axis=1)
.reindex(df['readings'].unique(), axis=1, level=0))
df_op.columns = df_op.columns.map('_'.join)
df_op = df_op.reset_index()
Can you help me achieve the above operation in pyspark? When I tried the below, it threw an error
df.groupby(['subject_id','readings'])['val']
For example, subject_id = 1 has 4 readings but 3 unique readings, so we get 3 * 8 = 24 columns for subject_id = 1. Why 8? Because there are 8 statistics: min, max, count, std, mean, and the 25th, 50th and 75th percentiles. Hope this helps
When I started off with this in pyspark, it returns the below error
TypeError: 'GroupedData' object is not subscriptable
I expect my output to be like as shown below
You need to group by subject_id and readings to get the statistics for each reading first, and then pivot to get the expected output:
import pyspark.sql.functions as F
agg_df = df.groupby("subject_id", "readings").agg(F.mean(F.col("val")), F.min(F.col("val")), F.max(F.col("val")),
F.count(F.col("val")),
F.expr('percentile_approx(val, 0.25)').alias("quantile_25"),
F.expr('percentile_approx(val, 0.75)').alias("quantile_75"))
This will give you the following output:
+----------+--------+--------+--------+--------+----------+-----------+-----------+
|subject_id|readings|avg(val)|min(val)|max(val)|count(val)|quantile_25|quantile_75|
+----------+--------+--------+--------+--------+----------+-----------+-----------+
| 2| READ_1| 5.0| 5| 5| 1| 5| 5|
| 2| READ_5| 7.0| 7| 7| 1| 7| 7|
| 2| READ_8| 12.0| 12| 12| 1| 12| 12|
| 4| READ_08| 43.0| 43| 43| 1| 43| 43|
| 1| READ_2| 6.0| 6| 6| 1| 6| 6|
| 1| READ_1| 6.0| 5| 7| 2| 5| 7|
| 2| READ_6| 16.0| 16| 16| 1| 16| 16|
| 1| READ_3| 11.0| 11| 11| 1| 11| 11|
| 4| READ_11| 32.0| 32| 32| 1| 32| 32|
| 3| READ_10| 13.0| 13| 13| 1| 13| 13|
| 3| READ_12| 56.0| 56| 56| 1| 56| 56|
| 4| READ_14| 13.0| 13| 13| 1| 13| 13|
| 4| READ_07| 46.0| 46| 46| 1| 46| 46|
| 4| READ_09| 45.0| 45| 45| 1| 45| 45|
+----------+--------+--------+--------+--------+----------+-----------+-----------+
Using groupby subject_id if you pivot readings, you will get the expected output:
agg_df2 = df.groupby("subject_id").pivot("readings").agg(F.mean(F.col("val")), F.min(F.col("val")), F.max(F.col("val")),
F.count(F.col("val")),
F.expr('percentile_approx(val, 0.25)').alias("quantile_25"),
F.expr('percentile_approx(val, 0.75)').alias("quantile_75"))
for i in agg_df2.columns:
    agg_df2 = agg_df2.withColumnRenamed(i, i.replace("(val)", ""))
agg_df2.show()

|subject_id|READ_07_avg(val)|READ_07_min(val)|READ_07_max(val)|READ_07_count(val)|READ_07_quantile_25|READ_07_quantile_75|READ_08_avg(val)|READ_08_min(val)|READ_08_max(val)|READ_08_count(val)|READ_08_quantile_25|READ_08_quantile_75|READ_09_avg(val)|READ_09_min(val)|READ_09_max(val)|READ_09_count(val)|READ_09_quantile_25|READ_09_quantile_75|READ_1_avg(val)|READ_1_min(val)|READ_1_max(val)|READ_1_count(val)|READ_1_quantile_25|READ_1_quantile_75|READ_10_avg(val)|READ_10_min(val)|READ_10_max(val)|READ_10_count(val)|READ_10_quantile_25|READ_10_quantile_75|READ_11_avg(val)|READ_11_min(val)|READ_11_max(val)|READ_11_count(val)|READ_11_quantile_25|READ_11_quantile_75|READ_12_avg(val)|READ_12_min(val)|READ_12_max(val)|READ_12_count(val)|READ_12_quantile_25|READ_12_quantile_75|READ_14_avg(val)|READ_14_min(val)|READ_14_max(val)|READ_14_count(val)|READ_14_quantile_25|READ_14_quantile_75|READ_2_avg(val)|READ_2_min(val)|READ_2_max(val)|READ_2_count(val)|READ_2_quantile_25|READ_2_quantile_75|READ_3_avg(val)|READ_3_min(val)|READ_3_max(val)|READ_3_count(val)|READ_3_quantile_25|READ_3_quantile_75|READ_5_avg(val)|READ_5_min(val)|READ_5_max(val)|READ_5_count(val)|READ_5_quantile_25|READ_5_quantile_75|READ_6_avg(val)|READ_6_min(val)|READ_6_max(val)|READ_6_count(val)|READ_6_quantile_25|READ_6_quantile_75|READ_8_avg(val)|READ_8_min(val)|READ_8_max(val)|READ_8_count(val)|READ_8_quantile_25|READ_8_quantile_75|
+----------+----------------+----------------+----------------+------------------+-------------------+-------------------+----------------+----------------+----------------+------------------+-------------------+-------------------+----------------+----------------+----------------+------------------+-------------------+-------------------+---------------+---------------+---------------+-----------------+------------------+------------------+----------------+----------------+----------------+------------------+-------------------+-------------------+----------------+----------------+----------------+------------------+-------------------+-------------------+----------------+----------------+----------------+------------------+-------------------+-------------------+----------------+----------------+----------------+------------------+-------------------+-------------------+---------------+---------------+---------------+-----------------+------------------+------------------+---------------+---------------+---------------+-----------------+------------------+------------------+---------------+---------------+---------------+-----------------+------------------+------------------+---------------+---------------+---------------+-----------------+------------------+------------------+---------------+---------------+---------------+-----------------+------------------+------------------+
| 1| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| 6.0| 5| 7| 2| 5| 7| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| 6.0| 6| 6| 1| 6| 6| 11.0| 11| 11| 1| 11| 11| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null|
| 3| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| 13.0| 13| 13| 1| 13| 13| null| null| null| null| null| null| 56.0| 56| 56| 1| 56| 56| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null|
| 2| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| 5.0| 5| 5| 1| 5| 5| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| 7.0| 7| 7| 1| 7| 7| 16.0| 16| 16| 1| 16| 16| 12.0| 12| 12| 1| 12| 12|
| 4| 46.0| 46| 46| 1| 46| 46| 43.0| 43| 43| 1| 43| 43| 45.0| 45| 45| 1| 45| 45| null| null| null| null| null| null| null| null| null| null| null| null| 32.0| 32| 32| 1| 32| 32| null| null| null| null| null| null| 13.0| 13| 13| 1| 13| 13| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null|
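To sanity-check the wide layout on a few rows without Spark, the group-then-pivot idea can be sketched in plain Python. This is only a reference under simplifying assumptions: it covers count, mean, min and max (not the percentiles or std), and the function name wide_summary and the "READING_stat" key format are hypothetical.

```python
from collections import defaultdict


def wide_summary(rows):
    """rows: list of (subject_id, reading, val).

    Returns {subject_id: {"<reading>_<stat>": value}}, one entry per
    reading/statistic pair that the subject actually has.
    """
    # Step 1: group values by (subject_id, reading).
    groups = defaultdict(list)
    for subject_id, reading, val in rows:
        groups[(subject_id, reading)].append(val)

    # Step 2: pivot, so each subject gets one flat record keyed by reading and statistic.
    wide = defaultdict(dict)
    for (subject_id, reading), vals in groups.items():
        wide[subject_id][f"{reading}_count"] = len(vals)
        wide[subject_id][f"{reading}_mean"] = sum(vals) / len(vals)
        wide[subject_id][f"{reading}_min"] = min(vals)
        wide[subject_id][f"{reading}_max"] = max(vals)
    return dict(wide)


# The first five rows of the sample dataframe.
rows = [
    (1, "READ_1", 5),
    (1, "READ_2", 6),
    (1, "READ_1", 7),
    (1, "READ_3", 11),
    (2, "READ_1", 5),
]
summary = wide_summary(rows)
print(summary[1]["READ_1_mean"])  # 6.0
```

Readings a subject does not have are simply absent here, whereas the Spark pivot fills those columns with null.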

I have the following example Spark DataFrame:
rdd = sc.parallelize([(1,"19:00:00", "19:30:00", 30), (1,"19:30:00", "19:40:00", 10),(1,"19:40:00", "19:43:00", 3), (2,"20:00:00", "20:10:00", 10), (1,"20:05:00", "20:15:00", 10),(1,"20:15:00", "20:35:00", 20)])
df = spark.createDataFrame(rdd, ["user_id", "start_time", "end_time", "duration"])
df.show()
+-------+----------+--------+--------+
|user_id|start_time|end_time|duration|
+-------+----------+--------+--------+
| 1| 19:00:00|19:30:00| 30|
| 1| 19:30:00|19:40:00| 10|
| 1| 19:40:00|19:43:00| 3|
| 2| 20:00:00|20:10:00| 10|
| 1| 20:05:00|20:15:00| 10|
| 1| 20:15:00|20:35:00| 20|
+-------+----------+--------+--------+
I want to group consecutive rows based on the start and end times. For instance, for the same user_id, if a row's start time is the same as the previous row's end time, I want to group them together and sum the duration.
The desired result is:
+-------+----------+--------+--------+
|user_id|start_time|end_time|duration|
+-------+----------+--------+--------+
| 1| 19:00:00|19:43:00| 43|
| 2| 20:00:00|20:10:00| 10|
| 1| 20:05:00|20:35:00| 30|
+-------+----------+--------+--------+
The first three rows of the dataframe were grouped together because they all correspond to user_id 1 and the start times and end times form a continuous timeline.
This was my initial approach:
Use the lag function to get the next start time:
from pyspark.sql.functions import *
from pyspark.sql import Window
import sys
# compute next start time
window = Window.partitionBy('user_id').orderBy('start_time')
df = df.withColumn("next_start_time", lag(df.start_time, -1).over(window))
df.show()
+-------+----------+--------+--------+---------------+
|user_id|start_time|end_time|duration|next_start_time|
+-------+----------+--------+--------+---------------+
| 1| 19:00:00|19:30:00| 30| 19:30:00|
| 1| 19:30:00|19:40:00| 10| 19:40:00|
| 1| 19:40:00|19:43:00| 3| 20:05:00|
| 1| 20:05:00|20:15:00| 10| 20:15:00|
| 1| 20:15:00|20:35:00| 20| null|
| 2| 20:00:00|20:10:00| 10| null|
+-------+----------+--------+--------+---------------+
Get the difference between the current row's end time and the next row's start time:
time_fmt = "HH:mm:ss"
timeDiff = unix_timestamp('next_start_time', format=time_fmt) - unix_timestamp('end_time', format=time_fmt)
df = df.withColumn("difference", timeDiff)
df.show()
+-------+----------+--------+--------+---------------+----------+
|user_id|start_time|end_time|duration|next_start_time|difference|
+-------+----------+--------+--------+---------------+----------+
| 1| 19:00:00|19:30:00| 30| 19:30:00| 0|
| 1| 19:30:00|19:40:00| 10| 19:40:00| 0|
| 1| 19:40:00|19:43:00| 3| 20:05:00| 1320|
| 1| 20:05:00|20:15:00| 10| 20:15:00| 0|
| 1| 20:15:00|20:35:00| 20| null| null|
| 2| 20:00:00|20:10:00| 10| null| null|
+-------+----------+--------+--------+---------------+----------+
Now my idea was to use the sum function with a window to get the cumulative sum of duration and then do a groupBy. But my approach was flawed for many reasons.
Here's one approach:
Gather together rows into groups where a group is a set of rows with the same user_id that are consecutive (start_time matches previous end_time). Then you can use this group to do your aggregation.
A way to get here is by creating intermediate indicator columns to tell you if the user has changed or the time is not consecutive. Then perform a cumulative sum over the indicator column to create the group.
For example:
import pyspark.sql.functions as f
from pyspark.sql import Window
w1 = Window.orderBy("start_time")
df = df.withColumn(
"userChange",
(f.col("user_id") != f.lag("user_id").over(w1)).cast("int")
)\
.withColumn(
"timeChange",
(f.col("start_time") != f.lag("end_time").over(w1)).cast("int")
)\
.fillna(
0,
subset=["userChange", "timeChange"]
)\
.withColumn(
"indicator",
(~((f.col("userChange") == 0) & (f.col("timeChange")==0))).cast("int")
)\
.withColumn(
"group",
f.sum(f.col("indicator")).over(w1.rangeBetween(Window.unboundedPreceding, 0))
)
df.show()
#+-------+----------+--------+--------+----------+----------+---------+-----+
#|user_id|start_time|end_time|duration|userChange|timeChange|indicator|group|
#+-------+----------+--------+--------+----------+----------+---------+-----+
#| 1| 19:00:00|19:30:00| 30| 0| 0| 0| 0|
#| 1| 19:30:00|19:40:00| 10| 0| 0| 0| 0|
#| 1| 19:40:00|19:43:00| 3| 0| 0| 0| 0|
#| 2| 20:00:00|20:10:00| 10| 1| 1| 1| 1|
#| 1| 20:05:00|20:15:00| 10| 1| 1| 1| 2|
#| 1| 20:15:00|20:35:00| 20| 0| 0| 0| 2|
#+-------+----------+--------+--------+----------+----------+---------+-----+
Now that we have the group column, we can aggregate as follows to get the desired result:
df.groupBy("user_id", "group")\
.agg(
f.min("start_time").alias("start_time"),
f.max("end_time").alias("end_time"),
f.sum("duration").alias("duration")
)\
.drop("group")\
.show()
#+-------+----------+--------+--------+
#|user_id|start_time|end_time|duration|
#+-------+----------+--------+--------+
#| 1| 19:00:00|19:43:00| 43|
#| 1| 20:05:00|20:35:00| 30|
#| 2| 20:00:00|20:10:00| 10|
#+-------+----------+--------+--------+
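The grouping logic above can also be checked on the sample rows with a small plain-Python reference (not Spark code). It assumes zero-padded HH:mm:ss strings, so sorting them as text matches chronological order, and it partitions by user before looking for consecutive rows. The function name merge_consecutive is hypothetical.

```python
from collections import defaultdict


def merge_consecutive(rows):
    """rows: list of (user_id, start_time, end_time, duration).

    Merges rows of the same user whose start_time equals the previous row's end_time.
    """
    by_user = defaultdict(list)
    for row in rows:
        by_user[row[0]].append(row)

    merged = []
    for user_id, user_rows in by_user.items():
        current = None  # the group currently being built: [user, start, end, duration]
        for _, start, end, duration in sorted(user_rows, key=lambda r: r[1]):
            if current is not None and start == current[2]:
                # Contiguous with the open group: extend its end and add the duration.
                current[2] = end
                current[3] += duration
            else:
                # A gap (or the first row): close the open group and start a new one.
                if current is not None:
                    merged.append(tuple(current))
                current = [user_id, start, end, duration]
        if current is not None:
            merged.append(tuple(current))
    return sorted(merged, key=lambda r: (r[1], r[0]))


rows = [
    (1, "19:00:00", "19:30:00", 30),
    (1, "19:30:00", "19:40:00", 10),
    (1, "19:40:00", "19:43:00", 3),
    (2, "20:00:00", "20:10:00", 10),
    (1, "20:05:00", "20:15:00", 10),
    (1, "20:15:00", "20:35:00", 20),
]
result = merge_consecutive(rows)
for r in result:
    print(r)
# (1, '19:00:00', '19:43:00', 43)
# (2, '20:00:00', '20:10:00', 10)
# (1, '20:05:00', '20:35:00', 30)
```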
Here is a working solution derived from Pault's answer:
Create the Dataframe:
rdd = sc.parallelize([(1,"19:00:00", "19:30:00", 30), (1,"19:30:00", "19:40:00", 10),(1,"19:40:00", "19:43:00", 3), (2,"20:00:00", "20:10:00", 10), (1,"20:05:00", "20:15:00", 10),(1,"20:15:00", "20:35:00", 20)])
df = spark.createDataFrame(rdd, ["user_id", "start_time", "end_time", "duration"])
df.show()
+-------+----------+--------+--------+
|user_id|start_time|end_time|duration|
+-------+----------+--------+--------+
| 1| 19:00:00|19:30:00| 30|
| 1| 19:30:00|19:40:00| 10|
| 1| 19:40:00|19:43:00| 3|
| 1| 20:05:00|20:15:00| 10|
| 1| 20:15:00|20:35:00| 20|
+-------+----------+--------+--------+
Create an indicator column that indicates whenever the time has changed, and use cumulative sum to give each group a unique id:
import pyspark.sql.functions as f
from pyspark.sql import Window
w1 = Window.partitionBy('user_id').orderBy('start_time')
df = df.withColumn(
"indicator",
(f.col("start_time") != f.lag("end_time").over(w1)).cast("int")
)\
.fillna(
0,
subset=[ "indicator"]
)\
.withColumn(
"group",
f.sum(f.col("indicator")).over(w1.rangeBetween(Window.unboundedPreceding, 0))
)
df.show()
+-------+----------+--------+--------+---------+-----+
|user_id|start_time|end_time|duration|indicator|group|
+-------+----------+--------+--------+---------+-----+
| 1| 19:00:00|19:30:00| 30| 0| 0|
| 1| 19:30:00|19:40:00| 10| 0| 0|
| 1| 19:40:00|19:43:00| 3| 0| 0|
| 1| 20:05:00|20:15:00| 10| 1| 1|
| 1| 20:15:00|20:35:00| 20| 0| 1|
+-------+----------+--------+--------+---------+-----+
Now group by user_id and the group column, aggregating as in Pault's answer (min of start_time, max of end_time, sum of duration):
+-------+----------+--------+--------+
|user_id|start_time|end_time|duration|
+-------+----------+--------+--------+
| 1| 19:00:00|19:43:00| 43|
| 1| 20:05:00|20:35:00| 30|
+-------+----------+--------+--------+
I have this sample dataframe:
from pyspark.sql.types import *
schema = StructType([
StructField("ClientId", IntegerType(), True),
StructField("m_ant21", IntegerType(), True),
StructField("m_ant22", IntegerType(), True),
StructField("m_ant23", IntegerType(), True),
StructField("m_ant24", IntegerType(), True)])
df = sqlContext.createDataFrame(
data=[(0, None, None, None, None),
(1, 23, 13, 17, 99),
(2, 0, 0, 0, 1),
(3, 0, None, 1, 0),
(4, None, None, None, None)],
schema=schema)
I have this data frame:
+--------+-------+-------+-------+-------+
|ClientId|m_ant21|m_ant22|m_ant23|m_ant24|
+--------+-------+-------+-------+-------+
| 0| null| null| null| null|
| 1| 23| 13| 17| 99|
| 2| 0| 0| 0| 1|
| 3| 0| null| 1| 0|
| 4| null| null| null| null|
+--------+-------+-------+-------+-------+
And I need to solve this question:
I'd like to create a new variable that counts how many null values each row has. For example:
ClientId 0 should be 4
ClientId 1 should be 0
ClientId 3 should be 1
Note that df is a pyspark.sql.dataframe.DataFrame.
Here is one option:
from pyspark.sql import Row
# add the column schema to the original schema
schema.add(StructField("count_null", IntegerType(), True))
# convert data frame to rdd and append an element to each row to count the number of nulls
df.rdd.map(lambda row: row + Row(sum(x is None for x in row))).toDF(schema).show()
+--------+-------+-------+-------+-------+----------+
|ClientId|m_ant21|m_ant22|m_ant23|m_ant24|count_null|
+--------+-------+-------+-------+-------+----------+
| 0| null| null| null| null| 4|
| 1| 23| 13| 17| 99| 0|
| 2| 0| 0| 0| 1| 0|
| 3| 0| null| 1| 0| 1|
| 4| null| null| null| null| 4|
+--------+-------+-------+-------+-------+----------+
If you don't want to deal with schema, here is another option:
from pyspark.sql.functions import col, when
df.withColumn("count_null", sum([when(col(x).isNull(),1).otherwise(0) for x in df.columns])).show()
+--------+-------+-------+-------+-------+----------+
|ClientId|m_ant21|m_ant22|m_ant23|m_ant24|count_null|
+--------+-------+-------+-------+-------+----------+
| 0| null| null| null| null| 4|
| 1| 23| 13| 17| 99| 0|
| 2| 0| 0| 0| 1| 0|
| 3| 0| null| 1| 0| 1|
| 4| null| null| null| null| 4|
+--------+-------+-------+-------+-------+----------+