classification_neural: Questions and suggestions for revision
Open questions
1. Training times. CNN training on a laptop runs about 5–10 minutes for 5 epochs on 10,000 examples. On older hardware this could be much longer. Should I include pre-trained model weights so students can evaluate without training? Alternatively, should I reduce the default dataset size further (e.g., 5,000 examples) at the cost of lower peak accuracy?
2. KNN speed. KNN prediction on 10,000 training examples takes 1–3 minutes. This is manageable but slow. Options: subsample the training set further (faster, lower accuracy), or use a BallTree index with a faster metric. Is waiting up to 3 minutes acceptable in a classroom context?
3. A4.3.4 (clustering). The standards for this lab include A4.3.4 (clustering in unsupervised learning), but I did not include a clustering activity. The current lab already has a lot of content. Options: (a) drop A4.3.4 from the standards list and cover it elsewhere, (b) add a short K-means activity at the end (group similar digits without labels), (c) address it in discussion prompts only (the current approach). What do you prefer?
4. Feature engineering experience from Part 1. The lab asks students to design features for MNIST and observe that they fail. This is meant to motivate neural networks. But some students might find it discouraging rather than motivating. Is there a way to frame this so the "failure" feels productive?
5. PyTorch vs. sklearn MLPClassifier. The current lab uses PyTorch for the MLP and CNN. An alternative is to use sklearn.neural_network.MLPClassifier for the MLP (simpler API, no GPU support, limited architecture options) and reserve PyTorch for the CNN. This would reduce complexity in Part 3. Worth considering?
6. The devnote sketch mentioned clustering (A4.3.4) in the lab standards, but the sketch itself only says "try to apply the feature-based techniques from the previous lab and get poor performance." I interpreted A4.3.4 as a secondary standard to cover in discussion, not a primary lab activity. Is that right?
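A sketch for question 1: if pre-trained weights are shipped, the notebook can load them when the file is present and fall back to training otherwise. This assumes the lab's CNN is a torch.nn.Module; the helper name load_or_train and the train_fn callback are hypothetical placeholders, not existing lab code.

```python
import os

import torch
from torch import nn


def load_or_train(model: nn.Module, weights_path: str, train_fn) -> nn.Module:
    """Load saved weights if the file exists; otherwise train and save them.

    train_fn is a hypothetical callable that trains `model` in place.
    """
    if os.path.exists(weights_path):
        # map_location="cpu" lets weights saved on a GPU machine load on a laptop.
        state = torch.load(weights_path, map_location="cpu")
        model.load_state_dict(state)
    else:
        train_fn(model)
        # Save only the parameters (the state_dict), not the whole pickled model.
        torch.save(model.state_dict(), weights_path)
    model.eval()  # disable dropout / batch-norm training behavior for evaluation
    return model
```

In the notebook this would be something like load_or_train(cnn, "cnn_weights.pt", train_cnn), where cnn, train_cnn, and the file name stand in for whatever the lab already defines. Because the file holds only tensors, it loads as long as the notebook builds the same architecture.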
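For question 2, one way to choose between the two options is to time them side by side. A sketch assuming sklearn's KNeighborsClassifier with k=3 (a placeholder, not necessarily the lab's setting); the bundled load_digits dataset (8×8 images) stands in for MNIST so the snippet runs without a download, and the helper name time_knn is hypothetical. One caveat worth checking on the real data: tree indexes such as BallTree may lose much of their advantage over brute-force search in high dimensions, and MNIST has 784 pixel features.

```python
import time

import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier


def time_knn(X_train, y_train, X_test, y_test, n_train, algorithm, seed=0):
    """Fit KNN on a random subsample of n_train examples; return (accuracy, seconds)."""
    rng = np.random.default_rng(seed)
    idx = rng.choice(len(X_train), size=min(n_train, len(X_train)), replace=False)
    knn = KNeighborsClassifier(n_neighbors=3, algorithm=algorithm, n_jobs=-1)
    start = time.perf_counter()
    knn.fit(X_train[idx], y_train[idx])
    # KNN does most of its work at prediction time, so time fit + score together.
    acc = knn.score(X_test, y_test)
    return acc, time.perf_counter() - start


if __name__ == "__main__":
    # load_digits stands in for the lab's MNIST arrays here.
    X, y = load_digits(return_X_y=True)
    X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.3, random_state=0)
    for n in (500, len(X_tr)):
        for algo in ("brute", "ball_tree"):
            acc, secs = time_knn(X_tr, y_tr, X_te, y_te, n, algo)
            print(f"n={n:5d} {algo:9s} acc={acc:.3f} time={secs:.2f}s")
```

Swapping in the lab's MNIST arrays and training sizes such as 2,000 or 5,000 would show the speed/accuracy trade-off for the actual classroom setup.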
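If option (b) for question 3 is chosen, the K-means activity could stay short: cluster the images without labels, then reveal the labels and ask which clusters line up with a single digit. A sketch using sklearn's KMeans, again with load_digits standing in for the lab's MNIST subset; the helper names cluster_digits and majority_labels are hypothetical, and the adjusted Rand index is just one possible agreement score (1.0 means clusters match labels exactly, values near 0 mean chance-level agreement).

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import load_digits
from sklearn.metrics import adjusted_rand_score


def cluster_digits(X, n_clusters=10, seed=0):
    """Group images into n_clusters with K-means; the labels are never used."""
    km = KMeans(n_clusters=n_clusters, n_init=10, random_state=seed)
    return km.fit_predict(X)


def majority_labels(clusters, y):
    """After clustering, peek at the true labels: most common digit per cluster."""
    return {int(c): int(np.bincount(y[clusters == c]).argmax()) for c in np.unique(clusters)}


X, y = load_digits(return_X_y=True)  # stand-in for the lab's MNIST subset
clusters = cluster_digits(X)
print(majority_labels(clusters, y))
print("adjusted Rand index:", round(adjusted_rand_score(y, clusters), 3))
```

If two clusters share the same majority digit, that is a concrete discussion prompt about what "similar" means when no labels are available.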
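For question 5, the MLP step with sklearn.neural_network.MLPClassifier could look like this. A minimal sketch: the single 128-unit hidden layer and max_iter=300 are placeholder choices rather than the lab's current PyTorch architecture, and load_digits again stands in for MNIST.

```python
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

X, y = load_digits(return_X_y=True)  # stand-in for the lab's MNIST arrays
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.3, random_state=0)

mlp = make_pipeline(
    StandardScaler(),  # MLPs usually train better on standardized inputs
    MLPClassifier(hidden_layer_sizes=(128,), max_iter=300, random_state=0),
)
mlp.fit(X_tr, y_tr)
print("test accuracy:", round(mlp.score(X_te, y_te), 3))
```

The fit/score calls follow the same sklearn pattern as KNeighborsClassifier, which is most of the simplicity argument; the trade-off is that students would first meet a PyTorch training loop in the CNN section.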
Suggestions for improvement
- Add a "misclassified digits" visualization. After training each classifier, show 9 digits that were misclassified (arranged in a 3×3 grid). This helps students see what each classifier struggles with and builds intuition for why CNNs help.
- Add filter visualization for the CNN. Show what the learned first-layer filters look like (e.g., an 8×8 grid if the first layer has 64 kernels). This makes the "what does a convolutional filter detect?" question concrete and needs only a few lines of matplotlib.
- Kaggle MNIST vs. sklearn MNIST. The fetch_openml download can be slow (~30 seconds) or fail on restricted networks. Consider including a fallback using torchvision.datasets.MNIST, which downloads from a different host and caches the raw files in a local directory.
- The comparison table at the end is the strongest piece of the lab pedagogically. Consider making it a live class artifact (e.g., a shared Google Sheet where each group contributes their best results) rather than just an individual submission.
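A sketch of the misclassified-digits grid from the first suggestion. It only needs the test images and two label arrays, so the same function would work for KNN, the MLP, and the CNN (after converting CNN predictions to a NumPy array). The function name plot_misclassified is hypothetical; images may be flattened vectors (e.g., 784 values) or 2-D arrays, and the reshape assumes square images.

```python
import matplotlib.pyplot as plt
import numpy as np


def plot_misclassified(images, y_true, y_pred):
    """Plot up to 9 misclassified images in a 3x3 grid; return (figure, indices)."""
    images = np.asarray(images)
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    wrong = np.flatnonzero(y_true != y_pred)[:9]
    fig, axes = plt.subplots(3, 3, figsize=(6, 6))
    for ax in axes.ravel():
        ax.axis("off")  # hide ticks, and leave unused cells blank
    for ax, i in zip(axes.ravel(), wrong):
        img = images[i]
        if img.ndim == 1:  # flattened vector, e.g. 784 -> 28x28
            side = int(round(np.sqrt(img.size)))
            img = img.reshape(side, side)
        ax.imshow(img, cmap="gray_r")
        ax.set_title(f"true {y_true[i]}, pred {y_pred[i]}", fontsize=9)
    fig.tight_layout()
    return fig, wrong
```

For example, fig, idx = plot_misclassified(X_test, y_test, knn.predict(X_test)) followed by plt.show(), with X_test, y_test, and knn standing in for the lab's own variables.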
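For the filter-visualization suggestion, a sketch that takes the first conv layer's weights as a NumPy array, so the plotting code itself does not depend on PyTorch. It assumes a PyTorch nn.Conv2d, whose weight has shape (out_channels, in_channels, kH, kW); the function name plot_filters is hypothetical, and the attribute name conv1 in the usage line is a placeholder for whatever the lab's first layer is called. With 64 kernels and ncols=8 this gives the 8×8 grid; other filter counts simply produce a different number of rows.

```python
import matplotlib.pyplot as plt
import numpy as np


def plot_filters(weights, ncols=8):
    """Plot first-layer conv kernels from an array of shape (out, in, kH, kW).

    Only input channel 0 is drawn, which is the only channel for grayscale MNIST.
    """
    kernels = np.asarray(weights)[:, 0]  # -> (out_channels, kH, kW)
    nrows = int(np.ceil(len(kernels) / ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols, nrows), squeeze=False)
    for ax in axes.ravel():
        ax.axis("off")
    for ax, k in zip(axes.ravel(), kernels):
        # Symmetric color limits so a weight of zero maps to white in "bwr".
        lim = np.abs(k).max() or 1.0
        ax.imshow(k, cmap="bwr", vmin=-lim, vmax=lim)
    fig.tight_layout()
    return fig
```

Usage would be along the lines of plot_filters(model.conv1.weight.detach().cpu().numpy()) then plt.show(). The diverging colormap draws positive weights in red and negative weights in blue, which tends to make stroke- and edge-like patterns easier to see.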
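For the dataset-download suggestion, a sketch of the fallback: try fetch_openml first, and if it raises (timeout, blocked host), try torchvision.datasets.MNIST. Both paths return the combined 70,000 images as float32 arrays scaled to [0, 1]. The helper names are hypothetical. Note that torchvision also needs network access on the first run; it helps only when its download host is reachable and OpenML is not, after which the files are reused from data_dir.

```python
import numpy as np


def first_successful(loaders):
    """Call each loader in order and return the first result that doesn't raise."""
    errors = []
    for load in loaders:
        try:
            return load()
        except Exception as err:  # network errors, timeouts, missing packages
            errors.append(f"{load.__name__}: {err!r}")
    raise RuntimeError("all MNIST loaders failed: " + "; ".join(errors))


def load_from_openml():
    """MNIST via sklearn's fetch_openml (cached under ~/scikit_learn_data by default)."""
    from sklearn.datasets import fetch_openml

    X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)
    return X.astype(np.float32) / 255.0, y.astype(np.int64)


def load_from_torchvision(data_dir="./data"):
    """MNIST via torchvision, concatenating its train and test splits."""
    from torchvision.datasets import MNIST  # imported lazily: only needed on fallback

    parts = [MNIST(root=data_dir, train=flag, download=True) for flag in (True, False)]
    X = np.concatenate([p.data.numpy().reshape(len(p), -1) for p in parts])
    y = np.concatenate([p.targets.numpy() for p in parts])
    return X.astype(np.float32) / 255.0, y.astype(np.int64)


def load_mnist():
    return first_successful([load_from_openml, load_from_torchvision])
```

In the notebook, X, y = load_mnist() would replace the direct fetch_openml call; the printed error list in the RuntimeError tells students which source failed.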