# catandog / train3.py
# Author: okeowo1014 — commit dc92021 ("Update train3.py", verified)
# File-page metadata: 2.45 kB, virus scan: no virus
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import Flatten, Dense
# --- Dataset locations (modify as needed) ------------------------------
# Each directory is read with flow_from_directory, so it is expected to
# contain one sub-folder per class.
train_data_dir = 'train'
validation_data_dir = 'valid'
# NOTE(review): this is the same folder as validation_data_dir, so the
# "test" accuracy reported at the end is measured on validation data, not
# on a held-out split. Point it at a separate test folder if one exists.
test_data_dir = 'valid'

# --- Input resolution --------------------------------------------------
# 224x224 is the resolution the ImageNet VGG16 weights were trained at
# (with include_top=False Keras accepts any size of at least 32x32).
img_height = img_width = 224
# Input preprocessing shared by every generator.
# The ImageNet VGG16 weights were trained on images prepared by
# vgg16.preprocess_input (RGB->BGR, per-channel ImageNet mean subtraction,
# no scaling to [0, 1]). Feeding 1/255-rescaled pixels to the frozen
# convolutional base would give it inputs far from what its filters were
# trained on, so the model's own preprocessing function is used instead.
# NOTE: inference code that loads the saved model must apply the same
# vgg16.preprocess_input to its images.
vgg16_preprocess = tf.keras.applications.vgg16.preprocess_input

# Training generator: random shear/zoom/horizontal flips for better
# generalization. ImageDataGenerator runs preprocessing_function after the
# image has been resized and augmented, so the geometric transforms still
# act on raw pixel values.
train_datagen = ImageDataGenerator(
    preprocessing_function=vgg16_preprocess,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
)
# Validation (and test) images get only the model's preprocessing, no augmentation.
validation_datagen = ImageDataGenerator(preprocessing_function=vgg16_preprocess)

# Load training and validation data.
# Keras' target_size is (height, width); the (width, height) order used here
# is harmless only while both dimensions are equal (224x224).
train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    target_size=(img_width, img_height),
    batch_size=32,  # Adjust batch size based on GPU memory
    class_mode='binary',  # Two classes: cat or dog
)
validation_generator = validation_datagen.flow_from_directory(
    validation_data_dir,
    target_size=(img_width, img_height),
    batch_size=32,
    class_mode='binary',
)
# --- Model -------------------------------------------------------------
# Convolutional part of VGG16 with ImageNet weights; include_top=False
# drops the original 1000-way ImageNet classifier.
# input_shape is (height, width, channels) in Keras; the (width, height)
# order used here is harmless only while both dimensions are equal.
base_model = VGG16(
    include_top=False,
    weights='imagenet',
    input_shape=(img_width, img_height, 3),
)
# Freeze all pre-trained weights so only the new head is trained.
# (Unfreezing later with a small learning rate would fine-tune the base.)
base_model.trainable = False

# New classification head: flatten the final convolutional feature map and
# feed it to one sigmoid unit. The output is the predicted probability of
# label 1; flow_from_directory assigns label indices alphanumerically by
# class sub-folder name.
x = Flatten()(base_model.output)
predictions = Dense(1, activation='sigmoid')(x)
model = tf.keras.Model(inputs=base_model.input, outputs=predictions)

# Binary cross-entropy pairs with the single sigmoid output and the
# class_mode='binary' labels; Adam runs with its default settings.
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy'],
)
# --- Training ----------------------------------------------------------
# Run 10 epochs over the training generator and score the validation
# generator after each one. The returned History object keeps the
# per-epoch metrics in history.history. Tune the epoch count to the
# dataset size and the validation curves.
history = model.fit(
    train_generator,
    validation_data=validation_generator,
    epochs=10,
)
# --- Evaluation (optional) ---------------------------------------------
# Read test_data_dir through the non-augmenting validation_datagen, using
# the same image size, batch size and label mode as validation.
# NOTE(review): test_data_dir is set to the validation folder ('valid'),
# so this repeats the validation score rather than a held-out estimate.
test_generator = validation_datagen.flow_from_directory(
    test_data_dir,
    class_mode='binary',
    batch_size=32,
    target_size=(img_width, img_height),
)
# evaluate() yields [loss, accuracy] because 'accuracy' is the only compiled metric.
test_loss, test_acc = model.evaluate(test_generator)
print('Test accuracy:', test_acc)

# --- Persistence (optional) --------------------------------------------
# Save in the native Keras format, selected by the ".keras" extension.
model.save('cat_dog_classifier.keras')