Ege Demir
commited on
Commit
·
2ee333c
1
Parent(s):
3bc6595
Initial copy-up of DCGAN code
Browse files- DCGAN_train.ipynb +408 -0
DCGAN_train.ipynb
ADDED
|
@@ -0,0 +1,408 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"id": "0d3774bd-5295-42ac-b0e6-4f3d3a82901a",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"import tensorflow as tf\n",
|
| 11 |
+
"from tensorflow import keras\n",
|
| 12 |
+
"from tensorflow.keras import layers\n",
|
| 13 |
+
"import numpy as np\n",
|
| 14 |
+
"import matplotlib.pyplot as plt\n",
|
| 15 |
+
"import os\n",
|
| 16 |
+
"import gdown\n",
|
| 17 |
+
"from zipfile import ZipFile\n"
|
| 18 |
+
]
|
| 19 |
+
},
|
| 20 |
+
{
|
| 21 |
+
"cell_type": "code",
|
| 22 |
+
"execution_count": 2,
|
| 23 |
+
"id": "4f7cd728-3373-4fb7-b595-f594b7b14525",
|
| 24 |
+
"metadata": {},
|
| 25 |
+
"outputs": [],
|
| 26 |
+
"source": [
|
| 27 |
+
"os.makedirs(\"celeba_gan\")\n",
|
| 28 |
+
"\n",
|
| 29 |
+
"url = \"https://drive.google.com/uc?id=1O7m1010EJjLE5QxLZiM9Fpjs7Oj6e684\"\n",
|
| 30 |
+
"output = \"celeba_gan/data.zip\"\n",
|
| 31 |
+
"gdown.download(url, output, quiet=True)\n",
|
| 32 |
+
"\n",
|
| 33 |
+
"with ZipFile(\"celeba_gan/data.zip\", \"r\") as zipobj:\n",
|
| 34 |
+
" zipobj.extractall(\"celeba_gan\")\n"
|
| 35 |
+
]
|
| 36 |
+
},
|
| 37 |
+
{
|
| 38 |
+
"cell_type": "code",
|
| 39 |
+
"execution_count": 4,
|
| 40 |
+
"id": "c74b2281-2fae-4be9-8463-0f9bba9d0c45",
|
| 41 |
+
"metadata": {},
|
| 42 |
+
"outputs": [
|
| 43 |
+
{
|
| 44 |
+
"name": "stdout",
|
| 45 |
+
"output_type": "stream",
|
| 46 |
+
"text": [
|
| 47 |
+
"Found 202599 files belonging to 1 classes.\n"
|
| 48 |
+
]
|
| 49 |
+
}
|
| 50 |
+
],
|
| 51 |
+
"source": [
|
| 52 |
+
"dataset = keras.preprocessing.image_dataset_from_directory(\n",
|
| 53 |
+
" \"celeba_gan\", label_mode=None, image_size=(64, 64), batch_size=32\n",
|
| 54 |
+
")\n",
|
| 55 |
+
"dataset = dataset.map(lambda x: x / 255.0)"
|
| 56 |
+
]
|
| 57 |
+
},
|
| 58 |
+
{
|
| 59 |
+
"cell_type": "code",
|
| 60 |
+
"execution_count": 8,
|
| 61 |
+
"id": "c9e9b947-45b0-456c-ba7e-914d43045f18",
|
| 62 |
+
"metadata": {},
|
| 63 |
+
"outputs": [
|
| 64 |
+
{
|
| 65 |
+
"data": {
|
| 66 |
+
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAADnCAYAAADl9EEgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAuHklEQVR4nO19a6xm13nW2vfvdu4zZ+bMjGfGlzjjXOzYcZvUtCaOA1VbXIGaVqhAqyqoQi1RSgu0QvAHVIgEpYEfTYuERC9CXBRCgaBISaFpGiltUO0msZ0Ze+zx3M+c23fOd913fhjv93nfc/b2N8djd9m8z6/1nbX23muvvdfZz3t3yrI0CoXCPrh/1hNQKBQHQzenQmEpdHMqFJZCN6dCYSl0cyoUlsJv6jxz+sRMqtw0TdnvoiiqtuM4NC7jx2UZ/UGeI8/zqv120yjL+eIaYFv+xuPkOVx3tv+jTec4DPA5GGNMq9Wq2nPdHutbWlqq2r1e78BjjDHG87wD28bw9XA9XJuCjSvL/MBjjDHGdel3FAS158f1kedgcxRLXxTiRT7gfK+OKw4cJ/H5//5l56C/65dTobAUujkVCkuhm1OhsBS6ORUKS6GbU6GwFLo5FQpL0WhKUSiaTD+303eY6zWZgmadx9sZ+uVUKCyFbk6FwlLo5lQoLIVuToXCUujmVCgshW5OhcJS6OZUKCyFbk6FwlLo5lQoLMUd8RBqCjL1fbzE2ytoWrEfh/EQavLYebMD6eW1mwKs6yCn+FYF/+uXU6GwFLo5FQpLoZtTobAUujkVCkuhm1OhsBS6ORUKS3FHTCmzqtRl7tFZ83oq3t54M4KyZ0VTDuHDmFXeSuiXU6GwFLo5FQpLoZtTobAUujkVCkuhm1OhsBS6ORUKS9FoSpnV+35WdbUs94bl5bAc4NsRuAKZEWXtDJgORGRO4LjQB6Yll4/LDS/FV3t1ZhIQzyWH0ozCjIW/XJdeC8cR7wCU15NlCevMEbKM4KzlDGeHfE9pHrlcA/iJ72bo8Dl57LcMS2noO8y4GuiXU6GwFLo5FQpL8ZZ6CO2vQOzW9r3dgLN3HU7RHcc/eKAxpqhdH06NPWCGcqVKoEwOUKlSUNICqHHucKrJyXADNS5v36vrdsolsL7GV6K+k9Fmee264w79+jXdC7WlhDjL+65fToXCUujmVCgsxZtCa+v6msZJDZ7U8NkP4i3+Ps0c0VxH/D8sS9Bsg4ZQnqHpv2hZo60tBeXC8xcOf/SoKXbhfHciW85hK5UdltY2Xvvw/PW2r/VGx+qXU6GwFLo5FQpLoZtTobAUb0pl68PImW9VLtCDcGfMOHSOwkjPGfghTBGhT/c9mYyqdhSFbJwXkWklzVLWV4J87kGe4H1ya0bXjoS3FnpoZQ3WkrzJIaYmeFl6huFzl7oG/hsmIp4RBuo35aaVYO8m76k9x6yvh7wu3vftzPE16JdTobAUujkVCkvRSGtndVCWuYDqPYTqz/929xBiju8up6QuM7MkrM/JiaJ2YLnPri6ycfedWKO+s2dZ35GjR6u2B2vKS2EYM4onVfvZ8+dZ3zPPXazaNza2qvY058/W9SOau1dPSQ/r3M7fg3qzCqeMTee486ijpHdaNNMvp0JhKXRzKhSWQjenQmEpGmVOxwQznWQfwy9RzqT973lcfkH1fZMbF3L5psDuxjk2mHcKl+7TLXlEiV+STOiK45IyhDb+n+PnWHLjqv3ke9dY31NPfA/8onlkwTwbd21E14qnU9Y3mtK6ehCkHYR8vmHYrdqPvu9+1vfYOZpXlpFcvNHnz+xrT79UtW/s9Flf4JIc6PotOp9wFeQyOF+rEt0IG2TYsqz/ruDz9aXsWx7smlgaGcEDv4t6WbLp3XyjLqj65VQoLIVuToXCUjTT2pkp42wUQ0YEYN+sqvdZ89Y0QR7jF0Q7S0HBYqdDPwTl7bhEebtABZ945B427qmP/rmq3Sq4KaUVtqt2f0rzGk44JSomY5p/yueBNBFpW5byaxmX7qVwFlhXGhHl9du0HnP+Nhv3/U+cqdpJeo71nX/petW+uUfHuT6/Vu4QfY9dLjq13IO9jA4b2eKK16OOoEpzIPNAEmPrqOztBJXPAv1yKhSWQjenQmEp3vRg6yY0UVnsm7Ua2WE0t8YY4wFdLVyRvhOuHQhKemqRxv7i3/zxqt2a3mLjliKiiWF7kfXd2qVzrm/uVu1MaMrDHDx/xHqgJ1Be0L0IFm4c0BrnXsT6kpwob5KSpjUq+XwjZNCTPut735mVqn1XQmv84s0dNm4PGLsb8Xk4JXXOHJQt0OSVhrgTHj0zO9kfYo/ol1OhsBS6ORUKS6GbU6GwFHfIlDJjHlKZ2b8hKuVOVz9u8uQoQc70hMdKL6dIjr/ykYdY3xMfuLtqz+f9qh3OL7Nx2yOSEa+9con15QXJlh6WbUgnbFwIcyxEwLaXczn5Nbi+kJ/B9JOYPdbnZ+SRlCUkcxpXyKaGImAKj6/jdEAyc7dNJqJza0fYuKvbw6rdnwz4pKM2/Ljz0SX47FGXIfUas5r26jzZDvp9u9Avp0JhKXRzKhSW4nVo7WxVkg6bt3ZW6opX3j8M6YhwaGdOznAtcStpRvQvzLlT+ad+4i9U7QdPzvHzD8kLJgNXlMujNhs3isE84HFvGddAsDUEBviucMAHy4q0LJUlmWOwQhhfG2PSkk4SlPzR+2CCSQqimmnAF2sAr4zfO8r6xnvkxRSVQMNHQzaul8N8Dc+H1DdIow8ZsA3tJqrZRDt5n6wydrBXUFM+JDWlKBTvIOjmVCgshW5OhcJSNMqcPZCjSkGZC49klGkiAmYL5Nok5xRSBmrKPQo/C1ZfT+QGLUiec4UsluBYcHFrxfxaKynJmb/8D3+a9YXpetUeDPt8/inJR7khWXKYc5nTycltTgYXhw7ImVCir9UWScLw3jwpR1E7h8Sy+ByMMcYHf74g4IuVpFCBHAKlHSETLoN73djhLoanTpPJ5NqVa1V7bfkUG5dOrlbt0NtlfVlOz2YKMnLmtti4AuVnh5uM8NYy94276O3PIEbNpoDtulxlxhhT5q/vkqpfToXCUujmVCgsRSOtPRmQSl1S0hwo7zjge3wIUQcZ0KCsOJwppRlQBkHmNgVTUB6T+t4RMch/+6c+VrXL0Rbr2x1CvpvoOOtLAgiOhvyzwb56BuCJksasxwP2GniwHgWnk05OlNQT+WIdFD/gWplgdJjTNhb5bTyPXgV81kUhq3QThYzEfcYTMpkcmSMamky5KWXtJEWvXLnOI3haQKORohcFp9ClDwsnljuTf3iD2G9ymY0qN+a+muF4/XIqFJZCN6dCYSl0cyoUlqJR5lx0RrV9GRwauPw0HvxG+XNS1Ls3zVpvZT/gnDKXKci7EXQ9eu4MG9YNaJwvMiH40WLVTh1uIkldkKdLMglEIgVBCe5qRcKjTbwQ1g7U66XhMqELsp5MWuWDDOoFJIulwowQgwwXStMYDE1h3YyQfdMUnpOoBxiAMI/3nAl5ywdTzdIyN5G4u2P4QdkZUpGsLPfouNwRUTogE8qkcrPWOWl6/zAfbVOCr7rzGWMac+G+Bv1yKhSWQjenQmEpGmltaoBKyc8+/A4FdQiAZngu0gNOGZESeCIw2IjqzXVw4P9LKW4Hzx/AuO975H1sXJxR37jkkScJzDnNuRkE6WTXIW+hyBNl/mBeIeSHfXUsRqXQ+YJAeMRAwi9H0NUMzB052JMSYVtKwRwjHgUzmZTgnuU6ku5hBWwROQM0H6NNFjp8TfOMPLLmhRmubMG9JHSO7ZEQiSB4pRBzLEAkcIQodZics64rqfHtB2LvO7+UTQ6AfjkVCkuhm1OhsBSNtBa1q/v9h4E+iS90DLSIKfQaKlvvy00LxzmszceV0iMfgFpkF3LOjkQ5g87CatWeGkG9wePc9wQ1wTy2KVEw6dzeBlVx5PElj4DK+j6Ma3Na63ikKS5k3l2IxE4TmNNozMa14bg45kHlbQgMmECfXF4f6Ngo4zQ/h6BvDx6a7/BrFRBogJ5br84LKCl4XWUTruUOuuRlZDzJ0ZGiC8+iQ9Ba2TdrTuXGBAIzuAjpl1OhsBS6ORUKS6GbU6GwFI0yZwGqeBmVgnJPIvY4SlwF9jVwcFeUEWTx1dDel7AJ2yLa2gNi3+uSOj/zuRySYvBsyk04LYwAETIWyk5oanKFx4pbYOKuek8oPyT7QKvLzQ8+lApszS2xvgSC3acT8upaXFpk44a75MU0t8DPv7GxUbW9BpkKX5jQ569PkoFnEa6BIzymnIPNNsYYk+bk4VTEGKQu4KCXjnh34NIyAHpWDyH++3ClK2fNfVt7/Bs6WqFQvGnQzalQWIpGWusy2iLpJFBeoU0ugKogbbkzJRdkPpf649DysQw0rky4an9zc7Nqn1hb4ydhJfU4rc3GFETsA+1PhNdLm5lPhBkE+tptmmOvx/Pblmxch/UdP0Fmha1NoqfJmJdcWF6hczCTizFmaYmo8t4eHZck0tRBazAY8lIKQUiU1A/oWtIzLIcg7VTkn3Ihl5HJIOeRCK7AsIBcBI67QHMlrUWa3mRKqSvb8GrfweaTO1E2BKFfToXCUujmVCgshW5OhcJSNMqcASvHJvLFYukzcZocStLFU5AphBa7qXwagssG4v8JO4c4DgKFuyGZT/KYB5HHE3Jz2xU1StaOUj2QWCTnwlJ8CcglrsgglqYki7U8bsbx4bfnQp/D17TV6VXtuYVF1ueAPLq4ROUHN4RsPRrRfY8HXB6dgHscynCpMC1hXxiIkvHg2sfkSpfLhA7cWxzz8xe4xHCtVsivNSjpOMfhMnjBTCn17nWzluiT4+pMJFoCUKH4/wS6ORUKS9FIaxcWsUKzyFsL9HIqAlrzCVHDMZbUizm9wc/+rN4UkiiwWYmgWzzjXI+iPFo+v1aREt0bblxlfZM2esuIiBUwEYRg3ghdTpujFtHaTofnIep2Kfg6iiAvjlDfZxDe4wlqnAL9w9Q0XeFlVAItL0UwuycjO/4ffF+KLOCZI0wkKeQNYmUhYi4qpHDt6VRGrICXFJhmkpDPYwh5jgppLoE5et6dNW8Ys9908xrkO6y0VqF4h0I3p0JhKZpp7fJCbV8J+3pP0JYpODa7SGUnwgG6yQsDndEx9aPw0sH/LqnDA5RTcNz3gIJ1fE51YshfVIpg6J09ol29OU4TXaBTnYiO64Jm1RhjwoAorxtyzWIOQdRIhktRYdsfQuXpq3wN2l04J6ypDKjGCtvjklPSdfAmKn0QNwSddD2qJObvCXoHtHm4fb1qO8LxPYNg9yQVFbZhHVvgaVV4vOra1pSocRHIfFOYvpP3zBpszUQuaSEoDx4nPYeaKlvPQnn1y6lQWArdnAqFpdDNqVBYimYPoYBU9u4+IwZ4/rtcHhhMKFrByQ5OXW/M4aJS8n0mF4xAEFEvPsosJGNlQn6JIJdsf4+bQXpQLiD1uJfKlSuXq/bCPMmZvS4//y6YZ1Z686zv9CkqK3jqLvJGmuvw/LbLy8eqdqvFZWv0/EHTzC4EVxtjzCQmmagViMRXI3pmVzeomveNjT4bNx2B147P5ziFEoAh5O7tRfyZJRDNEwsdQgffOYhsSVIuPObwrLN9BfXod7AvIdzBMmdTOZB9Bryaiux3JuqKoF9OhcJS6OZUKCxFI61FZbsjKSnoqF1RAcpMiWY54H1Tlgd7oRgz+2e/EA7hGKzs51xl74HHisEcqEI1PrdAdNJt8/w8U0MU9eLlddZ3c4toY75OFbGjFjeXnD11V9VORWWucpcczgclmR+W5zl1PbJD1+p2uakGvXj6OxQ4LgOlN3aIum6OeN8eVA9bH9M9l3On2bjnX3quat+8eYn1zfWI9i/B/NcWuVdUG3LJBlwCMEELRAcfKmxP+XxLeJ6eyMuEFdoKUc2rLthaevew91EmmQWq3OTZ1vROz+IRp19OhcJS6OZUKCyFbk6FwlI0m1LA5c0TJcty4Pwm4TU5igm5ggWsynN9CcBmVTYck4vgWSgn9/EffJJ1XbvwbNUe37pJHaeOsnH5PEXfnDx7F+v77c/9t6p9ZWOHz6WmjFvU5eaBsUsyopNz2WnybUrIle1tV+0f+uiH2bj3HCEZbmGBu1UuL9P80ayyvb3Nxl3cIvn2Nz73v1hfa4Fk7b/41Mer9q/8y8+wcT/3d36hal/+H59nfTf26D144Tqt9/d91yNs3AtPf7tqP/7we1lff5fk4pMgq2YDbhZa6pBseurMEdZ36xbpBl4WLoZ1ESVNNU/kc/ZqkgTIc89a9boO+uVUKCyFbk6FwlI00tqygM+0qAFYJETdsin3qsnht1/PFg6FUFCMCINe1q+zvqceJ2rowb1ceuUVNu75TaJM3/ryV1jfCKpe+y3uETOG/DcOqOyzKTfpYETFj/+Nn2R9v/avf6Vq74EHz3/4r/+bjfulH/1Q1Z6P+GM7tkjmjqsjuhe/5PP44u99tWonbX4vE4gO+ZNnL1btP/j602zc73/1a1W7KDiNSyDfUA6lp7/2jW+ycatw7WsbPPftgydIBMD36mMf/hAb5/QoV+83v/0c61uBb85FUd27jr42mVLcQ+ajbfIQ0qgUheJtDN2cCoWlaC7HAM7LcS6ceg1QuoRrMQOoM5a4pHGTWewbqweD87IPbVfkusFw4i/88bdY3/EV0nA++eBZ+vviWTZu6zuk1RxnL7O+FAJ3M5Em0smImgQOUVcv43Ry69JLVfuz//zTrG+nT9cuoIzAz37qU2zciWWiidPLz7O+eIO8gtoueSdtiQreP/v3/nHV/kef/gzrm4CYsn6Jzv8vPv0P2LgUnrVXiopsIHJkUOorDPgzy2B9bqxvsL4PnqSxZ9717qq92ONeRpcv/p+q/b7772d9X3yGnmGWNb7itWC0VrxzXk2VsabAjqbz10G/nAqFpdDNqVBYCt2cCoWlaCTkKBvkhgfnOlCSLfW4PJBBbtZsSsdJ75hZKwSzcwsPoQICvcM2D3H4T1/4RtVe7UDZA0H3c5BLlhYXWd9oDGUKPOHFBAmofJA3PFE5uyhIMs4nXC6ZC0Duhv+Vv/VvPsvGfeIvfW/VPneMz/HyNQrmDueoL5/22bjf+fVfrdrdnHsxHV2i5GVTSAR247kX2bgOPHf5zAowty1AJe5UlIWY90kXEJYypy2t8XBEQdk7HjfbnL773qodnbyX9d34ffJAKgMe3F6HJlPKPn1IQ+KuunMcBvrlVCgshW5OhcJSNJtS2uSt4bk8f04OqveyzSmSBwHRWFnYmXJHbPTWqHNIluOcfbmMCNOS/68JQQO+NQbTj6gWFvnkOH5K0Noc6F4paDlSWRfmn4tgbqaIF+p2NDN40Ld6lAd9F5CfZzDg9H2pTTQ6gfIGR3ucXj/x8JmqfX2rz/piEBeShOaRCvNXDH8InBXeCeaTEmmhCFb2wDTmiPN7Af1hMCDvoXuO8fXYA3Hjwp9+h/eV9K56M1YIa4Ikp4etTna70C+nQmEpdHMqFJZCN6dCYSmaTSldksUKkZvW8SmwNoq5LOaBetyZUF8TB59VFnBKLpu6BZSdc/gcMUj2XfeTK9i4v8nG3eqDPLrAzUKtNi1RO+Tz74ApBU0M8b4S2/Q7Eq5gvRZdrwOlAkcDHq2xC0HIwwGXmbs+BV9nKcliY6h/Yowxy/C0V08fY30+VP4eQKXvRCTIGkMFbFNymTaEe8uhqvhoyIPxMX+u54kcvBnN+e5Vkmnnl7jMGcK6/bt//yXWlwV0b57ha1UHqfPA3/ti6mtMKbfjvjcL9MupUFgK3ZwKhaVopLV5CflLC+EhVMKhQlWeQx7beEqUpilPy6xwZWgLMAlH6OULoBVeRNEaRSby7EIeouUOp8YBePBEInK8BdfzS6JPkgatLJF4sCDKCHYhT+se5Ka9tn6TjXOgZMSRFW7C6HXBIyuj88U5jxrZgRxIpYgo+ejHKP9SDB49u0Oeu6e/16/aEyirYIwxHnhQpRlR3HRelOsYQVTKzS3WlyVEJxMw120P+mzcBPL9po4oLWFuH02lFKQxpagpxyChphSF4h0K3ZwKhaVoLseQEHUoRXViBwKPXeFxk4z6VTsD7Z5ktYfRZhWyyjDAcziNSAuiVl/5o2eq9r2L/ByoqUwF5e2AZ9Rci2tyMZ9Rr01ax/k5roE8ukKpODui6vWt9VtVeziktbrnXh5AXMC1Tp5cZX33QjrPi5coP9JU0Kpjd91TtZMRz/t0/oULVfs950izHfp8rTrg1D9q8/scwnNPM3o/hkJbOxrTu5SJwPS2R/R9BFXSXhYlqi/coOMKT1QcBw2+4zXRVYKkoCxtqxClvJpzvFHtrIR+ORUKS6GbU6GwFLo5FQpL0WxKyfpVuyy5fBFg2b8pV6kXkDo/zNCDgl+urhxbE2T16tKpPy6Hvq98nQKv7/mRH2DjxluUvl+WnYhjkpcmjvCEapFM64IsHGYicBfMSVMReJzD2m3euFa1H/7go2zcf/wvVBbi0mleMmLv0Yer9nPPUoTG8+cvsHF//a/+WNUeDrn3kAEZ8TokJDtyRJhtQgiyT/nzdNFLakrni2NutsHfpTDDlVDp2o0oULq1dJKNu/JtSkI2NELmhDmagq83u1bDO8f65LAa0bJJbj0M9MupUFgK3ZwKhaVopLVODFWpDVcnpzHkxUmEczE4p5dwiWYvDFPbx5yL9w2EywqVtwNV0sqQTCKr9/DKVttgzuiJUgc+XLEQav8M6VoPc+ZwGjcaE60Vha3NFExSd58j88numJs6Hn/yo1XbcblHzNVNcopfOXW2av/AuQfYuG0QN/w2r74dBHTfOdCxnQEXWdptWsc05eaNFLx70hTNJbL6Fj53HggwTsDLyCez07UhX7ihgfm7/JmxqutNFasb0GRKccuDHd+lB5zSWoXiHQrdnAqFpdDNqVBYiuYSgFMyHRRiH2cgV41FKbg8IPlrmtQn7jochOodZEKhlTe+T7c3gnyov/yvfp2N+8EHz1XtjnAPTDHoVsicTkjmpcmYZPCyxZcV+xIR1Fu4JHMNIYlXLuSVFkSldHq8svX8UXLnK8Ekcmv9BhvngSyG8rgxxgQgI7bAndGTFcfhuTtTbqYYgHw6nUKJSCFzOhhFIoLnHQjwvzGkOV28fJmNw1zJMvGai26nAY+IqQvqb5QPRVddYrpU1NJRmVOheIdCN6dCYSkaae32HtEFX0SllAX9jjNhIgE6GUCpQKnyRrojKWmBJgwWUc3/n7igipdKcqTeqPIeiSDk60DH5iJuYvAhEHsi1PABeAhFkP9naySoGuTk8QPpVQPzhdyx3S6PXgkjonHzHe6ttbVDpqA2jFte4Xl3pmCeSWP+PDFSJMkhF5BJxDgK2F4MhccXlGPIoaJ0lomyjSAGTV2+Vpchv9Dkap/GlfwcpUO0WUYjOc7BZqFXD6yhmuLZ4nGODLaGKSfwjuXChMZeW3G5XCbsPQD65VQoLIVuToXCUjTS2hi0T/s+wgXmiynrugx+28u8If+PYBtIJSStqIMcVedlJL1EXrxMVbrOLPGKVW1whM/FFQaghXVLansLi2wclqdYXOKO5EwBDBpOR6QiTUdELzdGt1gf3k8Sk6YyFtrDMCItaSSCvv0W5FuCFJdxPmHjSnDUj4X2uoD3YDwBmiy0tRgw74qKbPEEylqAyCIMAqZ0MY+PECOYerW+QlgTmIeQ0FgzDS0E58tAfbdQba1C8Y6Ebk6FwlLo5lQoLEWjzJlCmYVSuknAz7TgezwBr/0cLrEvby2qq2X1YKx63ZDyvgl18oW8Vn9Cst7ckTXW58dkOoiHXP6agodMG+S5IJeeRCQ/bu/yIOf1dSqLuNsnU0eR8LVa65GMOCdy30YRRYpMJzSn8YQn1sI8tmWLrw2W24vBy+jcAzzRmOeRqSmL+b1kUB5wggHVwoRWwPuRCrMWJvwqQfYt9plECugTXkzwcrqi/AW+P2XD+8ejUoQHHOT8jSE4PEtldBbOicN9fUuKfjkVCluhm1OhsBTNphRQy2f72CTt63HBqQNk2zdj9EThbM+4QB08WQ0aHJY9uPZUmAckpUHUUWD59xRo19PPv8D63n+KvGwKny9XCc70I6CTkXC2LnA9RLT16go5enfAkV5Wg17t0Dzuu+8+1oe0FtX8A1GpbLNPpQ9evvwS6zu+SuefnyfaPJlyaux4mJ+HP7NbmyQCYEB4VvB7icG00h/woHLXo3lkrKI5f8eQKDYJOncir498w/IcxDb0hBLr4TZcSwaSHAT9cioUlkI3p0JhKXRzKhSWormyNSTukompCpAB4kIEF0OUyhSCeAshuKL6uikRk1NjVjFmfx7bunM0yRoJRDwMptwFa2dAcjeaS4wxJoSg5Dgl2czZ4yp1v0P31ut1WV8+haRbEPnj+xEbF0NA8ZWbV1nfMpQYxOXYE7lpJ1AfJRCS1FKHTCTtNs038Piz7e/QfY5jblrC18kLaG2SWCQCA1PTMtSRMcaY9W3QKcC9OK587rNhVplTmlK42U9sE4d0A0FA707QYB9xxWXzGRKN6ZdTobAUujkVCkvRSGuHIdAlmbvTo30dC86bQ0SCH0C+lVKkxodzBr78PwFmFsjJE7Q53StdDLYW8wBKjblqYpH7ptWm8z/2vY+zvpvPfbNq+6KydZHQ+eehdJ0sa5fi9SJ+nwtQlboAcSBJOb3eHhKV/cZ5Xg26BK8dD2jb8XkelH33CYqIec85Xt4gGRMVnwK1z2K+VpOcolLaQMmNMSZ1aA12CnpOhcPPURR0rZU5HqUTb9I54dEatxR2OEApzHAlUGBf0Mfa8n37Tlqfr9iF0SGY14p9ohlUhhemMXcGYq5fToXCUujmVCgsRSOt9TEQVmi5HKC1vnDeiKCaWIIpGCNBSUt0UBbVpsFrx4MLOB4/B+Ylkg7KRYLpGaHSssPp2Icff6xqP/AAL2Hg9jer9s3Na6wv8Ik2bsZ0L72ArxVW3ArbvM9zgeItED09EvAq2vd0z1TtR97P8xwdP3m2ak/B2T0bc21tPu1X7fGIr0ELlnUM6UxjESDvQKXvocc1zzuQK2gCIkuZ8JIOd585UbVHInVq1CXvpPEEg9nZMKbdlxp8FG+aSoA0BVTzCnj82uiEj/MIQ/5cWMU6mSRrhgQC+uVUKCyFbk6FwlLo5lQoLEWjzOkFKN+J8mZ4EqHKDoFrdwJQa/s8WsMwmVOUBwB51wUziydkTg+jH0SCpQRKGI4G/ard7nB57ud/8Zeq9sU/+jqfB9zL8dVV1nd9m84ZQHmAkUjslIFHzM51bgZ56IG7q/YemDCigHsjYWB3WvL7bEPg9K1bFLydTLkHz2iP5tvt8Gex1Se5cDel+V/Z5JEt23363erxJGGZRzJ4ASFIy4tcFjOQqzYXnj8/9Ym/VbV/9TOfqdq+x8fhi+uWsuQHypW8py7YWnoOsZILMsEXBFVjMrRQmEfQi05GGc3i46RfToXCUujmVCgsRSOtTVKsSlVPa2NBeQuHqGcE5Q2izhE2LoTcOr5fPxW8VtN/k1yYDlwPc6fSHP/ZP/0nbNzeiGjKYMiDiwNIqBuK0gHHFojGbW/ScVNP0FpwlL61x4OLkxcpB+1qF3LHXlln4+ahspiIXTbPrxNVnkDeJyOc1re2ifIeO3aC9V3b2KD5g3nj2hVeqezMSTLpTERla6z2tdihay/3hIkBRJNYiETdRfJK++mf+WTV/re/8Vl+Dihr4cgqY7g+t1FNHcHMLCInFOanzcGryxfJdfE1kKagUh3fFYq3L3RzKhSWQjenQmEpGmXOMcgvUl2do2eSyEuagByBsmpQcvOAcTAxFTelOEDS0USSpzzCAeWcOObyHCYo+/jHf6xqt3vzbNwu5JJdXlpkfdewGnTJk4stQjRLZ4Vc2a4NuAkDU9D2R1w+urbxctU+cYTkyl7E16M9ogWfCBPJ0ePHqvYr10hG3OpzGbw3R8mznrt1ifUZcJ/E6thb4l6OguzedcR6LJA5abVH7dDw4PMUzAqOz81aSytkrhpCLZq/9hM/ycb97uf+M83XyFop8LtB5sR2sT+bAJ2/Pt2yKSBaJhOlMDHXsyNzNjuv/13UL6dCYSl0cyoUlqI5KgW+0oXDh2aQH3UqKvrmkAunBZ/vMu+zcXs79KmXOXM80EM7Hp3/vjPcBPDDP/T9VfuhDzzC+q5DqYMhBBP7IY+miPbIjPClL/1P1rcIETGlWAMXuM8aRJS0W9z7ZgjrM9ri1OdqQvf54nUyn3QWeFXqIiXPnHaLU8ELNy/S+cEU5HlcjBiO+1V7POaRIt0umXta4I3zrpM8GHoppPOfXFlkfXOQT6flY0Vzvh65Q/PavcbzIblQ3iCEPL5rp3lpxk/+XfLqcgpO3184TwHy33rmO6xva4vMTtMRUHbBan2P5ow5iY3huXsLyFUbi+rbLLBbfAZlnuaDoF9OhcJS6OZUKCyFbk6FwlI0ypztFslm+9yNQJ0vS59hvtsEVMjz81zW+5Ef/ctV+8QJLks+9NB7q7YLMtDO1gYb1wb3wFaLu4nF8c2q3euS+SQVau3FEKLex5usD8tfdBa5XByCnByBS9qKSAQ2D0nJFh4+x/peWafrXQKZc2Oby1HDMcn4kTAdrC3SvR05QyUMxyN+jhhkp2yOR5QcOUKulavLZNJZiPhzX+qSLBZFPIHYwhw939GAMi3EuTDDQcKsVNRi2YXME+2QZOtUmNrwsxJG3DT27gc+QO13f4D1LYGprB3Rvdy4foWNe+bpP6naf/CHf8j6tndIXi/BpuiJTAguyqCijkohS78cAP1yKhSWQjenQmEpGmlta47ojUyHjxEPifDa//lf+PtV+7u/58NVe5xw+ru5SREZjkhl34EAYoxY2c642SaBnKgbGzyQ2UCisRwSVXkBv5cOlEiYF4HYow3yuJlf5lWvMdAWNeNdQVkKEAkCEcnRWlus2vesUUSGKwLTByMomyeovQGa2A5pIgtrXFTowL2FLU7RpzHmqoWkYyL37WKHzCCTVEQjwfPtdGlNExHpk0Pg+GKXn9/kdI5uh96/wVQm4IISHSL6ZnWNzD9dcX4HPHow2uS9Dz7Exj3wfhKrfuZTP8f6XDAF/eZv/nbV/sLvfpFfC7aXTLtb7svTvB/65VQoLIVuToXCUjTS2k98koJdU0FJz5w5VbXLgn+z0Rk4hHIDbotr1bZ3+3AMdwh3fKzgBR2uVHORa8feLnd8DwKiND4E5yYF9/gI54hOrg94X7xLlGxV5AZqGXTqJ9oZtsT/PJizH4rq2FDZue0jlefXOrlG6xicfhfrmwINDcGzRTpzt9rUF4sK4VlB5y9Bu+oLzbAH84qER3gGDu1DVvKCixGDXdLkznU5vc4TqBAOazXJ5LVorebmuDcVRjmPhDY4A7HCg2/T3p7I8Qti0I11nkfp2DF6Xz7y5Eeq9g8/9RSfByz/bCHeHPrlVCgshW5OhcJS6OZUKCxFo8x57DiViUuEF1AEAcuDMa+7gVLKAEwfoz7n/1g1eijKye3+6bNV24MSd902j7QoJySrzoPpxxhjipz+92QgL3oiOHwAkRBBi58jc8hrJx5wmXa+RXU9HDD3uIbLxTnkgY1CbiIJIci5hWp/GeCbQa7Xkj+L5Tkuy7+GVJi4MPjcFRmnIjALsTmKIHu8MylHxRj5A+aGWFTAHsA6HoHaKMYY0wL9RQAeQsWEvzttKKWYpfz8UJrGHFvl8uhoBInYxnRcFHLZFxO7JUI+fwVyD7ugT9hL+PuBVddlBfYSZNr7zcHQL6dCYSl0cyoUlqKR1n7zwgtVuxDOy0vgHL0o8u5sbZHz8oXx5ardi7jjewkq7/l5EVwMBCoCx/HBiFMHA+aCdkfk1gXmhuUdZGRtAA7LH/yu72Z9f3yTcvx4IncvOnDnaOIpOPX2gLrKcnXtNjh3g0kkTTklDYAmBqJUw17Wr9pd8Mwp8/o8qjIvjgseWhjIEAWchmfgSRQLUQc9plLw4ClFPld8MIGwjKUToN6Y10isG5rvSpGbKmrT85xM+BwLEG9wHX2fnz8A05sn+8C7Cj2rIhFk77Igc5nLyLwu9MupUFgK3ZwKhaXQzalQWIpGmXPxxPGqLUvSGeDueyK36XREqmf0zE+FWn4eAqBL0deG5E4p8Poi52rtEEq/+z4XYEAkZOXBSyFwTSDJ1mOP/3nW9/SXP1+1pwlX53dzWpNOm1y6pGyagio+EMJGALJqXtC4jpD1RkNa492dHdaH9zbqiHJ7gMVF0hO4Qu72IXlZ0VAar4TopFaXB2xP0c8SFn+/zEnP0xP1Z6YxBTJ7EGS/r0RfiX1Sjqd5TcfbrG9ujkw3CbgYTkbcRa8s0ZzEp4/VCAMwHKZD7gJYsDk3yZzvNgdBv5wKhaXQzalQWIrmvLWGaEohgoSTKZZIEAHEHVLnRxGoqyNZPRg9Z3iQs4E8P6g2bwsPG3C02GdiQPqKqfdltWMD1PjoEe5t05kjmuj7nPpg5ejxgPp6c5xa9qACdJZxE0nkYQVvmm+c8iidecjPE4hAXaTNaAZBM40xPIi6JWizB9TQbcipioHvucOfewbrncD8i4KLIh0wOQTCU2lriwLw348eUxuCMoKpxnGEPQYiVpKEr+METDVIa0PxXmFla0dEQkUgOhQsLzN/tq5087pN6JdTobAUujkVCkvRTGuBtuVCZZWBk3MutGWM1oJ3T1IIbWebKGSnxTV/mCfHx2pQsqIUUJpCpLxEjxV+jNCYAqWOM06DTp89XbWnu5dYnws0JoZgdJffpilhXi3hYI15bFrgxdRtc2qcGhrnin+prRZ5VyHFlevRgrxE3YhT3jCieWXwrFOhaZ3CfU5i7n2DjvZMuyq013NQgcwTuaOGQxAd4BzoZWWMYdYC6e2UQ19HaK+R1h5fpYpmu32u1c0h0ABzBhljzAQCGVCiCzxOjbHCniMnOcN3Ub+cCoWl0M2pUFgK3ZwKhaVolDmvQ1BsINTrLZBRehHn5B347WH1YyF7gDhq+v0brK+A/xs+lDqIhBnh/PkLVfvE2l2sD60W0sOEAfKqxotcFls5+0DV3ju/zvo6Lsmnu3ibQqZFT5GpSJSGUR+oss9ERAmWuOiJEoBoBkjhOUmZm3msBHwdc3g2GIGUC1OEg5E4os/FaBOXxi33eO7YBHLrdjs8UikLIfoGzy1y0xo04wiTTp5iEDU//xTMg9v9Pp1f5GVOwGMtyHnSty68myGs41gElU9BXpfyfy6e70HQL6dCYSl0cyoUlqKR1s63iRJIqhaFGKjK97gP1ZsSqGzV7fF8Mf0+9/pAZHly4N9jUUV7Cl4eRSGpK3qROAf+3RhjXKB/tzZ5lbFHP/RY1f69C19lfQFQ0lUIFveSesriSgdooD4sKNsTwcXgmTMec1uN9IKpziFsLj54UElPJTQxIAWTdAw9hGQgdg5jA6Ch8lpoTsKyDcYYs8vmDNXNhVjluNgnIpfhNUDvLGOMGUHlNbzno0eX2TgMQkinfH2zAPIyg0gRlnw9sOqHfBaNYtZrx7zuCIVC8WcC3ZwKhaXQzalQWIpmmRMCU1ORuxNpvowGQZV9Z4HOMRlxVbPD1P58KmFIF8Br74qaFsePQ0B4xF3jJpDTtmwIIPZBjR6KJGQRVL2O2jyn7VwL1fkkt5ahkPVA3uCJxoyB9KimANmsyIXqHc4h7xNNJnhvKB8aw80zUk5FubDVItOHXCuU79o+N+mgeQCPKwJ+L/i+SJORx4Li6Rw4J2OMyVPUNdRHf+QiUVoI6++C+Whvl79XbbDzlSLqCueFOg/UtRhjDHq1Nq1jHfTLqVBYCt2cCoWlcGZR6SoUirce+uVUKCyFbk6FwlLo5lQoLIVuToXCUujmVCgshW5OhcJS/F9fW39ZVVKmIgAAAABJRU5ErkJggg==\n",
|
| 67 |
+
"text/plain": [
|
| 68 |
+
"<Figure size 432x288 with 1 Axes>"
|
| 69 |
+
]
|
| 70 |
+
},
|
| 71 |
+
"metadata": {
|
| 72 |
+
"needs_background": "light"
|
| 73 |
+
},
|
| 74 |
+
"output_type": "display_data"
|
| 75 |
+
}
|
| 76 |
+
],
|
| 77 |
+
"source": [
|
| 78 |
+
"for x in dataset:\n",
|
| 79 |
+
" plt.axis(\"off\")\n",
|
| 80 |
+
" plt.imshow((x.numpy() * 255).astype(\"int32\")[0])\n",
|
| 81 |
+
" break\n"
|
| 82 |
+
]
|
| 83 |
+
},
|
| 84 |
+
{
|
| 85 |
+
"cell_type": "code",
|
| 86 |
+
"execution_count": 9,
|
| 87 |
+
"id": "2dea3fa4-1ac8-4889-8b52-8ec3e2ac7c9e",
|
| 88 |
+
"metadata": {},
|
| 89 |
+
"outputs": [
|
| 90 |
+
{
|
| 91 |
+
"name": "stdout",
|
| 92 |
+
"output_type": "stream",
|
| 93 |
+
"text": [
|
| 94 |
+
"Model: \"discriminator\"\n",
|
| 95 |
+
"_________________________________________________________________\n",
|
| 96 |
+
" Layer (type) Output Shape Param # \n",
|
| 97 |
+
"=================================================================\n",
|
| 98 |
+
" conv2d (Conv2D) (None, 32, 32, 64) 3136 \n",
|
| 99 |
+
" \n",
|
| 100 |
+
" leaky_re_lu (LeakyReLU) (None, 32, 32, 64) 0 \n",
|
| 101 |
+
" \n",
|
| 102 |
+
" conv2d_1 (Conv2D) (None, 16, 16, 128) 131200 \n",
|
| 103 |
+
" \n",
|
| 104 |
+
" leaky_re_lu_1 (LeakyReLU) (None, 16, 16, 128) 0 \n",
|
| 105 |
+
" \n",
|
| 106 |
+
" conv2d_2 (Conv2D) (None, 8, 8, 128) 262272 \n",
|
| 107 |
+
" \n",
|
| 108 |
+
" leaky_re_lu_2 (LeakyReLU) (None, 8, 8, 128) 0 \n",
|
| 109 |
+
" \n",
|
| 110 |
+
" flatten (Flatten) (None, 8192) 0 \n",
|
| 111 |
+
" \n",
|
| 112 |
+
" dropout (Dropout) (None, 8192) 0 \n",
|
| 113 |
+
" \n",
|
| 114 |
+
" dense (Dense) (None, 1) 8193 \n",
|
| 115 |
+
" \n",
|
| 116 |
+
"=================================================================\n",
|
| 117 |
+
"Total params: 404,801\n",
|
| 118 |
+
"Trainable params: 404,801\n",
|
| 119 |
+
"Non-trainable params: 0\n",
|
| 120 |
+
"_________________________________________________________________\n"
|
| 121 |
+
]
|
| 122 |
+
}
|
| 123 |
+
],
|
| 124 |
+
"source": [
|
| 125 |
+
"discriminator = keras.Sequential(\n",
|
| 126 |
+
" [\n",
|
| 127 |
+
" keras.Input(shape=(64, 64, 3)),\n",
|
| 128 |
+
" layers.Conv2D(64, kernel_size=4, strides=2, padding=\"same\"),\n",
|
| 129 |
+
" layers.LeakyReLU(alpha=0.2),\n",
|
| 130 |
+
" layers.Conv2D(128, kernel_size=4, strides=2, padding=\"same\"),\n",
|
| 131 |
+
" layers.LeakyReLU(alpha=0.2),\n",
|
| 132 |
+
" layers.Conv2D(128, kernel_size=4, strides=2, padding=\"same\"),\n",
|
| 133 |
+
" layers.LeakyReLU(alpha=0.2),\n",
|
| 134 |
+
" layers.Flatten(),\n",
|
| 135 |
+
" layers.Dropout(0.2),\n",
|
| 136 |
+
" layers.Dense(1, activation=\"sigmoid\"),\n",
|
| 137 |
+
" ],\n",
|
| 138 |
+
" name=\"discriminator\",\n",
|
| 139 |
+
")\n",
|
| 140 |
+
"discriminator.summary()\n"
|
| 141 |
+
]
|
| 142 |
+
},
|
| 143 |
+
{
|
| 144 |
+
"cell_type": "code",
|
| 145 |
+
"execution_count": 10,
|
| 146 |
+
"id": "2a2507b1-9ad7-48f3-8f90-1052ac67886b",
|
| 147 |
+
"metadata": {},
|
| 148 |
+
"outputs": [
|
| 149 |
+
{
|
| 150 |
+
"name": "stdout",
|
| 151 |
+
"output_type": "stream",
|
| 152 |
+
"text": [
|
| 153 |
+
"Model: \"generator\"\n",
|
| 154 |
+
"_________________________________________________________________\n",
|
| 155 |
+
" Layer (type) Output Shape Param # \n",
|
| 156 |
+
"=================================================================\n",
|
| 157 |
+
" dense_1 (Dense) (None, 8192) 1056768 \n",
|
| 158 |
+
" \n",
|
| 159 |
+
" reshape (Reshape) (None, 8, 8, 128) 0 \n",
|
| 160 |
+
" \n",
|
| 161 |
+
" conv2d_transpose (Conv2DTra (None, 16, 16, 128) 262272 \n",
|
| 162 |
+
" nspose) \n",
|
| 163 |
+
" \n",
|
| 164 |
+
" leaky_re_lu_3 (LeakyReLU) (None, 16, 16, 128) 0 \n",
|
| 165 |
+
" \n",
|
| 166 |
+
" conv2d_transpose_1 (Conv2DT (None, 32, 32, 256) 524544 \n",
|
| 167 |
+
" ranspose) \n",
|
| 168 |
+
" \n",
|
| 169 |
+
" leaky_re_lu_4 (LeakyReLU) (None, 32, 32, 256) 0 \n",
|
| 170 |
+
" \n",
|
| 171 |
+
" conv2d_transpose_2 (Conv2DT (None, 64, 64, 512) 2097664 \n",
|
| 172 |
+
" ranspose) \n",
|
| 173 |
+
" \n",
|
| 174 |
+
" leaky_re_lu_5 (LeakyReLU) (None, 64, 64, 512) 0 \n",
|
| 175 |
+
" \n",
|
| 176 |
+
" conv2d_3 (Conv2D) (None, 64, 64, 3) 38403 \n",
|
| 177 |
+
" \n",
|
| 178 |
+
"=================================================================\n",
|
| 179 |
+
"Total params: 3,979,651\n",
|
| 180 |
+
"Trainable params: 3,979,651\n",
|
| 181 |
+
"Non-trainable params: 0\n",
|
| 182 |
+
"_________________________________________________________________\n"
|
| 183 |
+
]
|
| 184 |
+
}
|
| 185 |
+
],
|
| 186 |
+
"source": [
|
| 187 |
+
"latent_dim = 128\n",
|
| 188 |
+
"\n",
|
| 189 |
+
"generator = keras.Sequential(\n",
|
| 190 |
+
" [\n",
|
| 191 |
+
" keras.Input(shape=(latent_dim,)),\n",
|
| 192 |
+
" layers.Dense(8 * 8 * 128),\n",
|
| 193 |
+
" layers.Reshape((8, 8, 128)),\n",
|
| 194 |
+
" layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding=\"same\"),\n",
|
| 195 |
+
" layers.LeakyReLU(alpha=0.2),\n",
|
| 196 |
+
" layers.Conv2DTranspose(256, kernel_size=4, strides=2, padding=\"same\"),\n",
|
| 197 |
+
" layers.LeakyReLU(alpha=0.2),\n",
|
| 198 |
+
" layers.Conv2DTranspose(512, kernel_size=4, strides=2, padding=\"same\"),\n",
|
| 199 |
+
" layers.LeakyReLU(alpha=0.2),\n",
|
| 200 |
+
" layers.Conv2D(3, kernel_size=5, padding=\"same\", activation=\"sigmoid\"),\n",
|
| 201 |
+
" ],\n",
|
| 202 |
+
" name=\"generator\",\n",
|
| 203 |
+
")\n",
|
| 204 |
+
"generator.summary()\n"
|
| 205 |
+
]
|
| 206 |
+
},
|
| 207 |
+
{
|
| 208 |
+
"cell_type": "markdown",
|
| 209 |
+
"id": "88691fae-b91b-40ad-9ce3-765777608598",
|
| 210 |
+
"metadata": {},
|
| 211 |
+
"source": [
|
| 212 |
+
"# Override train_step"
|
| 213 |
+
]
|
| 214 |
+
},
|
| 215 |
+
{
|
| 216 |
+
"cell_type": "code",
|
| 217 |
+
"execution_count": 11,
|
| 218 |
+
"id": "0cd186bd-94f4-4f3b-9937-5062bb568415",
|
| 219 |
+
"metadata": {},
|
| 220 |
+
"outputs": [],
|
| 221 |
+
"source": [
|
| 222 |
+
"class GAN(keras.Model):\n",
|
| 223 |
+
" def __init__(self, discriminator, generator, latent_dim):\n",
|
| 224 |
+
" super(GAN, self).__init__()\n",
|
| 225 |
+
" self.discriminator = discriminator\n",
|
| 226 |
+
" self.generator = generator\n",
|
| 227 |
+
" self.latent_dim = latent_dim\n",
|
| 228 |
+
"\n",
|
| 229 |
+
" def compile(self, d_optimizer, g_optimizer, loss_fn):\n",
|
| 230 |
+
" super(GAN, self).compile()\n",
|
| 231 |
+
" self.d_optimizer = d_optimizer\n",
|
| 232 |
+
" self.g_optimizer = g_optimizer\n",
|
| 233 |
+
" self.loss_fn = loss_fn\n",
|
| 234 |
+
" self.d_loss_metric = keras.metrics.Mean(name=\"d_loss\")\n",
|
| 235 |
+
" self.g_loss_metric = keras.metrics.Mean(name=\"g_loss\")\n",
|
| 236 |
+
"\n",
|
| 237 |
+
" @property\n",
|
| 238 |
+
" def metrics(self):\n",
|
| 239 |
+
" return [self.d_loss_metric, self.g_loss_metric]\n",
|
| 240 |
+
"\n",
|
| 241 |
+
" def train_step(self, real_images):\n",
|
| 242 |
+
" # Sample random points in the latent space\n",
|
| 243 |
+
" batch_size = tf.shape(real_images)[0]\n",
|
| 244 |
+
" random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))\n",
|
| 245 |
+
"\n",
|
| 246 |
+
" # Decode them to fake images\n",
|
| 247 |
+
" generated_images = self.generator(random_latent_vectors)\n",
|
| 248 |
+
"\n",
|
| 249 |
+
" # Combine them with real images\n",
|
| 250 |
+
" combined_images = tf.concat([generated_images, real_images], axis=0)\n",
|
| 251 |
+
"\n",
|
| 252 |
+
" # Assemble labels discriminating real from fake images\n",
|
| 253 |
+
" labels = tf.concat(\n",
|
| 254 |
+
" [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0\n",
|
| 255 |
+
" )\n",
|
| 256 |
+
" # Add random noise to the labels - important trick!\n",
|
| 257 |
+
" labels += 0.05 * tf.random.uniform(tf.shape(labels))\n",
|
| 258 |
+
"\n",
|
| 259 |
+
" # Train the discriminator\n",
|
| 260 |
+
" with tf.GradientTape() as tape:\n",
|
| 261 |
+
" predictions = self.discriminator(combined_images)\n",
|
| 262 |
+
" d_loss = self.loss_fn(labels, predictions)\n",
|
| 263 |
+
" grads = tape.gradient(d_loss, self.discriminator.trainable_weights)\n",
|
| 264 |
+
" self.d_optimizer.apply_gradients(\n",
|
| 265 |
+
" zip(grads, self.discriminator.trainable_weights)\n",
|
| 266 |
+
" )\n",
|
| 267 |
+
"\n",
|
| 268 |
+
" # Sample random points in the latent space\n",
|
| 269 |
+
" random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))\n",
|
| 270 |
+
"\n",
|
| 271 |
+
" # Assemble labels that say \"all real images\"\n",
|
| 272 |
+
" misleading_labels = tf.zeros((batch_size, 1))\n",
|
| 273 |
+
"\n",
|
| 274 |
+
" # Train the generator (note that we should *not* update the weights\n",
|
| 275 |
+
" # of the discriminator)!\n",
|
| 276 |
+
" with tf.GradientTape() as tape:\n",
|
| 277 |
+
" predictions = self.discriminator(self.generator(random_latent_vectors))\n",
|
| 278 |
+
" g_loss = self.loss_fn(misleading_labels, predictions)\n",
|
| 279 |
+
" grads = tape.gradient(g_loss, self.generator.trainable_weights)\n",
|
| 280 |
+
" self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))\n",
|
| 281 |
+
"\n",
|
| 282 |
+
" # Update metrics\n",
|
| 283 |
+
" self.d_loss_metric.update_state(d_loss)\n",
|
| 284 |
+
" self.g_loss_metric.update_state(g_loss)\n",
|
| 285 |
+
" return {\n",
|
| 286 |
+
" \"d_loss\": self.d_loss_metric.result(),\n",
|
| 287 |
+
" \"g_loss\": self.g_loss_metric.result(),\n",
|
| 288 |
+
" }\n"
|
| 289 |
+
]
|
| 290 |
+
},
|
| 291 |
+
{
|
| 292 |
+
"cell_type": "markdown",
|
| 293 |
+
"id": "6ccd520d-d223-4447-92c8-24299d7b1f5e",
|
| 294 |
+
"metadata": {},
|
| 295 |
+
"source": [
|
| 296 |
+
"## Create a callback that periodically saves generated images"
|
| 297 |
+
]
|
| 298 |
+
},
|
| 299 |
+
{
|
| 300 |
+
"cell_type": "code",
|
| 301 |
+
"execution_count": 12,
|
| 302 |
+
"id": "621b2abf-e343-47b8-82dd-5103a738f249",
|
| 303 |
+
"metadata": {},
|
| 304 |
+
"outputs": [],
|
| 305 |
+
"source": [
|
| 306 |
+
"class GANMonitor(keras.callbacks.Callback):\n",
|
| 307 |
+
" def __init__(self, num_img=3, latent_dim=128):\n",
|
| 308 |
+
" self.num_img = num_img\n",
|
| 309 |
+
" self.latent_dim = latent_dim\n",
|
| 310 |
+
"\n",
|
| 311 |
+
" def on_epoch_end(self, epoch, logs=None):\n",
|
| 312 |
+
" random_latent_vectors = tf.random.normal(shape=(self.num_img, self.latent_dim))\n",
|
| 313 |
+
" generated_images = self.model.generator(random_latent_vectors)\n",
|
| 314 |
+
" generated_images *= 255\n",
|
| 315 |
+
" generated_images.numpy()\n",
|
| 316 |
+
" for i in range(self.num_img):\n",
|
| 317 |
+
" img = keras.preprocessing.image.array_to_img(generated_images[i])\n",
|
| 318 |
+
" img.save(\"generated_img_%03d_%d.png\" % (epoch, i))\n"
|
| 319 |
+
]
|
| 320 |
+
},
|
| 321 |
+
{
|
| 322 |
+
"cell_type": "markdown",
|
| 323 |
+
"id": "0588f900-8567-4d3d-87e0-5ae559d85c36",
|
| 324 |
+
"metadata": {},
|
| 325 |
+
"source": [
|
| 326 |
+
"## Train the end-to-end model"
|
| 327 |
+
]
|
| 328 |
+
},
|
| 329 |
+
{
|
| 330 |
+
"cell_type": "code",
|
| 331 |
+
"execution_count": 13,
|
| 332 |
+
"id": "1c771d14-b327-40ab-8458-4eaf73c16a28",
|
| 333 |
+
"metadata": {},
|
| 334 |
+
"outputs": [
|
| 335 |
+
{
|
| 336 |
+
"name": "stdout",
|
| 337 |
+
"output_type": "stream",
|
| 338 |
+
"text": [
|
| 339 |
+
" 5/6332 [..............................] - ETA: 16:15:50 - d_loss: 0.6776 - g_loss: 0.7854"
|
| 340 |
+
]
|
| 341 |
+
},
|
| 342 |
+
{
|
| 343 |
+
"ename": "KeyboardInterrupt",
|
| 344 |
+
"evalue": "",
|
| 345 |
+
"output_type": "error",
|
| 346 |
+
"traceback": [
|
| 347 |
+
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
|
| 348 |
+
"\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
| 349 |
+
"\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_15592/2002100634.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 8\u001b[0m )\n\u001b[0;32m 9\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 10\u001b[1;33m gan.fit(\n\u001b[0m\u001b[0;32m 11\u001b[0m \u001b[0mdataset\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mepochs\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mepochs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcallbacks\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mGANMonitor\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnum_img\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m10\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlatent_dim\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mlatent_dim\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 12\u001b[0m )\n",
|
| 350 |
+
"\u001b[1;32m~\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\keras\\utils\\traceback_utils.py\u001b[0m in \u001b[0;36merror_handler\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m 62\u001b[0m \u001b[0mfiltered_tb\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 63\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 64\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 65\u001b[0m \u001b[1;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[1;33m:\u001b[0m \u001b[1;31m# pylint: disable=broad-except\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 66\u001b[0m \u001b[0mfiltered_tb\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_process_traceback_frames\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0me\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m__traceback__\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
|
| 351 |
+
"\u001b[1;32m~\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\keras\\engine\\training.py\u001b[0m in \u001b[0;36mfit\u001b[1;34m(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)\u001b[0m\n\u001b[0;32m 1382\u001b[0m _r=1):\n\u001b[0;32m 1383\u001b[0m \u001b[0mcallbacks\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mon_train_batch_begin\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstep\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1384\u001b[1;33m \u001b[0mtmp_logs\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtrain_function\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0miterator\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1385\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mdata_handler\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshould_sync\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1386\u001b[0m \u001b[0mcontext\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0masync_wait\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
|
| 352 |
+
"\u001b[1;32m~\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\tensorflow\\python\\util\\traceback_utils.py\u001b[0m in \u001b[0;36merror_handler\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m 148\u001b[0m \u001b[0mfiltered_tb\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 149\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 150\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 151\u001b[0m \u001b[1;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 152\u001b[0m \u001b[0mfiltered_tb\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_process_traceback_frames\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0me\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m__traceback__\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
|
| 353 |
+
"\u001b[1;32m~\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\tensorflow\\python\\eager\\def_function.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, *args, **kwds)\u001b[0m\n\u001b[0;32m 913\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 914\u001b[0m \u001b[1;32mwith\u001b[0m \u001b[0mOptionalXlaContext\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_jit_compile\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 915\u001b[1;33m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 916\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 917\u001b[0m \u001b[0mnew_tracing_count\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mexperimental_get_tracing_count\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
|
| 354 |
+
"\u001b[1;32m~\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\tensorflow\\python\\eager\\def_function.py\u001b[0m in \u001b[0;36m_call\u001b[1;34m(self, *args, **kwds)\u001b[0m\n\u001b[0;32m 945\u001b[0m \u001b[1;31m# In this case we have created variables on the first call, so we run the\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 946\u001b[0m \u001b[1;31m# defunned version which is guaranteed to never create variables.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 947\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_stateless_fn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;31m# pylint: disable=not-callable\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 948\u001b[0m \u001b[1;32melif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_stateful_fn\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 949\u001b[0m \u001b[1;31m# Release the lock early so that multiple threads can perform the call\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
|
| 355 |
+
"\u001b[1;32m~\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\tensorflow\\python\\eager\\function.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 2954\u001b[0m (graph_function,\n\u001b[0;32m 2955\u001b[0m filtered_flat_args) = self._maybe_define_function(args, kwargs)\n\u001b[1;32m-> 2956\u001b[1;33m return graph_function._call_flat(\n\u001b[0m\u001b[0;32m 2957\u001b[0m filtered_flat_args, captured_inputs=graph_function.captured_inputs) # pylint: disable=protected-access\n\u001b[0;32m 2958\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
|
| 356 |
+
"\u001b[1;32m~\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\tensorflow\\python\\eager\\function.py\u001b[0m in \u001b[0;36m_call_flat\u001b[1;34m(self, args, captured_inputs, cancellation_manager)\u001b[0m\n\u001b[0;32m 1851\u001b[0m and executing_eagerly):\n\u001b[0;32m 1852\u001b[0m \u001b[1;31m# No tape is watching; skip to running the function.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1853\u001b[1;33m return self._build_call_outputs(self._inference_function.call(\n\u001b[0m\u001b[0;32m 1854\u001b[0m ctx, args, cancellation_manager=cancellation_manager))\n\u001b[0;32m 1855\u001b[0m forward_backward = self._select_forward_and_backward_functions(\n",
|
| 357 |
+
"\u001b[1;32m~\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\tensorflow\\python\\eager\\function.py\u001b[0m in \u001b[0;36mcall\u001b[1;34m(self, ctx, args, cancellation_manager)\u001b[0m\n\u001b[0;32m 497\u001b[0m \u001b[1;32mwith\u001b[0m \u001b[0m_InterpolateFunctionError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 498\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mcancellation_manager\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 499\u001b[1;33m outputs = execute.execute(\n\u001b[0m\u001b[0;32m 500\u001b[0m \u001b[0mstr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msignature\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mname\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 501\u001b[0m \u001b[0mnum_outputs\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_num_outputs\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
|
| 358 |
+
"\u001b[1;32m~\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\tensorflow\\python\\eager\\execute.py\u001b[0m in \u001b[0;36mquick_execute\u001b[1;34m(op_name, num_outputs, inputs, attrs, ctx, name)\u001b[0m\n\u001b[0;32m 52\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 53\u001b[0m \u001b[0mctx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mensure_initialized\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 54\u001b[1;33m tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,\n\u001b[0m\u001b[0;32m 55\u001b[0m inputs, attrs, num_outputs)\n\u001b[0;32m 56\u001b[0m \u001b[1;32mexcept\u001b[0m \u001b[0mcore\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_NotOkStatusException\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
|
| 359 |
+
"\u001b[1;31mKeyboardInterrupt\u001b[0m: "
|
| 360 |
+
]
|
| 361 |
+
}
|
| 362 |
+
],
|
| 363 |
+
"source": [
|
| 364 |
+
"epochs = 1 # In practice, use ~100 epochs\n",
|
| 365 |
+
"\n",
|
| 366 |
+
"gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)\n",
|
| 367 |
+
"gan.compile(\n",
|
| 368 |
+
" d_optimizer=keras.optimizers.Adam(learning_rate=0.0001),\n",
|
| 369 |
+
" g_optimizer=keras.optimizers.Adam(learning_rate=0.0001),\n",
|
| 370 |
+
" loss_fn=keras.losses.BinaryCrossentropy(),\n",
|
| 371 |
+
")\n",
|
| 372 |
+
"\n",
|
| 373 |
+
"gan.fit(\n",
|
| 374 |
+
" dataset, epochs=epochs, callbacks=[GANMonitor(num_img=10, latent_dim=latent_dim)]\n",
|
| 375 |
+
")\n"
|
| 376 |
+
]
|
| 377 |
+
},
|
| 378 |
+
{
|
| 379 |
+
"cell_type": "code",
|
| 380 |
+
"execution_count": null,
|
| 381 |
+
"id": "ce3c558b-a39a-48f5-b109-d077057b3dcf",
|
| 382 |
+
"metadata": {},
|
| 383 |
+
"outputs": [],
|
| 384 |
+
"source": []
|
| 385 |
+
}
|
| 386 |
+
],
|
| 387 |
+
"metadata": {
|
| 388 |
+
"kernelspec": {
|
| 389 |
+
"display_name": "Python 3 (ipykernel)",
|
| 390 |
+
"language": "python",
|
| 391 |
+
"name": "python3"
|
| 392 |
+
},
|
| 393 |
+
"language_info": {
|
| 394 |
+
"codemirror_mode": {
|
| 395 |
+
"name": "ipython",
|
| 396 |
+
"version": 3
|
| 397 |
+
},
|
| 398 |
+
"file_extension": ".py",
|
| 399 |
+
"mimetype": "text/x-python",
|
| 400 |
+
"name": "python",
|
| 401 |
+
"nbconvert_exporter": "python",
|
| 402 |
+
"pygments_lexer": "ipython3",
|
| 403 |
+
"version": "3.9.6"
|
| 404 |
+
}
|
| 405 |
+
},
|
| 406 |
+
"nbformat": 4,
|
| 407 |
+
"nbformat_minor": 5
|
| 408 |
+
}
|