I have a dataframe like:
df = spark.createDataFrame([[1, '3.1.5'],
                            [2, '1.23.0'],
                            [3, '0.2.0'],
                            [4, None]], ['Row', 'Version'])
And I want to filter it to keep only the rows whose Version is greater than a given version, e.g. v = '1.2.0'.
Expected:
+---+-------+
|Row|Version|
+---+-------+
|  1|  3.1.5|
|  2| 1.23.0|
+---+-------+
I tried to use the packaging library in a UDF:
from packaging import version
from pyspark.sql import functions as F
from pyspark.sql import types as T
def version_parse(text):
    try:
        vers = version.parse(text)
    except TypeError:
        vers = None
    return str(vers)
version_parse_udf = F.udf(version_parse, T.StringType())
df = df.filter(version_parse_udf(F.col('Version')) > version.parse('1.2.0'))
But it gives me this exception:
AttributeError: 'Version' object has no attribute '_get_object_id'
CodePudding user response:
Following the comment of @yadavlpsir:
from packaging import version
from pyspark.sql import functions as F
from pyspark.sql import types as T

def version_parse(vers: str, boundary_vers: str) -> bool:
    try:
        # version.parse(None) raises TypeError, so null versions come out False
        return version.parse(vers) >= version.parse(boundary_vers)
    except TypeError:
        return False

version_parse_udf = F.udf(version_parse, T.BooleanType())
df = df.filter(version_parse_udf(F.col('Version'), F.lit('1.2.0')))
It works for me.
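For the sample frame from the question, this should keep rows 1 and 2; the row with a null Version is dropped because version.parse(None) raises a TypeError, so the UDF returns False:
df.show()
# +---+-------+
# |Row|Version|
# +---+-------+
# |  1|  3.1.5|
# |  2| 1.23.0|
# +---+-------+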
CodePudding user response:
Having this as the input:
from pyspark.sql import functions as F
df = spark.createDataFrame(
    [[1, '3.1.5'],
     [2, '1.23.0'],
     [3, '0.2.0'],
     [4, None],
     [5, '1.3.0']],
    ['Row', 'Version'])
The following simple filter would not work, because Spark compares strings lexicographically ('1.3.0' sorts after '1.20.0', since '3' > '2' at the third character):
df.filter(F.col("Version") > "1.20.0").show()
# +---+-------+
# |Row|Version|
# +---+-------+
# |  1|  3.1.5|
# |  2| 1.23.0|
# |  5|  1.3.0|
# +---+-------+
You could make use of split, which divides the version string major.minor.patch into an array of its parts, and then cast every part to an integer with transform, giving [major, minor, patch]. Integer arrays compare element by element, so the comparison works; the null Version row is dropped too, since split on NULL yields NULL and a NULL comparison never satisfies the filter.
def vers(v):
    return F.transform(F.split(v, r"\."), lambda x: x.cast('int'))

df = df.filter(vers("Version") > vers(F.lit("1.20.0")))
df.show()
# +---+-------+
# |Row|Version|
# +---+-------+
# |  1|  3.1.5|
# |  2| 1.23.0|
# +---+-------+
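Note that F.transform was added to the PySpark functions API in Spark 3.1. On older versions (2.4+), the same element-wise cast is available through the SQL higher-order function transform via F.expr. A minimal sketch of that variant, with a helper name vers_expr of my own choosing (the boundary version is passed as a quoted SQL literal):
# Same idea for Spark < 3.1: build the int array with a SQL expression.
def vers_expr(c):
    return F.expr(f"transform(split({c}, '\\\\.'), x -> cast(x as int))")

df.filter(vers_expr("Version") > vers_expr("'1.20.0'")).show()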