Compare two different columns from two different PySpark dataframes - python

I'm trying to compare two different columns that live in two different data frames; if I find a match I return 1, otherwise None -
df1 =
df2 =
df1 (Expected_Output) =
I have tried the below code -
def getImpact(row):
    match = df2.filter(df2.second_key == row)
    if match.count() > 0:
        return 1
    return None

udf_sol = udf(lambda x: getImpact(x), IntegerType())
df1 = df1.withColumn('impact', udf_sol(df1.first_key))
But I'm getting the below error -
TypeError: cannot pickle '_thread.RLock' object
Can anyone help me to achieve the expected output as shown above?
Thanks

Assuming first_key and second_key are unique, you can opt for a left join across the dataframes instead of a UDF. The pickling error comes from referencing df2 (and through it the SparkContext) inside the UDF; a DataFrame cannot be serialized and shipped to the executors, so this kind of lookup has to be expressed as a join.
from pyspark import SparkContext
from pyspark.sql import SQLContext
import pyspark.sql.functions as F

sc = SparkContext.getOrCreate()
sql = SQLContext(sc)
data_list1 = [
("abcd","Key1")
,("jkasd","Key2")
,("oigoa","Key3")
,("ad","Key4")
,("bas","Key5")
,("lkalsjf","Key6")
,("bsawva","Key7")
]
data_list2 = [
("cashj","Key1",10)
,("ax","Key11",12)
,("safa","Key5",21)
,("safasf","Key6",78)
,("vasv","Key3",4)
,("wgaga","Key8",0)
,("saasfas","Key7",10)
]
sparkDF1 = sql.createDataFrame(data_list1,['data','first_key'])
sparkDF2 = sql.createDataFrame(data_list2,['temp_data','second_key','frinks'])
>>> sparkDF1.show()
+-------+---------+
| data|first_key|
+-------+---------+
| abcd| Key1|
| jkasd| Key2|
| oigoa| Key3|
| ad| Key4|
| bas| Key5|
|lkalsjf| Key6|
| bsawva| Key7|
+-------+---------+
>>> sparkDF2.show()
+---------+----------+------+
|temp_data|second_key|frinks|
+---------+----------+------+
| cashj| Key1| 10|
| ax| Key11| 12|
| safa| Key5| 21|
| safasf| Key6| 78|
| vasv| Key3| 4|
| wgaga| Key8| 0|
| saasfas| Key7| 10|
+---------+----------+------+
#### Join the dataframes on the key columns
finalDF = sparkDF1.join(
sparkDF2
,(sparkDF1['first_key'] == sparkDF2['second_key'])
,'left'
).select(sparkDF1['*'],sparkDF2['frinks']).orderBy('frinks')
### Flag the impact depending on whether frinks is null
finalDF = finalDF.withColumn('impact',F.when(F.col('frinks').isNull(),0).otherwise(1))
>>> finalDF.show()
+-------+---------+------+------+
| data|first_key|frinks|impact|
+-------+---------+------+------+
| jkasd| Key2| null| 0|
| ad| Key4| null| 0|
| oigoa| Key3| 4| 1|
| abcd| Key1| 10| 1|
| bsawva| Key7| 10| 1|
| bas| Key5| 21| 1|
|lkalsjf| Key6| 78| 1|
+-------+---------+------+------+
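If the output should match the expected shape exactly (1 or None rather than 0, and without the helper column), a small optional follow-up on finalDF: when() without an otherwise() yields null for the non-matching rows.
import pyspark.sql.functions as F

# Optional: 1 where a match was found, null otherwise, then drop the helper column
finalDF = (finalDF
    .withColumn('impact', F.when(F.col('frinks').isNotNull(), 1))
    .drop('frinks'))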

If both frames are pandas DataFrames with aligned indexes (this approach does not work on Spark DataFrames), NumPy can do the comparison directly:
import numpy as np
df1['final'] = np.where(df1['first_key'] == df2['second_key'], '1', 'None')


Filter then count for many different thresholds

I want to count the rows of a very large dataframe that satisfy a condition, which can be done with
df.filter(col("value") >= thresh).count()
I want to know the result for each threshold in the range [1, 10]. Enumerating each threshold and running this action scans the dataframe 10 times, which is slow.
Can I achieve this by scanning the df only once?
Create an indicator column for each threshold, then sum:
import random
import pyspark.sql.functions as F
from pyspark.sql import Row
df = spark.createDataFrame([Row(value=random.randint(0,10)) for _ in range(1_000_000)])
df.select([
    (F.col("value") >= thresh)
        .cast("int")
        .alias(f"ind_{thresh}")
    for thresh in range(1, 11)
]).groupBy().sum().show()
# +----------+----------+----------+----------+----------+----------+----------+----------+----------+-----------+
# |sum(ind_1)|sum(ind_2)|sum(ind_3)|sum(ind_4)|sum(ind_5)|sum(ind_6)|sum(ind_7)|sum(ind_8)|sum(ind_9)|sum(ind_10)|
# +----------+----------+----------+----------+----------+----------+----------+----------+----------+-----------+
# | 908971| 818171| 727240| 636334| 545463| 454279| 363143| 272460| 181729| 90965|
# +----------+----------+----------+----------+----------+----------+----------+----------+----------+-----------+
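If the ten counts are then needed on the driver as plain Python values, the single aggregated row can be collected; a minimal sketch reusing the same expressions:
sums_row = df.select([
    (F.col("value") >= thresh).cast("int").alias(f"ind_{thresh}")
    for thresh in range(1, 11)
]).groupBy().sum().first()

counts = sums_row.asDict()  # {'sum(ind_1)': ..., ..., 'sum(ind_10)': ...}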
Using conditional aggregation with when expressions should do the job.
Here's an example:
from pyspark.sql import functions as F
df = spark.createDataFrame([(1,), (2,), (3,), (4,), (4,), (6,), (7,)], ["value"])
count_expr = [
    F.count(F.when(F.col("value") >= th, 1)).alias(f"gte_{th}")
    for th in range(1, 11)
]
df.select(*count_expr).show()
#+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+
#|gte_1|gte_2|gte_3|gte_4|gte_5|gte_6|gte_7|gte_8|gte_9|gte_10|
#+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+
#| 7| 6| 5| 4| 2| 2| 1| 0| 0| 0|
#+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+
Using a user-defined function udf from pyspark.sql.functions:
import pandas as pd
import numpy as np
df = pd.DataFrame(np.random.randint(0,100, size=(20)), columns=['val'])
thres = [90, 80, 30] # these are the thresholds
thres.sort(reverse=True) # list needs to be sorted
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
spark = SparkSession.builder \
    .master("local[2]") \
    .appName("myApp") \
    .getOrCreate()
sparkDF = spark.createDataFrame(df)
myUdf = udf(lambda x: 0 if x>thres[0] else 1 if x>thres[1] else 2 if x>thres[2] else 3)
sparkDF = sparkDF.withColumn("rank", myUdf(sparkDF.val))
sparkDF.show()
# +---+----+
# |val|rank|
# +---+----+
# | 28| 3|
# | 54| 2|
# | 19| 3|
# | 4| 3|
# | 74| 2|
# | 62| 2|
# | 95| 0|
# | 19| 3|
# | 55| 2|
# | 62| 2|
# | 33| 2|
# | 93| 0|
# | 81| 1|
# | 41| 2|
# | 80| 2|
# | 53| 2|
# | 14| 3|
# | 16| 3|
# | 30| 3|
# | 77| 2|
# +---+----+
sparkDF.groupby(['rank']).count().show()
# Out:
# +----+-----+
# |rank|count|
# +----+-----+
# | 3| 7|
# | 0| 2|
# | 1| 1|
# | 2| 10|
# +----+-----+
A value gets rank i if it is strictly greater than thres[i] but less than or equal to thres[i-1]. This keeps the number of comparisons small.
For thres = [90, 80, 30] the ranks correspond to the intervals 0 -> (90, max], 1 -> (80, 90], 2 -> (30, 80], 3 -> [min, 30].
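If you want to avoid the Python UDF entirely, the same bucketing can be expressed with chained when expressions so the comparisons stay inside the JVM; a minimal sketch assuming the same sparkDF and the descending thres list above:
import pyspark.sql.functions as F

# Same strict-greater-than semantics as the UDF, without Python serialization overhead
thres = [90, 80, 30]
rank_expr = (F.when(F.col('val') > thres[0], 0)
              .when(F.col('val') > thres[1], 1)
              .when(F.col('val') > thres[2], 2)
              .otherwise(3))
sparkDF = sparkDF.withColumn('rank', rank_expr)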

Py4JJavaError when applying a function to a column in pyspark

spark = 2.x
New to pyspark.
While encoding date-related columns for training a DNN, I keep facing the error mentioned in the title.
From df:
day  month  ...
  1      1
  2      3
  3      1   ...
I am trying to get the cos/sine value for each column in order to capture their cyclic nature. Applying a UDF to a column in pyspark has worked fine until now, but the code below doesn't work:
def to_cos(x, _max):
    return np.sin(2*np.pi*x / _max)

to_cos_udf = udf(to_cos, DecimalType())
df = df.withColumn("month", to_cos_udf("month", 12))
I've tried it with IntegerType and with a one-argument def to_cos(x), but neither seems to work; the output is:
Py4JJavaError: An error occurred while calling 0.24702.showString.
Since you haven't shared the entire stack trace, it's hard to say which error is actually causing the failure. Two things in the snippet will cause trouble, though: the literal 12 is passed to the UDF call as if it were a column (UDF arguments are treated as columns, so constants need to be wrapped in lit() or bound beforehand), and the UDF returns a NumPy float while DecimalType is declared as the return type.
Based on the code snippets you have shared, you first need to update your UDF definition as below. Binding the extra argument with a lambda is probably the cleanest approach; apart from that, you can use partial.
Data Preparation
import numpy as np
import pandas as pd
import pyspark.sql.functions as F
from pyspark.sql.types import FloatType

df = pd.DataFrame({
    'month': [i for i in range(0, 12)],
})
sparkDF = spark.createDataFrame(df)
sparkDF.show()
+-----+
|month|
+-----+
| 0|
| 1|
| 2|
| 3|
| 4|
| 5|
| 6|
| 7|
| 8|
| 9|
| 10|
| 11|
+-----+
Custom UDF
def to_cos(x, _max):
    try:
        res = np.sin(2*np.pi*x / _max)
    except Exception as e:
        res = 0.0
    return float(res)

max_cos = 12
to_cos_udf = F.udf(lambda x: to_cos(x, max_cos), FloatType())
sparkDF = sparkDF.withColumn('month_cos', to_cos_udf('month'))
sparkDF.show()
+-----+-------------+
|month| month_cos|
+-----+-------------+
| 0| 0.0|
| 1| 0.5|
| 2| 0.8660254|
| 3| 1.0|
| 4| 0.8660254|
| 5| 0.5|
| 6|1.2246469E-16|
| 7| -0.5|
| 8| -0.8660254|
| 9| -1.0|
| 10| -0.8660254|
| 11| -0.5|
+-----+-------------+
Custom UDF - Partial
from functools import partial
partial_func = partial(to_cos,_max=max_cos)
to_cos_partial_udf = F.udf(partial_func)
sparkDF = sparkDF.withColumn('month_cos',to_cos_partial_udf('month'))
sparkDF.show()
+-----+--------------------+
|month| month_cos|
+-----+--------------------+
| 0| 0.0|
| 1| 0.49999999999999994|
| 2| 0.8660254037844386|
| 3| 1.0|
| 4| 0.8660254037844388|
| 5| 0.49999999999999994|
| 6|1.224646799147353...|
| 7| -0.4999999999999998|
| 8| -0.8660254037844384|
| 9| -1.0|
| 10| -0.8660254037844386|
| 11| -0.5000000000000004|
+-----+--------------------+
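Side note: the to_cos helper above actually computes a sine, and for this particular transformation a UDF isn't strictly needed, since pyspark.sql.functions ships native trigonometric functions; a minimal sketch assuming the same sparkDF:
import math
import pyspark.sql.functions as F

# Native column expression: no Python UDF, so no per-row (de)serialization
sparkDF = sparkDF.withColumn('month_sin', F.sin(2 * math.pi * F.col('month') / 12))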

Compare two dataframes in PySpark

I'm trying to compare two data frames that have the same number of columns, i.e. 4 columns, with id as the key column in both data frames.
df1 = spark.read.csv("/path/to/data1.csv")
df2 = spark.read.csv("/path/to/data2.csv")
Now I want to append a new column to DF2, i.e. column_names, which is the list of the columns whose values differ from df1
df2.withColumn("column_names",udf())
DF1
+---+----+----+-------+
| id|name| sal|Address|
+---+----+----+-------+
|  1| ABC|5000|     US|
|  2| DEF|4000|     UK|
|  3| GHI|3000|    JPN|
|  4| JKL|4500|    CHN|
+---+----+----+-------+
DF2:
+---+-----+----+-------+
| id| name| sal|Address|
+---+-----+----+-------+
|  1|  ABC|5000|     US|
|  2|  DEF|4000|    CAN|
|  3|  GHI|3500|    JPN|
|  4|JKL_M|4800|    CHN|
+---+-----+----+-------+
Now I want DF3
DF3:
+---+-----+----+-------+------------+
| id| name| sal|Address|column_names|
+---+-----+----+-------+------------+
|  1|  ABC|5000|     US|          []|
|  2|  DEF|4000|    CAN|   [address]|
|  3|  GHI|3500|    JPN|       [sal]|
|  4|JKL_M|4800|    CHN|  [name,sal]|
+---+-----+----+-------+------------+
I saw this SO question, How to compare two dataframe and print columns that are different in scala. I tried that, however the result is different.
I'm thinking of going with a UDF by passing a row from each dataframe to the UDF, comparing column by column, and returning the list of differing columns. However, for that both data frames would have to be in sorted order so that rows with the same id are sent to the UDF together, and sorting is a costly operation here. Any solution?
Assuming that we can use id to join these two datasets, I don't think there is a need for a UDF. This can be solved just by using an inner join together with the array and array_remove functions, among others.
First let's create the two datasets:
df1 = spark.createDataFrame([
    [1, "ABC", 5000, "US"],
    [2, "DEF", 4000, "UK"],
    [3, "GHI", 3000, "JPN"],
    [4, "JKL", 4500, "CHN"]
], ["id", "name", "sal", "Address"])

df2 = spark.createDataFrame([
    [1, "ABC", 5000, "US"],
    [2, "DEF", 4000, "CAN"],
    [3, "GHI", 3500, "JPN"],
    [4, "JKL_M", 4800, "CHN"]
], ["id", "name", "sal", "Address"])
First we do an inner join between the two datasets, then we generate the condition df1[col] != df2[col] for each column except id. When the columns aren't equal we return the column name, otherwise an empty string. These conditions become the items of an array, from which we finally remove the empty items:
from pyspark.sql.functions import col, array, when, array_remove, lit

# get conditions for all columns except id
conditions_ = [when(df1[c] != df2[c], lit(c)).otherwise("") for c in df1.columns if c != 'id']

select_expr = [
    col("id"),
    *[df2[c] for c in df2.columns if c != 'id'],
    array_remove(array(*conditions_), "").alias("column_names")
]

df1.join(df2, "id").select(*select_expr).show()
# +---+-----+----+-------+------------+
# | id| name| sal|Address|column_names|
# +---+-----+----+-------+------------+
# | 1| ABC|5000| US| []|
# | 3| GHI|3500| JPN| [sal]|
# | 2| DEF|4000| CAN| [Address]|
# | 4|JKL_M|4800| CHN| [name, sal]|
# +---+-----+----+-------+------------+
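Side note: the plain != comparison evaluates to null when either side is null, so differences involving null values would be silently dropped by the when/otherwise above. If nulls are possible, a null-safe comparison can be swapped in (Column.eqNullSafe, available since Spark 2.3); a minimal sketch:
# Null-safe variant of the conditions: a column counts as different
# whenever the two values are not null-safely equal.
conditions_ = [
    when(~df1[c].eqNullSafe(df2[c]), lit(c)).otherwise("")
    for c in df1.columns if c != 'id'
]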
Here is your solution with a UDF. I have renamed the columns of the first dataframe dynamically so that they are not ambiguous during the check. Go through the code below and let me know if you have any concerns.
>>> from pyspark.sql.functions import *
>>> df.show()
+---+----+----+-------+
| id|name| sal|Address|
+---+----+----+-------+
| 1| ABC|5000| US|
| 2| DEF|4000| UK|
| 3| GHI|3000| JPN|
| 4| JKL|4500| CHN|
+---+----+----+-------+
>>> df1.show()
+---+----+----+-------+
| id|name| sal|Address|
+---+----+----+-------+
| 1| ABC|5000| US|
| 2| DEF|4000| CAN|
| 3| GHI|3500| JPN|
| 4|JKLM|4800| CHN|
+---+----+----+-------+
>>> df2 = df.select([col(c).alias("x_"+c) for c in df.columns])
>>> df3 = df1.join(df2, col("id") == col("x_id"), "left")
# udf declaration
>>> def CheckMatch(Column, r):
...     check = ''
...     ColList = Column.split(",")
...     for cc in ColList:
...         if r[cc] != r["x_" + cc]:
...             check = check + "," + cc
...     return check.replace(',', '', 1).split(",")
>>> CheckMatchUDF = udf(CheckMatch)
# final columns required for the select
>>> finalCol = df1.columns
>>> finalCol.insert(len(finalCol), "column_names")
>>> df3.withColumn("column_names", CheckMatchUDF(lit(','.join(df1.columns)),struct([df3[x] for x in df3.columns])))
.select(finalCol)
.show()
+---+----+----+-------+------------+
| id|name| sal|Address|column_names|
+---+----+----+-------+------------+
| 1| ABC|5000| US| []|
| 2| DEF|4000| CAN| [Address]|
| 3| GHI|3500| JPN| [sal]|
| 4|JKLM|4800| CHN| [name, sal]|
+---+----+----+-------+------------+
Python: PySpark version of my previous scala code.
import pyspark.sql.functions as f
df1 = spark.read.option("header", "true").csv("test1.csv")
df2 = spark.read.option("header", "true").csv("test2.csv")
columns = df1.columns
df3 = df1.alias("d1").join(df2.alias("d2"), f.col("d1.id") == f.col("d2.id"), "left")
for name in columns:
    df3 = df3.withColumn(name + "_temp", f.when(f.col("d1." + name) != f.col("d2." + name), f.lit(name)))
df3.withColumn("column_names", f.concat_ws(",", *map(lambda name: f.col(name + "_temp"), columns))).select("d1.*", "column_names").show()
Scala: Here is my best approach for your problem.
val df1 = spark.read.option("header", "true").csv("test1.csv")
val df2 = spark.read.option("header", "true").csv("test2.csv")
val columns = df1.columns
val df3 = df1.alias("d1").join(df2.alias("d2"), col("d1.id") === col("d2.id"), "left")
columns.foldLeft(df3) {(df, name) => df.withColumn(name + "_temp", when(col("d1." + name) =!= col("d2." + name), lit(name)))}
.withColumn("column_names", concat_ws(",", columns.map(name => col(name + "_temp")): _*))
.show(false)
First, I join the two dataframes into df3 and use the columns from df1. Folding left over df3 adds one temp column per column name, holding that name whenever the row with the same id has a different value in df1 and df2.
After that, concat_ws joins those column names; the nulls disappear and only the names of the differing columns are left.
+---+----+----+-------+------------+
|id |name|sal |Address|column_names|
+---+----+----+-------+------------+
|1 |ABC |5000|US | |
|2 |DEF |4000|UK |Address |
|3 |GHI |3000|JPN |sal |
|4 |JKL |4500|CHN |name,sal |
+---+----+----+-------+------------+
The only difference from your expected result is that the output is not a list but a string; see the split sketch below if you need an array.
p.s. I forgot to use PySpark at first, so the second snippet is normal (Scala) Spark, sorry.
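If a real array is needed rather than the comma-separated string, one option is to split the concatenated value and drop the empty entry produced by rows with no differences; a minimal sketch assuming the column_names column built by the PySpark version above:
# array_remove drops the "" left over when no column differs (requires Spark 2.4+)
df3 = df3.withColumn(
    "column_names",
    f.array_remove(f.split(f.col("column_names"), ","), "")
)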
You can get that query built for you in PySpark and Scala by the spark-extension package.
It provides the diff transformation that does exactly that.
from gresearch.spark.diff import *
options = DiffOptions().with_change_column('changes')
df1.diff_with_options(df2, options, 'id').show()
+----+-----------+---+---------+----------+--------+---------+------------+-------------+
|diff| changes| id|left_name|right_name|left_sal|right_sal|left_Address|right_Address|
+----+-----------+---+---------+----------+--------+---------+------------+-------------+
| N| []| 1| ABC| ABC| 5000| 5000| US| US|
| C| [Address]| 2| DEF| DEF| 4000| 4000| UK| CAN|
| C| [sal]| 3| GHI| GHI| 3000| 3500| JPN| JPN|
| C|[name, sal]| 4| JKL| JKL_M| 4500| 4800| CHN| CHN|
+----+-----------+---+---------+----------+--------+---------+------------+-------------+
While this is a simple example, diffing DataFrames can become complicated when wide schemas, insertions, deletions and null values are involved. That package is well-tested, so you don't have to worry about getting that query right yourself.
There is a wonderful package for pyspark that compares two dataframes. The name of the package is datacompy
https://capitalone.github.io/datacompy/
example code:
import datacompy as dc

# common_keys is the list of join key column names shared by both frames, e.g. ["id"]
comparison = dc.SparkCompare(spark, base_df=df1, compare_df=df2,
                             join_columns=common_keys, match_rates=True)
comparison.report()
The above code will generate a summary report, and the one below it will give you the mismatches.
comparison.rows_both_mismatch.display()
There are also more features that you can explore.

Reshape pyspark dataframe to show moving window of item interactions

I have a large pyspark dataframe of subject interactions in long format--each row describes a subject interacting with some item of interest, along with a timestamp and a rank-order for that subject's interaction (i.e., first interaction is 1, second is 2, etc.). Here are a few rows:
+----------+--------+----------------------+------+
|      date|  itemId|interaction_date_order|userId|
+----------+--------+----------------------+------+
|2019-07-23|10005880|                     1|    37|
|2019-07-23|10005903|                     2|    37|
|2019-07-23|10005903|                     3|    37|
|2019-07-23|12458442|                     4|    37|
|2019-07-26|10005903|                     5|    37|
|2019-07-26|12632813|                     6|    37|
|2019-07-26|12632813|                     7|    37|
|2019-07-26|12634497|                     8|    37|
|2018-11-24|12245677|                     1|     5|
|2018-11-24|12245677|                     1|     5|
|2019-07-29|12541871|                     2|     5|
|2019-07-29|12541871|                     3|     5|
|2019-07-30|12626854|                     4|     5|
|2019-08-31|12776880|                     5|     5|
|2019-08-31|12776880|                     6|     5|
+----------+--------+----------------------+------+
I need to reshape these data such that, for each subject, a row has a length-5 moving window of interactions. So then, something like this:
+------+--------+--------+--------+--------+--------+
|userId|     i-2|     i-1|       i|     i+1|     i+2|
+------+--------+--------+--------+--------+--------+
|    37|10005880|10005903|10005903|12458442|10005903|
|    37|10005903|10005903|12458442|10005903|12632813|
Does anyone have suggestions for how I might do this?
Import spark and everything
from pyspark.sql import *
from pyspark.context import SparkContext
from pyspark.sql.session import SparkSession
sc = SparkContext('local')
spark = SparkSession(sc)
Create your dataframe
columns = '| date|itemId |interaction_date_order| userId|'.split('|')
lines = '''2019-07-23| 10005880| 1|37 |
2019-07-23| 10005903| 2|37 |
2019-07-23| 10005903| 3|37 |
2019-07-23| 12458442| 4|37 |
2019-07-26| 10005903| 5|37 |
2019-07-26| 12632813| 6|37 |
2019-07-26| 12632813| 7|37 |
2019-07-26| 12634497| 8|37 |
2018-11-24| 12245677| 1|5 |
2018-11-24| 12245677| 2|5 |
2019-07-29| 12541871| 3|5 |
2019-07-29| 12541871| 4|5 |
2019-07-30| 12626854| 5|5 |
2019-08-31| 12776880| 6|5 |
2019-08-31| 12776880| 7|5 |'''
Interaction = Row("date", "itemId", "interaction_date_order", "userId")
interactions = []
for line in lines.split('\n'):
    column_values = line.split('|')
    interaction = Interaction(column_values[0], int(column_values[1]), int(column_values[2]), int(column_values[3]))
    interactions.append(interaction)
df = spark.createDataFrame(interactions)
now we have
df.show()
+----------+--------+----------------------+------+
| date| itemId|interaction_date_order|userId|
+----------+--------+----------------------+------+
|2019-07-23|10005880| 1| 37|
|2019-07-23|10005903| 2| 37|
|2019-07-23|10005903| 3| 37|
|2019-07-23|12458442| 4| 37|
|2019-07-26|10005903| 5| 37|
|2019-07-26|12632813| 6| 37|
|2019-07-26|12632813| 7| 37|
|2019-07-26|12634497| 8| 37|
|2018-11-24|12245677| 1| 5|
|2018-11-24|12245677| 2| 5|
|2019-07-29|12541871| 3| 5|
|2019-07-29|12541871| 4| 5|
|2019-07-30|12626854| 5| 5|
|2019-08-31|12776880| 6| 5|
|2019-08-31|12776880| 7| 5|
+----------+--------+----------------------+------+
Create a window and collect itemId with count
from pyspark.sql.window import Window
import pyspark.sql.functions as F
window = Window() \
    .partitionBy('userId') \
    .orderBy('interaction_date_order') \
    .rowsBetween(Window.currentRow, Window.currentRow + 4)
df2 = df.withColumn("itemId_list", F.collect_list('itemId').over(window))
df2 = df2.withColumn("itemId_count", F.count('itemId').over(window))
df_final = df2.where(df2['itemId_count'] == 5)
now we have
df_final.show()
+----------+--------+----------------------+------+--------------------+------------+
| date| itemId|interaction_date_order|userId| itemId_list|itemId_count|
+----------+--------+----------------------+------+--------------------+------------+
|2018-11-24|12245677| 1| 5|[12245677, 122456...| 5|
|2018-11-24|12245677| 2| 5|[12245677, 125418...| 5|
|2019-07-29|12541871| 3| 5|[12541871, 125418...| 5|
|2019-07-23|10005880| 1| 37|[10005880, 100059...| 5|
|2019-07-23|10005903| 2| 37|[10005903, 100059...| 5|
|2019-07-23|10005903| 3| 37|[10005903, 124584...| 5|
|2019-07-23|12458442| 4| 37|[12458442, 100059...| 5|
+----------+--------+----------------------+------+--------------------+------------+
Final touch
df_final2 = (df_final
    .withColumn('i-2', df_final['itemId_list'][0])
    .withColumn('i-1', df_final['itemId_list'][1])
    .withColumn('i',   df_final['itemId_list'][2])
    .withColumn('i+1', df_final['itemId_list'][3])
    .withColumn('i+2', df_final['itemId_list'][4])
    .select('userId', 'i-2', 'i-1', 'i', 'i+1', 'i+2')
)
df_final2.show()
+------+--------+--------+--------+--------+--------+
|userId| i-2| i-1| i| i+1| i+2|
+------+--------+--------+--------+--------+--------+
| 5|12245677|12245677|12541871|12541871|12626854|
| 5|12245677|12541871|12541871|12626854|12776880|
| 5|12541871|12541871|12626854|12776880|12776880|
| 37|10005880|10005903|10005903|12458442|10005903|
| 37|10005903|10005903|12458442|10005903|12632813|
| 37|10005903|12458442|10005903|12632813|12632813|
| 37|12458442|10005903|12632813|12632813|12634497|
+------+--------+--------+--------+--------+--------+
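For comparison, the same shape can also be produced with lead() instead of collect_list; a minimal sketch assuming the df built above:
from pyspark.sql.window import Window
import pyspark.sql.functions as F

w = Window.partitionBy('userId').orderBy('interaction_date_order')

df_lead = (df
    .withColumn('i-1', F.lead('itemId', 1).over(w))
    .withColumn('i',   F.lead('itemId', 2).over(w))
    .withColumn('i+1', F.lead('itemId', 3).over(w))
    .withColumn('i+2', F.lead('itemId', 4).over(w))
    .withColumnRenamed('itemId', 'i-2')
    .where(F.col('i+2').isNotNull())  # keep only complete windows of five
    .select('userId', 'i-2', 'i-1', 'i', 'i+1', 'i+2'))
df_lead.show()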

How to concatenate to a null column in pyspark dataframe

I have the dataframe below and I want to update its rows dynamically with some values
input_frame.show()
+----------+----------+---------+
|student_id|name |timestamp|
+----------+----------+---------+
| s1|testuser | t1|
| s1|sampleuser| t2|
| s2|test123 | t1|
| s2|sample123 | t2|
+----------+----------+---------+
input_frame = input_frame.withColumn('test', sf.lit(None))
input_frame.show()
+----------+----------+---------+----+
|student_id| name|timestamp|test|
+----------+----------+---------+----+
| s1| testuser| t1|null|
| s1|sampleuser| t2|null|
| s2| test123| t1|null|
| s2| sample123| t2|null|
+----------+----------+---------+----+
input_frame = input_frame.withColumn('test', sf.concat(sf.col('test'),sf.lit('test')))
input_frame.show()
+----------+----------+---------+----+
|student_id| name|timestamp|test|
+----------+----------+---------+----+
| s1| testuser| t1|null|
| s1|sampleuser| t2|null|
| s2| test123| t1|null|
| s2| sample123| t2|null|
+----------+----------+---------+----+
I want to update the 'test' column with some values and then apply a filter with partial matches on that column. But concatenating to a null column results in a null column again. How can we do this?
use concat_ws, like this:
from pyspark.sql import SparkSession
from pyspark.sql.functions import concat, concat_ws, when, lit

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([["1", "2"], ["2", None], ["3", "4"], ["4", "5"], [None, "6"]]).toDF("a", "b")

# This won't work: concat returns null as soon as any input is null
df = df.withColumn("concat", concat(df.a, df.b))

# This won't work either: casting a null to string still leaves a null
df = df.withColumn("concat + cast", concat(df.a.cast('string'), df.b.cast('string')))

# Do it like this: concat_ws skips null inputs
df = df.withColumn("concat_ws", concat_ws("", df.a, df.b))
df.show()
gives:
+----+----+------+-------------+---------+
| a| b|concat|concat + cast|concat_ws|
+----+----+------+-------------+---------+
| 1| 2| 12| 12| 12|
| 2|null| null| null| 2|
| 3| 4| 34| 34| 34|
| 4| 5| 45| 45| 45|
|null| 6| null| null| 6|
+----+----+------+-------------+---------+
Note specifically that casting a NULL column to string doesn't work as you might wish: concat still returns NULL for a row if any of its input columns is null.
There's no nice way of dealing with more complicated scenarios, but note that you can use a when statement inside a concat if you're willing to suffer the verbosity of it, like this:
df.withColumn("concat_custom", concat(
when(df.a.isNull(), lit('_')).otherwise(df.a),
when(df.b.isNull(), lit('_')).otherwise(df.b))
)
To get, eg:
+----+----+-------------+
| a| b|concat_custom|
+----+----+-------------+
| 1| 2| 12|
| 2|null| 2_|
| 3| 4| 34|
| 4| 5| 45|
|null| 6| _6|
+----+----+-------------+
You can use the coalesce function, which returns the first of its arguments that is not null; provide a literal in second place, which will be used whenever the column has a null value.
df = df.withColumn("concat", concat(coalesce(df.a, lit('')), coalesce(df.b, lit(''))))
You can fill null values with empty strings:
import pyspark.sql.functions as f
from pyspark.sql.types import *
data = spark.createDataFrame([('s1', 't1'), ('s2', 't2')], ['col1', 'col2'])
data = data.withColumn('test', f.lit(None).cast(StringType()))
display(data.na.fill('').withColumn('test2', f.concat('col1', 'col2', 'test')))
Is that what you were looking for?
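One caveat: display() is a notebook helper (Databricks and similar environments), not part of PySpark itself; in a plain PySpark session the same result can be shown with:
data.na.fill('').withColumn('test2', f.concat('col1', 'col2', 'test')).show()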
