Update app.py
Browse files
app.py
CHANGED
|
@@ -13,7 +13,7 @@ generator = Generator()
|
|
| 13 |
variables = generator.init(jax.random.PRNGKey(0), jnp.zeros([1, LATENT_DIM]), training=False)
|
| 14 |
|
| 15 |
fs = HfFileSystem()
|
| 16 |
-
with fs.open("PrakhAI/
|
| 17 |
g_state = from_state_dict(variables, msgpack_restore(f.read()))
|
| 18 |
|
| 19 |
def sample_latent(batch, key):
|
|
@@ -22,12 +22,12 @@ def sample_latent(batch, key):
|
|
| 22 |
def to_img(normalized):
|
| 23 |
return ((normalized+1)*255./2.).astype(np.uint8)
|
| 24 |
|
| 25 |
-
st.write("The model and its details are at https://huggingface.co/PrakhAI/
|
| 26 |
if st.button('Generate Random'):
|
| 27 |
st.session_state['generate'] = None
|
| 28 |
|
| 29 |
-
ROWS =
|
| 30 |
-
COLUMNS =
|
| 31 |
|
| 32 |
def set_latent(latent):
|
| 33 |
st.session_state['generate'] = latent
|
|
@@ -40,9 +40,9 @@ if 'generate' in st.session_state:
|
|
| 40 |
if "similarity" not in st.session_state:
|
| 41 |
st.session_state["similarity"] = 0.5
|
| 42 |
similarity = st.number_input(label="Mutation (for \"Generate Similar\") - lower value generates more similar images", key="similarity", min_value=0.01, max_value=1.0)
|
| 43 |
-
latents = np.repeat([previous], repeats=
|
| 44 |
-
|
| 45 |
-
img = np.array(to_img(
|
| 46 |
for row in range(ROWS):
|
| 47 |
with st.container():
|
| 48 |
for (col_idx, col) in enumerate(st.columns(COLUMNS)):
|
|
|
|
| 13 |
variables = generator.init(jax.random.PRNGKey(0), jnp.zeros([1, LATENT_DIM]), training=False)
|
| 14 |
|
| 15 |
fs = HfFileSystem()
|
| 16 |
+
with fs.open("PrakhAI/AIPlane3/g_checkpoint_200000.msgpack", "rb") as f:
|
| 17 |
g_state = from_state_dict(variables, msgpack_restore(f.read()))
|
| 18 |
|
| 19 |
def sample_latent(batch, key):
|
|
|
|
| 22 |
def to_img(normalized):
|
| 23 |
return ((normalized+1)*255./2.).astype(np.uint8)
|
| 24 |
|
| 25 |
+
st.write("The model and its details are at https://huggingface.co/PrakhAI/AIPlane3")
|
| 26 |
if st.button('Generate Random'):
|
| 27 |
st.session_state['generate'] = None
|
| 28 |
|
| 29 |
+
ROWS = 2
|
| 30 |
+
COLUMNS = 2
|
| 31 |
|
| 32 |
def set_latent(latent):
|
| 33 |
st.session_state['generate'] = latent
|
|
|
|
| 40 |
if "similarity" not in st.session_state:
|
| 41 |
st.session_state["similarity"] = 0.5
|
| 42 |
similarity = st.number_input(label="Mutation (for \"Generate Similar\") - lower value generates more similar images", key="similarity", min_value=0.01, max_value=1.0)
|
| 43 |
+
latents = np.repeat([previous], repeats=4, axis=0) + similarity * latents
|
| 44 |
+
g_out = generator.apply({'params': g_state['params']}, latents)
|
| 45 |
+
img = np.array(to_img(g_out))
|
| 46 |
for row in range(ROWS):
|
| 47 |
with st.container():
|
| 48 |
for (col_idx, col) in enumerate(st.columns(COLUMNS)):
|