How to preprocess images using PySpark?
Question Description: I have a project in which I need to set up a proof of concept of a big-data architecture (AWS S3 + SageMaker) to 1) preprocess images using PySpark, 2) perform a PCA, and 3) train some machine-learning or deep-learning models. My problem is that I do not understand how to manipulate image data with PySpark, and I could not find satisfactory answers online.
So I think any answer or hint could interest a broad audience of beginners like me. A similar thread remains unanswered here.
As follows you find what I have tried so far (using Python 3.8 on Jupyter Notebook):
- Creating a Spark session with the credentials for my AWS S3 bucket
from pyspark import SparkConf
from pyspark.sql import SparkSession
import sagemaker_pyspark
import botocore.session

# Look up the AWS credentials botocore resolves for this environment.
session = botocore.session.get_session()
credentials = session.get_credentials()

# Put the SageMaker Spark jars on the driver classpath.
conf = SparkConf().set(
    "spark.driver.extraClassPath", ":".join(sagemaker_pyspark.classpath_jars())
)

spark = (
    SparkSession.builder
    .config(conf=conf)
    .config('fs.s3a.access.key', credentials.access_key)
    .config('fs.s3a.secret.key', credentials.secret_key)
    # NOTE(review): if these are temporary (IAM-role) credentials,
    # credentials.token presumably must also be passed to S3A
    # (fs.s3a.session.token) -- confirm for your environment.
    .appName("test")
    .getOrCreate()
)
- Importing images from my S3 bucket
s3_url = "s3a://<MY_BUCKET>/dataset/*"
# dropInvalid skips files the image source cannot decode, instead of keeping
# placeholder rows for them that later pixel conversions could not handle.
df = spark.read.format("image").option("dropInvalid", True).load(s3_url)
print((df.count(), len(df.columns)))
# printSchema() prints the schema tree itself and returns None, so it is not
# wrapped in print() (that only produced a stray "None" line).
df.printSchema()
df.select('image.nChannels', "image.width", "image.height", "image.data").show(truncate=True)
Output:
(60, 1)
root
|-- image: struct (nullable = true)
| |-- origin: string (nullable = true)
| |-- height: integer (nullable = true)
| |-- width: integer (nullable = true)
| |-- nChannels: integer (nullable = true)
| |-- mode: integer (nullable = true)
| |-- data: binary (nullable = true)
None
+---------+-----+------+--------------------+
|nChannels|width|height| data|
+---------+-----+------+--------------------+
| 3| 100| 100|[FF FF FF FF FF F...|
| 3| 100| 100|[FF FF FF FF FF F...|
| 3| 100| 100|[FF FF FF FF FF F...|
| 3| 100| 100|[FF FF FF FF FF F...|
| 3| 100| 100|[FF FF FF FF FF F...|
| 3| 100| 100|[FF FF FF FF FF F...|
| 3| 100| 100|[FF FF FF FF FF F...|
| 3| 100| 100|[FF FF FF FF FF F...|
| 3| 100| 100|[FF FF FF FF FF F...|
| 3| 100| 100|[FF FF FF FF FF F...|
| 3| 100| 100|[FF FF FF FF FF F...|
| 3| 100| 100|[FF FF FF FF FF F...|
| 3| 100| 100|[FF FF FF FF FF F...|
| 3| 100| 100|[FF FF FF FF FF F...|
| 3| 100| 100|[FF FF FF FF FF F...|
| 3| 100| 100|[FF FF FF FF FF F...|
| 3| 100| 100|[FF FF FF FF FF F...|
| 3| 100| 100|[FF FF FF FF FF F...|
| 3| 100| 100|[FF FF FF FF FF F...|
| 3| 100| 100|[FF FF FF FF FF F...|
+---------+-----+------+--------------------+
only showing top 20 rows
So I have the images as raw bytes in the `image.data` column of `df`.
- Trying to use a `pandas_udf` to convert the bytes to arrays
import numpy as np
import io
from PIL import Image
from pyspark.sql.functions import pandas_udf, PandasUDFType
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.applications.resnet50 import preprocess_input
import pandas as pd


@pandas_udf('array<float>', PandasUDFType.SCALAR)
def preprocess(content, height=None, width=None, n_channels=None):
    """Convert a batch of images into flat float32 pixel arrays.

    This is a *scalar* pandas UDF: each argument arrives as a
    ``pandas.Series`` holding a whole batch of rows, and the function must
    return a Series of the same length. Handing that Series directly to
    ``Image.open`` is what raised
    ``TypeError: a bytes-like object is required, not 'Series'``.

    Args:
        content: Series of image bytes. When ``height``, ``width`` and
            ``n_channels`` are supplied, these are raw pixel buffers such as
            ``image.data`` from ``spark.read.format("image")``. Otherwise they
            are treated as encoded image files and decoded with PIL.
        height, width, n_channels: Optional Series holding the matching
            ``image.height``, ``image.width`` and ``image.nChannels`` values.

    Returns:
        Series of 1-D float32 arrays, one per input row.
    """
    if height is None or width is None or n_channels is None:
        # Encoded image-file bytes: let PIL decode each element of the batch.
        arrays = [
            img_to_array(Image.open(io.BytesIO(buf))).astype(np.float32).flatten()
            for buf in content
        ]
    else:
        # The image data source stores already-decoded 8-bit pixels, which PIL
        # cannot open; rebuild (h, w, c) from the struct's metadata instead.
        # reshape() raises if a buffer's size does not equal h * w * c.
        # NOTE(review): channel order is presumably BGR, as with
        # ImageSchema.toNDArray -- confirm before feeding RGB-trained models.
        arrays = [
            np.frombuffer(buf, dtype=np.uint8)
            .reshape(int(h), int(w), int(c))
            .astype(np.float32)
            .flatten()
            for buf, h, w, c in zip(content, height, width, n_channels)
        ]
    # object dtype keeps one ndarray per row (also for an empty batch), which
    # Arrow converts to the declared array<float> column.
    return pd.Series(arrays, dtype=object)


# Pass the pixel metadata alongside the raw bytes so the UDF can reshape them.
df_transformed = df.select(
    preprocess("image.data", "image.height", "image.width", "image.nChannels")
)
type(df_transformed)
df_transformed.printSchema()
df_transformed.show()
Output:
root
|-- preprocess(image.data): array (nullable = true)
| |-- element: float (containsNull = true)
---------------------------------------------------------------------------
Py4JJavaError Traceback (most recent call last)
<ipython-input-24-e3999c55086a> in <module>
20 type(df_transformed)
21 df_transformed.printSchema()
---> 22 df_transformed.show()
~/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/pyspark/sql/dataframe.py in show(self, n, truncate, vertical)
376 """
377 if isinstance(truncate, bool) and truncate:
--> 378 print(self._jdf.showString(n, 20, vertical))
379 else:
380 print(self._jdf.showString(n, int(truncate), vertical))
~/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/py4j/java_gateway.py in __call__(self, *args)
1255 answer = self.gateway_client.send_command(command)
1256 return_value = get_return_value(
-> 1257 answer, self.gateway_client, self.target_id, self.name)
1258
1259 for temp_arg in temp_args:
~/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/pyspark/sql/utils.py in deco(*a, **kw)
61 def deco(*a, **kw):
62 try:
---> 63 return f(*a, **kw)
64 except py4j.protocol.Py4JJavaError as e:
65 s = e.java_exception.toString()
~/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
326 raise Py4JJavaError(
327 "An error occurred while calling {0}{1}{2}.\n".
--> 328 format(target_id, ".", name), value)
329 else:
330 raise Py4JError(
Py4JJavaError: An error occurred while calling o433.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 12.0 failed 1 times, most recent failure: Lost task 0.0 in stage 12.0 (TID 16, localhost, executor driver): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
File "/home/ec2-user/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 372, in main
process()
File "/home/ec2-user/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 367, in process
serializer.dump_stream(func(split_index, iterator), outfile)
File "/home/ec2-user/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/pyspark/python/lib/pyspark.zip/pyspark/serializers.py", line 283, in dump_stream
for series in iterator:
File "<string>", line 1, in <lambda>
File "/home/ec2-user/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 96, in <lambda>
return lambda *a: (verify_result_length(*a), arrow_return_type)
File "/home/ec2-user/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 87, in verify_result_length
result = f(*a)
File "/home/ec2-user/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/pyspark/python/lib/pyspark.zip/pyspark/util.py", line 99, in wrapper
return f(*args, **kwargs)
File "<ipython-input-24-e3999c55086a>", line 14, in preprocess
TypeError: a bytes-like object is required, not 'Series'
at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:452)
at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:172)
at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:122)
at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:406)
at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
at org.apache.spark.sql.execution.python.ArrowEvalPythonExec$$anon$2.<init>(ArrowEvalPythonExec.scala:98)
at org.apache.spark.sql.execution.python.ArrowEvalPythonExec.evaluate(ArrowEvalPythonExec.scala:96)
at org.apache.spark.sql.execution.python.EvalPythonExec$$anonfun$doExecute$1.apply(EvalPythonExec.scala:127)
at org.apache.spark.sql.execution.python.EvalPythonExec$$anonfun$doExecute$1.apply(EvalPythonExec.scala:89)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:801)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:801)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
at org.apache.spark.scheduler.Task.run(Task.scala:121)
at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:402)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:408)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
at java.lang.Thread.run(Thread.java:748)
Driver stacktrace:
at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1887)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1875)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1874)
at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1874)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
at scala.Option.foreach(Option.scala:257)
at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:926)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2108)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2057)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2046)
at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:737)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2061)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2082)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2101)
at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:365)
at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:38)
at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collectFromPlan(Dataset.scala:3384)
at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2545)
at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2545)
at org.apache.spark.sql.Dataset$$anonfun$53.apply(Dataset.scala:3365)
at org.apache.spark.sql.execution.SQLExecution$$anonfun$withNewExecutionId$1.apply(SQLExecution.scala:78)
at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:125)
at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:73)
at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3364)
at org.apache.spark.sql.Dataset.head(Dataset.scala:2545)
at org.apache.spark.sql.Dataset.take(Dataset.scala:2759)
at org.apache.spark.sql.Dataset.getRows(Dataset.scala:255)
at org.apache.spark.sql.Dataset.showString(Dataset.scala:292)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
at py4j.Gateway.invoke(Gateway.java:282)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.GatewayConnection.run(GatewayConnection.java:238)
at java.lang.Thread.run(Thread.java:748)
Caused by: org.apache.spark.api.python.PythonException: Traceback (most recent call last):
File "/home/ec2-user/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 372, in main
process()
File "/home/ec2-user/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 367, in process
serializer.dump_stream(func(split_index, iterator), outfile)
File "/home/ec2-user/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/pyspark/python/lib/pyspark.zip/pyspark/serializers.py", line 283, in dump_stream
for series in iterator:
File "<string>", line 1, in <lambda>
File "/home/ec2-user/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 96, in <lambda>
return lambda *a: (verify_result_length(*a), arrow_return_type)
File "/home/ec2-user/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 87, in verify_result_length
result = f(*a)
File "/home/ec2-user/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/pyspark/python/lib/pyspark.zip/pyspark/util.py", line 99, in wrapper
return f(*args, **kwargs)
File "<ipython-input-24-e3999c55086a>", line 14, in preprocess
TypeError: a bytes-like object is required, not 'Series'
at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:452)
at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:172)
at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:122)
at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:406)
at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
at org.apache.spark.sql.execution.python.ArrowEvalPythonExec$$anon$2.<init>(ArrowEvalPythonExec.scala:98)
at org.apache.spark.sql.execution.python.ArrowEvalPythonExec.evaluate(ArrowEvalPythonExec.scala:96)
at org.apache.spark.sql.execution.python.EvalPythonExec$$anonfun$doExecute$1.apply(EvalPythonExec.scala:127)
at org.apache.spark.sql.execution.python.EvalPythonExec$$anonfun$doExecute$1.apply(EvalPythonExec.scala:89)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:801)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:801)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
at org.apache.spark.scheduler.Task.run(Task.scala:121)
at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:402)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:408)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
... 1 more
- I managed to do it with `ImageSchema`, but I am breaking the DataFrame chain (by calling `.collect()`), which is not suitable:
from pyspark.ml.image import ImageSchema
# Approach adapted from:
# https://stackoverflow.com/questions/67705881/unable-to-read-images-simultaneously-in-parallels-using-pyspark

# Promote the struct's fields to top-level columns so every RDD row carries
# the image fields directly.
df = df.select('image.*')

# Touch imageFields on the driver so the required schema is cached before
# toNDArray is shipped to the workers; without this line an error is raised.
ImageSchema.imageFields

# Decode each row to a numpy array.
# NOTE: collect() materializes every decoded image in driver memory.
pixel_rdd = df.rdd.map(ImageSchema.toNDArray)
arrays = pixel_rdd.collect()
img = np.array(arrays)
print(img.shape)
Output: (60, 100, 100, 3)
On top of that, I need to perform a PCA to reduce the image dimensions.

Expert Answer
Try using `ImageSchema` and `DenseVector` inside a UDF, and apply that function to the `image` column (which is a struct). The result is a new column holding each image as a dense vector.
import pyspark.sql.functions as F
from pyspark.ml.image import ImageSchema
from pyspark.ml.linalg import DenseVector, VectorUDT

# dropInvalid skips files the image source cannot decode, instead of keeping
# placeholder rows that toNDArray would presumably fail on.
df = spark.read.format("image").option("dropInvalid", True).load(url)
df.show()
# +--------------------+
# | image|
# +--------------------+
# |[file:///content/...|
# |[file:///content/...|
# +--------------------+

# Touch imageFields on the driver so the field list is cached before the UDF
# (which calls toNDArray) is serialized to the executors.
ImageSchema.imageFields


def _image_to_vector(image):
    """Flatten one `image` struct row into a DenseVector of its pixel values."""
    return DenseVector(ImageSchema.toNDArray(image).flatten())


img2vec = F.udf(_image_to_vector, VectorUDT())
df = df.withColumn('vecs', img2vec("image"))
df.show()
# +--------------------+--------------------+
# | image| vecs|
# +--------------------+--------------------+
# |[file:///content/...|[255.0,255.0,255....|
# |[file:///content/...|[248.0,248.0,248....|
# +--------------------+--------------------+

# NOTE(review): 'vecs' can serve as inputCol for pyspark.ml.feature.PCA in the
# next step; with 100x100x3 images that is 30000 features per vector, so check
# driver memory before fitting.