I have a customer table that hosts information about several processes for each customer.
The goal is to extract features for each customer and each process. This means every feature is mostly an aggregate or sort-and-compare computation on a .groupby(customerID, processID) object.
However, the goal is to be able to add more and more features over time. So basically the user should be able to define new functions with some filters, metrics and aggregations, and add each new function to a pool of functions which operate on the table.
The output should be a customerID, processID table, with all features.
So I started a little minimal working example:
l = [('CM1','aa1', 100,0.1),('CM1','aa1', 110,0.2),\
('CM1','aa1', 110,0.9),('CM1','aa1', 100,1.5),\
('CX2','bb9', 100,0.1),('CX2','bb9', 100,0.2),\
('CX2','bb9', 110,6.0),('CX2','bb9', 100,0.18)]
rdd = sc.parallelize(l)
df = sqlContext.createDataFrame(rdd,['customid','procid','speed','timestamp'])
+--------+------+-----+---------+
|customid|procid|speed|timestamp|
+--------+------+-----+---------+
| CM1| aa1| 100| 0.1|
| CM1| aa1| 110| 0.2|
| CM1| aa1| 110| 0.9|
| CM1| aa1| 100| 1.5|
| CX2| bb9| 100| 0.1|
| CX2| bb9| 100| 0.2|
| CX2| bb9| 110| 6.0|
| CX2| bb9| 100| 0.18|
+--------+------+-----+---------+
Then I define two arbitrary features, which are extracted by these functions:
from pyspark.sql.functions import col, count, max as spark_max, min as spark_min

def extr_ft_1 (proc_data, limit=100):
    proc_data = proc_data.filter(proc_data.speed > limit).agg(count(proc_data.speed))
    proc_data = proc_data.select(col('count(speed)').alias('speed_feature'))
    proc_data.show()
    return proc_data
def extr_ft_0 (proc_data):
    max_t = proc_data.agg(spark_max(proc_data.timestamp))
    min_t = proc_data.agg(spark_min(proc_data.timestamp))
    max_t = max_t.select(col('max(timestamp)').alias('max'))
    min_t = min_t.select(col('min(timestamp)').alias('min'))
    X = max_t.crossJoin(min_t)
    X = X.withColumn('time_feature', X.max+X.min)
    X = X.drop(X.min).drop(X.max)
    X.show()
    return (X)
They return one-row DataFrames which just hold an aggregate value.
Next, all feature functions are applied for a given process and combined in a result RDD for each process:
def get_proc_features(proc, data, *features):
    proc_data = data.filter( data.customid == proc)
    features_for_proc = [feature_value(proc_data) for feature_value in features]
    for number, feature in enumerate(features_for_proc):
        if number == 0:
            l = [(proc,'dummy')]
            rdd = sc.parallelize(l)
            df = sqlContext.createDataFrame(rdd,['customid','dummy'])
            df = df.drop(df.dummy)
            df.show()
            features_for_proc_rdd = feature
            features_for_proc_rdd = features_for_proc_rdd.crossJoin(df)
            continue
        features_for_proc_rdd = features_for_proc_rdd.crossJoin(feature)
    features_for_proc_rdd.show()
    return features_for_proc_rdd
The last step is to append all rows which contain the features for each process to one dataframe:
for number, proc in enumerate(customer_list_1):
    if number == 0:
        #results = get_trip_features(trip, df, extr_ft_0, extr_ft_1)
        results = get_proc_features(proc, df, *extr_feature_funcs)
        continue
    results = results.unionAll(get_proc_features(proc, df, *extr_feature_funcs))
results.show()
The chain of transformations goes like this:
get features 1 and 2 for customer 1:
+------------+
|time_feature|
+------------+
| 1.6|
+------------+
+-------------+
|speed_feature|
+-------------+
| 2|
+-------------+
Combine them to:
+------------+--------+-------------+
|time_feature|customid|speed_feature|
+------------+--------+-------------+
| 1.6| CM1| 2|
+------------+--------+-------------+
Do the same for customer 2 and append all RDDs to the final result RDD:
+------------+--------+-------------+
|time_feature|customid|speed_feature|
+------------+--------+-------------+
| 1.6| CM1| 2|
| 6.1| CX2| 1|
+------------+--------+-------------+
If I run the code on the cluster, it works for 2 customers.
But when I test it on a reasonable number of customers, I mostly get GC and heap memory errors.
Am I working with too many RDDs here? I am afraid my code is very inefficient, but I don't know where to start optimizing it. I think I only call one action at the end (I drop all show() calls in live mode and just collect() the very last RDD).
I would really appreciate your help.
Your code needs refactoring. The problem is not the RDDs but the fact that you filter the data down to one key at a time and then cross join the results. Iterating through the key values in a driver-side loop makes you lose the distributed aspect of pyspark. Keep in mind that you should always keep working on a single table if you don't need features from another one.
The best way to do it is using dataframes and window functions.
First let's rewrite your functions:
import pyspark.sql.functions as psf
def extr_ft_1 (proc_data, w, limit=100):
return proc_data.withColumn(
"speed_feature",
psf.sum((proc_data.speed > limit).cast("int")).over(w)
)
def extr_ft_0(proc_data, w):
return proc_data.withColumn(
"time_feature",
psf.min(proc_data.timestamp).over(w) + psf.max(proc_data.timestamp).over(w)
)
Where w is a window spec:
from pyspark.sql import Window
w = Window.partitionBy("customid")
df1 = extr_ft_1(df, w)
df0 = extr_ft_0(df1, w)
df0.show()
+--------+------+-----+---------+-------------+------------+
|customid|procid|speed|timestamp|speed_feature|time_feature|
+--------+------+-----+---------+-------------+------------+
| CM1| aa1| 100| 0.1| 2| 1.6|
| CM1| aa1| 110| 0.2| 2| 1.6|
| CM1| aa1| 110| 0.9| 2| 1.6|
| CM1| aa1| 100| 1.5| 2| 1.6|
| CX2| bb9| 100| 0.1| 1| 6.1|
| CX2| bb9| 100| 0.2| 1| 6.1|
| CX2| bb9| 110| 6.0| 1| 6.1|
| CX2| bb9| 100| 0.18| 1| 6.1|
+--------+------+-----+---------+-------------+------------+
Here we never lose information (we keep all the lines), so if you want to add extra features you can. If you want final aggregated results, just run it through a groupBy("customid").
Note that you can also modify the aggregation key in the window spec to include procid for instance.
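The extensible "pool of feature functions" the question asks for fits this one-pass style: group once, then apply every registered feature to each group instead of filtering per key. Below is a plain-Python sketch of that pattern (no Spark required) using the sample rows from the question; the names FEATURES and compute_features are hypothetical. In PySpark the same idea would be a list of aggregate expressions passed to groupBy("customid", "procid").agg(...).

```python
from collections import defaultdict

# Sample rows from the question: (customid, procid, speed, timestamp)
ROWS = [
    ("CM1", "aa1", 100, 0.1), ("CM1", "aa1", 110, 0.2),
    ("CM1", "aa1", 110, 0.9), ("CM1", "aa1", 100, 1.5),
    ("CX2", "bb9", 100, 0.1), ("CX2", "bb9", 100, 0.2),
    ("CX2", "bb9", 110, 6.0), ("CX2", "bb9", 100, 0.18),
]

def speed_feature(group, limit=100):
    """Count rows whose speed exceeds `limit` (mirrors extr_ft_1)."""
    return sum(1 for _, _, speed, _ in group if speed > limit)

def time_feature(group):
    """min(timestamp) + max(timestamp) (mirrors extr_ft_0)."""
    timestamps = [t for _, _, _, t in group]
    return min(timestamps) + max(timestamps)

# Hypothetical feature pool: adding a feature means adding one entry here.
FEATURES = {"speed_feature": speed_feature, "time_feature": time_feature}

def compute_features(rows, features):
    """Group once by (customid, procid), then apply every feature to each group."""
    groups = defaultdict(list)
    for row in rows:
        groups[(row[0], row[1])].append(row)
    return {key: {name: fn(group) for name, fn in features.items()}
            for key, group in groups.items()}

result = compute_features(ROWS, FEATURES)
```

The output has one entry per (customid, procid) key with all features, which is the shape the question asks for.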
I have a dataframe like the one below:
+--------+-------+-------+--------------+--------------------+-----------------+----------+--------------------+------------+
|sequence|recType|valCode|registerNumber| rest| errorCode|errorType | errorDescription|isSuccessful|
+--------+-------+-------+--------------+--------------------+-----------------+----------+--------------------+------------+
| 9| 11| 0| XXXX2288|110XXXX2288MKKKKK...| CHAR0088| ERROR|Records out of se...| N|
| 9| 12| 0| XXXX2288|130XXXX22880011ZZ...| CHAR0088| ERROR|Records out of se...| N|
| 9| 18| 0| XXXX2288|140XXXX2288 ...| CHAR0088| ERROR|Records out of se...| N|
+--------+-------+-------+--------------+--------------------+-----------------+----------+--------------------+------------+
The code below uses UDFs to populate the data for the errorType and errorDescription columns.
The UDFs resolveErrorTypeUDF and resolveErrorDescUDF each take one errorCode as input and return the corresponding errorType and errorDescription, respectively.
errorFinalDf = errorDfAll.na.fill("") \
.withColumn("errorType", resolveErrorTypeUDF(col("errorCode"))) \
.withColumn("errorDescription", resolveErrorDescUDF(col("errorCode"))) \
.withColumn("isSuccessful", when(trim(col("errorCode")).eqNullSafe(""), "Y").otherwise("N")) \
.dropDuplicates()
Note that I used to get only one error code in the errorCode column. From now on, I will get one or more error codes in the errorCode column, separated by '-'. I need to look up all the corresponding errorType and errorDescription values and write them into the respective columns, also separated by '-'.
The new dataframe would look like this.
+--------+-------+-------+--------------+--------------------+-----------------+----------+--------------------+------------+
|sequence|recType|valCode|registerNumber| rest| errorCode|errorType | errorDescription|isSuccessful|
+--------+-------+-------+--------------+--------------------+-----------------+----------+--------------------+------------+
| 7| 1| 0| XXXX8822|010XXXX8822XBCDEF...|CHAR0009-CHAR0021|ERROR-WARN|Short Failed-Miss...| N|
| 7| 11| 0| XXXX8822|110XXXX8822LLLLLL...|CHAR0009-CHAR0021|ERROR-WARN|Short Failed-Miss...| N|
| 7| 12| 0| XXXX8822|120XXXX8822011GB ...|CHAR0009-CHAR0021|ERROR-WARN|Short Failed-Miss...| N|
| 7| 18| 0| XXXX8822|180XXXX8822 ...|CHAR0009-CHAR0021|ERROR-WARN|Short Failed-Miss...| N|
| 7| 18| 0| XXXX8822|180XXXX88220 ...|CHAR0009-CHAR0021|ERROR-WARN|Short Failed-Miss...| N|
+--------+-------+-------+--------------+--------------------+-----------------+----------+--------------------+------------+
What changes would be needed to accommodate the new scenario? Please help. Thank you.
You need minimal changes, limited only to your UDFs.
Suppose you have a simple python function, get_type_from_code, able to convert a string with the error code to the corresponding type (the same applies to the description).
from pyspark.sql import functions as F, types as T
def get_type_from_code(c: str) -> str:
"""Function to convert error code to error type.
Mind the interface: string in, string out
"""
return {'CHAR0009': 'ERROR', 'CHAR0021': 'WARNING'}.get(c, 'UNKNOWN')
@F.udf(returnType=T.StringType())
def convert_errcodes_to_types(codes: str) -> str:
"""Convert a string of error codes separated by '-' into a string of types concatenated with '-'"""
return '-'.join(
map(get_type_from_code, codes.split('-'))
)
Done!
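The same split, look up and re-join logic can be written once and reused for both the types and the descriptions. Here is a plain-Python sketch (the helper name convert_codes is hypothetical) that also returns an empty string for empty or None input instead of "UNKNOWN", which is an assumption about what an empty errorCode should map to, and strips stray spaces around each code.

```python
ERROR_TYPES = {"CHAR0009": "ERROR", "CHAR0021": "WARNING"}  # mapping from the answer

def convert_codes(codes, lookup, default="UNKNOWN", sep="-"):
    """Map each `sep`-separated code through `lookup` and re-join with `sep`.

    Assumption: an empty or None errorCode means "no codes", so it yields ""
    rather than a single `default` entry.
    """
    if not codes:
        return ""
    return sep.join(lookup.get(code.strip(), default) for code in codes.split(sep))
```

To use it in Spark, wrap it once per mapping, e.g. `F.udf(lambda c: convert_codes(c, ERROR_TYPES), T.StringType())`, and pass the result to the existing withColumn("errorType", ...) call in place of resolveErrorTypeUDF; a description dictionary works the same way for errorDescription.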
My dataset:
+--------------------+----------+----------+-------------------+--------------------+--------+-------+---------+--------------------+-------------------+-------------------+-----------+----+---------------+-----------------+-----------+------------+----------------------+------------------+---------+
| event_time|event_type|product_id| category_id| category_code| brand| price| user_id| user_session| Event_time_NoUTC| Event_timestamp|day_of_week|hour|primaryCategory|secondaryCategory|eventVisits|productCount|secondaryCategoryCount| AvgCatExpense|SessCount|
+--------------------+----------+----------+-------------------+--------------------+--------+-------+---------+--------------------+-------------------+-------------------+-----------+----+---------------+-----------------+-----------+------------+----------------------+------------------+---------+
|2019-10-06 07:04:...| view| 1004565|2053013555631882655|electronics.smart...| huawei| 169.84|231943435|428ebb99-3568-4e1...|2019-10-06 07:04:50|2019-10-06 07:04:50| 1| 7| electronics| smartphone| 1| 1| 1| 380.2349402627628| 1|
|2019-10-25 03:50:...| view| 5100337|2053013553341792533| electronics.clocks| apple| 319.34|266287781|f55edf02-3fd4-48f...|2019-10-25 03:50:28|2019-10-25 03:50:28| 6| 3| electronics| clocks| 7| 7| 7| 369.7054359810376| 4|
|2019-10-25 03:52:...| view| 1005105|2053013555631882655|electronics.smart...| apple|1397.09|266287781|118dbcd6-fe31-4cc...|2019-10-25 03:52:09|2019-10-25 03:52:09| 6| 3| electronics| smartphone| 7| 7| 7| 369.7054359810376| 4|
|2019-10-26 12:15:...| view| 6000157|2053013560807654091|auto.accessories....|starline| 91.12|266287781|992d03b4-c561-4fb...|2019-10-26 12:15:56|2019-10-26 12:15:56| 7| 12| auto| accessories| 7| 7| 7| 369.7054359810376| 4|
The event type has three categories: view, cart and purchase. I want to label each user_id and product_id pair with a new column is_purchased, set to 1 if the pair has a purchase event and 0 otherwise. After that, I would remove the redundant rows, which would basically help me classify whether a customer will churn or not.
I am thinking of partitioning the data by user_id and product_id and then flagging the partitions which contain a purchase. Please suggest approaches to solve this.
Step 1: group the data by user and product and mark if each group contains the event purchase:
from pyspark.sql import functions as F
data = [("A",123, "view", "other attributes 1"),
("A",123, "cart", "other attributes 2"),
("A",123, "purchase", "other attributes 3"),
("B",123, "cart", "other attributes 4")]
df = spark.createDataFrame(data, schema = ["user", "product", "event", "other"])
is_purchased = df.groupBy("user", "product").agg(
F.array_contains(F.collect_set("event"), "purchase").alias("is_purchased"))
# +----+-------+------------+
# |user|product|is_purchased|
# +----+-------+------------+
# | A| 123| true|
# | B| 123| false|
# +----+-------+------------+
Step 2: join the result from step 1 with the original data and filter out the redundant rows:
result = df.join(is_purchased, on=["user", "product"], how="left") \
.filter("event= 'cart'")
# +----+-------+-----+------------------+------------+
# |user|product|event| other|is_purchased|
# +----+-------+-----+------------------+------------+
# | A| 123| cart|other attributes 2| true|
# | B| 123| cart|other attributes 4| false|
# +----+-------+-----+------------------+------------+
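The question asks for is_purchased as 1/0 rather than true/false; in Spark that is a cast of the flag, e.g. F.col("is_purchased").cast("int"). The plain-Python sketch below (no Spark; same sample data as above) mirrors the two steps: collect the set of events per (user, product), then keep only the cart rows and attach the flag as an integer.

```python
from collections import defaultdict

data = [("A", 123, "view", "other attributes 1"),
        ("A", 123, "cart", "other attributes 2"),
        ("A", 123, "purchase", "other attributes 3"),
        ("B", 123, "cart", "other attributes 4")]

# Step 1: the set of events seen for each (user, product), like collect_set
events = defaultdict(set)
for user, product, event, _ in data:
    events[(user, product)].add(event)

# Step 2: keep only 'cart' rows and attach is_purchased as 1/0
result = [(user, product, event, other, int("purchase" in events[(user, product)]))
          for user, product, event, other in data if event == "cart"]
```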
You can also apply a window function to get all events of each user and product, then filter (I'm using the same sample data as @werner):
from pyspark.sql import functions as F
from pyspark.sql import Window as W
(df
.withColumn('events', F.collect_set('event').over(W.partitionBy('user', 'product')))
.withColumn('is_purchased', F.array_contains(F.col('events'), 'purchase'))
.where(F.col('event') == 'cart')
.show(10, False)
)
+----+-------+-----+------------------+----------------------+------------+
|user|product|event|other |events |is_purchased|
+----+-------+-----+------------------+----------------------+------------+
|A |123 |cart |other attributes 2|[cart, view, purchase]|true |
|B   |123    |cart |other attributes 4|[cart]                |false       |
+----+-------+-----+------------------+----------------------+------------+
I have been trying to join two dataframes, df_1 and df_2, using a list of join keys, and I want to add the ability to join on a subset of the keys when one of the key values is null.
data1 = [[1,'2018-07-31',215,'a'],
[2,'2018-07-30',None,'b'],
[3,'2017-10-28',201,'c']
]
df_1 = sqlCtx.createDataFrame(data1,
['application_number','application_dt','account_id','var1'])
and
data2 = [[1,'2018-07-31',215,'aaa'],
[2,'2018-07-30',None,'bbb'],
[3,'2017-10-28',201,'ccc']
]
df_2 = sqlCtx.createDataFrame(data2,
['application_number','application_dt','account_id','var2'])
The code I use to join is this:
key_a = ['application_number','application_dt','account_id']
new = df_1.join(df_2,key_a,'left')
The output for the same is:
+------------------+--------------+----------+----+----+
|application_number|application_dt|account_id|var1|var2|
+------------------+--------------+----------+----+----+
| 1| 2018-07-31| 215| a| aaa|
| 3| 2017-10-28| 201| c| ccc|
| 2| 2018-07-30| null| b|null|
+------------------+--------------+----------+----+----+
My concern here is that, in the case where account_id is null, the join should still work by comparing the other two keys.
The required output should be like this:
+------------------+--------------+----------+----+----+
|application_number|application_dt|account_id|var1|var2|
+------------------+--------------+----------+----+----+
| 1| 2018-07-31| 215| a| aaa|
| 3| 2017-10-28| 201| c| ccc|
| 2| 2018-07-30| null| b| bbb|
+------------------+--------------+----------+----+----+
I have found a similar approach using the statement:
join_elem = ("df_1.application_number == df_2.application_number|"
             "df_1.application_dt == df_2.application_dt|"
             "F.coalesce(df_1.account_id, F.lit(0)) == F.coalesce(df_2.account_id, F.lit(0))").split("|")
join_elem_column = [eval(x) for x in join_elem]
But design considerations do not allow me to use a full join expression, and I am stuck with using a list of column names as the join key.
I have been trying to find a way to fit this coalesce into the list itself, but have not had any success so far.
I would call this solution a workaround.
The issue here is that we have a Null value for one of the keys in the DataFrame, and the OP wants the rest of the key columns to be used instead. Why not assign an arbitrary value to this Null and then apply the join? When account_id is null on both sides, this is effectively the same as joining on the remaining two keys.
from pyspark.sql.functions import col, when

# Let's replace Null with an arbitrary value, which has
# little chance of occurring in the dataset, e.g. -100000
df_1 = df_1.withColumn('account_id', when(col('account_id').isNull(), -100000).otherwise(col('account_id')))
df_2 = df_2.withColumn('account_id', when(col('account_id').isNull(), -100000).otherwise(col('account_id')))
# Do a FULL join
df = df_1.join(df_2, ['application_number', 'application_dt', 'account_id'], 'full')
# Replace the arbitrary value back with Null
df = df.withColumn('account_id', when(col('account_id') == -100000, None).otherwise(col('account_id')))
df.show()
+------------------+--------------+----------+----+----+
|application_number|application_dt|account_id|var1|var2|
+------------------+--------------+----------+----+----+
| 1| 2018-07-31| 215| a| aaa|
| 2| 2018-07-30| null| b| bbb|
| 3| 2017-10-28| 201| c| ccc|
+------------------+--------------+----------+----+----+
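The sentinel trick can be checked without Spark. The plain-Python sketch below (hypothetical helper left_join_null_safe, assuming the keys on the right side are unique) does a left join on the first n_keys fields and substitutes the sentinel for None only while matching, so the output keeps the original null. For reference, Spark's own null-safe equality is Column.eqNullSafe (SQL `<=>`), but it needs a join expression, which the question's design rules out.

```python
SENTINEL = -100000  # assumption: never occurs as a real account_id

def left_join_null_safe(left, right, n_keys, sentinel=SENTINEL):
    """Left join on the first n_keys fields, treating None == None via a sentinel."""
    def key(row):
        # Replace None with the sentinel only for matching; rows themselves are untouched
        return tuple(sentinel if v is None else v for v in row[:n_keys])

    width = len(right[0]) - n_keys if right else 0
    index = {key(r): r[n_keys:] for r in right}  # assumes unique keys on the right
    # Unmatched left rows get None for every right-hand column, as in a left join
    return [row + index.get(key(row), (None,) * width) for row in left]

left_rows = [(1, "2018-07-31", 215, "a"), (2, "2018-07-30", None, "b"), (3, "2017-10-28", 201, "c")]
right_rows = [(1, "2018-07-31", 215, "aaa"), (2, "2018-07-30", None, "bbb"), (3, "2017-10-28", 201, "ccc")]
joined = left_join_null_safe(left_rows, right_rows, 3)
```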
I'm using the Spark Graphframes library to create an identity resolution system. I have been able to use Spark to find matches. My plan was to use a graph to find transitive links between people and assign a single id to them for further analysis etc.
I used the following data (from the public febrl database):
vertex data sample:
+----------+--------+-------------+-------------------+--------------------+----------------+--------+-----+-------------+----------+---+-----+
|given_name| surname|street_number| address_1| address_2| suburb|postcode|state|date_of_birth|soc_sec_id| id|block|
+----------+--------+-------------+-------------------+--------------------+----------------+--------+-----+-------------+----------+---+-----+
| michaela| neumann| 8| stanley street| miami| winston hills| 4223| nsw| 19151111| 5304218| 0| mneu|
| courtney| painter| 12| pinkerton circuit| bega flats| richlands| 4560| vic| 19161214| 4066625| 1| cpai|
| charles| green| 38|salkauskas crescent| kela| dapto| 4566| nsw| 19480930| 4365168| 2| cgre|
| vanessa| parr| 905| macquoid place| broadbridge manor| south grafton| 2135| sa| 19951119| 9239102| 3| vpar|
| mikayla|malloney| 37| randwick road| avalind|hoppers crossing| 4552| vic| 19860208| 7207688| 4| mmal|
| blake| howie| 1| cutlack street|belmont park belt...| budgewoi| 6017| vic| 19250301| 5180548| 5| bhow|
| blakeston| broadby| 53| traeger street| valley of springs| north ward| 3083| qld| 19120907| 4308555| 7| bbro|
| edward| denholm| 10| corin place| gold tyne| clayfield| 4221| vic| 19660306| 7119771| 9| eden|
| charlie|alderson| 266|hawkesbury crescent|deergarden caravn...| cooma| 4128| vic| 19440908| 1256748| 10| cald|
| molly| roche| 59|willoughby crescent| donna valley| carrara| 4825| nsw| 19200712| 1847058| 11| mroc|
+----------+--------+-------------+-------------------+--------------------+----------------+--------+-----+-------------+----------+---+-----+
Edge data sample:
+---+-----+-----+
|src| dst|match|
+---+-----+-----+
| 0|10000| 1|
| 1|17750| 1|
| 1|10001| 1|
| 1| 7750| 1|
| 2|19656| 1|
| 2|10002| 1|
| 2| 9656| 1|
| 3|19119| 1|
| 3|10003| 1|
| 3| 9119| 1|
+---+-----+-----+
created graph:
g = GraphFrame(vertix_data, edge_data)
used connected components:
connected = g.connectedComponents(algorithm='graphframes')
which results in:
+----------+--------+-------------+-------------------+--------------------+----------------+--------+-----+-------------+----------+---+-----+---------+
|given_name| surname|street_number| address_1| address_2| suburb|postcode|state|date_of_birth|soc_sec_id| id|block|component|
+----------+--------+-------------+-------------------+--------------------+----------------+--------+-----+-------------+----------+---+-----+---------+
| michaela| neumann| 8| stanley street| miami| winston hills| 4223| nsw| 19151111| 5304218| 0| mneu| 0|
| courtney| painter| 12| pinkerton circuit| bega flats| richlands| 4560| vic| 19161214| 4066625| 1| cpai| 1|
| charles| green| 38|salkauskas crescent| kela| dapto| 4566| nsw| 19480930| 4365168| 2| cgre| 2|
| vanessa| parr| 905| macquoid place| broadbridge manor| south grafton| 2135| sa| 19951119| 9239102| 3| vpar| 3|
| mikayla|malloney| 37| randwick road| avalind|hoppers crossing| 4552| vic| 19860208| 7207688| 4| mmal| 4|
| blake| howie| 1| cutlack street|belmont park belt...| budgewoi| 6017| vic| 19250301| 5180548| 5| bhow| 5|
| blakeston| broadby| 53| traeger street| valley of springs| north ward| 3083| qld| 19120907| 4308555| 7| bbro| 7|
| edward| denholm| 10| corin place| gold tyne| clayfield| 4221| vic| 19660306| 7119771| 9| eden| 9|
| charlie|alderson| 266|hawkesbury crescent|deergarden caravn...| cooma| 4128| vic| 19440908| 1256748| 10| cald| 10|
| molly| roche| 59|willoughby crescent| donna valley| carrara| 4825| nsw| 19200712| 1847058| 11| mroc| 11|
+----------+--------+-------------+-------------------+--------------------+----------------+--------+-----+-------------+----------+---+-----+---------+
The component column doesn't always increase in increments of 1 but seems to randomly skip numbers. I would like to make sure the values increase in increments of one, since I am using this number to assign each person an id.
Does anybody know why Graphframes does this?
Looking further into this, for the approximately 20,000 rows in my development dataframe, about 17% of entries have a skip. In extreme cases the gap can be around 20-30, i.e. one row's id is 5846 and the next one is 5868. My worry is that when I scale to millions and hundreds of millions of rows, the gaps between ids will get very large, which could create problems down the line.
TL;DR: Why does Spark's connected components seem to randomly skip values and not always increment by 1?
Graphframes documentation never promises consecutive ids; the only guarantee it provides is:
The resulting DataFrame contains all the vertex information and one additional column:
component (LongType): unique ID for this component
In practice, the GraphX implementation uses the smallest ID in the component ("return a graph with the vertex value containing the lowest vertex id in the connected component containing that vertex"), and Graphframes seems to do the same thing.
Like @user10802135 said, the component values are not guaranteed to be sequential. If you want to make them sequential, you'll need to do some post-processing on the component field. A pyspark solution to this would look something like this:
import pyspark.sql.functions as F
from pyspark.sql import Window
# Define our window for partitioning data on - necessary for dense_rank() function
windowSpec = Window.partitionBy(F.lit(1)).orderBy('component')
# Redefine the component field, now in sequential order
df = df.withColumn('component', F.dense_rank().over(windowSpec))
By partitioning by the literal value 1, all rows are considered together in the dense_rank() (which also means Spark moves every row into a single partition, so this can be slow on very large data), and the ranking order is determined by the .orderBy() argument. In this case the .orderBy() argument is 'component', which sorts in ascending order by default. rank() would also give records with the same component the same value, but it leaves gaps after ties (1, 1, 3, ...); dense_rank() does not, so the new ids are consecutive (1, 1, 2, ...).
There are some great examples and explanations of .dense_rank() and other window functions here.
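To see the rank() versus dense_rank() difference on some made-up component ids, here is a plain-Python sketch of both ranking rules (no Spark needed; the input list is hypothetical).

```python
def dense_rank(values):
    """1-based rank among the distinct sorted values: ties share a rank, no gaps."""
    ranks = {v: i for i, v in enumerate(sorted(set(values)), start=1)}
    return [ranks[v] for v in values]

def rank(values):
    """Competition ranking, as rank() does: ties share a rank, then gaps follow."""
    first_position = {}
    for i, v in enumerate(sorted(values), start=1):
        first_position.setdefault(v, i)  # position of the first occurrence
    return [first_position[v] for v in values]

components = [0, 0, 2, 7, 7, 7, 11]
print(dense_rank(components))  # [1, 1, 2, 3, 3, 3, 4]
print(rank(components))        # [1, 1, 3, 4, 4, 4, 7]
```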
I'm new to using Spark in Python and have been unable to solve this problem: After running groupBy on a pyspark.sql.dataframe.DataFrame
df = sqlsc.read.json("data.json")
df.groupBy('teamId')
how can you choose N random samples from each resulting group (grouped by teamId) without replacement?
I'm basically trying to choose N random users from each team, maybe using groupBy is wrong to start with?
Well, it is kind of wrong. GroupedData is not really designed for data access. It just describes grouping criteria and provides aggregation methods. See my answer to Using groupBy in Spark and getting back to a DataFrame for more details.
Another problem with this idea is selecting N random samples. It is a task which is really hard to achieve in parallel without physical grouping of data, and that is not something that happens when you call groupBy on a DataFrame.
There are at least two ways to handle this:
convert to RDD, groupBy and perform local sampling
import random
n = 3
def sample(iter, n):
    rs = random.Random()  # We should probably use os.urandom as a seed
    items = list(iter)
    return rs.sample(items, min(n, len(items)))  # avoid ValueError for groups smaller than n
df = sqlContext.createDataFrame(
[(x, y, random.random()) for x in (1, 2, 3) for y in "abcdefghi"],
("teamId", "x1", "x2"))
grouped = df.rdd.map(lambda row: (row.teamId, row)).groupByKey()
sampled = sqlContext.createDataFrame(
grouped.flatMap(lambda kv: sample(kv[1], n)))
sampled.show()
## +------+---+-------------------+
## |teamId| x1| x2|
## +------+---+-------------------+
## | 1| g| 0.81921738561455|
## | 1| f| 0.8563875814036598|
## | 1| a| 0.9010425238735935|
## | 2| c| 0.3864428179837973|
## | 2| g|0.06233470405822805|
## | 2| d|0.37620872770129155|
## | 3| f| 0.7518901502732027|
## | 3| e| 0.5142305439671874|
## | 3| d| 0.6250620479303716|
## +------+---+-------------------+
use window functions
from pyspark.sql import Window
from pyspark.sql.functions import col, rand, row_number  # rowNumber before Spark 1.6
w = Window.partitionBy(col("teamId")).orderBy(col("rnd_"))
sampled = (df
    .withColumn("rnd_", rand())  # Add random numbers column
    .withColumn("rn_", row_number().over(w))  # Add row number over window
    .where(col("rn_") <= n)  # Take n observations
    .drop("rn_")  # drop helper columns
    .drop("rnd_"))
sampled.show()
## +------+---+--------------------+
## |teamId| x1| x2|
## +------+---+--------------------+
## | 1| f| 0.8563875814036598|
## | 1| g| 0.81921738561455|
## | 1| i| 0.8173912535268248|
## | 2| h| 0.10862995810038856|
## | 2| c| 0.3864428179837973|
## | 2| a| 0.6695356657072442|
## | 3| b|0.012329360826023095|
## | 3| a| 0.6450777858109182|
## | 3| e| 0.5142305439671874|
## +------+---+--------------------+
but I am afraid both will be rather expensive. If the sizes of the individual groups are balanced and relatively large, I would simply use DataFrame.randomSplit.
If the number of groups is relatively small, it is possible to try something else:
from pyspark.sql.functions import count, udf
from pyspark.sql.types import BooleanType
from operator import truediv
counts = (df
.groupBy(col("teamId"))
.agg(count("*").alias("n"))
.rdd.map(lambda r: (r.teamId, r.n))
.collectAsMap())
# This defines fraction of observations from a group which should
# be taken to get n values
counts_bd = sc.broadcast({k: truediv(n, v) for (k, v) in counts.items()})
to_take = udf(lambda k, rnd: rnd <= counts_bd.value.get(k), BooleanType())
sampled = (df
.withColumn("rnd_", rand())
.where(to_take(col("teamId"), col("rnd_")))
.drop("rnd_"))
sampled.show()
## +------+---+--------------------+
## |teamId| x1| x2|
## +------+---+--------------------+
## | 1| d| 0.14815204548854788|
## | 1| f| 0.8563875814036598|
## | 1| g| 0.81921738561455|
## | 2| a| 0.6695356657072442|
## | 2| d| 0.37620872770129155|
## | 2| g| 0.06233470405822805|
## | 3| b|0.012329360826023095|
## | 3| h| 0.9022527556458557|
## +------+---+--------------------+
In Spark 1.5+ you can replace udf with a call to sampleBy method:
df.sampleBy("teamId", counts_bd.value)
It won't give you an exact number of observations, but it should be good enough most of the time as long as the number of observations per group is large enough to get proper samples. You can also use sampleByKey on an RDD in a similar way.
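The fractions passed to sampleBy are just n divided by each group's size. The plain-Python sketch below (hypothetical helper names) computes them, caps them at 1.0 so that a group smaller than n is kept whole, and then keeps each row independently with its group's probability. That per-row coin flip is roughly how fraction-based sampling works, and it is why per-group counts come out only approximately equal to n.

```python
import random
from collections import Counter

def sampling_fractions(counts, n):
    """Per-group fraction that keeps about n rows; capped at 1.0 for small groups."""
    return {k: min(1.0, n / v) for k, v in counts.items()}

def bernoulli_sample(rows, key, fractions, seed=None):
    """Keep each row independently with its group's probability (0.0 if unlisted)."""
    rng = random.Random(seed)
    return [row for row in rows if rng.random() < fractions.get(key(row), 0.0)]

rows = [(t, x) for t in (1, 2, 3) for x in "abcdefghi"] + [(4, "a"), (4, "b")]
counts = Counter(t for t, _ in rows)
fractions = sampling_fractions(counts, 3)
sampled = bernoulli_sample(rows, key=lambda r: r[0], fractions=fractions, seed=42)
```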
I find this approach more DataFrame-oriented than going the RDD way.
You can use a window function to create a ranking within each group, where the ranking can be random to suit your case. Then you can filter based on the number of samples (N) you want for each group.
import pyspark.sql.functions as F
from pyspark.sql import Window
window_1 = Window.partitionBy(data['teamId']).orderBy(F.rand())
data_1 = data.select('*', F.rank().over(window_1).alias('rank')).filter(F.col('rank') <= N).drop('rank')
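Both window-based answers (row_number over rand() and rank over rand()) rest on the same idea: give each row a random sort key, order within the group, and keep the first N. Here is a plain-Python sketch of that idea; the helper name take_n_random_per_group is hypothetical, and the seed is added only to make runs repeatable. Groups smaller than N are returned whole.

```python
import random
from collections import defaultdict

def take_n_random_per_group(rows, key, n, seed=None):
    """Plain-Python analogue of a row number over a window ordered by rand()."""
    rng = random.Random(seed)
    groups = defaultdict(list)
    for row in rows:
        groups[key(row)].append((rng.random(), row))  # random sort key, like F.rand()
    sampled = []
    for tagged in groups.values():
        tagged.sort(key=lambda pair: pair[0])          # orderBy(rand) within the group
        sampled.extend(row for _, row in tagged[:n])   # keep rows ranked 1..n
    return sampled

rows = [(t, x) for t in (1, 2, 3) for x in "abcdefghi"] + [(4, "a"), (4, "b")]
picked = take_n_random_per_group(rows, key=lambda r: r[0], n=3, seed=7)
```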
Here's an alternative using the pandas DataFrame.sample method. It uses Spark's applyInPandas method, available from Spark 3.0.0, to distribute the groups. This allows you to select an exact number of rows per group.
I've added args and kwargs to the function so you can pass the other arguments of DataFrame.sample. Note that pandas raises a ValueError if a group has fewer than n rows, unless you pass replace=True.
def sample_n_per_group(n, *args, **kwargs):
    def sample_per_group(pdf):
        return pdf.sample(n, *args, **kwargs)
    return sample_per_group
df = spark.createDataFrame(
[
(1, 1.0),
(1, 2.0),
(2, 3.0),
(2, 5.0),
(2, 10.0)
],
("id", "v")
)
(df.groupBy("id")
.applyInPandas(
sample_n_per_group(2, random_state=2),
schema=df.schema
)
)
Be aware of the limitations for very large groups. From the documentation:
This function requires a full shuffle. All the data of a group will be
loaded into memory, so the user should be aware of the potential OOM
risk if data is skewed and certain groups are too large to fit in
memory.
See also here:
How take a random row from a PySpark DataFrame?