|

How to preprocess images using PySpark?

Question Description: I have a project where I need to set up a proof-of-concept of a big data architecture (AWS S3 + SageMaker) to 1) preprocess images using PySpark, 2) perform a PCA, and 3) train some machine learning or deep learning models. My issue is understanding how to manipulate image data using PySpark, and I could not find satisfactory answers online.

So I think that any answer or hint could interest a broad audience of beginners like me. A similar thread remains unanswered here.

Below is what I have tried so far (using Python 3.8 in a Jupyter Notebook):

  • Creating spark session with credentials to my AWS S3
from pyspark import SparkConf
from pyspark.sql import SparkSession
import sagemaker_pyspark
import botocore.session

# Fetch the AWS credentials available to the botocore session so that no
# keys have to be hard-coded in the notebook.
session = botocore.session.get_session()
credentials = session.get_credentials()

# Put the jars reported by sagemaker_pyspark on the driver classpath.
conf = (SparkConf()
        .set("spark.driver.extraClassPath", ":".join(sagemaker_pyspark.classpath_jars())))

# Build (or reuse) the session, passing the key pair to the S3A connector.
spark = (
    SparkSession
    .builder
    .config(conf=conf)
    .config('fs.s3a.access.key',  credentials.access_key)
    .config('fs.s3a.secret.key', credentials.secret_key)
    .appName("test")
    .getOrCreate()
)
  • Importing images from my S3 bucket
# Glob over every object under the dataset/ prefix; <MY_BUCKET> is a placeholder.
s3_url = "s3a://<MY_BUCKET>/dataset/*"
# The "image" data source yields a single struct column named `image`
# (see the schema printed below).
df = spark.read.format("image").load(s3_url)
# Prints (row count, column count).
print((df.count(), len(df.columns)))
# printSchema() prints the schema itself and returns None, so wrapping it in
# print() also emits the literal "None" seen in the output.
print(df.printSchema())
# Fields nested inside the struct are addressed with dotted paths.
df.select('image.nChannels', "image.width", "image.height", "image.data").show(truncate=True)

Output:

(60, 1)
root
 |-- image: struct (nullable = true)
 |    |-- origin: string (nullable = true)
 |    |-- height: integer (nullable = true)
 |    |-- width: integer (nullable = true)
 |    |-- nChannels: integer (nullable = true)
 |    |-- mode: integer (nullable = true)
 |    |-- data: binary (nullable = true)

None
+---------+-----+------+--------------------+
|nChannels|width|height|                data|
+---------+-----+------+--------------------+
|        3|  100|   100|[FF FF FF FF FF F...|
|        3|  100|   100|[FF FF FF FF FF F...|
|        3|  100|   100|[FF FF FF FF FF F...|
|        3|  100|   100|[FF FF FF FF FF F...|
|        3|  100|   100|[FF FF FF FF FF F...|
|        3|  100|   100|[FF FF FF FF FF F...|
|        3|  100|   100|[FF FF FF FF FF F...|
|        3|  100|   100|[FF FF FF FF FF F...|
|        3|  100|   100|[FF FF FF FF FF F...|
|        3|  100|   100|[FF FF FF FF FF F...|
|        3|  100|   100|[FF FF FF FF FF F...|
|        3|  100|   100|[FF FF FF FF FF F...|
|        3|  100|   100|[FF FF FF FF FF F...|
|        3|  100|   100|[FF FF FF FF FF F...|
|        3|  100|   100|[FF FF FF FF FF F...|
|        3|  100|   100|[FF FF FF FF FF F...|
|        3|  100|   100|[FF FF FF FF FF F...|
|        3|  100|   100|[FF FF FF FF FF F...|
|        3|  100|   100|[FF FF FF FF FF F...|
|        3|  100|   100|[FF FF FF FF FF F...|
+---------+-----+------+--------------------+
only showing top 20 rows

So I got the images as bytes in my df.data.

  • Trying to use pandas_udf to pass from bytes to arrays
import numpy as np
import io
from PIL import Image
from pyspark.sql.functions import pandas_udf, PandasUDFType
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.applications.resnet50 import preprocess_input


# NOTE(review): the second positional argument is not a PandasUDFType
# constant; presumably it is ignored and the UDF runs as a scalar pandas
# UDF — confirm against the pandas_udf documentation.
@pandas_udf('array<float>', 'pyspark.sql.dataframe.DataFrame')
def preprocess(content):
    """
    Preprocesses raw image bytes for prediction.

    As written this fails at runtime: the traceback below shows that the
    UDF is called with a pandas Series (a batch of rows) rather than a
    single bytes object, so io.BytesIO(content) raises TypeError.
    """
    # Failing line reported by the traceback
    # ("a bytes-like object is required, not 'Series'").
    # NOTE(review): the image data source's `data` field presumably holds
    # decoded pixel bytes rather than an encoded image file, in which case
    # Image.open would fail even on a single row — confirm.
    img = Image.open(io.BytesIO(content))
    arr = img_to_array(img)
    # Flatten to 1-D to match the declared 'array<float>' return type.
    return arr.flatten()


# Apply the UDF to the raw bytes field of the image struct.
df_transformed = df.select(preprocess("image.data"))
type(df_transformed)  # result is discarded; nothing is shown for it in the output
# printSchema() succeeds (schema shown in the output below).
df_transformed.printSchema()
# show() is the call the traceback points at: this is where the UDF error surfaces.
df_transformed.show()

Output:

root
 |-- preprocess(image.data): array (nullable = true)
 |    |-- element: float (containsNull = true)

---------------------------------------------------------------------------
Py4JJavaError                             Traceback (most recent call last)
<ipython-input-24-e3999c55086a> in <module>
     20 type(df_transformed)
     21 df_transformed.printSchema()
---> 22 df_transformed.show()

~/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/pyspark/sql/dataframe.py in show(self, n, truncate, vertical)
    376         """
    377         if isinstance(truncate, bool) and truncate:
--> 378             print(self._jdf.showString(n, 20, vertical))
    379         else:
    380             print(self._jdf.showString(n, int(truncate), vertical))

~/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/py4j/java_gateway.py in __call__(self, *args)
   1255         answer = self.gateway_client.send_command(command)
   1256         return_value = get_return_value(
-> 1257             answer, self.gateway_client, self.target_id, self.name)
   1258 
   1259         for temp_arg in temp_args:

~/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/pyspark/sql/utils.py in deco(*a, **kw)
     61     def deco(*a, **kw):
     62         try:
---> 63             return f(*a, **kw)
     64         except py4j.protocol.Py4JJavaError as e:
     65             s = e.java_exception.toString()

~/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
    326                 raise Py4JJavaError(
    327                     "An error occurred while calling {0}{1}{2}.\n".
--> 328                     format(target_id, ".", name), value)
    329             else:
    330                 raise Py4JError(

Py4JJavaError: An error occurred while calling o433.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 12.0 failed 1 times, most recent failure: Lost task 0.0 in stage 12.0 (TID 16, localhost, executor driver): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/home/ec2-user/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 372, in main
    process()
  File "/home/ec2-user/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 367, in process
    serializer.dump_stream(func(split_index, iterator), outfile)
  File "/home/ec2-user/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/pyspark/python/lib/pyspark.zip/pyspark/serializers.py", line 283, in dump_stream
    for series in iterator:
  File "<string>", line 1, in <lambda>
  File "/home/ec2-user/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 96, in <lambda>
    return lambda *a: (verify_result_length(*a), arrow_return_type)
  File "/home/ec2-user/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 87, in verify_result_length
    result = f(*a)
  File "/home/ec2-user/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/pyspark/python/lib/pyspark.zip/pyspark/util.py", line 99, in wrapper
    return f(*args, **kwargs)
  File "<ipython-input-24-e3999c55086a>", line 14, in preprocess
TypeError: a bytes-like object is required, not 'Series'

    at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:452)
    at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:172)
    at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:122)
    at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:406)
    at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
    at org.apache.spark.sql.execution.python.ArrowEvalPythonExec$$anon$2.<init>(ArrowEvalPythonExec.scala:98)
    at org.apache.spark.sql.execution.python.ArrowEvalPythonExec.evaluate(ArrowEvalPythonExec.scala:96)
    at org.apache.spark.sql.execution.python.EvalPythonExec$$anonfun$doExecute$1.apply(EvalPythonExec.scala:127)
    at org.apache.spark.sql.execution.python.EvalPythonExec$$anonfun$doExecute$1.apply(EvalPythonExec.scala:89)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:801)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:801)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
    at org.apache.spark.scheduler.Task.run(Task.scala:121)
    at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:402)
    at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:408)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    at java.lang.Thread.run(Thread.java:748)

Driver stacktrace:
    at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1887)
    at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1875)
    at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1874)
    at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
    at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
    at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1874)
    at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
    at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
    at scala.Option.foreach(Option.scala:257)
    at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:926)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2108)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2057)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2046)
    at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
    at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:737)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2061)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2082)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2101)
    at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:365)
    at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:38)
    at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collectFromPlan(Dataset.scala:3384)
    at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2545)
    at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2545)
    at org.apache.spark.sql.Dataset$$anonfun$53.apply(Dataset.scala:3365)
    at org.apache.spark.sql.execution.SQLExecution$$anonfun$withNewExecutionId$1.apply(SQLExecution.scala:78)
    at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:125)
    at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:73)
    at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3364)
    at org.apache.spark.sql.Dataset.head(Dataset.scala:2545)
    at org.apache.spark.sql.Dataset.take(Dataset.scala:2759)
    at org.apache.spark.sql.Dataset.getRows(Dataset.scala:255)
    at org.apache.spark.sql.Dataset.showString(Dataset.scala:292)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
    at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
    at py4j.Gateway.invoke(Gateway.java:282)
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
    at py4j.commands.CallCommand.execute(CallCommand.java:79)
    at py4j.GatewayConnection.run(GatewayConnection.java:238)
    at java.lang.Thread.run(Thread.java:748)
Caused by: org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/home/ec2-user/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 372, in main
    process()
  File "/home/ec2-user/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 367, in process
    serializer.dump_stream(func(split_index, iterator), outfile)
  File "/home/ec2-user/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/pyspark/python/lib/pyspark.zip/pyspark/serializers.py", line 283, in dump_stream
    for series in iterator:
  File "<string>", line 1, in <lambda>
  File "/home/ec2-user/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 96, in <lambda>
    return lambda *a: (verify_result_length(*a), arrow_return_type)
  File "/home/ec2-user/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 87, in verify_result_length
    result = f(*a)
  File "/home/ec2-user/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/pyspark/python/lib/pyspark.zip/pyspark/util.py", line 99, in wrapper
    return f(*args, **kwargs)
  File "<ipython-input-24-e3999c55086a>", line 14, in preprocess
TypeError: a bytes-like object is required, not 'Series'

    at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:452)
    at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:172)
    at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:122)
    at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:406)
    at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
    at org.apache.spark.sql.execution.python.ArrowEvalPythonExec$$anon$2.<init>(ArrowEvalPythonExec.scala:98)
    at org.apache.spark.sql.execution.python.ArrowEvalPythonExec.evaluate(ArrowEvalPythonExec.scala:96)
    at org.apache.spark.sql.execution.python.EvalPythonExec$$anonfun$doExecute$1.apply(EvalPythonExec.scala:127)
    at org.apache.spark.sql.execution.python.EvalPythonExec$$anonfun$doExecute$1.apply(EvalPythonExec.scala:89)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:801)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:801)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
    at org.apache.spark.scheduler.Task.run(Task.scala:121)
    at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:402)
    at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:408)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    ... 1 more

  • I managed to do so using ImageSchema, but this breaks the DataFrame chain (it uses .collect()), which is not suitable
from pyspark.ml.image import ImageSchema
#https://stackoverflow.com/questions/67705881/unable-to-read-images-simultaneously-in-parallels-using-pyspark
# Unpack the image struct into top-level columns
# (origin, height, width, nChannels, mode, data).
df = df.select('image.*')

# Pre-caching the required schema. If you remove this line an error will be raised.
ImageSchema.imageFields

# Transforming images to np.array
# NOTE: collect() returns every converted array to the driver as a Python
# list, so everything after this line runs outside Spark.
arrays = df.rdd.map(ImageSchema.toNDArray).collect()

# Stack the per-image arrays into one array; the printed shape is (60, 100, 100, 3).
img = np.array(arrays)
print(img.shape)

Output: (60, 100, 100, 3)

On top of that, I need to perform PCA to reduce the image dimensions.

Expert Answer

Try using ImageSchema and DenseVector inside a UDF, and apply the function to the unpacked image column (struct format). The result is each image in dense vector format.

# NOTE(review): `url` is not defined in this snippet — presumably the image
# location (e.g. an s3a:// or file:// path); confirm before running.
df = spark.read.format("image").load(url)
df.show()

# +--------------------+
# |               image|
# +--------------------+
# |[file:///content/...|
# |[file:///content/...|
# +--------------------+

import pyspark.sql.functions as F
from pyspark.ml.image import ImageSchema
from pyspark.ml.linalg import DenseVector, VectorUDT

# Accessing imageFields up front pre-caches the image schema; omitting this
# access reportedly raises an error.
ImageSchema.imageFields


def _image_to_vector(image_row):
    """Flatten one image struct into a DenseVector of its pixel values."""
    pixels = ImageSchema.toNDArray(image_row)
    return DenseVector(pixels.flatten())


# Wrap the helper as a Spark UDF that returns an ML vector column.
img2vec = F.udf(_image_to_vector, VectorUDT())

df = df.withColumn('vecs', img2vec("image"))
df.show()

# +--------------------+--------------------+
# |               image|                vecs|
# +--------------------+--------------------+
# |[file:///content/...|[255.0,255.0,255....|
# |[file:///content/...|[248.0,248.0,248....|
# +--------------------+--------------------+

Similar Posts