Why do I get worse results in Java than in Python when using the same Tensorflow models?

Time:04-29

Introduction:

For educational purposes, I developed a Java class that enables students to load TensorFlow models [...]

Python demo classification

Java: My Java-based approach produces an image like this:

Java demo classification

I think the result is not bad, but it is not perfect. With other models, e.g. Google's imagenet_mobilenet model, I get similar results that are OK, but I suspect the results are always a bit better in the online demos running in Jupyter notebooks. I have no further evidence, only a feeling. In some cases, but not always, the same image from the online demo is recognized as a different class. I might provide more data on that later.

Assumptions and work done so far:

There might be an error in the data structures, or in the algorithms operating on them, in my Java code. I have searched the web for several weeks now, but I am unsure whether my code is really correct, mainly because there are so few examples available. For example, I tried changing the order of the RGB channels and the way they are calculated in the method that converts an image into an NdArray. However, I saw no significant changes. Maybe the error is somewhere else, or maybe this is simply how it is. If my code is correct and works as intended, that is fine with me, but I still wonder why there are differences. Thanks for any answers!
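As a side check on the channel order mentioned above, here is a minimal sketch (not part of the original code; the class name RgbUnpackSketch and the method unpack are made up for illustration) that unpacks the packed integer returned by BufferedImage.getRGB with bit shifts and compares the result with java.awt.Color. If both agree, the R, G, B extraction itself is unlikely to be the source of the difference.

```java
import java.awt.Color;

// Hypothetical helper for illustration only.
final class RgbUnpackSketch {

    // Splits a packed ARGB integer (as returned by BufferedImage.getRGB)
    // into {red, green, blue}, each in the range 0..255.
    static int[] unpack(int argb) {
        int red = (argb >> 16) & 0xFF;
        int green = (argb >> 8) & 0xFF;
        int blue = argb & 0xFF;
        return new int[] { red, green, blue };
    }

    public static void main(String[] args) {
        int packed = 0xFF102030;
        int[] rgb = unpack(packed);
        Color c = new Color(packed);
        // Both ways of reading the channels should agree.
        System.out.println(rgb[0] == c.getRed());
        System.out.println(rgb[1] == c.getGreen());
        System.out.println(rgb[2] == c.getBlue());
    }
}
```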

Code:

Here is a fully working example with two classes (I know, drawing in a Panel inside the Frame like this is bad; I wrote it quickly just for this example):

/**
 * 1. TensorFlow Core API Library: org.tensorflow -> tensorflow-core-api
 *      https://mvnrepository.com/artifact/org.tensorflow/tensorflow-core-api
 *          -> tensorflow-core-api-0.4.0.jar
 *      
 * 2.   additionally click "View All" and open:
 *      https://repo1.maven.org/maven2/org/tensorflow/tensorflow-core-api/0.4.0/
 *      Download the correct native library for your OS
 *          -> tensorflow-core-api-0.4.0-macosx-x86_64.jar
 *          -> tensorflow-core-api-0.4.0-windows-x86_64.jar
 *          -> tensorflow-core-api-0.4.0-linux-x86_64.jar 
 *      
 * 3. TensorFlow Framework Library:  org.tensorflow -> tensorflow-framework
 *      https://mvnrepository.com/artifact/org.tensorflow/tensorflow-framework/0.4.0
 *          -> tensorflow-framework-0.4.0.jar      
 *          
 * 4. Protocol Buffers [Core]: com.google.protobuf -> protobuf-java
 *      https://mvnrepository.com/artifact/com.google.protobuf/protobuf-java
 *          -> protobuf-java-4.0.0-rc-2.jar
 * 
 * 5. JavaCPP: org.bytedeco -> javacpp
 *      https://mvnrepository.com/artifact/org.bytedeco/javacpp
 *          -> javacpp-1.5.7.jar
 * 
 * 6. TensorFlow NdArray Library:  org.tensorflow -> ndarray
 *      https://mvnrepository.com/artifact/org.tensorflow/ndarray
 *          -> ndarray-0.3.3.jar
 */
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.IntNdArray;
import org.tensorflow.ndarray.NdArrays;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.types.TInt32;
import java.util.HashMap;
import java.util.Map;
import java.awt.image.BufferedImage;
import javax.imageio.ImageIO;
import java.awt.Color;
import java.io.File;
import javax.swing.JFrame;
import javax.swing.JButton;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.BorderLayout;

public class MoveNetDemo {

    private SavedModelBundle model;
    private String inputLayerName;
    private String outputLayerName;
    private String keyName;
    private BufferedImage image;
    private float[][] output;    
    private int width;
    private int height;

    public MoveNetDemo(String pFoldername, int pImageWidth, int pImageHeight) {
        width = pImageWidth;
        height = pImageHeight;

        model = SavedModelBundle.load(pFoldername, "serve");
        // Read input and output layer names from file
        inputLayerName = model.signatures().get(0).getInputs().keySet().toString();
        outputLayerName = model.signatures().get(0).getOutputs().keySet().toString();
        inputLayerName = inputLayerName.substring(1, inputLayerName.length()-1);
        outputLayerName = outputLayerName.substring(1, outputLayerName.length()-1);
        keyName = model.signatures().get(0).key();        
    }

    // not necessary here
    public String getModelInformation() { 
        String infos = "";
        for (int i=0; i<model.signatures().size(); i++) {
            infos += model.signatures().get(i).toString();
        }         
        return infos;
    }  

    public void setData(String pFilename) {
        image = null;
        try {
            image = ImageIO.read(new File(pFilename));            
        } 
        catch (Exception e) {          
        }
    }

    public BufferedImage getData() {
        return image;
    }

    private IntNdArray fillIntNdArray(IntNdArray pMatrix, BufferedImage pImage) {        
        try {
            int w = pImage.getWidth();
            int h = pImage.getHeight();                

            for (int i = 0; i < h; i++) {
                for (int j = 0; j < w; j++) {                 
                    Color mycolor = new Color(pImage.getRGB(j, i));
                    int red = mycolor.getRed();
                    int green = mycolor.getGreen();
                    int blue = mycolor.getBlue();
                    pMatrix.setInt(red, 0, j, i, 0);
                    pMatrix.setInt(green, 0, j, i, 1);
                    pMatrix.setInt(blue, 0, j, i, 2);                                       
                }
            }
        }
        catch (Exception e) {            
        }
        return pMatrix;        
    }

    public void run() {
        Map<String, Tensor> feed_dict = null;
        IntNdArray input_matrix = NdArrays.ofInts(Shape.of(1, width, height, 3));
        input_matrix = fillIntNdArray(input_matrix, image);            
        Tensor input_tensor = TInt32.tensorOf(input_matrix);
        feed_dict = new HashMap<>();
        feed_dict.put(inputLayerName, input_tensor); 
        Map<String, Tensor> res = model.function(keyName).call(feed_dict);                
        Tensor output_tensor = res.get(outputLayerName); 

        output = new float[17][3];
        for (int i= 0; i<17; i++) {
            output[i][0] = output_tensor.asRawTensor().data().asFloats().getFloat(i*3)*256;                
            output[i][1] = output_tensor.asRawTensor().data().asFloats().getFloat(i*3+1)*256;                
            output[i][2] = output_tensor.asRawTensor().data().asFloats().getFloat(i*3+2);
        }
    }

    public float[][] getOutputArray() {
        return output;
    }

    public static void main(String[] args) {
        MoveNetDemo im = new MoveNetDemo("/Users/myname/Downloads/Code/TF_Test_04_NEW/movenet_singlepose_thunder_4", 256, 256);        
        im.setData("/Users/myname/Downloads/Code/TF_Test_04_NEW/test.jpeg");

        JFrame jf = new JFrame("TEST");
        jf.setSize(300, 300);
        jf.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        ImagePanel ip = new ImagePanel(im.getData());
        jf.add(ip, BorderLayout.CENTER);

        JButton st = new JButton("RUN");
        st.addActionListener(new ActionListener() { 
                public void actionPerformed(ActionEvent e) {
                    im.run();                            
                    ip.update(im.getOutputArray());
                    
                } 
            });        
        jf.add(st, BorderLayout.NORTH);

        jf.setVisible(true);
    }
}

and the ImagePanel class:

import javax.swing.JPanel;
import java.awt.image.BufferedImage;
import java.awt.Graphics;
import java.awt.Color;

public class ImagePanel extends JPanel {

    private BufferedImage image;
    private float[][] points;

    public ImagePanel(BufferedImage pImage) {        
        image = pImage;        
    }

    public void update(float[][] pPoints) {
        points = pPoints;
        repaint();
    }

    @Override
    protected void paintComponent(Graphics g) {                
        super.paintComponent(g);        
        g.drawImage(image, 0,0,null);
        g.setColor(Color.GREEN);
        if (points != null) {
            for (int j=0; j<17; j++) {                            
                g.fillOval((int)points[j][0], (int)points[j][1], 5, 5);
            } 
        }
    }
}

Answer:

I found the answer: I mixed up height and width twice! I have no idea why the result was so strange (nearly correct but not perfect), but it works now.

In the Jupyter notebook it says:

input_image: A [1, height, width, 3]

So I changed the method fillIntNdArray to:

private IntNdArray fillIntNdArray(IntNdArray pMatrix, BufferedImage pImage) {        
        try {
            int w = pImage.getWidth();
            int h = pImage.getHeight();                

            for (int i = 0; i < h; i++) {
                for (int j = 0; j < w; j++) {                 
                    Color mycolor = new Color(pImage.getRGB(j, i));
                    int red = mycolor.getRed();
                    int green = mycolor.getGreen();
                    int blue = mycolor.getBlue();
                    pMatrix.setInt(red, 0, i, j, 0); // switched j and i 
                    pMatrix.setInt(green, 0, i, j, 1); // switched j and i 
                    pMatrix.setInt(blue, 0, i, j, 2); // switched j and i                                    
                }
            }
        }
        catch (Exception e) {            
        }
        return pMatrix;        
    }

and accordingly in the run()-method:

IntNdArray input_matrix = NdArrays.ofInts(Shape.of(1, height, width, 3));
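To make the [1, height, width, 3] layout concrete, here is a minimal sketch that uses a plain Java array instead of the TensorFlow IntNdArray (the class name NhwcFillSketch and the method toNhwc are made up for illustration). The row index y comes first and the column index x second, while BufferedImage.getRGB takes its arguments in the order (x, y). A plausible explanation for the earlier "nearly correct" behaviour: with a square 256x256 input, swapping the two indices stays in bounds and feeds the model a transposed image, and reading the output as (x, y) instead of (y, x) transposes the points back, so they land roughly in the right place although the model saw a flipped pose. With a non-square image the swapped indices would run out of bounds instead.

```java
import java.awt.image.BufferedImage;

// Hypothetical helper for illustration only; it uses a plain int array
// in place of the TensorFlow IntNdArray from the original code.
final class NhwcFillSketch {

    // Copies an image into a [1][height][width][3] array:
    // batch, then row (y), then column (x), then channel (R, G, B).
    static int[][][][] toNhwc(BufferedImage img) {
        int h = img.getHeight();
        int w = img.getWidth();
        int[][][][] out = new int[1][h][w][3];
        for (int y = 0; y < h; y++) {
            for (int x = 0; x < w; x++) {
                int rgb = img.getRGB(x, y);           // note: (x, y) order
                out[0][y][x][0] = (rgb >> 16) & 0xFF; // red
                out[0][y][x][1] = (rgb >> 8) & 0xFF;  // green
                out[0][y][x][2] = rgb & 0xFF;         // blue
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // A deliberately non-square test image: 3 pixels wide, 2 pixels high.
        BufferedImage img = new BufferedImage(3, 2, BufferedImage.TYPE_INT_RGB);
        img.setRGB(2, 1, 0xFF8040);
        int[][][][] t = toNhwc(img);
        System.out.println("rows (height): " + t[0].length);
        System.out.println("columns (width): " + t[0][0].length);
        System.out.println("red at y=1, x=2: " + t[0][1][2][0]);
    }
}
```

Testing with a non-square image, as in main above, is a cheap way to catch this kind of mix-up, because a wrong index order then fails immediately instead of producing a plausible-looking result.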

In the Jupyter notebook you can expand the helper functions for visualization and see that the y coordinate is read first, then the x coordinate: height first, then width. Making the same change in the ImagePanel class solves the problem. The classification is now as expected and of the same quality as in the online demonstration!

if (points != null) {
    for (int j=0; j<17; j++) {                            
        // switched 0 and 1
        g.fillOval((int)points[j][1], (int)points[j][0], 5, 5);
    } 
}
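The y-before-x convention can also be isolated in one small conversion step, so that the drawing code never has to remember it. The following is only a sketch under the assumption that each keypoint row is [y, x, score] with y and x normalized to the range 0..1, as the notebook's helper functions suggest; the class name KeypointSketch and the method toPixels are made up for illustration. Unlike run() above, it takes the raw normalized values and scales them by the actual image size rather than a fixed 256.

```java
// Hypothetical helper for illustration only.
final class KeypointSketch {

    // Converts rows of [y, x, score] (y and x assumed normalized to 0..1)
    // into pixel coordinates {x, y} for an image of the given size.
    static int[][] toPixels(float[][] keypoints, int imgWidth, int imgHeight) {
        int[][] pixels = new int[keypoints.length][2];
        for (int k = 0; k < keypoints.length; k++) {
            float y = keypoints[k][0]; // first value is the vertical position
            float x = keypoints[k][1]; // second value is the horizontal position
            pixels[k][0] = Math.round(x * imgWidth);
            pixels[k][1] = Math.round(y * imgHeight);
        }
        return pixels;
    }

    public static void main(String[] args) {
        float[][] keypoints = { { 0.25f, 0.5f, 0.9f } };
        int[][] px = toPixels(keypoints, 256, 128);
        System.out.println("x=" + px[0][0] + ", y=" + px[0][1]);
    }
}
```

The returned pairs are already in (x, y) order, so they could be passed straight to g.fillOval(px[k][0], px[k][1], 5, 5).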

