How to register a custom object for TFViTMainLayer in this model?
Discussion #2, opened by mbluetail
I have a Google Vision Transformer (ViT) model that I trained in TensorFlow 2 and saved as an .h5 file:
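```python
import tensorflow as tf
from tensorflow import keras
from transformers import TFViTModel

# Base model pre-trained on ImageNet-21k with the 224x224 image resolution
base_model = TFViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
# Freeze base model
base_model.trainable = False
# Create new model (ViT takes channels-first pixel values: (channels, height, width))
inputs = keras.Input(shape=(3, 224, 224))
x = data_augmentation_vit(inputs)  # augmentation pipeline defined elsewhere (not shown)
# NOTE: `x` is not used below; the ViT layer receives `inputs`, so augmentation is skipped
vit = base_model.vit(inputs)[0]  # last_hidden_state, shape (None, 197, 768)
vit = keras.layers.GlobalAveragePooling1D()(vit)
vit = tf.keras.layers.Dense(256, activation='relu')(vit)
vit = tf.keras.layers.Dropout(0.15)(vit)
outputs = tf.keras.layers.Dense(1, activation='sigmoid', name='outputs')(vit)
model_vit = tf.keras.Model(inputs, outputs)
model_vit.summary()  # prints the table itself; wrapping it in print() would also print "None"
```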
```
Model: "model_1"
_________________________________________________________________
 Layer (type)                 Output Shape                     Param #
=================================================================
 input_2 (InputLayer)         [(None, 3, 224, 224)]            0
 vit (TFViTMainLayer)         TFBaseModelOutputWithPooling(    86389248
                                last_hidden_state=(None, 197, 768),
                                pooler_output=(None, 768),
                                hidden_states=None,
                                attentions=None)
 global_average_pooling1d     (None, 768)                      0
 (GlobalAveragePooling1D)
 dense_2 (Dense)              (None, 256)                      196864
 dropout_37 (Dropout)         (None, 256)                      0
 outputs (Dense)              (None, 1)                        257
=================================================================
```
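The summary lists `vit (TFViTMainLayer)`: that layer class comes from the `transformers` package rather than from Keras, so `load_model` cannot rebuild it unless the class is available when the file is loaded. As a rough check of which classes a loading script would need to import or pass via `custom_objects`, one could list every top-level layer whose class is not defined in a Keras or TensorFlow module. This is a minimal sketch: `non_keras_layer_classes` and the toy `Doubler` layer are hypothetical, and the module-prefix test is a heuristic assumption, not a Keras API.

```python
import tensorflow as tf

# Module prefixes treated as "built in" (heuristic assumption: covers tf.keras,
# standalone Keras 2/3 and the tf_keras package).
BUILTIN_PREFIXES = ("keras", "tf_keras", "tensorflow")


class Doubler(tf.keras.layers.Layer):
    """Hypothetical custom layer standing in for TFViTMainLayer."""

    def call(self, inputs):
        return inputs * 2.0


def non_keras_layer_classes(model):
    """Map class name -> class for top-level layers not defined by Keras/TensorFlow."""
    found = {}
    for layer in model.layers:  # layers nested inside sub-models are not inspected
        cls = type(layer)
        if not cls.__module__.startswith(BUILTIN_PREFIXES):
            found[cls.__name__] = cls
    return found


inputs = tf.keras.Input(shape=(4,))
hidden = Doubler()(inputs)
outputs = tf.keras.layers.Dense(1)(hidden)
toy_model = tf.keras.Model(inputs, outputs)

print(sorted(non_keras_layer_classes(toy_model)))  # -> ['Doubler']
```

For `model_vit` above, the helper would be expected to report `TFViTMainLayer`, matching the `vit (TFViTMainLayer)` row of the summary.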
When I load this model with the code below and run my app in Streamlit, I get this ValueError:
```
ValueError: Unknown layer: Custom>TFViTMainLayer. Please ensure this object is passed to the `custom_objects` argument. See https://www.tensorflow.org/guide/keras/save_and_serialize#registering_the_custom_object for details.
Traceback:
File "C:\Users\maria\anaconda3\envs\tfenv\lib\site-packages\streamlit\scriptrunner\script_runner.py", line 557, in _run_script
    exec(code, module.__dict__)
File "app_extended.py", line 248, in <module>
    model_loader = tf.keras.models.load_model(path_to_model)
File "C:\Users\maria\anaconda3\envs\tfenv\lib\site-packages\keras\utils\traceback_utils.py", line 67, in error_handler
    raise e.with_traceback(filtered_tb) from None
File "C:\Users\maria\anaconda3\envs\tfenv\lib\site-packages\keras\utils\generic_utils.py", line 562, in class_and_config_for_serialized_keras_object
    raise ValueError(
```
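The error asks for the class to be passed through `custom_objects`. Below is a minimal sketch of that mechanism, assuming TensorFlow 2.x with `h5py` installed: a toy model containing a hypothetical, unregistered `ScaleLayer` is saved to HDF5 and reloaded by mapping the class name stored in the file back to the Python class. For this post, the class would presumably be `TFViTMainLayer` (assumed importable from `transformers.models.vit.modeling_tf_vit`); the lookup compares the dictionary key with the name recorded in the file, which the error reports as `Custom>TFViTMainLayer`.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class ScaleLayer(tf.keras.layers.Layer):
    """Hypothetical custom layer; it is NOT registered with Keras."""

    def __init__(self, factor=2.0, **kwargs):
        super().__init__(**kwargs)
        self.factor = factor

    def call(self, inputs):
        return inputs * self.factor

    def get_config(self):
        # load_model rebuilds the layer from this dict, so it must hold every __init__ argument.
        config = super().get_config()
        config.update({"factor": self.factor})
        return config


inputs = tf.keras.Input(shape=(4,))
outputs = ScaleLayer(factor=3.0)(inputs)
model = tf.keras.Model(inputs, outputs)

path = os.path.join(tempfile.mkdtemp(), "toy_model.h5")
model.save(path)  # the .h5 suffix selects the HDF5 format

# The file stores only the class *name*; custom_objects maps that name back to the class.
reloaded = tf.keras.models.load_model(path, custom_objects={"ScaleLayer": ScaleLayer})

x = np.ones((2, 4), dtype="float32")
print(reloaded(x).numpy())  # every entry should be 3.0
```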
My code:

```python
import tensorflow as tf

if model_name == 'Vision Transformer(ViT)':
    ## Initialize TensorFlow model
    path_to_model = "C:/Users/maria/Jupiter_Notebooks/Dataset_Thermal_Project/Camera_videos/Saved_models/model_vit.h5"
    model_loader = tf.keras.models.load_model(path_to_model)  # this call raises the ValueError above
    model_vit = tf.keras.models.Model(model_loader.inputs, model_loader.outputs)
    # .....
```
After reading the guide at https://www.tensorflow.org/guide/keras/save_and_serialize#registering_the_custom_object, I'm still not sure how to register the custom object in my case.
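The `Custom>` prefix in the error is the default package name added by `tf.keras.utils.register_keras_serializable`, which suggests `TFViTMainLayer` was registered under that name when the model was saved, and that the Streamlit script never ran the code that registers it before calling `load_model`. The sketch below, with a hypothetical `RegisteredDoubler` layer, shows that a registered class needs no `custom_objects` entry as long as its definition has run (that is, its module has been imported) before loading. For this post, that would presumably translate to importing the ViT classes from `transformers` in `app_extended.py` before `load_model`; this is an inference from the error text, not something the post confirms.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


# Registration happens when this class statement runs (i.e. when its module is imported).
# The default package is "Custom", so the name stored in the file is "Custom>RegisteredDoubler".
@tf.keras.utils.register_keras_serializable()
class RegisteredDoubler(tf.keras.layers.Layer):
    """Hypothetical registered layer standing in for TFViTMainLayer."""

    def call(self, inputs):
        return inputs * 2.0


inputs = tf.keras.Input(shape=(4,))
model = tf.keras.Model(inputs, RegisteredDoubler()(inputs))

path = os.path.join(tempfile.mkdtemp(), "registered_model.h5")
model.save(path)

# No custom_objects: Keras resolves "Custom>RegisteredDoubler" from its global registry.
reloaded = tf.keras.models.load_model(path)
print(reloaded(np.ones((1, 4), dtype="float32")).numpy())  # every entry should be 2.0
```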
Could someone help me with this, please?