I would like to apply VIT for image classification. But I have one problem and I don't know as resolve it. My error is this "KeyError: 'img'". The error is shown when I apply the last comand, and I don't know where is my error. The image within dataset are in .png, but I don't think that this was mistake. Below there is the script:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa
import os
import cv2
import matplotlib.pyplot as plt
from transformers import ViTFeatureExtractor, ViTForImageClassification
from transformers import TrainingArguments, Trainer
from tensorflow import keras
from tensorflow.keras import layers
from datasets import load_metric
from PIL import Image as img
from IPython.display import Image, display
from datasets import load_dataset
import torch
dataset = load_dataset("imagefolder", data_dir="Datasets")
dataset
example = dataset["train"][10]
example
dataset["train"].features
example['image']
example['image'].resize((200, 200))
example['label']
dataset["train"].features["label"]
img_class_labels = dataset["train"].features["label"].names
from transformers import ViTFeatureExtractor
from tensorflow import keras
from tensorflow.keras import layers
model_id = "google/vit-base-patch16-224-in21k"
feature_extractor = ViTFeatureExtractor.from_pretrained(model_id)
# learn more about data augmentation here: https://www.tensorflow.org/tutorials/images/data_augmentation
data_augmentation = keras.Sequential(
[
layers.Resizing(feature_extractor.size, feature_extractor.size),
layers.Rescaling(1./255),
layers.RandomFlip("horizontal"),
layers.RandomRotation(factor=0.02),
layers.RandomZoom(
height_factor=0.2, width_factor=0.2
),
],
name="data_augmentation",
)
# use keras image data augementation processing
def augmentation(examples):
# print(examples["img"])
examples["pixel_values"] = [data_augmentation(image) for image in examples["img"]]
return examples
# basic processing (only resizing)
def process(examples):
examples.update(feature_extractor(examples['img'], ))
return examples
# we are also renaming our label col to labels to use `.to_tf_dataset` later
dataset_ds = dataset["train"].rename_column("label", "labels")
processed_dataset = dataset_ds.map(augmentation, batched=True)
processed_dataset
CodePudding user response:
I guess the error is here:
def augmentation(examples):
# print(examples["img"])
examples["pixel_values"] = [data_augmentation(image) for image in examples["img"]]
return examples
You are trying to access 'examples' dictionary using 'img' key. From some code above it looks like the key should be 'image':
examples["pixel_values"] = [data_augmentation(image) for image in examples["image"]]