[DevoxxUK2024] Is It (F)ake?! Image Classification with TensorFlow.js by Carly Richmond
Carly Richmond, a Principal Developer Advocate at Elastic, captivated the DevoxxUK2024 audience with her engaging exploration of image classification using TensorFlow.js. Inspired by her love for the Netflix show Is It Cake?, Carly embarked on a project to build a model distinguishing cakes disguised as everyday objects from their non-cake counterparts. Despite her self-professed lack of machine learning expertise, Carly’s journey through data gathering, pre-trained models, custom model development, and transfer learning offers a relatable and insightful narrative for developers venturing into AI-driven JavaScript applications.
Gathering and Preparing Data
Carly’s project begins with the critical task of data collection, a foundational step in machine learning. To source images of cakes resembling other objects, she leverages Playwright, a JavaScript-based automation framework, to scrape images from bakers’ websites and Instagram galleries. For non-cake images, Carly utilizes the Unsplash API, which provides royalty-free photos with a rate-limited free tier. She queries categories like reptiles, candles, and shoes to align with the deceptive cakes from the show. However, Carly acknowledges limitations, such as inadvertently including biscuits or company logos in the dataset, highlighting the challenges of ensuring data purity with a modest set of 367 cake and 174 non-cake images.
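The Unsplash queries described above can be sketched as a small request builder. This is a minimal illustration rather than Carly's code: the `buildUnsplashSearch` helper, the `YOUR_ACCESS_KEY` placeholder, and the page size are assumptions, while the search endpoint and `Client-ID` authorization header follow Unsplash's public API documentation.

```javascript
// Hypothetical helper: builds the URL and headers for one Unsplash search page.
// The access key is a placeholder; nothing is sent until fetch() is called.
function buildUnsplashSearch(query, page = 1, perPage = 30, accessKey = "YOUR_ACCESS_KEY") {
  const url = new URL("https://api.unsplash.com/search/photos");
  url.searchParams.set("query", query);
  url.searchParams.set("page", String(page));
  url.searchParams.set("per_page", String(perPage)); // Unsplash caps this at 30
  return {
    url: url.toString(),
    headers: { Authorization: `Client-ID ${accessKey}` },
  };
}

// Categories mentioned in the talk, chosen to mirror the disguised cakes.
const categories = ["reptiles", "candles", "shoes"];
const requests = categories.map((category) => buildUnsplashSearch(category));

// Example usage (commented out so the sketch runs without network access):
// const res = await fetch(requests[0].url, { headers: requests[0].headers });
// const { results } = await res.json(); // each result exposes urls.small, urls.regular, ...
console.log(requests.map((r) => r.url));
```

Because the free tier is rate-limited, a real script would likely pause between pages or cache responses instead of firing every request at once.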
Exploring Pre-Trained Models
To avoid building a model from scratch, Carly initially experiments with TensorFlow.js’s pre-trained models, Coco SSD and MobileNet. Coco SSD, trained on the Common Objects in Context (COCO) dataset, excels in object detection, identifying bounding boxes and classifying objects like cakes with reasonable accuracy. MobileNet, designed for lightweight classification, struggles with Carly’s dataset, often misclassifying cakes as cups or ice cream due to visual similarities like frosting. CORS issues further complicate browser-based MobileNet deployment, prompting Carly to shift to a Node.js backend, where she converts images into tensors for processing. These experiences underscore the trade-offs between model complexity and practical deployment.
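Converting an image into a tensor, as mentioned for the Node.js backend, comes down to turning decoded pixel bytes into a numeric array with a known shape. The sketch below uses plain JavaScript rather than TensorFlow.js, so it is not the code from the talk: the `rgbaToNormalizedRgb` name, the RGBA input layout, and the [0, 1] scaling are assumptions made for illustration.

```javascript
// Hypothetical helper: turns flat RGBA bytes (for example, canvas ImageData)
// into normalized RGB floats laid out as [height, width, 3].
function rgbaToNormalizedRgb(rgba, width, height) {
  if (rgba.length !== width * height * 4) {
    throw new Error(`Expected ${width * height * 4} bytes, got ${rgba.length}`);
  }
  const data = new Float32Array(width * height * 3);
  for (let pixel = 0; pixel < width * height; pixel++) {
    // Copy R, G, B and skip the alpha channel.
    for (let channel = 0; channel < 3; channel++) {
      data[pixel * 3 + channel] = rgba[pixel * 4 + channel] / 255;
    }
  }
  return { shape: [height, width, 3], data };
}

// A 2x1 image: one red pixel followed by one white pixel.
const pixels = new Uint8ClampedArray([255, 0, 0, 255, 255, 255, 255, 255]);
const image = rgbaToNormalizedRgb(pixels, 2, 1);
console.log(image.shape, Array.from(image.data));
```

A model generally expects a fixed input size and a leading batch dimension, so resizing and batching would still be needed before prediction.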
Building and Refining a Custom Model
Undeterred by initial setbacks, Carly ventures into crafting a custom convolutional neural network (CNN) using TensorFlow.js. She outlines the CNN’s structure, which includes convolution layers to extract features, pooling layers to reduce dimensionality, and a final softmax activation that outputs a probability for each of the two classes (cake vs. not cake). Despite her efforts, the model’s accuracy languishes at 48%, worse than the 50% expected from random guessing between two classes, and development is plagued by issues like tensor shape mismatches and premature tensor disposal. Carly candidly admits to errors, such as mislabeling cakes as non-cakes, illustrating the steep learning curve for non-experts. This section of her talk resonates with developers, emphasizing perseverance and the iterative nature of machine learning.
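Tensor shape mismatches of the kind Carly describes are often easier to spot by tracing shapes layer by layer before training. The sketch below does that arithmetic in plain JavaScript for an assumed stack of convolution and pooling layers without padding; the 224x224x3 input, the filter counts, and the kernel sizes are illustrative and not taken from the talk.

```javascript
// Output size along one dimension with no padding ("valid"):
// floor((input - kernel) / stride) + 1
function validOutputSize(inputSize, kernelSize, stride = 1) {
  return Math.floor((inputSize - kernelSize) / stride) + 1;
}

// A convolution changes height and width and sets the channel count to `filters`.
function conv2dShape([h, w], { filters, kernelSize, stride = 1 }) {
  return [validOutputSize(h, kernelSize, stride), validOutputSize(w, kernelSize, stride), filters];
}

// Max pooling shrinks height and width but keeps the channel count.
function maxPool2dShape([h, w, c], { poolSize, stride = poolSize }) {
  return [validOutputSize(h, poolSize, stride), validOutputSize(w, poolSize, stride), c];
}

// Assumed architecture: conv -> pool -> conv -> pool -> flatten.
const trace = [[224, 224, 3]];
trace.push(conv2dShape(trace[trace.length - 1], { filters: 16, kernelSize: 3 }));
trace.push(maxPool2dShape(trace[trace.length - 1], { poolSize: 2 }));
trace.push(conv2dShape(trace[trace.length - 1], { filters: 32, kernelSize: 3 }));
trace.push(maxPool2dShape(trace[trace.length - 1], { poolSize: 2 }));

const finalShape = trace[trace.length - 1];
const flattenedSize = finalShape[0] * finalShape[1] * finalShape[2];

trace.forEach((shape, i) => console.log(`step ${i}:`, shape));
console.log("flattened size:", flattenedSize);
```

The flattened size is the number of inputs the first dense layer has to accept, so a mismatch there is a common source of shape errors.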
Leveraging Transfer Learning
Recognizing the limitations of her dataset and custom model, Carly pivots to transfer learning, using MobileNet’s feature vectors as a foundation. By adding a custom classification head with ReLU and softmax layers, she achieves a significant improvement: reported accuracy reaches 100% by the third epoch, and the model correctly classifies 319 cakes. The results are still not perfect, but this approach clearly outperforms her custom model, demonstrating the power of leveraging pre-trained models for specialized tasks. Carly’s comparison of human performance (90% accuracy by the DevoxxUK audience) with her model’s results adds a playful yet insightful dimension, highlighting the gap between human intuition and machine precision.
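A classification head of the kind described, a ReLU layer followed by a softmax output, can be sketched as a forward pass over a feature vector. The code below is plain JavaScript with tiny hand-picked weights, purely to show the data flow: the function names, the three-value feature vector, and every weight are hypothetical, and a real head would sit on MobileNet's much larger feature vector with weights learned during training.

```javascript
// Dense layer: output[j] = sum over i of input[i] * weights[i][j], plus bias[j]
function dense(input, weights, bias) {
  return bias.map((b, j) => input.reduce((sum, x, i) => sum + x * weights[i][j], b));
}

// ReLU zeroes out negative activations.
const relu = (values) => values.map((v) => Math.max(0, v));

// Softmax turns raw scores into probabilities that sum to 1.
function softmax(logits) {
  const max = Math.max(...logits); // subtract the max for numerical stability
  const exps = logits.map((v) => Math.exp(v - max));
  const total = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / total);
}

// Hypothetical weights: 3 features -> 2 hidden units -> 2 classes.
const head = {
  hidden: { weights: [[1, -1], [0.5, 0.5], [-1, 1]], bias: [0, 0] },
  output: { weights: [[2, -2], [-2, 2]], bias: [0, 0] },
};
const labels = ["cake", "not cake"];

function classify(features) {
  const hidden = relu(dense(features, head.hidden.weights, head.hidden.bias));
  const probs = softmax(dense(hidden, head.output.weights, head.output.bias));
  return { label: labels[probs.indexOf(Math.max(...probs))], probs };
}

// Stand-in for a feature vector produced by the frozen base model.
const result = classify([0.9, 0.2, 0.1]);
console.log(result.label, result.probs);
```

In transfer learning only the head's weights are trained while the pre-trained feature extractor stays fixed, which is why a small dataset can still produce usable results.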