Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import numpy as np | |
| import argparse | |
| import itertools | |
| import time | |
| import ast | |
| import re | |
| from tracr.compiler import compiling | |
| from typing import get_args | |
| import inspect | |
| import pickle | |
| import base64 | |
| from abstract_syntax_tree import * | |
| from python_embedded_rasp import * | |
| from rasp_synthesizer import * | |
# HELPER FUNCTIONS
def model_download_href(model, filename="model_params.pkl",
                        label="Download Haiku Model Parameters in a .pkl File"):
    """Build an HTML ``<a>`` tag that downloads *model* as a pickle file.

    The object is pickled, base64-encoded and embedded directly in a
    ``data:`` URI, so the entire payload lives inside the returned string.

    Args:
        model: Any picklable object.
        filename: File name the browser should save the download under.
        label: Visible link text.

    Returns:
        The HTML anchor tag as a string.
    """
    import html  # only needed to escape the attribute value and link text

    payload = base64.b64encode(pickle.dumps(model)).decode("ascii")
    # Generic binary MIME type; the `download` attribute supplies the name.
    return (
        f'<a href="data:application/octet-stream;base64,{payload}" '
        f'download="{html.escape(filename, quote=True)}">'
        f'{html.escape(label)}</a>'
    )


def download_model(model, filename="model_params.pkl",
                   label="Download Haiku Model Parameters in a .pkl File"):
    """Render a link in the Streamlit page that downloads *model* as a pickle.

    Args:
        model: Any picklable object (the app passes the compiled model params).
        filename: File name the browser should save the download under.
        label: Visible link text.
    """
    # Raw HTML is required to render the anchor tag; the filename and label
    # are escaped inside model_download_href.
    st.markdown(model_download_href(model, filename, label), unsafe_allow_html=True)
# APP DRIVER CODE
st.title("Bottom Up Synthesis for RASP")
max_weight = st.slider(
    "Choose the maximum program weight to search for (~ size of transformer)",
    2, 20, 15,
)
default_example = "[[['h', 'e', 'l', 'l', 'o'], [1,1,2,2,1]]]"
example_text = st.text_input(
    label="Provide Input and Output Examples", value=default_example
)
# analyze_examples comes from one of the star-imported project modules; it
# presumably parses the text into parallel input/output lists — verify there.
inputs, outs = analyze_examples(example_text)
examples = list(zip(inputs, outs))
st.write("Received the following input and output examples:")
st.write(examples)

# The compiled transformer must accommodate the longest example input.
max_seq_len = max((len(seq) for seq in inputs), default=0)
vocab = get_vocabulary(examples)

st.subheader("Synthesis Configuration")
st.write("Running synthesizer with")
st.write(f"Vocab: {vocab}")
st.write(f"Max sequence length: {max_seq_len}")
st.write(f"Max weight: {max_weight}")

# Show the results header and the "may take a while" warning BEFORE the
# blocking synthesis call; previously they only appeared once it had finished.
st.subheader("Synthesis Results:")
st.caption("May take a while.")
with st.spinner("Running synthesizer..."):
    program, approx_programs = run_synthesizer(examples, max_weight)
def extract_layer_number(s):
    """Return ``k + 1`` for a parameter name containing ``layer_<k>``, else None.

    Args:
        s: A parameter name string.

    Returns:
        The 1-based layer index encoded in *s*, or None when *s* has no
        ``layer_<digits>`` component.
    """
    match = re.search(r'layer_(\d+)', s)
    if match:
        return int(match.group(1)) + 1
    return None


if program:
    # Turn the synthesized program into the RASP expression tracr compiles.
    algorithm = program.to_python()
    bos = "BOS"
    model = compiling.compile_rasp_to_model(
        algorithm,
        vocab=vocab,
        max_seq_len=max_seq_len,
        compiler_bos=bos,
    )
    # Use the highest layer index across ALL parameter names instead of only
    # the last key: dict ordering is not guaranteed to put the deepest layer
    # last, and a non-layer final key made the old code report "None".
    layer_numbers = [
        n for n in map(extract_layer_number, model.params) if n is not None
    ]
    layer_num = max(layer_numbers, default=0)
    st.write(f"The following program has been compiled to a transformer with {layer_num} layer(s):")
    st.write(program.str())
    st.write("Here is a model download link: ")
    hk_model = model.params
    download_model(hk_model)
else:
    st.write("No exact program found but here is some brainstorming fodder for you: {}".format(approx_programs))