Home > Other > TensorFlow: the generated TFLite file is unusable
TensorFlow: the generated TFLite file is unusable

Time: 02-04

Model code:
 # Model definition (TF 1.x + TF-Slim), reconstructed from a garbled paste:
# placeholders -> 4 conv/pool stages -> 2 fully connected layers -> loss,
# accuracy, train op, summaries and top-k metrics.
# NOTE(review): `tf`, `slim`, `control_flow_ops`, `FLAGS` and `top_k` are not
# defined in this snippet; presumably they come from the surrounding code
# (`top_k` looks like an argument of an enclosing build function) — confirm.

# Graph inputs. The scalar placeholders (shape=[]) are fed at run time.
keep_prob = tf.placeholder(dtype=tf.float32, shape=[], name='keep_prob')
images = tf.placeholder(dtype=tf.float32, shape=[None, 64, 64, 1], name='image_batch')
labels = tf.placeholder(dtype=tf.int64, shape=[None], name='label_batch')
is_training = tf.placeholder(dtype=tf.bool, shape=[], name='train_flag')

with tf.device('/CPU:0'):
    # Every slim.conv2d / slim.fully_connected created in this scope (including
    # the logits layer fc2) is followed by slim.batch_norm driven by the
    # `is_training` placeholder.
    # NOTE(review): because `is_training` is a tensor rather than a Python bool,
    # batch_norm likely builds a runtime conditional, which is then baked into
    # the frozen graph together with the `train_flag` input. That is a probable
    # cause of an unusable TFLite model. Consider exporting from a separate
    # inference graph built with is_training=False, and without batch norm on
    # the logits layer.
    with slim.arg_scope([slim.conv2d, slim.fully_connected],
                        normalizer_fn=slim.batch_norm,
                        normalizer_params={'is_training': is_training}):
        conv3_1 = slim.conv2d(images, 64, (3, 3), 1, padding='SAME', scope='conv3_1')
        max_pool_1 = slim.max_pool2d(conv3_1, [2, 2], [2, 2], padding='SAME', scope='pool1')
        conv3_2 = slim.conv2d(max_pool_1, 128, (3, 3), padding='SAME', scope='conv3_2')
        max_pool_2 = slim.max_pool2d(conv3_2, [2, 2], [2, 2], padding='SAME', scope='pool2')
        conv3_3 = slim.conv2d(max_pool_2, 256, (3, 3), padding='SAME', scope='conv3_3')
        max_pool_3 = slim.max_pool2d(conv3_3, [2, 2], [2, 2], padding='SAME', scope='pool3')
        conv3_4 = slim.conv2d(max_pool_3, 512, (3, 3), padding='SAME', scope='conv3_4')
        conv3_5 = slim.conv2d(conv3_4, 512, (3, 3), padding='SAME', scope='conv3_5')
        max_pool_4 = slim.max_pool2d(conv3_5, [2, 2], [2, 2], padding='SAME', scope='pool4')

        flatten = slim.flatten(max_pool_4)
        # Dropout here is controlled only through `keep_prob` (slim.dropout's
        # own is_training argument is not passed), so feed keep_prob=1.0 when
        # evaluating.
        fc1 = slim.fully_connected(slim.dropout(flatten, keep_prob), 1024,
                                   activation_fn=tf.nn.relu, scope='fc1')
        # Logits layer: no activation, FLAGS.charset_size output units.
        logits = slim.fully_connected(slim.dropout(fc1, keep_prob), FLAGS.charset_size,
                                      activation_fn=None, scope='fc2')

    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels))
    accuracy = tf.reduce_mean(
        tf.cast(tf.equal(tf.argmax(logits, 1), labels), tf.float32))

    # Make the loss depend on everything in the UPDATE_OPS collection (where
    # slim.batch_norm registers its moving-statistics updates by default), so
    # those updates run on every training step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    if update_ops:
        updates = tf.group(*update_ops)
        loss = control_flow_ops.with_dependencies([updates], loss)

    global_step = tf.get_variable("step", [], initializer=tf.constant_initializer(0.0),
                                  trainable=False)
    # NOTE(review): 0.1 is an unusually high learning rate for Adam — confirm intent.
    optimizer = tf.train.AdamOptimizer(learning_rate=0.1)
    train_op = slim.learning.create_train_op(loss, optimizer, global_step=global_step)
    probabilities = tf.nn.softmax(logits)

    tf.summary.scalar('loss', loss)
    tf.summary.scalar('accuracy', accuracy)
    merged_summary_op = tf.summary.merge_all()
    predicted_val_top_k, predicted_index_top_k = tf.nn.top_k(probabilities, k=top_k)
    accuracy_in_top_k = tf.reduce_mean(
        tf.cast(tf.nn.in_top_k(probabilities, labels, top_k), tf.float32))


Generating the frozen .pb file:
 # Freeze the trained checkpoint into a single GraphDef (.pb) file in which
# every variable has been replaced by a constant.
# NOTE(review): `tf`, `graph_util` and `input_checkpoint` (the checkpoint path
# prefix) are not defined in this snippet — presumably imported/defined above.

# Comma-separated list of graph nodes to keep; everything not needed to compute
# them is pruned by convert_variables_to_constants.
output_node_names = "fc2/BatchNorm/Reshape_1"

# Rebuild the graph from the checkpoint's .meta file. The extension needs its
# leading dot: checkpoints are saved as <prefix>.meta, <prefix>.index, ...
saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=False)
with tf.Session() as sess:
    # Load the trained variable values into the imported graph.
    saver.restore(sess, input_checkpoint)
    # Replace variables with constants holding their current values.
    output_graph_def = graph_util.convert_variables_to_constants(
        sess=sess,
        input_graph_def=sess.graph_def,
        # Split on ',' and strip so both "a,b" and "a, b" work.
        output_node_names=[name.strip() for name in output_node_names.split(",")],
    )

# List the nodes that survived freezing (useful to pick converter inputs/outputs).
all_output_node_names = [n.name for n in output_graph_def.node]
print(str(all_output_node_names))

pb_file1 = './output_graph-7.pb'
with tf.gfile.GFile(pb_file1, "wb") as f:  # save the frozen model
    f.write(output_graph_def.SerializeToString())  # serialized GraphDef protobuf



Generating the .tflite file:
 # Convert the frozen .pb produced above into a TFLite flatbuffer.
def get_tflite_file(pb_file_path, output_path="../bak/output_graph-1.tflite"):
    """Convert a frozen TF 1.x GraphDef to TFLite and write it to disk.

    Args:
        pb_file_path: Path to the frozen GraphDef (.pb) file.
        output_path: Where to write the converted .tflite model. The default
            is the path the original snippet wrote to.

    Returns:
        ``output_path``, the file that was written.
    """
    input_names = ["image_batch", "train_flag", "keep_prob"]
    output_names = ["fc2/BatchNorm/Reshape_1"]
    # `input_shapes` maps each input array name to its *shape* as a list of
    # ints, not to a value. The two flags are scalar placeholders, hence [].
    # The old [True] / [4] entries were values, and conflict with the scalar
    # shape recorded in the graph.
    input_shapes = {
        "image_batch": [1, 64, 64, 1],
        "train_flag": [],
        "keep_prob": [],
    }
    # NOTE(review): exposing `train_flag` as a model input means the runtime
    # must feed False (and keep_prob=1.0), and any training/inference
    # conditional from batch norm stays in the graph. Freezing an inference
    # graph built with is_training=False is probably needed — confirm.
    converter = tf.lite.TFLiteConverter.from_frozen_graph(
        pb_file_path,
        input_arrays=input_names,
        output_arrays=output_names,
        input_shapes=input_shapes,
    )
    converter.post_training_quantize = True
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS,
    ]
    converter.allow_custom_ops = True
    # Only consulted for quantized (uint8) inference types.
    converter.default_ranges_stats = (0, 255)
    tflite_model = converter.convert()
    # Use a context manager so the file is flushed and closed deterministically.
    with open(output_path, "wb") as f:
        f.write(tflite_model)
    return output_path
  • Related