{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%reload_ext autoreload\n",
"%autoreload 2\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from fastai.imports import *\n",
"from fastai.torch_imports import *\n",
"from fastai.transforms import *\n",
"from fastai.conv_learner import *\n",
"from fastai.model import *\n",
"from fastai.dataset import *\n",
"from fastai.sgdr import *\n",
"from fastai.plots import *"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"torch.cuda.set_device(0)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"PATH = \"intel/\""
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"train_csv = f'{PATH}train.csv'\n",
"n = len(list(open(train_csv))) - 1 # header is not counted (-1)\n",
"val_idxs = get_cv_idxs(n) # random 20% data for validation set"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"fastai\tintel intel.ipynb intel1.ipynb s.csv t.py\r\n"
]
}
],
"source": [
"!ls"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"models\tsample.csv test test.csv tmp train\ttrain.csv train.zip\r\n"
]
}
],
"source": [
"!ls {PATH}"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"train = pd.read_csv(train_csv)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" image_name | \n",
" label | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 0.jpg | \n",
" 0 | \n",
"
\n",
" \n",
" 1 | \n",
" 1.jpg | \n",
" 4 | \n",
"
\n",
" \n",
" 2 | \n",
" 2.jpg | \n",
" 5 | \n",
"
\n",
" \n",
" 3 | \n",
" 4.jpg | \n",
" 0 | \n",
"
\n",
" \n",
" 4 | \n",
" 7.jpg | \n",
" 4 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" image_name label\n",
"0 0.jpg 0\n",
"1 1.jpg 4\n",
"2 2.jpg 5\n",
"3 4.jpg 0\n",
"4 7.jpg 4"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train.head()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" image_name | \n",
"
\n",
" \n",
" label | \n",
" | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 2628 | \n",
"
\n",
" \n",
" 1 | \n",
" 2745 | \n",
"
\n",
" \n",
" 4 | \n",
" 2784 | \n",
"
\n",
" \n",
" 5 | \n",
" 2883 | \n",
"
\n",
" \n",
" 2 | \n",
" 2957 | \n",
"
\n",
" \n",
" 3 | \n",
" 3037 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" image_name\n",
"label \n",
"0 2628\n",
"1 2745\n",
"4 2784\n",
"5 2883\n",
"2 2957\n",
"3 3037"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train.pivot_table(index='label',aggfunc=len).sort_values(by='image_name')"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"arch = resnet34\n",
"#arch = resnext101\n",
"bs = 64\n",
"\n",
"def get_data(sz, bs): # sz: image size, bs: batch size\n",
" tfms = tfms_from_model(arch, sz, aug_tfms=transforms_side_on, max_zoom=1.1)\n",
" data = ImageClassifierData.from_csv(PATH, 'train', f'{PATH}train.csv', test_name='test',\n",
" val_idxs=[0], tfms=tfms, bs=bs)\n",
" \n",
" return data "
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import shutil"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" image_name | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 3.jpg | \n",
"
\n",
" \n",
" 1 | \n",
" 5.jpg | \n",
"
\n",
" \n",
" 2 | \n",
" 6.jpg | \n",
"
\n",
" \n",
" 3 | \n",
" 11.jpg | \n",
"
\n",
" \n",
" 4 | \n",
" 14.jpg | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" image_name\n",
"0 3.jpg\n",
"1 5.jpg\n",
"2 6.jpg\n",
"3 11.jpg\n",
"4 14.jpg"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test = pd.read_csv(f'{PATH}test.csv');test.head()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\"\\nfor r in test.iterrows():\\n shutil.copy2((f'{PATH}train/' +r[1].values[0]), (f'{PATH}test/' +r[1].values[0]))\\n #print (f'{PATH}train/' +r[1].values[0])\\n #print(type(r[1].values[0]))\\n\""
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"'''\n",
"for r in test.iterrows():\n",
" shutil.copy2((f'{PATH}train/' +r[1].values[0]), (f'{PATH}test/' +r[1].values[0]))\n",
" #print (f'{PATH}train/' +r[1].values[0])\n",
" #print(type(r[1].values[0]))\n",
"'''"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"sz = 100\n",
"data = get_data(sz, bs)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"learn = ConvLearner.pretrained(arch, data, precompute=True)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "47746d3aeea54e76ab4e86187f751d10",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Epoch', max=1, style=ProgressStyle(description_width='initial…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" 83%|████████▎ | 221/267 [00:03<00:00, 59.53it/s, loss=1.37] \n",
" \r"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAEOCAYAAABmVAtTAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3Xd8VFX6x/HPk56QAiSEkgCh11ACSBFdVFTEAoqdRUVc7G3dVX/u2nd11dVdBEQRWWxYFgt2VCyAgBB6LwkloSWE9F7O748Zslk2JBPInTszed6v17wyc+fO3G8YkifnnnvOEWMMSimlFICf3QGUUkp5Di0KSimlqmlRUEopVU2LglJKqWpaFJRSSlXToqCUUqqaFgWllFLVtCgopZSqpkVBKaVUNS0KSimlqgXYHaChYmJiTEJCgt0xlFLKq6xZs+aoMaZVfft5XVFISEggOTnZ7hhKKeVVRGSfK/vp6SOllFLVtCgopZSqpkVBKaVUNS0KSimlqmlRUEopVU2LglJKqWpaFJRSygt8t/UIuzMKLD+OFgWllPJwxhhuf2cNH61Nt/xYWhSUUsrD5RVXUFFliG4WZPmxtCgopZSHO1pYCkBMeLDlx9KioJRSHi6roAyA6HBtKSilVJOXVeBoKUQ305ZCo8kqKOWFRdupqKyyO4pSSjXI0UJHSyHGm1sKIjJXRDJEZPNJno8Skc9FZIOIbBGRyVZlAVieksXMH1N4+YfdVh5GKaUa3fGWQksv72ieB4yp4/k7ga3GmP7AKOBFEbHsO760fzuuHBTP9B92sTzlqFWHUUqpRpdVUEaLsEAC/K0/uWPZEYwxS4Bjde0CRIiIAOHOfSusygPw1Lg+dIppxn3vr6+uvEop5emyCkuJdsOVR2Bvn8IMoBdwENgE3GuMsfSEf1hQADOvTyKnuJzJ81ZrYVBKeYWj+WVuGaMA9haFC4H1QDtgADBDRCJr21FEpopIsogkZ2ZmntZBe7WNZNbEJHYczufKV1eQdqzotN5PKaWsdrSw1C1jFMDeojAZ+Ng47Ab2AD1r29EYM9sYM9gYM7hVq3qXGK3Xeb1aM/93QzlWWMaEWcs5lFt82u+plFJWySooc8sYBbC3KOwHzgMQkdZADyDVXQcf1LElH9w6jMLSCu54dy2lFZXuOrRSSrmsrKKK3OJyt4xRAGsvSX0PWAH0EJF0EZkiIreJyG3OXZ4GRojIJmAx8JAxxq2XBfVsE8kLV/Vn3f4cnvp8qzsPrZRSLskuct9oZoAAq97YGHNdPc8fBC6w6viuGpvYllvP7sxrS1LpFx/FNUM62B1JKaWqHS1w37xH0IRGNNfljxf24KxuMfzpk82sSMmyO45SSlU7Pu+RO0YzgxYFAAL8/ZhxfRIJMc247Z01pGZav5CFUkq5Iss5Q2pTGKfgUaJCA5l74xD8/YRb3kqmpFw7npVS9nPnDKmgReG/dIgO4+VrB5KaWcg/v99ldxyllOJoQRlB/n5EBFvWBfxftCicYGS3GK4Z3J7Xl6ay+UCu3XGUUk3c0YJSosODcMwIZD0tCrV4ZGwvWjYL4sEFGynXqbaVUjbKchYFd9GiUIuosECeHteXrYfyeHmxnkZSStknq7DMbQPXQIvCSY3p24YrB8Uz48fd/LJbp9pWStkjq6DMbWMUQItCnZ4a14fOMc2474P1ZObrjKpKKfcyxnC0oNRtYxRAi0KdwoICmDkxibzicu55b53Oj6SUcqvCskpKK6q0T8GT9GwTyd8mJLIiNYt731uvazwrpdzm+Jov2qfgYS4fGM9jl/Tmmy2HeeSTTRhj7I6klGoCjrp54BpYOCGer7l5ZCdyist5efEuusaGM/XsLnZHUkr5uCw3T4YH2lJokPtHd2NsYhue+2YHq/bUtfy0UkqdPjtaCloUGkBEeG5CPzq0DOOu+Wv1iiSllKWO/47RPgUPFhESyCsTk8gtLufu99Zqx7NSyjKZBSW0bBZEUID7flVrUTgFvdpG8szliaxMPcbzi3bYHUcp5aMy8kpp5cb+BNCicMomDIpn0rCOzF6SypcbD9kdRynlgzILSmkVoUXBazx6SW+SOjTnjws2sC+r0O44Sikfk5FXSqwWBe8RFOBYsc3fT3jgww1UVun4BaVU4zDGaEvBG7VrHsqTl/UheV82byxLtTuOUspH5JVUUFZRpUXBG10+MI4Lerfm74t2svNIvt1xlFI+IDO/BECLgjcSEZ65IpGIkADufX+9TpynlDptGc4xCloUvFRMeDDPTejHtkN5vPjtTrvjKKW83PGBa7ERIW49rhaFRjS6d2smDu3A7CWpujCPUuq0ZGpLwTf8+eLedG7VjAc+3EBeSbndcZRSXiozv5SgAD8iQ9w7b6kWhUYWGuTPP64eQEZ+Cc9+td3uOEopL5WR7xijICJuPa4WBQv0b9+cW87qzHur9rM8RU8jKaUaLjPf/WMUwMKiICJzRSRDRDbXsc8oEVkvIltE5Gerstjh/tHdSYgO4+GPNlFcplcjKaUaJiO/xO2jmcHalsI8YMzJnhSR5sArwGXGmD7AVRZmcbvQIH/+NqEf+48V8Zcvt9odRynlZXyupWCMWQLUtRLN9cDHxpj9zv0zrMpil2Gdo7n17M68++t+nTRPKeWysooqsovKaRXu3stRwd4+he5ACxH5SUTWiMgNNmaxzAMX9KB/++Y8/PFG0o4V2R1HKeUFjjqX4YyN9KGWggsCgEHAxcCFwKMi0r22HUVkqogki0hyZmamOzOetqAAP2ZcNxCAu+av1dHOSql6VY9RcPNaCmBvUUgHvjHGFBpjjgJLgP617WiMmW2MGWyMGdyqVSu3hmwM7VuG8cKV/dmQnssTn2n/glKqbsenuGhqLYWFwFkiEiAiYcBQYJuNeSw1pm8b7hjVhfdW7ef9VfvtjqOU8mB2jWYGxykcS4jIe8AoIEZE0oHHgUAAY8yrxphtIvINsBGoAuYYY056+aoveOCCHmw6kMtjC7fQNTacwQkt7Y6klPJAGc4ZUqOb+VBRMMZc58I+LwAvWJXB0/j7CS9fO5AJs5Yz5c1kFtw2nG6tI+yOpZTyMJn5pbRsFkRQgPtP5uiIZjdr0SyIN28+g6AAP26cu4pDucV2R1JKeZjM/FJbOplBi4It2rcMY97kIeSVVHDzvGSKyirsjqSU8iAZNg1cAy0KtunTLooZ1w9kx+E8HvhwA1W6vrNSCqiorCIlo4AO0WG2HF+Lgo1G9YjlkbG9+HrzYab/sNvuOEopD7DlYB75pRUM7WTPhShaFGw2ZWQnJiTF84/vd7IiJcvuOEopm61IdfweGN452pbja1GwmYjwl/F96dAyjD99somSch3xrFRTtiIliy6tmhEb6f55j0CLgkcIDfLnr5f3JfVoIa/8qKeRlGqqyiurWL33GCO6xNiWQYuChzirWyuuGBjHrJ9T2Hkk3+44SikbbEzPpaiskuFd7Dl1BFoUPMqfLu5FREgg972/XifOU6oJWunsTxhmU38CaFHwKNHhwTw/oR9bD+Xp+s5KNUErUrLo2SaCls2CbMugRcHDjO7dmslnJjBv+V6+23rE7jhKKTcpragked8xW1sJoEXBIz18UU/6tIvkjws2cDi3xO44Sik32Hwgl5LyKs8vCiLSTET8nPe7i8hlIhJofbSmKzjAn+nXDaSsoor7P1hPpY52VsrnpWQWAtCrrb2TZLrSUlgChIhIHLAYmAzMszKUgs6twnnisj6sSM3itSUpdsdRSlksPbsYP4F2zUNtzeFKURBjTBFwBTDdGHM50NvaWArgqkHxXNyvLS99u5MNaTl2x1FKWSj9WBFto0IJ9Lf3rL5LRUFEhgMTgS+d2yxbh0H9h4jwzOWJxIQH89BHGymvrLI7klLKImnZRcS3sLeVAK4VhfuA/wM+McZsEZHOwI/WxlLHRYUG8vT4vmw/nM/sJal2x1FKWSTtWDHxLeyZGbWmeouCMeZnY8xlxpjnnB3OR40x97ghm3I6v3drxia2YdriXaRmFtgdRynVyEorKjmSX0L7ll7QUhCR+SISKSLNgK3ADhH5o/XRVE1PXNqH4AA/Hv54k16NpJSPOZhTgjHQ3htaCkBvY0weMB74CugATLI0lfofsZEhPHpJb1btOcbrS/U0klK+JO1YEeBYldFurhSFQOe4hPHAQmNMOaB/qtrgqkHxXNS3DS9+u4NN6bl2x1FKNZK0bEdR8JaO5teAvUAzYImIdATyrAylaiciPHtFItHNgrn3/XW6trNSPiLtWDGB/kJrm9ZQqMmVjuaXjTFxxpixxmEfcI4bsqlaNA8L4qVr+rMnq5A/f7IZY7TRppS3S88uIq55KP5+YncUlzqao0TkJRFJdt5exNFqUDYZ0SWGe8/rxsfrDvD+6jS74yilTlNadrFH9CeAa6eP5gL5wNXOWx7wLytDqfrdfW43zuoWw+OfbWHzAe1fUMqbpR/zjIFr4FpR6GKMedwYk+q8PQl0tjqYqpu/n/DPawbQMiyIO+evpaBU+xeU8kaFpRVkFZZ5xMA1cK0oFIvIyOMPRORMoNi6SMpV0eHBTLt2AGnHinh84Ra74yilTkF6tuPXqTedProdmCkie0VkHzADuM3aWMpVQztHc9c5XflobToL1x+wO45SqoHSnZejtveW00fGmPXGmP5APyDRGDPQGLOhvteJyFwRyRCRzfXsN0REKkXkStdjq5ruOa8bgzq24M+fbGbP0UK74yilGsCTBq5BHbOdisjvT7IdAGPMS/W89zwcrYq36jiGP/AcsKie91J1CPD3Y9q1A7hk+jKmvpXMJ3eeSXiwTmSrlDdIyy4mNNCfaBvXZa6prpZCRD23OhljlgDH6tntbuAjIMOVsOrk4luEMeO6JFIyC3jgw/VU6fxISnmFvUcLiW8RWv0Ht91O+uek8yojyzhXcrscOBcYYuWxmoqR3WJ4ZGwv/vLlNqb/sJt7R3ezO5JSqg4FpRUs232Uqwe3tztKNTuX+Pkn8JAxprK+HUVk6vHBc5mZmW6I5r2mjOzEFQPj+Mf3O7XjWSkPt2jzYUorqhg/sJ3dUarZeeJ5MPC+s8kUA4wVkQpjzKcn7miMmQ3MBhg8eLCeF6mDiPDshETSc4r547830q55KEMSWtodSylVi4UbDhLfIpSkDi3sjlLNtpaCMaaTMSbBGJMALADuqK0gqIYLDvBn9qRBxLcIZepbyezVK5KU8jiZ+aUs25XJuAHtPKY/AVyb+yhYRK4XkUdE5LHjNxde9x6wAughIukiMkVEbhMRHePgBs3Dgph7k6Or5uZ5q8kpKrM5kVKqpi83HqTKwLgBcXZH+S+unD5aCOQCa4BSV9/YGHNdA/a9ydV9lesSYpox+4bBTHz9V259ew1vTxlKUICd3UhKqeMWbjhIr7aRdG9d78WcbuVKUYg3xoyxPImyxJCElrxwVT/ufX89t7+zhhnXJxEa5G93LKWatPTsItbtz+GhMT3tjvI/XPmzcbmIJFqeRFlm3IA4nh7flx92ZDDpjV/JLSq3O5JSTdrylCwAzusVa3OS/+VKURgJrBGRHSKyUUQ2ichGq4OpxjVpWEdmXJfExvRcrp+zktKKeq8EVkpZZGVKFtHNgugWG253lP/hyumjiyxPodzi4n5tCfQXpr69hn9+v8sjm65K+TpjDCtSsxjWOdqjrjo6zpUJ8fYBzYFLnbfmzm3KC13Qpw1XD47ntZ9TWLs/2+44SjU5+7KKOJRbwrAu0XZHqZUrl6TeC7wLxDpv74jI3VYHU9b58yW9aRMZwh/+vYGScj2NpJQ7rUx19CcM7+ylRQGYAgw1xjxmjHkMGAb8ztpYykqRIYE8f2V/UjML+dMnmzFGB4kr5S4rUrNoFRFMl1aeudS9K0VBgJp/TlY6tykvNrJbDPeN7sZHa9N5Y9keu+Mo1SQYY1iR4rn9CeBaR/O/gF9F5BPn4/HAG9ZFUu5yz7nd2HYoj2e+2kb31hGc3b2V3ZGU8mmpRwvJyC/12FNH4FpH80vAZBxrI2QDk40x/7Q6mLKen5/w0tUD6N46gnveX8eBHF16WykrVfcneGgnM9RRFEQk0vm1JbAXeAd4G9jn3KZ8QLPgAGb9dhAVlYY7311LWUWV3ZGU8lk/bs+gXVQICdGesfRmbepqKcx3fl0DJNe4HX+sfESnmGa8cGU/1qfl8OzX2+yOo5RPyi0q5+edmYxNbOux/QlQ98prlzi/dnJfHGWXixLbMvnMBP71y14GtG/ucTM3KuXtFm09THml4dL+nrOgTm1cGaew2JVtyvv930W9OKNTSx5csJENaTl2x1HKp3y+4SAdWobRLz7K7ih1qqtPIcTZdxAjIi1EpKXzlgB4dqlTpyQowI9ZE5OICQ9m6tvJZOSV2B1JKZ9wtKCU5SlZXNLPs08dQd0thVtx9B/0dH49flsIzLQ+mrJDdHgwc24cTH5JBTe/uZr8Ep1RVanT9fXmw1RWef6pI6ijKBhjpjn7E/5gjOnsXD6zkzGmvzFmhhszKjfr1TaSmdcnsf1QPlPfWqNTYSh1mr7YcJCuseH0bONZC+rUxpVxCtNFpK+IXC0iNxy/uSOcss85PWP5+1X9WZGaxf0frKeqSqfCUOpU5JeUs2rvMcb2bePxp47AtY7mx4Hpzts5wPPAZRbnUh5g/MA4/nxxL77efJhpi3fZHUcpr7T1YB7GwMAOLeyO4hJX5j66EjgPOGyMmQz0B4ItTaU8xpSRnZiQFM+0xbv4YfsRu+Mo5XW2HMwDoE+7SJuTuMaVolBsjKkCKpyjnDOAztbGUp5CRPjr5X3p3TaS+95fz76sQrsjKeVVthzMIyY8iFYR3vG3tCtFIVlEmgOv47j6aC2wytJUyqOEBPrz2qRB+PkJ17/+K2nHiuyOpJTX2Hooj97toryiPwFc62i+wxiTY4x5FTgfuNF5Gkk1Ie1bhvHOlKEUlFZw7eyVWhiUckFpRSW7juR7zakjqHvwWtKJN6AlEOC8r5qYvnFRvHvLUArLKrjmtRV6Kkmpeuw6UkBFlfGNogC86LzNBH4FZuM4hfQr8LL10ZQn6hsXxfxbhlFcXsk1r60kNbPA7khKeawtB3MB6NPOs6e2qKmuwWvnGGPOAfYBScaYwcaYQcBAYLe7AirP07tdJO9NHUZ5ZRXXzl7JnqPaYlCqNlsO5hEeHEDHlp47VfaJXOlo7mmM2XT8gTFmMzDAukjKG/RsE8n7U4dRUWW4RafDUKpWWw7m0attBH5+3tHJDK4VhW0iMkdERonIb0TkdUAn3Vd0ax3BzOuT2JtVpKOelTpBZZVh26E8rzp1BK4VhcnAFuBe4D5gq3NbnURkrohkiMjmkzw/UUQ2Om/LRaR/Q4IrzzC8SzSPXtyL77dl8OJ3O+yOo5TH2JtVSFFZJb29qJMZ6lhk5zhjTAnwD+etIeYBM4C3TvL8HuA3xphsEbkIR0f20AYeQ3mAG0cksP1wPjN/TCHI3597R3ezO5JSttt84Hgns48UBRH50BhztYhsAv7nvIAxpl9db2yMWeJce+Fkzy+v8XAlEF9vWuWRHKOeEymrrOIf3++kyhjuG93NawbrKNXYcorKeGHRDtpFhdC9tefPjFpTXS2Fe51fL3FDjinA1244jrKIv5/wwpX98Rdh2uJdFJVV8MjYXloYVJNTWWW45/31ZOSV8uFtwwn0d+Usveeoa43mQ86v+6wMICLn4CgKI+vYZyowFaBDhw5WxlGnwd9PeG5CP8KC/Hl96R5yisp59opEArzsh0Kp0zFt8S6W7MzkmcsTGdC+ud1xGqyu00f51HLaCBDAGGNO+0SZiPQD5gAXGWOyTrafMWY2jj4HBg8erJe4eDA/P+GJy/rQPCyIaYt3UVhWwcvXDtTCoJqE8soq5ixN5eLEtlx3Rnu745ySuloKlp4IE5EOwMfAJGPMTiuPpdxLRLj//O6EBwfw16+2ERKwkb9f1d+rrtVW6lRsOZhHUVklFyV6x4I6tan36qPjRCQWCDn+2Bizv5793wNGATEikg48DgQ6X/sq8BgQDbzi/MerMMYMbmB+5cF+d3ZnSsorefG7nYQG+fOX8X299gdFKVes3nMMgDMSWtqc5NTVWxRE5DIccyC1w7GWQkccg9f61PU6Y8x19Tx/C3CLy0mVV7rr3K4UllXy6s8ptIoI5r7R3e2OpJRlVu09RsfoMGIjQ+rf2UO5cqL3aWAYsNMY0wnHKmy/WJpK+QwR4aExPZiQFM8/v9/Fh8lpdkdSyhJVVYbkvccY4sWtBHCtKJQ7O4H9RMTPGPMjOveRagAR4dkrEhnZNYZHPt7Ev5PTdEoM5XNSMgvILir36lNH4FpRyBGRcGAJ8K6ITAMqrI2lfE1QgB+zfptEYnwUf1ywkYunL2Pprky7YynVaFbvzQZgSCffLwrjgCLgfuAbIAW41MpQyjdFhATy0W0jmHbtAApLK7hh7iq+3XLY7lhKNYrVe48REx5EQrT3TJNdG1eKwlSgnTGmwhjzpjHm5brGFChVFz8/YdyAOBbddzb94ptzz/vr2JCWY3cspU7bqj2O/gRvv8LOlaIQCSwSkaUicqeItLY6lPJ9oUH+zLlhMK0igpny5mpd81l5tYM5xRzIKfb6TmZwoSgYY540xvQB7sRxWerPIvK95cmUz2sVEcy/bjqDsooqpr69hqIy7apS3mnx9gwAhnWOtjnJ6WvI3AMZwGEgC4i1Jo5qarrGhvPydQPZfjiPBxdsxBi9Kkl5F2MMb6/YS9+4SHq19a4ZUWtTb1EQkdtF5CdgMRAD/K6+abOVaohRPWJ58MKefLHxELN+TrE7jlINsiI1i51HCrhheILX9yeAa9NcdATuM8astzqMarpu+01nth7K4/lvdhAREsikYR3tjqSUS95avo8WYYFc1r+d3VEahSsrrz3sjiCqaRMRXryqP8VlFTz6qWMFVy0MytMdyCnm262HmXp2F0IC/e2O0yh0PmPlMYIC/Hhl4iBG94rl0U83M3tJivYxKI/27krHcjO/HeY767xoUVAeJSjAj5kTk7g4sS3PfLWdp77YqlNiKI/1zZbDnNWtFfEtvHvAWk0uT52tlLsEB/gz/bqBtI4MYe4vezhWWMZLVw/AX9djUB4ku7CM1MxCJiT51vLyWhSUR/LzEx67tDfR4UG8sGgH4cEBuh6D8ijr0hxzHQ3q2MLmJI1Li4LyaHee05WC0gpm/ZRCVGggD47paXckpQBYsy8bfz+hf7z3rcNcFy0KyuM9eGEPcorKeeWnFHKKy3ni0j4EBWh3mLLX2n059G4bSWiQb1x1dJz+ZCmPJyL8ZXxfbh/Vhfm/7mfinJVk5pfaHUs1YRWVVaxPy/G5U0egRUF5CX8/4aExPZl27QA2Hcjl0unLWLPvmN2xVBO1/XA+xeWVDOzgW6eOQIuC8jLjBsTx0e0jCA7045rXVvKvX/boWAbldmv3+2YnM2hRUF6oT7soPrtrJKN6xPLk51t5ftEOLQzKrdbsy6Z1ZDBxzUPtjtLotCgorxQVGsjsSYOYOLQDs35K4cnPdZCbcp+1+7NJ6tDCJy+R1quPlNfy83N0QIcG+jNn2R5Kyiv56+WJOshNWSojr4S0Y8XcODzB7iiW0KKgvJqI8KeLexEa5M/0H3ZTXF7Ji1f1J8BfG8HKGt9uPQLAiC4xNiexhhYF5fVEhAcu6EFokD/Pf7OD4rJKpl070OeuH1ee4bP1B+kWG+4TC+rURv+cUj7jjlFdefKyPny37QjXzF7BkbwSuyMpH3Mgp5hVe48xbkA7n+xPAC0KysfcOCKB1ycNZndGAZfNWMb6tBy7Iykf8vmGgwBc1j/O5iTW0aKgfM7o3q1ZcNsIAvz8uOrV5by9Yq9esqoaxcL1BxnYoTkdon1nquwTWVYURGSuiGSIyOaTPC8i8rKI7BaRjSKSZFUW1fT0bhfJF3ePZGTXGB5duIUHPtxAaUWl3bGUF9t5JJ9th/IY5yPLbp6MlR3N84AZwFsnef4ioJvzNhSY5fyqVKNo0SyIN24cwowfd/PSdzs5mFvMa5MGExUaaHc05SWMMfzrl72s3Z/NtkN5+Alc3M+3i4JlLQVjzBKgrslpxgFvGYeVQHMRaWtVHtU0+fkJ95zXjX9eM4A1+7K56tXl7DqSb3cs5SWW7jrKU19sZd3+HGLCg/nDhT1oFRFsdyxL2XlJahyQVuNxunPbIXviKF82fmAcsRHB3Dl/LRdPX8YD53fnlrM660A3VafXl6YSGxHMj38Y1WSma7fzu6ztp7HW3kARmSoiySKSnJmZaXEs5atGdI3h2/t/w6jurXj26+3c8uZq7WdQJ7XtUB5Ldx3lxhEJTaYggL1FIR1oX+NxPHCwth2NMbONMYONMYNbtWrllnDKN7WKCOa1SYN4enxfftyRye3vrNXCoGo1Z+keQgP9mTi0g91R3MrOovAZcIPzKqRhQK4xRk8dKcuJCJOGdeSvl/flh+0Z3PnuOsoqquyOpTzIkbwSPttwgGuGtKd5WJDdcdzKsj4FEXkPGAXEiEg68DgQCGCMeRX4ChgL7AaKgMlWZVGqNhOHdqSyyvDYwi3cOX8tM69PalKnCVTtqqoMT32+lcoqw81ndrI7jttZVhSMMdfV87wB7rTq+Eq54gbnTJdaGNRxz3y1jS83HeLhi3r69CC1k9H//arJu2F4Ak+N68N3W49w69vJFJdpH0NT9cayPcxZtoebRiRw69md7Y5jCy0KSuEoDM9cnshPOzO5Ye6v5BaX2x1JuVluUTnPfb2d0b1iefSS3j474V19tCgo5XT90A7MuC6J9Wk5jJ/5Cx8mp1FSrq2GpuKLTQcpq6zi3vO6N+nxK1oUlKrh4n5teXPyGQQH+PHggo2MfO4HXl+SqsWhCfh03QG6tGpG37hIu6PYSouCUicY0TWGr+89i/m3DKVX20j++tU2znvxZz5MTqOiUi9d9UVpx4pYvTebK5Lim+xpo+O0KChVCxFhRNcY3p4ylPm3DCU6PIgHF2zk3Bd/5sPVWhx8zSfrDgAwboBvT3bnCi0KStVjRNcYFt55JnNucMyw+uBHG7l0xi+s2lPXfI/KWxhj+GQxsFNCAAAQeUlEQVTdAYZ2akl8i6Z3CeqJtCgo5QIRYXTv1nx215m8MjGJ3KIyrn5tBY9+upnKKl3Ax5ut3pvNnqOFXJHku6upNYSds6Qq5XVEhLGJbTmnRyx//3YHbyzbQ3ZRGS9dPUAHvXmhowWl3P/BelpHBjM2UWfuBy0KSp2S0CB/Hr2kN7ERwTz79XbySyr424RE2kaF2h1Nuaisooo73lnL0YJS/n3bcCJCdPEl0KKg1Gm59TddiAwN5LGFm/nNCz8xcWgHbj6zE+1b6rlpT5VbXM7POzP5d3Iaq/YeY9q1A+gX39zuWB5DvG1B88GDB5vk5GS7Yyj1X9Kzi5i+eDcL1qZTWWVIjIviiqQ4bhyegF8THgjlSTak5TBv+V6+2HiQ8kpDi7BA7jynK7ec1TSmsxCRNcaYwfXup0VBqcaTnl3EV5sO8eXGQ2xIz+Wivm146eoBhAb52x2tyTLG8PBHm/ggOY3w4AAmJMVxaf92DOzQokmNXNaioJSNjDHM/WUvf/lyK/3ionj9hsHERobYHatJemvFXh5buIWbz+zE/ed3a7J9B64WBb1cQikLiAhTRnZi9qTB7DxSwNiXl/HL7qN2x2py1u7P5ukvtnJuz1j+fHGvJlsQGkJbCkpZbNeRfG5/dy0pmQXcNCKBy/q3o398c+1rsMi6/dlM/2E3+SXl7DxSQERIAF/efRZRYU27IOjpI6U8SFFZBY8v3MJHa9OpMhATHsSEpHgmDu3YJBdysUpVlWHMtCVk5pfSs00kzcMCuW90d3q0ibA7mu20KCjlgXKKyvh5ZyZfbzrMd9uOUGUM5/WM5e5zu9G/vV4WeboWbTnMrW+vYdq1Axg3QEco16RFQSkPdzi3hPmr9vPWir3kFJUzqkcr7jmvG0kdWtgdzSsZYxg/8xeyi8r54YHfEOCvXaY1aUezUh6uTVQIvz+/O8seOpcHx/RgQ1oOV7yynBvnrmLboTy743mdZbuPsiE9l9tHddGCcBr0X04pm4UHB3DHqK4se+hcHr6oJxvTcxg38xfm/7ofb2vJ2+VYYRkvfruT1pHBOrHdadJpLpTyEM2CA7jtN124alA8932wnkc+2cSSnZlcnhTH8C7RROrllP+jvLKK15emMuvHFArLKnj+yv4EB+hAwdOhRUEpDxMdHsybk8/glZ9288pPKXyz5TD+fsI1Q9rzfxf11Gvta5jxw26mLd7F6F6xPDSmJ91a61VGp0s7mpXyYGUVVazbn82Xmw7x9sp9tIsK5dFLenNOz1ZN/i/i/JJyzvzbDwzrHM3sG+rtP23yXO1o1paCUh4sKMCPoZ2jGdo5mnED4vjjgg3c9s4awoL8Gdk1hvtGd6d3u6a50Pw7K/eTV1LBXed2tTuKT9GOZqW8xKCOLfj63rOYe9NgrkiKI3lfNuNmLuO1n1Oa3OpvJeWVvLEslbO6xei0141MWwpKeZHgAH/O7dmac3u25vfn9+CRjzfx7Nfb+WF7Bi9e3d8n1xjOyC/hcG4JoYH+BAf4U1pRyZebDnG0oIy7ztFWQmPToqCUl2rZLIhZv01iwZp0nvx8K2P+uZTfn9+d0ooqNh3IoX3LMKac2ckrZ2ctq6ji3V/38eXGQ6zZn01tXZ9nJLRkaOdo94fzcdrRrJQPSDtWxAMfbmDV3mMAxLcI5WBOMQH+flyS2Jb4FqFEhgYSHhxAs+AA2kaFMKB98+pBXlVVBgMesb5A2rEi7npvHRvScujdNpIxfdvQq20kJeWVlFZUERzgR0igP0kdmhMdHmx3XK/hER3NIjIGmAb4A3OMMX874fkOwJtAc+c+DxtjvrIyk1K+qH3LMN6bOowtB3Np3yKMFs2C2JdVyKs/p/D15sPkFpf/z1/bzcMCGdk1hqyCMjYdyEUERvdqzQW9WzMooQWxEe5vYXy39QgPfLgeA8yamMRFiW3dnqGps6ylICL+wE7gfCAdWA1cZ4zZWmOf2cA6Y8wsEekNfGWMSajrfbWloFTDVVUZCsoqKCipoLC0gl0ZBXy/9Qi/pByldWQI/eKjKCmv4vttR8gpKgegTWQIfeOiSIyLYnBCC0Z0iUak8VoS83/dz5ylqUw+M4GrBrdn1k8pTFu8i8S4KGZen6SzxzYyT2gpnAHsNsakOgO9D4wDttbYxwDHr6eLAg5amEepJsvPT4gMCaweFd2tdQRja/krvKKyinVpOWxIy2HTgVw2Hchl8fYjGAPn927Ns1ckEtMIp2yW7z7Kows3ExUayKMLt/C3r7dTWFbJlYPi+cv4voQENu0xGHaysijEAWk1HqcDQ0/Y5wngWxG5G2gGjK7tjURkKjAVoEOHDo0eVCnlEODvx5CElgxJaFm9raC0gvd+3c8L3+7gwn8s4d7R3Rg3II6o0FMbWb0/q4g75q+lc0wzPr5jBGv35zBnaSoX9G7Nb4d1bNTWiGo4K08fXQVcaIy5xfl4EnCGMebuGvv83pnhRREZDrwB9DXGVJ3sffX0kVL22HkknwcXbGR9Wg7BAX5c0KcNo3vFcla3VhhjOJhTQl5JOf5+QnCAHz3aRBAWFIAxho3puSzedoSNB3JZtz8HgIV3nklCTDObv6umwxNOH6UD7Ws8jud/Tw9NAcYAGGNWiEgIEANkWJhLKXUKureO4JM7RrD5QB4fJqfx1aZDfL7h5Gd8A/yExPgosgvL2JtVhL+f0C02nNG9WnPD8I5aEDyUlUVhNdBNRDoBB4BrgetP2Gc/cB4wT0R6ASFApoWZlFKnQcTxiz4xPoonL+vDpgO5/JJylJAAf+JahNI8NJBKYygoqWB9Wg6r9x6jfcsw7hjVlQv7tjnlU07KfSwrCsaYChG5C1iE43LTucaYLSLyFJBsjPkMeAB4XUTux9HpfJPxtoETSjVRfn5C//bNT7qM6AV92rg5kWoMlo5TcI45+OqEbY/VuL8VONPKDEoppVynE+IppZSqpkVBKaVUNS0KSimlqmlRUEopVU2LglJKqWpaFJRSSlXToqCUUqqa1y2yIyKZwL5GftsoINfm93L1da7sV9c+J3uuIdtjgKP1ZLCa3Z9ZQ15T376n+nxt2/XzOv3X+erPWEdjTKt69zLGNPkbMNvu93L1da7sV9c+J3uuIdtxjEhv0p9ZQ15T376n+vxJPhv9vE7zdU39Z0xPHzl87gHv5errXNmvrn1O9lxDt9vN7s+sIa+pb99Tfb627fp5nf7rmvTPmNedPlL2E5Fk48IUvMoz6Oflfez8zLSloE7FbLsDqAbRz8v72PaZaUtBKaVUNW0pKKWUqqZFQSmlVDUtCkoppappUVCNSkTGi8jrIrJQRC6wO4+qm4h0FpE3RGSB3VlU7USkmYi86fy5mmj18bQoqGoiMldEMkRk8wnbx4jIDhHZLSIP1/UexphPjTG/A24CrrEwbpPXSJ9XqjFmirVJ1Yka+NldASxw/lxdZnU2LQqqpnnAmJobRMQfmAlcBPQGrhOR3iKSKCJfnHCLrfHSPztfp6wzj8b7vJR7zcPFzw6IB9Kcu1VaHczSNZqVdzHGLBGRhBM2nwHsNsakAojI+8A4Y8yzwCUnvoeICPA34GtjzFprEzdtjfF5KXs05LMD0nEUhvW44Q95bSmo+sTxn79SwPEfNK6O/e8GRgNXishtVgZTtWrQ5yUi0SLyKjBQRP7P6nCqTif77D4GJojILNwwJYa2FFR9pJZtJx3xaIx5GXjZujiqHg39vLIALd6eodbPzhhTCEx2VwhtKaj6pAPtazyOBw7alEXVTz8v7+URn50WBVWf1UA3EekkIkHAtcBnNmdSJ6efl/fyiM9Oi4KqJiLvASuAHiKSLiJTjDEVwF3AImAb8KExZoudOZWDfl7ey5M/O50QTymlVDVtKSillKqmRUEppVQ1LQpKKaWqaVFQSilVTYuCUkqpaloUlFJKVdOioCwnIgVuOMZl9U0TbcExR4nIiFN43UARmeO8f5OIzGj8dA0nIgknTuVcyz6tROQbd2VS7qdFQXkN59TCtTLGfGaM+ZsFx6xrfrBRQIOLAvAIMP2UAtnMGJMJHBKRM+3OoqyhRUG5lYj8UURWi8hGEXmyxvZPRWSNiGwRkak1theIyFMi8iswXET2isiTIrJWRDaJSE/nftV/cYvIPBF5WUSWi0iqiFzp3O4nIq84j/GFiHx1/LkTMv4kIs+IyM/AvSJyqYj8KiLrROR7EWntnPb4NuB+EVkvImc5/4r+yPn9ra7tF6eIRAD9jDEbanmuo4gsdv7bLBaRDs7tXURkpfM9n6qt5SWO1bm+FJENIrJZRK5xbh/i/HfYICKrRCTC2SJY6vw3XFtba0dE/EXkhRqf1a01nv4UsHwFMGUTY4ze9GbpDShwfr0AmI1jNkg/4AvgbOdzLZ1fQ4HNQLTzsQGurvFee4G7nffvAOY4798EzHDenwf823mM3jjmqAe4EvjKub0NkA1cWUven4BXajxuwX9G/98CvOi8/wTwhxr7zQdGOu93ALbV8t7nAB/VeFwz9+fAjc77NwOfOu9/AVznvH/b8X/PE953AvB6jcdRQBCQCgxxbovEMTNyGBDi3NYNSHbeTwA2O+9PBf7svB8MJAOdnI/jgE12/7/SmzU3nTpbudMFzts65+NwHL+UlgD3iMjlzu3tnduzcKw09dEJ7/Ox8+saHEsV1uZTY0wVsFVEWju3jQT+7dx+WER+rCPrBzXuxwMfiEhbHL9o95zkNaOB3o51hgCIFJEIY0x+jX3aApknef3wGt/P28DzNbaPd96fD/y9ltduAv4uIs8BXxhjlopIInDIGLMawBiTB45WBTBDRAbg+PftXsv7XQD0q9GSisLxmewBMoB2J/kelJfToqDcSYBnjTGv/ddGkVE4fqEON8YUichPQIjz6RJjzIlLEJY6v1Zy8v/DpTXuywlfXVFY4/504CVjzGfOrE+c5DV+OL6H4jret5j/fG/1cXliMmPMThEZBIwFnhWRb3Gc5qntPe4HjgD9nZlLatlHcLTIFtXyXAiO70P5IO1TUO60CLhZRMIBRCROHOsERwHZzoLQExhm0fGX4VjBys/Zehjl4uuigAPO+zfW2J4PRNR4/C2OWS4BcP4lfqJtQNeTHGc5jumSwXHOfpnz/kocp4eo8fx/EZF2QJEx5h0cLYkkYDvQTkSGOPeJcHacR+FoQVQBk4DaOvAXAbeLSKDztd2dLQxwtCzqvEpJeS8tCsptjDHf4jj9sUJENgELcPxS/QYIEJGNwNM4fgla4SMcC5lsBl4DfgVyXXjdE8C/RWQpcLTG9s+By493NAP3AIOdHbNbqWVFM2PMdiDK2eF8onuAyc5/h0nAvc7t9wG/F5FVOE4/1ZY5EVglIuuBPwF/McaUAdcA00VkA/Adjr/yXwFuFJGVOH7BF9byfnOArcBa52Wqr/GfVtk5wJe1vEb5AJ06WzUpIhJujCkQkWhgFXCmMeawmzPcD+QbY+a4uH8YUGyMMSJyLY5O53GWhqw7zxJgnDEm264Myjrap6Cami9EpDmODuOn3V0QnGYBVzVg/0E4OoYFyMFxZZItRKQVjv4VLQg+SlsKSimlqmmfglJKqWpaFJRSSlXToqCUUqqaFgWllFLVtCgopZSqpkVBKaVUtf8HED/1oeujmTwAAAAASUVORK5CYII=\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.lr_find()\n",
"learn.sched.plot(100)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1b5df4c939c247ac9ba721c49b907c23",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Epoch', max=10, style=ProgressStyle(description_width='initia…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch trn_loss val_loss accuracy \n",
" 0 0.457467 0.241464 1.0 \n",
" 1 0.371568 0.720745 1.0 \n",
" 2 0.308038 0.527178 1.0 \n",
" 3 0.299045 0.348353 1.0 \n",
" 4 0.290519 0.441548 1.0 \n",
" 5 0.275544 0.700616 1.0 \n",
" 6 0.27094 0.62167 1.0 \n",
" 7 0.291084 0.277607 1.0 \n",
" 8 0.271876 0.484772 1.0 \n",
" 9 0.258556 0.099193 1.0 \n",
"\n"
]
},
{
"data": {
"text/plain": [
"[array([0.09919]), 1.0]"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lr =8e-2\n",
"learn.fit(lr, 10)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a8496f7ffabc4094bf950cad728e0dba",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Epoch', max=3, style=ProgressStyle(description_width='initial…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch trn_loss val_loss accuracy \n",
" 0 0.298079 0.411199 1.0 \n",
" 1 0.287929 0.256555 1.0 \n",
" 2 0.267667 0.309349 1.0 \n",
"\n"
]
},
{
"data": {
"text/plain": [
"[array([0.30935]), 1.0]"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lrs = np.array([lr/8,lr/2,lr])\n",
"learn.unfreeze()\n",
"learn.fit(lrs, 2, cycle_len=1, cycle_mult=2)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e0bd5a404ebe4745b2350a277208d0ad",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Epoch', max=3, style=ProgressStyle(description_width='initial…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch trn_loss val_loss accuracy \n",
" 0 0.264641 0.335151 1.0 \n",
" 1 0.253506 0.301091 1.0 \n",
" 2 0.227851 0.210845 1.0 \n",
"\n"
]
},
{
"data": {
"text/plain": [
"[array([0.21084]), 1.0]"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sz=150\n",
"learn.set_data(get_data(sz,bs))\n",
"learn.freeze()\n",
"learn.fit(lr, 2, cycle_len=1, cycle_mult=2)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c51ad72d155f4920ba8e31a6015359e5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Epoch', max=3, style=ProgressStyle(description_width='initial…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch trn_loss val_loss accuracy \n",
" 0 0.290355 0.33793 1.0 \n",
" 1 0.275085 4.280131 0.0 \n",
" 2 0.193974 0.089761 1.0 \n",
"\n"
]
},
{
"data": {
"text/plain": [
"[array([0.08976]), 1.0]"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"learn.unfreeze()\n",
"learn.fit(lrs, 2, cycle_len=1, cycle_mult=2)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8f70eb7cf1bf47ce931834dc96b6fcd0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Epoch', max=3, style=ProgressStyle(description_width='initial…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch trn_loss val_loss accuracy \n",
" 0 0.180609 0.178459 1.0 \n",
" 1 0.190406 0.078421 1.0 \n",
" 2 0.162267 0.077021 1.0 \n",
"\n"
]
},
{
"data": {
"text/plain": [
"[array([0.07702]), 1.0]"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sz=200\n",
"learn.set_data(get_data(sz,bs))\n",
"learn.freeze()\n",
"learn.fit(lr, 2, cycle_len=1, cycle_mult=2)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d769857d113445dd855e58d0c9867c56",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Epoch', max=7, style=ProgressStyle(description_width='initial…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch trn_loss val_loss accuracy \n",
" 0 0.165711 0.150288 1.0 \n",
" 1 0.200272 0.130915 1.0 \n",
" 2 0.131086 0.156715 1.0 \n",
" 3 0.172976 0.449126 1.0 \n",
" 4 0.144101 0.022469 1.0 \n",
" 5 0.101792 0.092335 1.0 \n",
" 6 0.084492 0.046531 1.0 \n",
"\n"
]
},
{
"data": {
"text/plain": [
"[array([0.04653]), 1.0]"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"learn.unfreeze()\n",
"learn.fit(lrs, 3, cycle_len=1, cycle_mult=2)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"fn = data.test_ds.fnames"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(7301,)"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(np.array(fn)).shape"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" \r"
]
}
],
"source": [
"log_preds, y = learn.TTA(is_test=True) # use test dataset rather than validation dataset\n",
"probs = np.mean(np.exp(log_preds),0)"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(7301, 6)"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"probs.shape"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(7301,)"
]
},
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"p = np.argmax(probs,axis=1);p.shape"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
"df = pd.DataFrame(p,columns=['label'])"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"df.insert(0, 'image_name', [f[5:] for f in data.test_ds.fnames])"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"df['ord'] = fn\n",
"df['ord'] = df['ord'].str[5:-4].astype(int)\n",
"df = df.sort_values(by='ord')\n",
"#df.reset_index();df.head()\n",
"\n",
"#sample = df[['image_name','label']];sample.head()"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [],
"source": [
"sample = df.copy()"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" image_name | \n",
" label | \n",
" ord | \n",
"
\n",
" \n",
" \n",
" \n",
" 5729 | \n",
" 3.jpg | \n",
" 5 | \n",
" 3 | \n",
"
\n",
" \n",
" 2123 | \n",
" 5.jpg | \n",
" 0 | \n",
" 5 | \n",
"
\n",
" \n",
" 525 | \n",
" 6.jpg | \n",
" 4 | \n",
" 6 | \n",
"
\n",
" \n",
" 5376 | \n",
" 11.jpg | \n",
" 2 | \n",
" 11 | \n",
"
\n",
" \n",
" 2851 | \n",
" 14.jpg | \n",
" 5 | \n",
" 14 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" image_name label ord\n",
"5729 3.jpg 5 3\n",
"2123 5.jpg 0 5\n",
"525 6.jpg 4 6\n",
"5376 11.jpg 2 11\n",
"2851 14.jpg 5 14"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sample.head()"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [],
"source": [
"sub =sample.reset_index()"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" image_name | \n",
" label | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 3.jpg | \n",
" 5 | \n",
"
\n",
" \n",
" 1 | \n",
" 5.jpg | \n",
" 0 | \n",
"
\n",
" \n",
" 2 | \n",
" 6.jpg | \n",
" 4 | \n",
"
\n",
" \n",
" 3 | \n",
" 11.jpg | \n",
" 2 | \n",
"
\n",
" \n",
" 4 | \n",
" 14.jpg | \n",
" 5 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" image_name label\n",
"0 3.jpg 5\n",
"1 5.jpg 0\n",
"2 6.jpg 4\n",
"3 11.jpg 2\n",
"4 14.jpg 5"
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sub = sub[['image_name','label']];sub.head()"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
"sub.to_csv('s.csv', encoding='utf-8', index=False)"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
"learn.save('200_all_2')"
]
},
{
"cell_type": "code",
"execution_count": 132,
"metadata": {},
"outputs": [],
"source": [
"#learn.load('200_all')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}