Calculate relative frequency of bigrams in PySpark - python

I'm trying to count word pairs in a text file. First, I've done some pre-processing on the text, and then I counted word pairs as shown below:
((Aspire, to), 1) ; ((to, inspire), 4) ; ((inspire, before), 38)...
Now, I want to report the 1000 most frequent pairs, sorted by:
Word (second word of the pair)
Relative frequency (pair occurrences / 2nd word total occurrences)
Here's what I've done so far:
from pyspark.sql import SparkSession
import re
spark = SparkSession.builder.appName("Bigram occurences and relative frequencies").master("local[*]").getOrCreate()
sc = spark.sparkContext
text = sc.textFile("big.txt")
tokens = text.map(lambda x: x.lower()).map(lambda x: re.split(r"[\s,.;:!?]+", x))
pairs = tokens.flatMap(lambda xs: (tuple(x) for x in zip(xs, xs[1:]))).map(lambda x: (x, 1)).reduceByKey(lambda x, y: x + y)
frame = pairs.toDF(['pair', 'count'])
# Dataframe ordered by the most frequent pair to the least
most_frequent = frame.sort(frame['count'].desc())
# For each row, trying to add a column with the relative frequency, but I'm getting an error
with_rf = frame.withColumn("rf", frame['count'] / (frame.pair._2.sum()))
I think I'm relatively close to the result I want but I can't figure it out. I'm new to Spark and DataFrames in general.
I also tried
import pyspark.sql.functions as F
frame.groupBy(frame['pair._2']).agg((F.col('count') / F.sum('count')).alias('rf')).show()
Any help would be appreciated.
EDIT: here's a sample of the frame dataframe
+--------------------+-----+
| pair|count|
+--------------------+-----+
|{project, gutenberg}| 69|
| {gutenberg, ebook}| 14|
| {ebook, of}| 5|
| {adventures, of}| 6|
| {by, sir}| 12|
| {conan, doyle)}| 1|
| {changing, all}| 2|
| {all, over}| 24|
+--------------------+-----+
root
|-- pair: struct (nullable = true)
| |-- _1: string (nullable = true)
| |-- _2: string (nullable = true)
|-- count: long (nullable = true)

The relative frequency can be computed with a window function that partitions by the second word in the pair and applies a sum.
Then we limit the entries in the DataFrame to the top x based on count, and finally order by the second word in the pair and the relative frequency.
from pyspark.sql import functions as F
from pyspark.sql import Window as W
data = [(("project", "gutenberg"), 69,),
(("gutenberg", "ebook"), 14,),
(("ebook", "of"), 5,),
(("adventures", "of"), 6,),
(("by", "sir"), 12,),
(("conan", "doyle"), 1,),
(("changing", "all"), 2,),
(("all", "over"), 24,), ]
df = spark.createDataFrame(data, ("pair", "count", ))
ws = W.partitionBy(F.col("pair")._2).rowsBetween(W.unboundedPreceding, W.unboundedFollowing)
(df.withColumn("relative_freq", F.col("count") / F.sum("count").over(ws))
.orderBy(F.col("count").desc())
.limit(3) # change here to select top 1000
.orderBy(F.desc(F.col("pair")._2), F.col("relative_freq").desc())
).show()
"""
+--------------------+-----+-------------+
| pair|count|relative_freq|
+--------------------+-----+-------------+
| {all, over}| 24| 1.0|
|{project, gutenberg}| 69| 1.0|
| {gutenberg, ebook}| 14| 1.0|
+--------------------+-----+-------------+
"""

Related

Pyspark use partition or groupby with agg and datediff

I'm new to Pyspark.
I would like to find the products that were not seen after 10 days from the first day they entered the store, and create a column in the dataframe set to 1 for these products and 0 for the rest.
First I need to group the data by product_id, then find the maximum of seen_date. Then I calculate the difference between import_date and max(seen_date) in each group, and finally create a new column based on the value of date_diff in each group.
Following is the code I used to get the difference between import_date and seen_date, but it gives an error:
from pyspark.sql.window import Window
from pyspark.sql import functions as F
w = (Window()
.partitionBy(df.product_id)
.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))
df.withColumn("date_diff", F.datediff(F.max(F.from_unixtime(F.col("import_date")).over(w)), F.from_unixtime(F.col("seen_date"))))
Error:
AnalysisException: It is not allowed to use a window function inside an aggregate function. Please use the inner window function in a sub-query.
This is the rest of my code to define a new column based on the date_diff:
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

not_seen = udf(lambda x: 0 if x > 10 else 1, IntegerType())
df = df.withColumn('not_seen', not_seen("date_diff"))
Q: Can someone provide a fix for this code or a better approach to solve this problem?
sample data generation:
columns = ["product_id","import_date", "seen_date"]
data = [("123", "2014-05-06", "2014-05-07"),
("123", "2014-05-06", "2014-06-11"),
("125", "2015-01-02", "2015-01-03"),
("125", "2015-01-02", "2015-01-04"),
("128", "2015-08-06", "2015-08-25")]
dfFromData2 = spark.createDataFrame(data).toDF(*columns)
dfFromData2 = dfFromData2.withColumn("import_date",F.unix_timestamp(F.col("import_date"),'yyyy-MM-dd'))
dfFromData2 = dfFromData2.withColumn("seen_date",F.unix_timestamp(F.col("seen_date"),'yyyy-MM-dd'))
+----------+-----------+----------+
|product_id|import_date| seen_date|
+----------+-----------+----------+
| 123| 1399334400|1399420800|
| 123| 1399334400|1402444800|
| 125| 1420156800|1420243200|
| 125| 1420156800|1420329600|
| 128| 1438819200|1440460800|
+----------+-----------+----------+
columns = ["product_id","import_date", "seen_date"]
data = [("123", "2014-05-06", "2014-05-07"),
("123", "2014-05-06", "2014-06-11"),
("125", "2015-01-02", "2015-01-03"),
("125", "2015-01-02", "2015-01-04"),
("128", "2015-08-06", "2015-08-25")]
df = spark.createDataFrame(data).toDF(*columns)
df = df.withColumn("import_date",F.to_date(F.col("import_date"),'yyyy-MM-dd'))
df = df.withColumn("seen_date",F.to_date(F.col("seen_date"),'yyyy-MM-dd'))
from pyspark.sql.window import Window
from pyspark.sql import functions as F
w = (Window()
.partitionBy(df.product_id)
.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))
df\
.withColumn('max_import_date', F.max(F.col("import_date")).over(w))\
.withColumn("date_diff", F.datediff(F.col('seen_date'), F.col('max_import_date')))\
.withColumn('not_seen', F.when(F.col('date_diff') > 10, 0).otherwise(1))\
.show()
+----------+-----------+----------+---------------+---------+--------+
|product_id|import_date| seen_date|max_import_date|date_diff|not_seen|
+----------+-----------+----------+---------------+---------+--------+
| 123| 2014-05-06|2014-05-07| 2014-05-06| 1| 1|
| 123| 2014-05-06|2014-06-11| 2014-05-06| 36| 0|
| 125| 2015-01-02|2015-01-03| 2015-01-02| 1| 1|
| 125| 2015-01-02|2015-01-04| 2015-01-02| 2| 1|
| 128| 2015-08-06|2015-08-25| 2015-08-06| 19| 0|
+----------+-----------+----------+---------------+---------+--------+
You can use the max windowing function to extract the max date.
dfFromData2 = dfFromData2.withColumn(
'not_seen',
F.expr('if(datediff(max(from_unixtime(seen_date)) over (partition by product_id), from_unixtime(import_date)) > 10, 1, 0)')
)
dfFromData2.show(truncate=False)
# +----------+-----------+----------+--------+
# |product_id|import_date|seen_date |not_seen|
# +----------+-----------+----------+--------+
# |125 |1420128000 |1420214400|0 |
# |125 |1420128000 |1420300800|0 |
# |123 |1399305600 |1399392000|1 |
# |123 |1399305600 |1402416000|1 |
# |128 |1438790400 |1440432000|1 |
# +----------+-----------+----------+--------+
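If you prefer the DataFrame API to a SQL expression, the same window logic can be written as below; a sketch, assuming dfFromData2 still holds the unix-timestamp columns built above.
from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Max seen_date per product, computed over the whole partition.
w = Window.partitionBy("product_id")

dfFromData2 = dfFromData2.withColumn(
    "not_seen",
    F.when(
        F.datediff(
            F.max(F.from_unixtime("seen_date")).over(w),
            F.from_unixtime("import_date")
        ) > 10, 1
    ).otherwise(0)
)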

Pyspark: sum over a window based on a condition

Consider the simple DataFrame:
from pyspark import SparkContext
import pyspark
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql.window import Window
from pyspark.sql.types import *
from pyspark.sql.functions import pandas_udf, PandasUDFType
spark = SparkSession.builder.appName('Trial').getOrCreate()
simpleData = (("2000-04-17", "144", 1), \
("2000-07-06", "015", 1), \
("2001-01-23", "015", -1), \
("2001-01-18", "144", -1), \
("2001-04-17", "198", 1), \
("2001-04-18", "036", -1), \
("2001-04-19", "012", -1), \
("2001-04-19", "188", 1), \
("2001-04-25", "188", 1),\
("2001-04-27", "015", 1) \
)
columns= ["dates", "id", "eps"]
df = spark.createDataFrame(data = simpleData, schema = columns)
df.printSchema()
df.show(truncate=False)
Out:
root
|-- dates: string (nullable = true)
|-- id: string (nullable = true)
|-- eps: long (nullable = true)
+----------+---+---+
|dates |id |eps|
+----------+---+---+
|2000-04-17|144|1 |
|2000-07-06|015|1 |
|2001-01-23|015|-1 |
|2001-01-18|144|-1 |
|2001-04-17|198|1 |
|2001-04-18|036|-1 |
|2001-04-19|012|-1 |
|2001-04-19|188|1 |
|2001-04-25|188|1 |
|2001-04-27|015|1 |
+----------+---+---+
I would like to sum the values in the eps column over a rolling window, keeping only the last value for any given ID in the id column. For example, defining a window of 5 rows and assuming we are on 2001-04-17, I want to sum only the last eps value for each unique ID. In the 5 rows we have only 3 different IDs, so the sum must be over 3 elements: -1 for the ID 144 (fourth row), -1 for the ID 015 (third row) and 1 for the ID 198 (fifth row), for a total of -1.
In my mind, within the rolling window I should do something like F.sum(groupBy('id').agg(F.last('eps'))) that of course is not possible to achieve in a rolling window.
I obtained the desired result using a UDF.
import pandas as pd

@pandas_udf(IntegerType(), PandasUDFType.GROUPED_AGG)
def fun_sum(id, eps):
    df = pd.DataFrame()
    df['id'] = id
    df['eps'] = eps
    value = df.groupby('id').last().sum()
    return value
And then:
w = Window.orderBy('dates').rowsBetween(-5,0)
df = df.withColumn('sum', fun_sum(F.col('id'), F.col('eps')).over(w))
The problem is that my dataset contains more than 8 million rows and performing this task with this UDF takes about 2 hours.
I was wondering whether there is a way to achieve the same result with built-in PySpark functions, avoiding a UDF, or at least whether there is a way to improve the performance of my UDF.
For completeness, the desired output should be:
+----------+---+---+----+
|dates |id |eps|sum |
+----------+---+---+----+
|2000-04-17|144|1 |1 |
|2000-07-06|015|1 |2 |
|2001-01-23|015|-1 |0 |
|2001-01-18|144|-1 |-2 |
|2001-04-17|198|1 |-1 |
|2001-04-18|036|-1 |-2 |
|2001-04-19|012|-1 |-3 |
|2001-04-19|188|1 |-1 |
|2001-04-25|188|1 |0 |
|2001-04-27|015|1 |0 |
+----------+---+---+----+
EDIT: the result must also be achievable using a .rangeBetween() window.
In case you haven't figured it out yet, here's one way of achieving it.
Assuming that df is defined and initialised the way you defined and initialised it in your question.
Import the required functions and classes:
from pyspark.sql.functions import row_number, col
from pyspark.sql.window import Window
Create the necessary WindowSpec:
window_spec = (
Window
# Partition by 'id'.
.partitionBy(df.id)
# Order by 'dates', latest dates first.
.orderBy(df.dates.desc())
)
Create a DataFrame with partitioned data:
partitioned_df = (
df
# Use the window function 'row_number()' to populate a new column
# containing a sequential number starting at 1 within a window partition.
.withColumn('row', row_number().over(window_spec))
# Only select the first entry in each partition (i.e. the latest date).
.where(col('row') == 1)
)
Just in case you want to double-check the data:
partitioned_df.show()
# +----------+---+---+---+
# | dates| id|eps|row|
# +----------+---+---+---+
# |2001-04-19|012| -1| 1|
# |2001-04-25|188| 1| 1|
# |2001-04-27|015| 1| 1|
# |2001-04-17|198| 1| 1|
# |2001-01-18|144| -1| 1|
# |2001-04-18|036| -1| 1|
# +----------+---+---+---+
Group and aggregate the data:
sum_rows = (
partitioned_df
# Aggregate data.
.groupBy()
# Sum all rows in 'eps' column.
.sum('eps')
# Get all records as a list of Rows.
.collect()
)
Get the result:
print(f"sum eps: {sum_rows[0][0]})
# sum eps: 0
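If you do keep the pandas UDF from the question for the rolling case, one small change that is sometimes faster is to take the last row per id with drop_duplicates instead of a full groupby; a sketch, assuming pandas is available as pd and the same window w as in the question.
import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.types import IntegerType

@pandas_udf(IntegerType(), PandasUDFType.GROUPED_AGG)
def fun_sum(id, eps):
    # Keep only the last eps value seen for each id inside the window, then sum.
    pdf = pd.DataFrame({"id": id, "eps": eps})
    return int(pdf.drop_duplicates("id", keep="last")["eps"].sum())
It would be applied over the same window as before, e.g. df.withColumn('sum', fun_sum(F.col('id'), F.col('eps')).over(w)).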

spark extract columns from string

I need help parsing a string that contains values for each attribute. Below is my sample string:
otherPartofString Name=<Series VR> Type=<1Ac4> SqVal=<34> conn ID=<2>
Sometimes the string can include other values with a different delimiter, like:
otherPartofString Name=<Series X> Type=<1B3> SqVal=<34> conn ID=<2> conn Loc=sfo dest=chc bridge otherpartofString..
the output columns will be
Name | Type | SqVal | ID | Loc | dest
-------------------------------------------
Series VR | 1Ac4 | 34 | 2 | null | null
Series X | 1B3 | 34 | 2 | sfo | chc
As we discussed, to use the str_to_map function on your sample data, we can set pairDelim and keyValueDelim to the following:
pairDelim: '(?i)>? *(?=Name|Type|SqVal|conn ID|conn Loc|dest|$)'
keyValueDelim: '=<?'
Here pairDelim is case-insensitive (?i) with an optional > followed by zero or more spaces, then followed by one of the pre-defined keys (we use '|'.join(keys) to generate it dynamically) or the end-of-string anchor $. keyValueDelim is an '=' with an optional <.
from pyspark.sql import functions as F
df = spark.createDataFrame([
("otherPartofString Name=<Series VR> Type=<1Ac4> SqVal=<34> conn ID=<2>",),
("otherPartofString Name=<Series X> Type=<1B3> SqVal=<34> conn ID=<2> conn Loc=sfo dest=chc bridge otherpartofString..",)
],["value"])
keys = ["Name", "Type", "SqVal", "conn ID", "conn Loc", "dest"]
# add the following conf for Spark 3.0 to overcome duplicate map key ERROR
#spark.conf.set("spark.sql.mapKeyDedupPolicy", "LAST_WIN")
df.withColumn("m", F.expr("str_to_map(value, '(?i)>? *(?={}|$)', '=<?')".format('|'.join(keys)))) \
.select([F.col('m')[k].alias(k) for k in keys]) \
.show()
+---------+----+-----+-------+--------+--------------------+
| Name|Type|SqVal|conn ID|conn Loc| dest|
+---------+----+-----+-------+--------+--------------------+
|Series VR|1Ac4| 34| 2| null| null|
| Series X| 1B3| 34| 2| sfo|chc bridge otherp...|
+---------+----+-----+-------+--------+--------------------+
We will need to do some post-processing on the value of the last mapped key, since there is no anchor or pattern to distinguish it from other unrelated text (this could be a problem, as it might happen with any key); please let me know if you can specify any pattern.
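For example, if the real values never contain spaces (which holds for the sample rows shown), one option is to keep only the first whitespace-separated token of the affected column; a sketch, assuming the mapped result above is saved in a DataFrame called mapped (a hypothetical name).
from pyspark.sql import functions as F

# Hypothetical cleanup: keep only the first whitespace-separated token of 'dest',
# assuming legitimate values never contain spaces.
cleaned = mapped.withColumn("dest", F.split(F.col("dest"), r"\s+").getItem(0))
cleaned.show()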
Edit: if using a map is less efficient for case-insensitive search, since it requires some expensive pre-processing, try the following:
ptn = '|'.join(keys)
df.select("*", *[F.regexp_extract('value', r'(?i)\b{0}=<?([^=>]+?)>? *(?={1}|$)'.format(k,ptn), 1).alias(k) for k in keys]).show()
In case the angle brackets < and > are used only when values or their next adjacent key contain any non-word chars, it can be simplified with some pre-processing:
df.withColumn('value', F.regexp_replace('value','=(\w+)','=<$1>')) \
.select("*", *[F.regexp_extract('value', r'(?i)\b{0}=<([^>]+)>'.format(k), 1).alias(k) for k in keys]) \
.show()
Edit-2: added a dictionary to handle key aliases:
keys = ["Name", "Type", "SqVal", "ID", "Loc", "dest"]
# aliases are case-insensitive and added only if exist
key_aliases = {
'Type': [ 'ThisType', 'AnyName' ],
'ID': ['conn ID'],
'Loc': ['conn Loc']
}
# set up regex pattern for each key differently
key_ptns = [ (k, '|'.join([k, *key_aliases[k]]) if k in key_aliases else k) for k in keys ]
#[('Name', 'Name'),
# ('Type', 'Type|ThisType|AnyName'),
# ('SqVal', 'SqVal'),
# ('ID', 'ID|conn ID'),
# ('Loc', 'Loc|conn Loc'),
# ('dest', 'dest')]
df.withColumn('value', F.regexp_replace('value','=(\w+)','=<$1>')) \
.select("*", *[F.regexp_extract('value', r'(?i)\b(?:{0})=<([^>]+)>'.format(p), 1).alias(k) for k,p in key_ptns]) \
.show()
+--------------------+---------+----+-----+---+---+----+
| value| Name|Type|SqVal| ID|Loc|dest|
+--------------------+---------+----+-----+---+---+----+
|otherPartofString...|Series VR|1Ac4| 34| 2| | |
|otherPartofString...| Series X| 1B3| 34| 2|sfo| chc|
+--------------------+---------+----+-----+---+---+----+

Creating a new dataframe from a pyspark dataframe column efficiently

I wonder what the most efficient way is to extract a column from a pyspark dataframe and turn it into a new dataframe. The following code runs without any problem on small datasets, but runs very slowly and even causes an out-of-memory error on larger ones. How can I improve the efficiency of this code?
pdf_edges = sdf_grp.rdd.flatMap(lambda x: x).collect()
edgelist = reduce(lambda a, b: a + b, pdf_edges, [])
sdf_edges = spark.createDataFrame(edgelist)
In the pyspark dataframe sdf_grp, the column "pairs" contains information as below:
+-------------------------------------------------------------------+
|pairs |
+-------------------------------------------------------------------+
|[[39169813, 24907492], [39169813, 19650174]] |
|[[10876191, 139604770]] |
|[[6481958, 22689674]] |
|[[73450939, 114203936], [73450939, 21226555], [73450939, 24367554]]|
|[[66306616, 32911686], [66306616, 19319140], [66306616, 48712544]] |
+-------------------------------------------------------------------+
with a schema of
root
|-- pairs: array (nullable = true)
| |-- element: struct (containsNull = true)
| | |-- node1: integer (nullable = false)
| | |-- node2: integer (nullable = false)
I'd like to convert it into a new dataframe sdf_edges that looks like below:
+---------+---------+
| node1| node2|
+---------+---------+
| 39169813| 24907492|
| 39169813| 19650174|
| 10876191|139604770|
| 6481958| 22689674|
| 73450939|114203936|
| 73450939| 21226555|
| 73450939| 24367554|
| 66306616| 32911686|
| 66306616| 19319140|
| 66306616| 48712544|
+---------+---------+
The most efficient way to extract columns is to avoid collect(). When you call collect(), all the data is transferred to the driver and processed there. A better way to achieve what you want is to use the explode() function. Have a look at the example below:
from pyspark.sql import types as T
import pyspark.sql.functions as F
schema = T.StructType([
    T.StructField("pairs", T.ArrayType(
        T.StructType([
            T.StructField("node1", T.IntegerType()),
            T.StructField("node2", T.IntegerType())
        ])
    ))
])
df = spark.createDataFrame(
    [
        ([[39169813, 24907492], [39169813, 19650174]],),
        ([[10876191, 139604770]],),
        ([[6481958, 22689674]],),
        ([[73450939, 114203936], [73450939, 21226555], [73450939, 24367554]],),
        ([[66306616, 32911686], [66306616, 19319140], [66306616, 48712544]],)
    ], schema)
df = df.select(F.explode('pairs').alias('exploded')).select('exploded.node1', 'exploded.node2')
df.show(truncate=False)
Output:
+--------+---------+
| node1 | node2 |
+--------+---------+
|39169813|24907492 |
|39169813|19650174 |
|10876191|139604770|
|6481958 |22689674 |
|73450939|114203936|
|73450939|21226555 |
|73450939|24367554 |
|66306616|32911686 |
|66306616|19319140 |
|66306616|48712544 |
+--------+---------+
Well, I just solved it with the below:
sdf_edges = sdf_grp.select('pairs').rdd.flatMap(lambda x: x[0]).toDF()

Choosing random items from a Spark GroupedData Object

I'm new to using Spark in Python and have been unable to solve this problem: After running groupBy on a pyspark.sql.dataframe.DataFrame
df = sqlsc.read.json("data.json")
df.groupBy('teamId')
how can you choose N random samples from each resulting group (grouped by teamId) without replacement?
I'm basically trying to choose N random users from each team; maybe using groupBy is the wrong way to start?
Well, it is kind of wrong. GroupedData is not really designed for data access. It just describes grouping criteria and provides aggregation methods. See my answer to Using groupBy in Spark and getting back to a DataFrame for more details.
Another problem with this idea is selecting N random samples. It is a task which is really hard to achieve in parallel without physical grouping of data, and it is not something that happens when you call groupBy on a DataFrame.
There are at least two ways to handle this:
convert to RDD, groupBy and perform local sampling
import random
n = 3
def sample(iter, n):
    rs = random.Random()  # We should probably use os.urandom as a seed
    return rs.sample(list(iter), n)
df = sqlContext.createDataFrame(
[(x, y, random.random()) for x in (1, 2, 3) for y in "abcdefghi"],
("teamId", "x1", "x2"))
grouped = df.rdd.map(lambda row: (row.teamId, row)).groupByKey()
sampled = sqlContext.createDataFrame(
grouped.flatMap(lambda kv: sample(kv[1], n)))
sampled.show()
## +------+---+-------------------+
## |teamId| x1| x2|
## +------+---+-------------------+
## | 1| g| 0.81921738561455|
## | 1| f| 0.8563875814036598|
## | 1| a| 0.9010425238735935|
## | 2| c| 0.3864428179837973|
## | 2| g|0.06233470405822805|
## | 2| d|0.37620872770129155|
## | 3| f| 0.7518901502732027|
## | 3| e| 0.5142305439671874|
## | 3| d| 0.6250620479303716|
## +------+---+-------------------+
use window functions
from pyspark.sql import Window
from pyspark.sql.functions import col, rand, rowNumber  # rowNumber was renamed row_number in Spark 1.6+
w = Window.partitionBy(col("teamId")).orderBy(col("rnd_"))
sampled = (df
.withColumn("rnd_", rand()) # Add random numbers column
.withColumn("rn_", rowNumber().over(w)) # Add rowNumber over windw
.where(col("rn_") <= n) # Take n observations
.drop("rn_") # drop helper columns
.drop("rnd_"))
sampled.show()
## +------+---+--------------------+
## |teamId| x1| x2|
## +------+---+--------------------+
## | 1| f| 0.8563875814036598|
## | 1| g| 0.81921738561455|
## | 1| i| 0.8173912535268248|
## | 2| h| 0.10862995810038856|
## | 2| c| 0.3864428179837973|
## | 2| a| 0.6695356657072442|
## | 3| b|0.012329360826023095|
## | 3| a| 0.6450777858109182|
## | 3| e| 0.5142305439671874|
## +------+---+--------------------+
but I am afraid both will be rather expensive. If the size of the individual groups is balanced and relatively large, I would simply use DataFrame.randomSplit.
If the number of groups is relatively small, it is possible to try something else:
from pyspark.sql.functions import count, udf
from pyspark.sql.types import BooleanType
from operator import truediv
counts = (df
.groupBy(col("teamId"))
.agg(count("*").alias("n"))
.rdd.map(lambda r: (r.teamId, r.n))
.collectAsMap())
# This defines fraction of observations from a group which should
# be taken to get n values
counts_bd = sc.broadcast({k: truediv(n, v) for (k, v) in counts.items()})
to_take = udf(lambda k, rnd: rnd <= counts_bd.value.get(k), BooleanType())
sampled = (df
.withColumn("rnd_", rand())
.where(to_take(col("teamId"), col("rnd_")))
.drop("rnd_"))
sampled.show()
## +------+---+--------------------+
## |teamId| x1| x2|
## +------+---+--------------------+
## | 1| d| 0.14815204548854788|
## | 1| f| 0.8563875814036598|
## | 1| g| 0.81921738561455|
## | 2| a| 0.6695356657072442|
## | 2| d| 0.37620872770129155|
## | 2| g| 0.06233470405822805|
## | 3| b|0.012329360826023095|
## | 3| h| 0.9022527556458557|
## +------+---+--------------------+
In Spark 1.5+ you can replace udf with a call to sampleBy method:
df.sampleBy("teamId", counts_bd.value)
It won't give you an exact number of observations, but it should be good enough most of the time, as long as the number of observations per group is large enough to get proper samples. You can also use sampleByKey on an RDD in a similar way.
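For reference, a minimal stand-alone sampleBy call might look like the following; the fractions here are made-up per-group probabilities rather than the counts_bd-derived ones above.
# A minimal sketch: the fractions dict maps each teamId to the probability of
# keeping a row from that group, so the resulting sample sizes are approximate.
fractions = {1: 0.3, 2: 0.3, 3: 0.3}
df.sampleBy("teamId", fractions, seed=42).show()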
I found this approach more DataFrame-like, rather than going the RDD way.
You can use a window function to create a ranking within each group, where the ranking can be random to suit your case. Then you can filter based on the number of samples (N) you want for each group:
window_1 = Window.partitionBy(data['teamId']).orderBy(F.rand())
data_1 = data.select('*', F.rank().over(window_1).alias('rank')).filter(F.col('rank') <= N).drop('rank')
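A close variant that avoids any chance of ties is row_number() instead of rank(); a sketch, with the imports the snippet above assumes and data standing for your DataFrame.
from pyspark.sql import functions as F
from pyspark.sql.window import Window

N = 3  # samples per group (assumed)

window_1 = Window.partitionBy("teamId").orderBy(F.rand())
sampled = (data
           .withColumn("rn", F.row_number().over(window_1))
           .filter(F.col("rn") <= N)
           .drop("rn"))
sampled.show()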
Here's an alternative using the pandas DataFrame.sample method. This uses the Spark applyInPandas method to distribute the groups, available from Spark 3.0.0. This allows you to select an exact number of rows per group.
I've added args and kwargs to the function so you can access the other arguments of DataFrame.sample.
def sample_n_per_group(n, *args, **kwargs):
    def sample_per_group(pdf):
        return pdf.sample(n, *args, **kwargs)
    return sample_per_group
df = spark.createDataFrame(
[
(1, 1.0),
(1, 2.0),
(2, 3.0),
(2, 5.0),
(2, 10.0)
],
("id", "v")
)
(df.groupBy("id")
.applyInPandas(
sample_n_per_group(2, random_state=2),
schema=df.schema
)
)
Be aware of the limitations for very large groups; from the documentation:
This function requires a full shuffle. All the data of a group will be
loaded into memory, so the user should be aware of the potential OOM
risk if data is skewed and certain groups are too large to fit in
memory.
See also here:
How take a random row from a PySpark DataFrame?
