提问者:小点点

从pyspark中的dataframe ArrayType列中获取第一个N元素


我有一个spark数据框,其中行为-

1   |   [a, b, c]
2   |   [d, e, f]
3   |   [g, h, i]

现在我想只保留数组列中的前两个元素。

1   |   [a, b]
2   |   [d, e]
3   |   [g, h]

如何才能做到这一点?

注意-请记住,我在这里提取的不是单个数组元素,而是数组中可能包含多个元素的一部分。


共2个答案

匿名用户

下面是如何使用API函数。

假设您的DataFrame如下:

df.show()
#+---+---------+
#| id|  letters|
#+---+---------+
#|  1|[a, b, c]|
#|  2|[d, e, f]|
#|  3|[g, h, i]|
#+---+---------+

df.printSchema()
#root
# |-- id: long (nullable = true)
# |-- letters: array (nullable = true)
# |    |-- element: string (containsNull = true)

您可以使用方括号按索引访问letters列中的元素,并将其包装在调用pyspark中。sql。功能。array()创建一个新的ArrayType列。

import pyspark.sql.functions as f

df.withColumn("first_two", f.array([f.col("letters")[0], f.col("letters")[1]])).show()
#+---+---------+---------+
#| id|  letters|first_two|
#+---+---------+---------+
#|  1|[a, b, c]|   [a, b]|
#|  2|[d, e, f]|   [d, e]|
#|  3|[g, h, i]|   [g, h]|
#+---+---------+---------+

或者,如果要列出的索引太多,可以使用列表:

df.withColumn("first_two", f.array([f.col("letters")[i] for i in range(2)])).show()
#+---+---------+---------+
#| id|  letters|first_two|
#+---+---------+---------+
#|  1|[a, b, c]|   [a, b]|
#|  2|[d, e, f]|   [d, e]|
#|  3|[g, h, i]|   [g, h]|
#+---+---------+---------+

对于pyspark 2.4版,还可以使用pyspark。sql。功能。切片()

df.withColumn("first_two",f.slice("letters",start=1,length=2)).show()
#+---+---------+---------+
#| id|  letters|first_two|
#+---+---------+---------+
#|  1|[a, b, c]|   [a, b]|
#|  2|[d, e, f]|   [d, e]|
#|  3|[g, h, i]|   [g, h]|
#+---+---------+---------+

slice对于大型阵列可能具有更好的性能(请注意,起始索引为1,而不是0)

匿名用户

要么我的Pypark技能已经生锈了(我承认我现在已经不怎么磨练了),要么这确实是一个难题。。。我做到这一点的唯一方法是使用SQL语句:

spark.version
#  u'2.3.1'

# dummy data:

from pyspark.sql import Row
x = [Row(col1="xx", col2="yy", col3="zz", col4=[123,234, 456])]
rdd = sc.parallelize(x)
df = spark.createDataFrame(rdd)
df.show()
# result:
+----+----+----+---------------+
|col1|col2|col3|           col4|
+----+----+----+---------------+
|  xx|  yy|  zz|[123, 234, 456]|
+----+----+----+---------------+

df.createOrReplaceTempView("df")
df2 = spark.sql("SELECT col1, col2, col3, (col4[0], col4[1]) as col5 FROM df")
df2.show()
# result:
+----+----+----+----------+ 
|col1|col2|col3|      col5|
+----+----+----+----------+ 
|  xx|  yy|  zz|[123, 234]|
+----+----+----+----------+

对于未来的问题,最好按照建议的指导方针来制作可复制的ApacheSpark数据帧示例。