Here's the initial code
from pyspark.sql import functions as F, Window as W
df_subs_loc_movmnt_ts = df_subs_loc_movmnt.withColumn("new_ts",from_unixtime(unix_timestamp(col("ts"),"HH:mm:ss"),"HH:mm:ss"))
df_subs_loc_movmnt_ts.show(5)
Here's the output
+--------+---------------+--------+---------------+-------------+----+-----+---+--------+
| date_id| ts| subs_no| cgi| msisdn|year|month|day| new_ts|
+--------+---------------+--------+---------------+-------------+----+-----+---+--------+
|20200801|14:27:18.000000|10007239|510-11-542276-6|1034506190093|2022| 6| 1|14:27:18|
|20200801|14:29:44.000000|10054647|510-11-610556-5|7845838057779|2022| 6| 1|14:29:44|
|20200801|08:24:21.000000|10057750|510-11-542301-6| 570449692639|2022| 6| 1|08:24:21|
|20200801|13:49:27.000000|10019958|510-11-610206-6|6674433175670|2022| 6| 1|13:49:27|
|20200801|20:07:32.000000|10019958|510-11-611187-6|6674433175670|2022| 6| 1|20:07:32|
+--------+---------------+--------+---------------+-------------+----+-----+---+--------+
What I plan to do
w = W.partitionBy('date_id', 'subs_no', 'year', 'month', 'day').orderBy(new_ts)
df_subs_loc_movmnt_duration = df_subs_loc_movmnt_ts.withColumn('duration', F.regexp_extract(new_ts - F.min(new_ts).over(w), 0))
df_subs_loc_movmnt_duration = df_subs_loc_movmnt_duration.replace('00:00:00', 'first', 'duration')
What happen is
----> 1 w = W.partitionBy('date_id', 'subs_no', 'year', 'month', 'day').orderBy(new_ts)
`NameError: name 'new_ts' is not defined`
Note new_ts is avaliable in df_subs_loc_movmnt_ts not df_subs_loc_movmnt
Related
I am solving a problem using spark running in my local machine.
I am reading a parquet file from the local disk and storing it to the dataframe.
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
spark = SparkSession.builder\
.config("spark.driver.memory","4g")\
.config("spark.executor.memory","4g")\
.config("spark.driver.maxResultSize","2g")\
.getOrCreate()
content = spark.read.parquet('./files/file')
So, Content Dataframe contents around 500k rows i.e.
+-----------+----------+
|EMPLOYEE_ID|MANAGER_ID|
+-----------+----------+
| 100| 0|
| 101| 100|
| 102| 100|
| 103| 100|
| 104| 100|
| 105| 100|
| 106| 101|
| 101| 101|
| 101| 101|
| 101| 101|
| 101| 102|
| 101| 102|
. .
. .
. .
I write this code to provide each EMPLOYEE_ID an EMPLOYEE_LEVEL according to their hierarchy.
# Assign EMPLOYEE_LEVEL 1 WHEN MANAGER_ID is 0 ELSE NULL
content_df = content.withColumn("EMPLOYEE_LEVEL", when(col("MANAGER_ID") == 0, 1).otherwise(lit('')))
level_df = content_df.select("*").filter("Level = 1")
level = 1
while True:
ldf = level_df
temp_df = content_df.join(
ldf,
((ldf["EMPLOYEE_LEVEL"] == level) &
(ldf["EMPLOYEE_ID"] == content_df["MANAGER_ID"])),
"left") \
.withColumn("EMPLOYEE_LEVEL",ldf["EMPLOYEE_LEVEL"]+1)\
.select("EMPLOYEE_ID","MANAGER_ID","EMPLOYEE_LEVEL")\
.filter("EMPLOYEE_LEVEL IS NOT NULL")\
.distinct()
if temp_df.count() == 0:
break
level_df = level_df.union(temp_df)
level += 1
It's running, but very slow execution and after some period of time it gives this error.
Py4JJavaError: An error occurred while calling o383.count.
: java.lang.OutOfMemoryError: Java heap space
at scala.collection.immutable.List.$colon$colon(List.scala:117)
at scala.collection.immutable.List.$plus$colon(List.scala:220)
at org.apache.spark.sql.catalyst.expressions.String2TrimExpression.children(stringExpressions.scala:816)
at org.apache.spark.sql.catalyst.expressions.String2TrimExpression.children$(stringExpressions.scala:816)
at org.apache.spark.sql.catalyst.expressions.StringTrim.children(stringExpressions.scala:948)
at org.apache.spark.sql.catalyst.trees.TreeNode.withNewChildren(TreeNode.scala:351)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:595)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:486)
at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$3(TreeNode.scala:486)
at org.apache.spark.sql.catalyst.trees.TreeNode$$Lambda$1822/0x0000000100d21040.apply(Unknown Source)
at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:286)
at scala.collection.TraversableLike$$Lambda$61/0x00000001001d2040.apply(Unknown Source)
at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
at scala.collection.TraversableLike.map(TraversableLike.scala:286)
at scala.collection.TraversableLike.map$(TraversableLike.scala:279)
at scala.collection.AbstractTraversable.map(Traversable.scala:108)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:595)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:486)
at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$3(TreeNode.scala:486)
at org.apache.spark.sql.catalyst.trees.TreeNode$$Lambda$1822/0x0000000100d21040.apply(Unknown Source)
at org.apache.spark.sql.catalyst.trees.BinaryLike.mapChildren(TreeNode.scala:1148)
at org.apache.spark.sql.catalyst.trees.BinaryLike.mapChildren$(TreeNode.scala:1147)
at org.apache.spark.sql.catalyst.expressions.BinaryExpression.mapChildren(Expression.scala:555)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:486)
at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$3(TreeNode.scala:486)
at org.apache.spark.sql.catalyst.trees.TreeNode$$Lambda$1822/0x0000000100d21040.apply(Unknown Source)
at org.apache.spark.sql.catalyst.trees.UnaryLike.mapChildren(TreeNode.scala:1122)
at org.apache.spark.sql.catalyst.trees.UnaryLike.mapChildren$(TreeNode.scala:1121)
at org.apache.spark.sql.catalyst.expressions.UnaryExpression.mapChildren(Expression.scala:467)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:486)
I tried many solutions including increasing driver and executor memory, using cache() and persist() for dataframe also doesn't worked for me.
I am using Spark 3.2.1
Spark
Any help will be appreciated.
Thank you.
I figure out the problem. This error related to the mechanism of spark DAG, it use DAG lineage to track a series transformations, when the algorithms need to iterate, the lineage can grow fast and hit the limitation of memory. So break the lineage is necessary when implementing iteration algorithms.
There are mainly 2 ways: 1. add checkpoint. 2.recreate dataframe.
I modify my codes below, which just add a checkpoint to break the lineage and works for me.
epoch_cnt = 0
while True:
print('hahaha1')
print('cached df', len(spark.sparkContext._jsc.getPersistentRDDs().items()))
singer_pairs_undirected_ungrouped = singer_pairs_undirected.join(old_song_group_kernel,
on=singer_pairs_undirected['src'] == old_song_group_kernel['id'],
how='left').filter(F.col('id').isNull()) \
.select('src', 'dst')
windowSpec = Window.partitionBy("src").orderBy(F.col("song_group_id_cnt").desc())
singer_pairs_vote = singer_pairs_undirected_ungrouped.join(old_song_group_kernel,
on=singer_pairs_undirected_ungrouped['dst'] ==
old_song_group_kernel['id'], how='inner') \
.groupBy('src', 'song_group_id') \
.agg(F.count('song_group_id').alias('song_group_id_cnt')) \
.withColumn('song_group_id_cnt_rnk', F.row_number().over(windowSpec)) \
.filter(F.col('song_group_id_cnt_rnk') == 1)
singer_pairs_vote_output = singer_pairs_vote.select('src', 'song_group_id') \
.withColumnRenamed('src', 'id')
print('hahaha5')
new_song_group_kernel = old_song_group_kernel.union(singer_pairs_vote_output) \
.select('id', 'song_group_id').dropDuplicates().persist().checkpoint()
print('hahaha9')
current_kernel_cnt = new_song_group_kernel.count()
print('hahaha2')
old_song_group_kernel.unpersist()
print('hahaha3')
old_song_group_kernel = new_song_group_kernel
epoch_cnt += 1
print('epoch rounds: ', epoch_cnt)
print('previous kernel count: ', previous_kernel_cnt)
print('current kernel count: ', current_kernel_cnt)
if current_kernel_cnt <= previous_kernel_cnt:
print('Iteration done !')
break
print('hahaha4')
previous_kernel_cnt = current_kernel_cnt
I am new to Python and DataFrame. Here I am writing a Python code to run an ETL job in AWS Glue. Please find the same code snippet below.
test_DyF = glueContext.create_dynamic_frame.from_catalog(database="teststoragedb", table_name="testtestfile_csv")
test_dataframe = test_DyF.select_fields(['empid','name']).toDF()
now the above test_dataframe is of type pyspark.sql.dataframe.DataFrame
Now, I need to loop through the above test_dataframe. As far as I see, I could see only collect or toLocalIterator. Please find the below sample code
for row_val in test_dataframe.collect():
But both these methods are very slow and not efficient. I cannot use pandas as it is not supported by AWS Glue.
Please find the steps I am doing
source information:
productid|matchval|similar product|similar product matchval
product A|100|product X|100
product A|101|product Y|101
product B|100|product X|100
product C|102|product Z|102
expected result:
product |similar products
product A|product X, product Y
product B|product X
product C|product Z
This is the code I am writing
I am getting a distinct dataframe of the source with productID
Loop through this distinct data frame set
a) get the list of matchval for the product from the source
b) identify the similar product based on matchval filters
c) loop through to get the concatinated string ---> This loop using the rdd.collect is affecting the performance
Can you please share any better suggestion on what can be done?
please elaborate what logic you want to try it out. DF looping can be done via SQL approach or you can also follow below RDD approach
def my_function(each_record):
#my_logic
#loop through for each command.
df.rdd.foreach(my_function)
Added following code further based on your input
df = spark.read.csv("/mylocation/61250775.csv", header=True, inferSchema=True, sep="|")
seq = ['product X','product Y','product Z']
df2 = df.groupBy("productid").pivot("similar_product",seq).count()
+---------+---------+---------+---------+
|productid|product X|product Y|product Z|
+---------+---------+---------+---------+
|product B| 1| null| null|
|product A| 1| 1| null|
|product C| null| null| 1|
+---------+---------+---------+---------+
The final approach which match your requirement
df = spark.read.csv("/mylocation/61250775.csv", header=True, inferSchema=True, sep="|")
df.printSchema()
>>> df.printSchema()
root
|-- id: string (nullable = true)
|-- matchval1: integer (nullable = true)
|-- similar: string (nullable = true)
|-- matchval3: integer (nullable = true)
from pyspark.sql.functions import concat_ws
from pyspark.sql.functions import collect_list
dfx = df.groupBy("id").agg(concat_ws(",", collect_list("similar")).alias("Similar_Items")).select(col("id"), col("Similar_Items"))
dfx.show()
+---------+-------------------+
| id| Similar_Items|
+---------+-------------------+
|product B| product X|
|product A|product X,product Y|
|product C| product Z|
+---------+-------------------+
You can also use the MAP class. In my case, I was iterating through data and calculate hash for the full row.
import sys
from awsglue.transforms import *
from awsglue.utils import getResolvedOptions
from pyspark.context import SparkContext
from awsglue.context import GlueContext
from awsglue.job import Job
import hashlib
## #params: [JOB_NAME]
args = getResolvedOptions(sys.argv, ['JOB_NAME'])
sc = SparkContext()
glueContext = GlueContext(sc)
spark = glueContext.spark_session
job = Job(glueContext)
job.init(args['JOB_NAME'], args)
## #type: DataSource
## #args: [database = "load-test", table_name = "table_test", transformation_ctx = "datasource0"]
## #return: datasource0
## #inputs: []
datasource0 = glueContext.create_dynamic_frame.from_catalog(database = "load-test", table_name = "table_test", transformation_ctx = "datasource0")
def hash_calculation(rec):
md5 = hashlib.md5()
md5.update('{}_{}_{}_{}'.format(rec["funcname"],rec["parameter"],rec["paramtype"],rec["structure"]).encode())
rec["hash"] = md5.hexdigest()
print("looping the recs")
return rec
mapped_dyF = Map.apply(frame = datasource0, f = hash_calculation)
I am trying to get a row from my database, which contains multiple columns which are each paired with a unit id column, like so:
id|run_id|diesel_engine_installed_power|diesel_engine_installed_power_unit_id|pv_installed_power|pv_installed_power_unit_id|battery_capacity|battery_capacity_unit_id|
--|------|-----------------------------|-------------------------------------|------------------|--------------------------|----------------|------------------------|
1| | 300| 1| 200| 1| 1000| 4|
2| 484| 300| 1| 200| 1| 1000| 4|
To do so, I am trying to alias the various unit columns while querying them in SQLAlchemy:
diesel_engine_installed_power_MiscUnit = aliased(MiscUnit)
pv_installed_power_MiscUnit = aliased(MiscUnit)
battery_capacity_MiscUnit = aliased(MiscUnit)
mg_res = session.query(ProcRun, ProcMemoGridInput, diesel_engine_installed_power_MiscUnit, pv_installed_power_MiscUnit, battery_capacity_MiscUnit). \
with_entities(
ProcRun,
ProcMemoGridInput,
diesel_engine_installed_power_MiscUnit.codeLabel.label("diesel_engine_installed_power_MiscUnit"),
pv_installed_power_MiscUnit.codeLabel.label("pv_installed_power_MiscUnit"),
battery_capacity_MiscUnit.codeLabel.label("battery_capacity_MiscUnit")
). \
filter(ProcRun.id == ProcMemoGridInput.run_id). \
filter(ProcRun.id == 484). \
filter(ProcMemoGridInput.diesel_engine_installed_power_unit_id == diesel_engine_installed_power_MiscUnit.id). \
filter(ProcMemoGridInput.pv_installed_power_unit_id == pv_installed_power_MiscUnit.id). \
filter(ProcMemoGridInput.battery_capacity_unit_id == battery_capacity_MiscUnit.id). \
one()
It is based on this solution:
Usage of "aliased" in SQLAlchemy ORM
But it tells me that AttributeError: type object 'MiscUnit' has no attribute 'codeLabel'. I don't really understand what the difference is, from what I understand this is the same process for aliasing the MiscUnit ORM object.
Running into an error I think being caused by the Window Function.
When I apply this script and persist just a few sample rows it works fine however when I apply it to my whole dataset (only a few GB)
it fails with this bizarre error on the last step when trying to persist to hdfs ... the script works when I persist w/o the Window Function so the problem must be from that (I have around 325 feature columns running through the for loop).
Any idea what could be causing the problem? My goal is to just impute time series data via forward fill method on every variable in my dataframe.
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql import types as T
from pyspark.sql import Window
import sys
print(spark.version)
'2.3.0'
# sample data
df = spark.createDataFrame([('2019-05-10 7:30:05', '1', '10', '0.5', 'FALSE'),\
('2019-05-10 7:30:10', '2', 'UNKNOWN', '0.24', 'FALSE'),\
('2019-05-10 7:30:15', '3', '6', 'UNKNOWN', 'TRUE'),\
('2019-05-10 7:30:20', '4', '7', 'UNKNOWN', 'UNKNOWN'),\
('2019-05-10 7:30:25', '5', '10', '1.1', 'UNKNOWN'),\
('2019-05-10 7:30:30', '6', 'UNKNOWN', '1.1', 'NULL'),\
('2019-05-10 7:30:35', '7', 'UNKNOWN', 'UNKNOWN', 'TRUE'),\
('2019-05-10 7:30:49', '8', '50', 'UNKNOWN', 'UNKNOWN')], ["date", "id", "v1", "v2", "v3"])
df = df.withColumn("date", F.col("date").cast("timestamp"))
# imputer process / all cols that need filled are strings
def stringReplacer(x, y):
return F.when(x != y, x).otherwise(F.lit(None)) # replace with NULL
def forwardFillImputer(df, cols=[], partitioner="date", value="UNKNOWN"):
for i in cols:
window = Window\
.partitionBy(F.month(partitioner))\
.orderBy(partitioner)\
.rowsBetween(-sys.maxsize, 0)
df = df\
.withColumn(i, stringReplacer(F.col(i), value))
fill = F.last(df[i], ignorenulls=True).over(window)
df = df.withColumn(i, fill)
return df
df2 = forwardFillImputer(df, cols=[i for i in df.columns])
# errors here
df2\
.write\
.format("csv")\
.mode("overwrite")\
.option("header", "true")\
.save("test_window_func.csv")
Py4JJavaError: An error occurred while calling o13504.save.
: org.apache.spark.SparkException: Job aborted.
at org.apache.spark.sql.execution.datasources.FileFormatWriter$.write(FileFormatWriter.scala:224)
at org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand.run(InsertIntoHadoopFsRelationCommand.scala:154)
at org.apache.spark.sql.execution.command.DataWritingCommandExec.sideEffectResult$lzycompute(commands.scala:104)
at org.apache.spark.sql.execution.command.DataWritingCommandExec.sideEffectResult(commands.scala:102)
at org.apache.spark.sql.execution.command.DataWritingCommandExec.doExecute(commands.scala:122)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:131)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:127)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:155)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:152)
at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:127)
at org.apache.spark.sql.execution.QueryExecution.toRdd$lzycompute(QueryExecution.scala:80)
at org.apache.spark.sql.execution.QueryExecution.toRdd(QueryExecution.scala:80)
at org.apache.spark.sql.DataFrameWriter$$anonfun$runCommand$1.apply(DataFrameWriter.scala:654)
at org.apache.spark.sql.DataFrameWriter$$anonfun$runCommand$1.apply(DataFrameWriter.scala:654)
at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:77)
at org.apache.spark.sql.DataFrameWriter.runCommand(DataFrameWriter.scala:654)
at org.apache.spark.sql.DataFrameWriter.saveToV1Source(DataFrameWriter.scala:273)
at org.apache.spark.sql.DataFrameWriter.save(DataFrameWriter.scala:267)
at org.apache.spark.sql.DataFrameWriter.save(DataFrameWriter.scala:225)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
at py4j.Gateway.invoke(Gateway.java:282)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.GatewayConnection.run(GatewayConnection.java:238)
at java.lang.Thread.run(Thread.java:748)
Caused by: java.lang.StackOverflowError
at org.apache.spark.sql.execution.SparkPlan.prepare(SparkPlan.scala:200)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$prepare$1.apply(SparkPlan.scala:200)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$prepare$1.apply(SparkPlan.scala:200)
at scala.collection.immutable.List.foreach(List.scala:381)
possible working solution
def forwardFillImputer(df, cols=[], partitioner="date", value="UNKNOWN"):
window = Window \
.partitionBy(F.month(partitioner)) \
.orderBy(partitioner) \
.rowsBetween(-sys.maxsize, 0)
imputed_cols = [F.last(stringReplacer(F.col(i), value), ignorenulls=True).over(window).alias(i)
for i in cols]
missing_cols = [i for i in df.columns if i not in cols]
return df.select(missing_cols+imputed_cols)
df2 = forwardFillImputer(df, cols=[i for i in df.columns[1:]])
df2.printSchema()
root
|-- date: timestamp (nullable = true)
|-- id: string (nullable = true)
|-- v1: string (nullable = true)
|-- v2: string (nullable = true)
|-- v3: string (nullable = true)
df2.show()
+-------------------+---+---+----+-----+
| date| id| v1| v2| v3|
+-------------------+---+---+----+-----+
|2019-05-10 07:30:05| 1| 10| 0.5|FALSE|
|2019-05-10 07:30:10| 2| 10|0.24|FALSE|
|2019-05-10 07:30:15| 3| 6|0.24| TRUE|
|2019-05-10 07:30:20| 4| 7|0.24| TRUE|
|2019-05-10 07:30:25| 5| 10| 1.1| TRUE|
|2019-05-10 07:30:30| 6| 10| 1.1| NULL|
|2019-05-10 07:30:35| 7| 10| 1.1| TRUE|
|2019-05-10 07:30:49| 8| 50| 1.1| TRUE|
+-------------------+---+---+----+-----+
By the stacktrace provided I believe the error comes from preparation of the execution plan, as it says:
Caused by: java.lang.StackOverflowError
at org.apache.spark.sql.execution.SparkPlan.prepare(SparkPlan.scala:200)
I believe that the reason for that is because you call the method .withColumn twice in the loop. What .withColumn does in the Spark execution plan is basically a select statement of all columns with 1 column changed as specified in the method. If you have 325 columns, then for single iteration this will call select on 325 columns twice -> 650 columns passed into the planner. Doing this 325 times you can see how it can create an overhead.
However it is very interesting though that you do not receive this error for a small sample, I'd expect otherwise.
Anyway you can try replacing your forwardFillImputer like this:
def forwardFillImputer(df, cols=[], partitioner="date", value="UNKNOWN"):
window = Window \
.partitionBy(F.month(partitioner)) \
.orderBy(partitioner) \
.rowsBetween(-sys.maxsize, 0)
imputed_cols = [F.last(stringReplacer(F.col(i), value), ignorenulls=True).over(window).alias(i)
for i in cols]
missing_cols = [F.col(i) for i in df.columns if i not in cols]
return df.select(missing_cols + imputed_cols)
This way you basically just parse into planner a single select statement, which should be easier to handle.
Just as a warning, generally Spark doesn't do well with high number of columns, so you might see other strange issues along the way.
I need to pass a list into a UDF, the list will determine the score/category of the distance. For now, I am hard coding all distances to be the 4th score.
a= spark.createDataFrame([("A", 20), ("B", 30), ("D", 80)],["Letter", "distances"])
from pyspark.sql.functions import udf
def cate(label, feature_list):
if feature_list == 0:
return label[4]
label_list = ["Great", "Good", "OK", "Please Move", "Dead"]
udf_score=udf(cate, StringType())
a.withColumn("category", udf_score(label_list,a["distances"])).show(10)
when I try something like this, I get this error.
Py4JError: An error occurred while calling z:org.apache.spark.sql.functions.col. Trace:
py4j.Py4JException: Method col([class java.util.ArrayList]) does not exist
at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:318)
at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:339)
at py4j.Gateway.invoke(Gateway.java:274)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.GatewayConnection.run(GatewayConnection.java:214)
at java.lang.Thread.run(Thread.java:745)
from pyspark.sql.functions import udf, col
#sample data
a= sqlContext.createDataFrame([("A", 20), ("B", 30), ("D", 80)],["Letter", "distances"])
label_list = ["Great", "Good", "OK", "Please Move", "Dead"]
def cate(label, feature_list):
if feature_list == 0:
return label[4]
else: #you may need to add 'else' condition as well otherwise 'null' will be added in this case
return 'I am not sure!'
def udf_score(label_list):
return udf(lambda l: cate(l, label_list))
a.withColumn("category", udf_score(label_list)(col("distances"))).show()
Output is:
+------+---------+--------------+
|Letter|distances| category|
+------+---------+--------------+
| A| 20|I am not sure!|
| B| 30|I am not sure!|
| D| 80|I am not sure!|
+------+---------+--------------+
Try currying the function, so that the only argument in the DataFrame call is the name of the column on which you want the function to act:
udf_score=udf(lambda x: cate(label_list,x), StringType())
a.withColumn("category", udf_score("distances")).show(10)
I think this may help by passing list as a default value of a variable
from pyspark.sql.functions import udf, col
#sample data
a= sqlContext.createDataFrame([("A", 20), ("B", 30), ("D", 80),("E",0)],["Letter", "distances"])
label_list = ["Great", "Good", "OK", "Please Move", "Dead"]
#Passing List as Default value to a variable
def cate( feature_list,label=label_list):
if feature_list == 0:
return label[4]
else: #you may need to add 'else' condition as well otherwise 'null' will be added in this case
return 'I am not sure!'
udfcate = udf(cate, StringType())
a.withColumn("category", udfcate("distances")).show()
Output:
+------+---------+--------------+
|Letter|distances| category|
+------+---------+--------------+
| A| 20|I am not sure!|
| B| 30|I am not sure!|
| D| 80|I am not sure!|
| E| 0| Dead|
+------+---------+--------------+