How can I centre/detect the digits for MNIST Handwritten Digit Prediction?

Time:11-24

I am producing a mobile app and in the first part of it, the user will have to take a photo of a sudoku grid and the computer will scan and read it, using my trained TensorFlow Model.

I have a big issue with the TensorFlow model though, it seems to not be very good at its job, and I think it's not the model's fault but rather that the tensors being sent don't have the digits centred.

Obviously, I don't expect 100% accuracy, but I would expect better, especially for digits printed on the grid: around 20% of digits seem to be wrong.

Here is a picture of a 6 which the model predicted as an 8:

img=<base64 string representing 313*320 image of grid>

[This is the base64 image at the top](https://i.stack.imgur.com/OrPis.jpg)

import cv2
import json
import numpy as np
import tensorflow as tf
import base64
import matplotlib.pyplot as plt

model = tf.keras.models.load_model("newmodel")

# img holds the base64 string shown above
data = base64.b64decode(img)
np_data = np.frombuffer(data, np.uint8)  # np.fromstring is deprecated for binary data
img = cv2.imdecode(np_data, cv2.IMREAD_GRAYSCALE)
height, width = img.shape
img = cv2.resize(img, (width, width))  # force the grid to be square
height, width = img.shape
print(height, width)

nums = []
for y in range(9):
    row = []
    for x in range(9):
        # Cell bounds, trimmed by 3 px on every side to skip the grid lines
        left = round(x * width / 9 + 3)
        top = round(y * height / 9 + 3)
        right = round((x + 1) * width / 9 - 3)
        bottom = round((y + 1) * height / 9 - 3)
        image = img[top:bottom, left:right]
        image = cv2.resize(image, (28, 28))
        image = 255 - image  # invert: bright digit on dark background, like MNIST

        # Checking for empty cells: count the bright (ink) pixels after inversion
        numofBlack = int((image > 127).sum())
        if numofBlack < 50:
            row.append(0)
        else:
            pred = model.predict(image.reshape(1, 28, 28, 1))
            row.append(int(pred.argmax()))

    nums.append(row)
print(nums)

The grid from the image above returned:

[[8, 3, 0, 0, 7, 0, 0, 0, 0], [8, 0, 0, 1, 9, 5, 0, 0, 0], [0, 9, 8, 0, 0, 0, 0, 6, 0], [8, 0, 0, 0, 6, 0, 0, 0, 8], [4, 0, 0, 8, 0, 3, 0, 0, 7], [7, 0, 0, 0, 2, 0, 0, 0, 6], [0, 8, 0, 0, 0, 0, 2, 8, 0], [0, 0, 0, 4, 1, 9, 0, 0, 5], [0, 0, 0, 0, 8, 0, 0, 7, 6]]

Mistaking one or two digits is fine with me, because I can have a manual check after the image detection, and I get that the model won't be perfect. But the number of 8s read from digits that aren't 8 is suspicious, and I feel like it might come from the digits being slightly off-centre.

So, the question: is there a Python library I can use to detect the digits in the cells, rather than the current manual method with round(x*(width)/9+3) and ugly maths like that? Or is the problem I'm facing, and hence the solution, something completely different?

CodePudding user response:

Thanks to NickODell's helpful comments, my solution is going to be to find the centre of mass of each image and shift the images to centre them.

Here is the link to the python solution I'm using:

center of mass of pixels in grayscale image
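
For anyone who wants to see the idea in code, here is a minimal sketch of the centring step. It is not the code from the linked answer: the function name centre_by_mass is made up for illustration, it uses only NumPy, and it assumes the cell has already been inverted (bright digit on a dark background, as after image = 255 - image in the question).

```python
import numpy as np


def centre_by_mass(cell):
    """Shift a 2-D digit image so its centre of mass lands on the image centre.

    Hypothetical helper for illustration. Assumes a bright digit on a dark
    background. Pixels shifted past the edge are dropped, not wrapped around.
    """
    cell = np.asarray(cell)
    weights = cell.astype(np.float64)
    total = weights.sum()
    if total == 0:
        return cell.copy()  # blank cell: nothing to centre

    h, w = cell.shape
    ys, xs = np.indices(cell.shape)
    cy = (ys * weights).sum() / total  # intensity-weighted mean row
    cx = (xs * weights).sum() / total  # intensity-weighted mean column

    # Whole-pixel offset that moves the centre of mass to the image centre
    dy = int(round((h - 1) / 2 - cy))
    dx = int(round((w - 1) / 2 - cx))

    out = np.zeros_like(cell)
    src_y = slice(max(0, -dy), min(h, h - dy))
    dst_y = slice(max(0, dy), min(h, h + dy))
    src_x = slice(max(0, -dx), min(w, w - dx))
    dst_x = slice(max(0, dx), min(w, w + dx))
    out[dst_y, dst_x] = cell[src_y, src_x]
    return out
```

In the loop from the question this would sit between the inversion and the prediction, e.g. image = centre_by_mass(image) before model.predict(...). Leftover grid lines at the cell border also count as mass and will pull the centre towards them, so it may help to trim a few more pixels or blank the border first.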
