import io
import json
from typing import Any, Dict, Optional, Union
def _custom_default_decoder(sample: Dict[str, Any], format: Optional[Union[bool, str]] = True):
"""A custom decoder for webdataset. Support multiple images list decoding.
This handles common file extensions: .txt, .cls, .cls2,
.jpg, .png, .json, .npy, .mp, .pt, .pth, .pickle, .pkl.
These are the most common extensions used in webdataset.
For other extensions, users can provide their own decoder.
Args:
sample: sample, modified in place
"""
sample = dict(sample)
for key, value in sample.items():
extension = key.split(".")[-1]
if key.startswith("__"):
continue
elif extension in ["txt", "text"]:
sample[key] = value.decode("utf-8")
elif extension in ["cls", "cls2"]:
sample[key] = int(value.decode("utf-8"))
elif extension in ["jpg", "png", "ppm", "pgm", "pbm", "pnm"]:
import numpy as np
import PIL.Image
if format == "PIL":
sample[key] = PIL.Image.open(io.BytesIO(value))
else:
sample[key] = np.asarray(PIL.Image.open(io.BytesIO(value)))
elif extension in ["jpgs", "pngs", "ppms", "pgms", "pbms", "pnms"]:
import pickle
import numpy as np
import PIL.Image
value = pickle.loads(value)
if format == "PIL":
sample[key] = [PIL.Image.open(io.BytesIO(v)) for v in value]
else:
sample[key] = [np.asarray(PIL.Image.open(io.BytesIO(v))) for v in value]
elif extension == "json":
sample[key] = json.loads(value)
elif extension == "npy":
import numpy as np
sample[key] = np.load(io.BytesIO(value))
elif extension == "mp":
import msgpack
sample[key] = msgpack.unpackb(value, raw=False)
elif extension in ["pt", "pth"]:
import torch
sample[key] = torch.load(io.BytesIO(value))
elif extension in ["pickle", "pkl"]:
import pickle
sample[key] = pickle.loads(value)
return sample
def _custom_default_encoder(sample: Dict[str, Any], format: Optional[Union[str, bool]] = True):
"""A custom encoder for webdataset.
In addition to the original encoding, it also supports encode image lists and byte type images.
This handles common file extensions: .txt, .cls, .cls2, .jpg,
.png, .json, .npy, .mp, .pt, .pth, .pickle, .pkl, .jpgs (images list),
.jpegs (images list), .pngs (images list).
These are the most common extensions used in webdataset.
For other extensions, users can provide their own encoder.
Args:
sample (Dict[str, Any]): sample
"""
from ray.data._internal.datasource.webdataset_datasource import extension_to_format
sample = dict(sample)
for key, value in sample.items():
extension = key.split(".")[-1]
if key.startswith("__"):
continue
elif extension in ["txt"]:
if isinstance(value, list):
sample[key] = [v.encode("utf-8") for v in value]
else:
sample[key] = value.encode("utf-8")
elif extension in ["cls", "cls2"]:
sample[key] = str(value).encode("utf-8")
elif extension in ["jpg", "jpeg", "png", "ppm", "pgm", "pbm", "pnm"]:
import numpy as np
import PIL.Image
if isinstance(value, np.ndarray):
value = PIL.Image.fromarray(value)
elif isinstance(value, bytes):
pass
else:
assert isinstance(value, PIL.Image.Image), f"{key} should be a PIL image, got {type(value)}"
stream = io.BytesIO()
value.save(stream, format=extension_to_format.get(extension.lower(), extension))
sample[key] = stream.getvalue()
elif extension in ["jpgs", "jpegs", "pngs", "ppms", "pgms", "pbms", "pnms"]:
import numpy as np
import PIL.Image
def _encode_image(value):
if isinstance(value, np.ndarray):
value = PIL.Image.fromarray(value)
elif isinstance(value, bytes):
return value
assert isinstance(value, PIL.Image.Image)
stream = io.BytesIO()
value.save(stream, format=extension_to_format.get(extension.lower(), extension))
return stream.getvalue()
import pickle
sample[key] = pickle.dumps([_encode_image(v) for v in value])
elif extension == "json":
sample[key] = json.dumps(value).encode("utf-8")
elif extension == "npy":
import numpy as np
stream = io.BytesIO()
np.save(stream, value)
sample[key] = stream.getvalue()
elif extension == "mp":
import msgpack
sample[key] = msgpack.dumps(value)
elif extension in ["pt", "pth"]:
import torch
stream = io.BytesIO()
torch.save(value, stream)
sample[key] = stream.getvalue()
elif extension in ["pickle", "pkl"]:
import pickle
stream = io.BytesIO()
pickle.dump(value, stream)
sample[key] = stream.getvalue()
return sample