Upload 49 files
Browse files- .gitattributes +21 -0
- frontend/README.md +126 -0
- frontend/configs/webpack/common.js +85 -0
- frontend/configs/webpack/dev.js +25 -0
- frontend/configs/webpack/prod.js +22 -0
- frontend/model/sam_onnx_example.onnx +3 -0
- frontend/model/sam_onnx_quantized_example.onnx +3 -0
- frontend/package.json +64 -0
- frontend/postcss.config.js +10 -0
- frontend/src/App.tsx +306 -0
- frontend/src/assets/examples/1.jpg +3 -0
- frontend/src/assets/examples/10.jpg +3 -0
- frontend/src/assets/examples/11.jpg +3 -0
- frontend/src/assets/examples/12.jpg +3 -0
- frontend/src/assets/examples/13.jpg +3 -0
- frontend/src/assets/examples/14.jpg +3 -0
- frontend/src/assets/examples/15.jpg +3 -0
- frontend/src/assets/examples/16.jpg +3 -0
- frontend/src/assets/examples/17.jpg +3 -0
- frontend/src/assets/examples/18.jpg +3 -0
- frontend/src/assets/examples/19.jpg +3 -0
- frontend/src/assets/examples/2.jpg +3 -0
- frontend/src/assets/examples/20.jpg +3 -0
- frontend/src/assets/examples/21.jpg +3 -0
- frontend/src/assets/examples/3.jpg +3 -0
- frontend/src/assets/examples/4.jpg +3 -0
- frontend/src/assets/examples/5.jpg +3 -0
- frontend/src/assets/examples/6.jpg +3 -0
- frontend/src/assets/examples/7.jpg +3 -0
- frontend/src/assets/examples/8.jpg +3 -0
- frontend/src/assets/examples/9.jpg +3 -0
- frontend/src/assets/index.html +18 -0
- frontend/src/assets/scss/App.scss +99 -0
- frontend/src/components/ErrorModal.tsx +32 -0
- frontend/src/components/LoadingOverlay.tsx +30 -0
- frontend/src/components/QueueStatusIndicator.tsx +29 -0
- frontend/src/components/Stage.tsx +343 -0
- frontend/src/components/Tool.tsx +182 -0
- frontend/src/components/helpers/Interfaces.tsx +47 -0
- frontend/src/components/helpers/imageUtils.tsx +21 -0
- frontend/src/components/helpers/maskUtils.tsx +65 -0
- frontend/src/components/helpers/onnxModelAPI.tsx +71 -0
- frontend/src/components/helpers/scaleHelper.tsx +18 -0
- frontend/src/components/hooks/context.tsx +35 -0
- frontend/src/components/hooks/createContext.tsx +35 -0
- frontend/src/index.tsx +17 -0
- frontend/src/services/maskApi.tsx +211 -0
- frontend/tailwind.config.js +12 -0
- frontend/tsconfig.json +24 -0
- frontend/yarn.lock +0 -0
.gitattributes
CHANGED
|
@@ -54,3 +54,24 @@ dist/examples/6.jpg filter=lfs diff=lfs merge=lfs -text
|
|
| 54 |
dist/examples/7.jpg filter=lfs diff=lfs merge=lfs -text
|
| 55 |
dist/examples/8.jpg filter=lfs diff=lfs merge=lfs -text
|
| 56 |
dist/examples/9.jpg filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
dist/examples/7.jpg filter=lfs diff=lfs merge=lfs -text
|
| 55 |
dist/examples/8.jpg filter=lfs diff=lfs merge=lfs -text
|
| 56 |
dist/examples/9.jpg filter=lfs diff=lfs merge=lfs -text
|
| 57 |
+
frontend/src/assets/examples/1.jpg filter=lfs diff=lfs merge=lfs -text
|
| 58 |
+
frontend/src/assets/examples/10.jpg filter=lfs diff=lfs merge=lfs -text
|
| 59 |
+
frontend/src/assets/examples/11.jpg filter=lfs diff=lfs merge=lfs -text
|
| 60 |
+
frontend/src/assets/examples/12.jpg filter=lfs diff=lfs merge=lfs -text
|
| 61 |
+
frontend/src/assets/examples/13.jpg filter=lfs diff=lfs merge=lfs -text
|
| 62 |
+
frontend/src/assets/examples/14.jpg filter=lfs diff=lfs merge=lfs -text
|
| 63 |
+
frontend/src/assets/examples/15.jpg filter=lfs diff=lfs merge=lfs -text
|
| 64 |
+
frontend/src/assets/examples/16.jpg filter=lfs diff=lfs merge=lfs -text
|
| 65 |
+
frontend/src/assets/examples/17.jpg filter=lfs diff=lfs merge=lfs -text
|
| 66 |
+
frontend/src/assets/examples/18.jpg filter=lfs diff=lfs merge=lfs -text
|
| 67 |
+
frontend/src/assets/examples/19.jpg filter=lfs diff=lfs merge=lfs -text
|
| 68 |
+
frontend/src/assets/examples/2.jpg filter=lfs diff=lfs merge=lfs -text
|
| 69 |
+
frontend/src/assets/examples/20.jpg filter=lfs diff=lfs merge=lfs -text
|
| 70 |
+
frontend/src/assets/examples/21.jpg filter=lfs diff=lfs merge=lfs -text
|
| 71 |
+
frontend/src/assets/examples/3.jpg filter=lfs diff=lfs merge=lfs -text
|
| 72 |
+
frontend/src/assets/examples/4.jpg filter=lfs diff=lfs merge=lfs -text
|
| 73 |
+
frontend/src/assets/examples/5.jpg filter=lfs diff=lfs merge=lfs -text
|
| 74 |
+
frontend/src/assets/examples/6.jpg filter=lfs diff=lfs merge=lfs -text
|
| 75 |
+
frontend/src/assets/examples/7.jpg filter=lfs diff=lfs merge=lfs -text
|
| 76 |
+
frontend/src/assets/examples/8.jpg filter=lfs diff=lfs merge=lfs -text
|
| 77 |
+
frontend/src/assets/examples/9.jpg filter=lfs diff=lfs merge=lfs -text
|
frontend/README.md
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Segment Anything Simple Web demo
|
| 2 |
+
|
| 3 |
+
This **front-end only** React-based web demo shows how to load a fixed image and corresponding `.npy` file of the SAM image embedding, and run the SAM ONNX model in the browser using WebAssembly with multithreading enabled by `SharedArrayBuffer`, Web Worker, and SIMD128.
|
| 4 |
+
|
| 5 |
+
<img src="https://github.com/facebookresearch/segment-anything/raw/main/assets/minidemo.gif" width="500"/>
|
| 6 |
+
|
| 7 |
+
## Run the app
|
| 8 |
+
|
| 9 |
+
Install Yarn
|
| 10 |
+
|
| 11 |
+
```
|
| 12 |
+
npm install --g yarn
|
| 13 |
+
```
|
| 14 |
+
|
| 15 |
+
Build and run:
|
| 16 |
+
|
| 17 |
+
```
|
| 18 |
+
yarn && yarn start
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
Navigate to [`http://localhost:8081/`](http://localhost:8081/)
|
| 22 |
+
|
| 23 |
+
Move your cursor around to see the mask prediction update in real time.
|
| 24 |
+
|
| 25 |
+
## Export the image embedding
|
| 26 |
+
|
| 27 |
+
In the [ONNX Model Example notebook](https://github.com/facebookresearch/segment-anything/blob/main/notebooks/onnx_model_example.ipynb), upload the image of your choice, then generate and save the corresponding embedding.
|
| 28 |
+
|
| 29 |
+
Initialize the predictor:
|
| 30 |
+
|
| 31 |
+
```python
|
| 32 |
+
checkpoint = "sam_vit_h_4b8939.pth"
|
| 33 |
+
model_type = "vit_h"
|
| 34 |
+
sam = sam_model_registry[model_type](checkpoint=checkpoint)
|
| 35 |
+
sam.to(device='cuda')
|
| 36 |
+
predictor = SamPredictor(sam)
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
Set the new image and export the embedding:
|
| 40 |
+
|
| 41 |
+
```
|
| 42 |
+
image = cv2.imread('src/assets/dogs.jpg')
|
| 43 |
+
predictor.set_image(image)
|
| 44 |
+
image_embedding = predictor.get_image_embedding().cpu().numpy()
|
| 45 |
+
np.save("dogs_embedding.npy", image_embedding)
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
Save the new image and embedding in `src/assets/data`.
|
| 49 |
+
|
| 50 |
+
## Export the ONNX model
|
| 51 |
+
|
| 52 |
+
You also need to export the quantized ONNX model from the [ONNX Model Example notebook](https://github.com/facebookresearch/segment-anything/blob/main/notebooks/onnx_model_example.ipynb).
|
| 53 |
+
|
| 54 |
+
Run the cell in the notebook which saves the `sam_onnx_quantized_example.onnx` file, download it and copy it to the path `/model/sam_onnx_quantized_example.onnx`.
|
| 55 |
+
|
| 56 |
+
Here is a snippet of the export/quantization code:
|
| 57 |
+
|
| 58 |
+
```
|
| 59 |
+
onnx_model_path = "sam_onnx_example.onnx"
|
| 60 |
+
onnx_model_quantized_path = "sam_onnx_quantized_example.onnx"
|
| 61 |
+
quantize_dynamic(
|
| 62 |
+
model_input=onnx_model_path,
|
| 63 |
+
model_output=onnx_model_quantized_path,
|
| 64 |
+
optimize_model=True,
|
| 65 |
+
per_channel=False,
|
| 66 |
+
reduce_range=False,
|
| 67 |
+
weight_type=QuantType.QUInt8,
|
| 68 |
+
)
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
**NOTE: if you change the ONNX model by using a new checkpoint you need to also re-export the embedding.**
|
| 72 |
+
|
| 73 |
+
## Update the image, embedding, model in the app
|
| 74 |
+
|
| 75 |
+
Update the following file paths at the top of `App.tsx`:
|
| 76 |
+
|
| 77 |
+
```ts
|
| 78 |
+
const IMAGE_PATH = "/assets/data/dogs.jpg";
|
| 79 |
+
const IMAGE_EMBEDDING = "/assets/data/dogs_embedding.npy";
|
| 80 |
+
const MODEL_DIR = "/model/sam_onnx_quantized_example.onnx";
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
## ONNX multithreading with SharedArrayBuffer
|
| 84 |
+
|
| 85 |
+
To use multithreading, the appropriate headers need to be set to create a cross origin isolation state which will enable use of `SharedArrayBuffer` (see this [blog post](https://cloudblogs.microsoft.com/opensource/2021/09/02/onnx-runtime-web-running-your-machine-learning-model-in-browser/) for more details).
|
| 86 |
+
|
| 87 |
+
The headers below are set in `configs/webpack/dev.js`:
|
| 88 |
+
|
| 89 |
+
```js
|
| 90 |
+
headers: {
|
| 91 |
+
"Cross-Origin-Opener-Policy": "same-origin",
|
| 92 |
+
"Cross-Origin-Embedder-Policy": "credentialless",
|
| 93 |
+
}
|
| 94 |
+
```
|
| 95 |
+
|
| 96 |
+
## Structure of the app
|
| 97 |
+
|
| 98 |
+
**`App.tsx`**
|
| 99 |
+
|
| 100 |
+
- Initializes ONNX model
|
| 101 |
+
- Loads image embedding and image
|
| 102 |
+
- Runs the ONNX model based on input prompts
|
| 103 |
+
|
| 104 |
+
**`Stage.tsx`**
|
| 105 |
+
|
| 106 |
+
- Handles mouse move interaction to update the ONNX model prompt
|
| 107 |
+
|
| 108 |
+
**`Tool.tsx`**
|
| 109 |
+
|
| 110 |
+
- Renders the image and the mask prediction
|
| 111 |
+
|
| 112 |
+
**`helpers/maskUtils.tsx`**
|
| 113 |
+
|
| 114 |
+
- Conversion of ONNX model output from array to an HTMLImageElement
|
| 115 |
+
|
| 116 |
+
**`helpers/onnxModelAPI.tsx`**
|
| 117 |
+
|
| 118 |
+
- Formats the inputs for the ONNX model
|
| 119 |
+
|
| 120 |
+
**`helpers/scaleHelper.tsx`**
|
| 121 |
+
|
| 122 |
+
- Handles image scaling logic for SAM (longest size 1024)
|
| 123 |
+
|
| 124 |
+
**`hooks/`**
|
| 125 |
+
|
| 126 |
+
- Handle shared state for the app
|
frontend/configs/webpack/common.js
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
const { resolve } = require("path");
|
| 8 |
+
const HtmlWebpackPlugin = require("html-webpack-plugin");
|
| 9 |
+
const FriendlyErrorsWebpackPlugin = require("friendly-errors-webpack-plugin");
|
| 10 |
+
const CopyPlugin = require("copy-webpack-plugin");
|
| 11 |
+
const webpack = require("webpack");
|
| 12 |
+
|
| 13 |
+
module.exports = {
|
| 14 |
+
entry: "./src/index.tsx",
|
| 15 |
+
resolve: {
|
| 16 |
+
extensions: [".js", ".jsx", ".ts", ".tsx"],
|
| 17 |
+
fallback: { 'process/browser': require.resolve('process/browser'), }
|
| 18 |
+
},
|
| 19 |
+
output: {
|
| 20 |
+
path: resolve(__dirname, "dist"),
|
| 21 |
+
},
|
| 22 |
+
module: {
|
| 23 |
+
rules: [
|
| 24 |
+
{
|
| 25 |
+
test: /\.mjs$/,
|
| 26 |
+
include: /node_modules/,
|
| 27 |
+
type: "javascript/auto",
|
| 28 |
+
resolve: {
|
| 29 |
+
fullySpecified: false,
|
| 30 |
+
},
|
| 31 |
+
},
|
| 32 |
+
{
|
| 33 |
+
test: [/\.jsx?$/, /\.tsx?$/],
|
| 34 |
+
use: ["ts-loader"],
|
| 35 |
+
exclude: /node_modules/,
|
| 36 |
+
},
|
| 37 |
+
{
|
| 38 |
+
test: /\.css$/,
|
| 39 |
+
use: ["style-loader", "css-loader"],
|
| 40 |
+
},
|
| 41 |
+
{
|
| 42 |
+
test: /\.(scss|sass)$/,
|
| 43 |
+
use: ["style-loader", "css-loader", "postcss-loader"],
|
| 44 |
+
},
|
| 45 |
+
{
|
| 46 |
+
test: /\.(jpe?g|png|gif|svg)$/i,
|
| 47 |
+
use: [
|
| 48 |
+
"file-loader?hash=sha512&digest=hex&name=img/[contenthash].[ext]",
|
| 49 |
+
"image-webpack-loader?bypassOnDebug&optipng.optimizationLevel=7&gifsicle.interlaced=false",
|
| 50 |
+
],
|
| 51 |
+
},
|
| 52 |
+
{
|
| 53 |
+
test: /\.(woff|woff2|ttf)$/,
|
| 54 |
+
use: {
|
| 55 |
+
loader: "url-loader",
|
| 56 |
+
},
|
| 57 |
+
},
|
| 58 |
+
],
|
| 59 |
+
},
|
| 60 |
+
plugins: [
|
| 61 |
+
new CopyPlugin({
|
| 62 |
+
patterns: [
|
| 63 |
+
{
|
| 64 |
+
from: "node_modules/onnxruntime-web/dist/*.wasm",
|
| 65 |
+
to: "[name][ext]",
|
| 66 |
+
},
|
| 67 |
+
{
|
| 68 |
+
from: "model",
|
| 69 |
+
to: "model",
|
| 70 |
+
},
|
| 71 |
+
{
|
| 72 |
+
from: "src/assets/examples",
|
| 73 |
+
to: "examples",
|
| 74 |
+
},
|
| 75 |
+
],
|
| 76 |
+
}),
|
| 77 |
+
new HtmlWebpackPlugin({
|
| 78 |
+
template: "./src/assets/index.html",
|
| 79 |
+
}),
|
| 80 |
+
new FriendlyErrorsWebpackPlugin(),
|
| 81 |
+
new webpack.ProvidePlugin({
|
| 82 |
+
process: "process/browser",
|
| 83 |
+
}),
|
| 84 |
+
],
|
| 85 |
+
};
|
frontend/configs/webpack/dev.js
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
// development config
|
| 8 |
+
const { merge } = require("webpack-merge");
|
| 9 |
+
const commonConfig = require("./common");
|
| 10 |
+
|
| 11 |
+
module.exports = merge(commonConfig, {
|
| 12 |
+
mode: "development",
|
| 13 |
+
devServer: {
|
| 14 |
+
hot: true, // enable HMR on the server
|
| 15 |
+
open: true,
|
| 16 |
+
// These headers enable the cross origin isolation state
|
| 17 |
+
// needed to enable use of SharedArrayBuffer for ONNX
|
| 18 |
+
// multithreading.
|
| 19 |
+
headers: {
|
| 20 |
+
"Cross-Origin-Opener-Policy": "same-origin",
|
| 21 |
+
"Cross-Origin-Embedder-Policy": "credentialless",
|
| 22 |
+
},
|
| 23 |
+
},
|
| 24 |
+
devtool: "cheap-module-source-map",
|
| 25 |
+
});
|
frontend/configs/webpack/prod.js
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
// production config
|
| 8 |
+
const { merge } = require("webpack-merge");
|
| 9 |
+
const { resolve } = require("path");
|
| 10 |
+
const Dotenv = require("dotenv-webpack");
|
| 11 |
+
const commonConfig = require("./common");
|
| 12 |
+
|
| 13 |
+
// Settings layered on top of the shared config for production builds.
const productionOverrides = {
  mode: "production",
  output: {
    // Content-hashed bundle name, emitted under js/ in the output directory.
    filename: "js/bundle.[contenthash].min.js",
    // Two levels up from this config file's directory, then "dist".
    path: resolve(__dirname, "../../dist"),
    publicPath: "/",
  },
  devtool: "source-map",
  plugins: [new Dotenv()],
};

module.exports = merge(commonConfig, productionOverrides);
|
frontend/model/sam_onnx_example.onnx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:76edead1afcb7a3ed7672939bf375aeb64b57f943769ee286fa02f0a382b3393
|
| 3 |
+
size 16501766
|
frontend/model/sam_onnx_quantized_example.onnx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b134e7a1a9779862761aa207bb91dd2bdbf2c1182486dc216c71dbba1b109819
|
| 3 |
+
size 8824322
|
frontend/package.json
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"name": "segment-anything-mini-demo",
|
| 3 |
+
"version": "0.1.0",
|
| 4 |
+
"license": "MIT",
|
| 5 |
+
"scripts": {
|
| 6 |
+
"build": "yarn run clean-dist && webpack --config=configs/webpack/prod.js && mv dist/*.wasm dist/js && rsync -r --delete dist ../",
|
| 7 |
+
"clean-dist": "rimraf dist/*",
|
| 8 |
+
"lint": "eslint './src/**/*.{js,ts,tsx}' --quiet",
|
| 9 |
+
"start": "yarn run start-dev",
|
| 10 |
+
"test": "yarn run start-model-test",
|
| 11 |
+
"start-dev": "webpack serve --config=configs/webpack/dev.js"
|
| 12 |
+
},
|
| 13 |
+
"devDependencies": {
|
| 14 |
+
"@babel/core": "^7.18.13",
|
| 15 |
+
"@babel/preset-env": "^7.18.10",
|
| 16 |
+
"@babel/preset-react": "^7.18.6",
|
| 17 |
+
"@babel/preset-typescript": "^7.18.6",
|
| 18 |
+
"@pmmmwh/react-refresh-webpack-plugin": "^0.5.7",
|
| 19 |
+
"@testing-library/react": "^13.3.0",
|
| 20 |
+
"@types/node": "^18.7.13",
|
| 21 |
+
"@types/react": "^18.0.17",
|
| 22 |
+
"@types/react-dom": "^18.0.6",
|
| 23 |
+
"@types/underscore": "^1.11.4",
|
| 24 |
+
"@typescript-eslint/eslint-plugin": "^5.35.1",
|
| 25 |
+
"@typescript-eslint/parser": "^5.35.1",
|
| 26 |
+
"babel-loader": "^8.2.5",
|
| 27 |
+
"copy-webpack-plugin": "^11.0.0",
|
| 28 |
+
"css-loader": "^6.7.1",
|
| 29 |
+
"dotenv": "^16.0.2",
|
| 30 |
+
"dotenv-webpack": "^8.0.1",
|
| 31 |
+
"eslint": "^8.22.0",
|
| 32 |
+
"eslint-plugin-react": "^7.31.0",
|
| 33 |
+
"file-loader": "^6.2.0",
|
| 34 |
+
"fork-ts-checker-webpack-plugin": "^7.2.13",
|
| 35 |
+
"friendly-errors-webpack-plugin": "^1.7.0",
|
| 36 |
+
"html-webpack-plugin": "^5.5.0",
|
| 37 |
+
"image-webpack-loader": "^8.1.0",
|
| 38 |
+
"postcss-loader": "^7.0.1",
|
| 39 |
+
"postcss-preset-env": "^7.8.0",
|
| 40 |
+
"process": "^0.11.10",
|
| 41 |
+
"rimraf": "^3.0.2",
|
| 42 |
+
"sass": "^1.54.5",
|
| 43 |
+
"sass-loader": "^13.0.2",
|
| 44 |
+
"style-loader": "^3.3.1",
|
| 45 |
+
"tailwindcss": "^3.1.8",
|
| 46 |
+
"ts-loader": "^9.3.1",
|
| 47 |
+
"typescript": "^4.8.2",
|
| 48 |
+
"webpack": "^5.74.0",
|
| 49 |
+
"webpack-cli": "^4.10.0",
|
| 50 |
+
"webpack-dev-server": "^4.10.0",
|
| 51 |
+
"webpack-dotenv-plugin": "^2.1.0",
|
| 52 |
+
"webpack-merge": "^5.8.0"
|
| 53 |
+
},
|
| 54 |
+
"dependencies": {
|
| 55 |
+
"@gradio/client": "^1.7.1",
|
| 56 |
+
"npyjs": "^0.4.0",
|
| 57 |
+
"onnxruntime-web": "1.14.0",
|
| 58 |
+
"react": "^18.2.0",
|
| 59 |
+
"react-dom": "^18.2.0",
|
| 60 |
+
"react-refresh": "^0.14.0",
|
| 61 |
+
"underscore": "^1.13.6",
|
| 62 |
+
"axios": "^1.6.7"
|
| 63 |
+
}
|
| 64 |
+
}
|
frontend/postcss.config.js
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
const tailwindcss = require("tailwindcss");
|
| 8 |
+
// PostCSS plugin chain, listed in the order PostCSS receives it.
const plugins = ["postcss-preset-env", 'tailwindcss/nesting', tailwindcss];

module.exports = { plugins };
|
frontend/src/App.tsx
ADDED
|
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import { InferenceSession, Tensor } from "onnxruntime-web";
|
| 8 |
+
import React, { useContext, useEffect, useState, useRef } from "react";
|
| 9 |
+
import axios from "axios";
|
| 10 |
+
import "./assets/scss/App.scss";
|
| 11 |
+
import { handleImageScale } from "./components/helpers/scaleHelper";
|
| 12 |
+
import { modelScaleProps, QueueStatus } from "./components/helpers/Interfaces";
|
| 13 |
+
import { onnxMaskToImage, arrayToImageData, imageDataToURL } from "./components/helpers/maskUtils";
|
| 14 |
+
import { modelData } from "./components/helpers/onnxModelAPI";
|
| 15 |
+
import Stage, { DescriptionState } from "./components/Stage";
|
| 16 |
+
import AppContext from "./components/hooks/createContext";
|
| 17 |
+
import { imageToSamEmbedding } from "./services/maskApi";
|
| 18 |
+
import LoadingOverlay from "./components/LoadingOverlay";
|
| 19 |
+
import ErrorModal from './components/ErrorModal';
|
| 20 |
+
import QueueStatusIndicator from "./components/QueueStatusIndicator";
|
| 21 |
+
|
| 22 |
+
const ort = require("onnxruntime-web");
|
| 23 |
+
|
| 24 |
+
// Define image and model paths
|
| 25 |
+
const MODEL_DIR = "/model/sam_onnx_quantized_example.onnx";
|
| 26 |
+
|
| 27 |
+
const App = () => {
|
| 28 |
+
const {
|
| 29 |
+
clicks: [clicks, setClicks],
|
| 30 |
+
image: [image, setImage],
|
| 31 |
+
maskImg: [maskImg, setMaskImg],
|
| 32 |
+
maskImgData: [maskImgData, setMaskImgData],
|
| 33 |
+
isClicked: [isClicked, setIsClicked]
|
| 34 |
+
} = useContext(AppContext)!;
|
| 35 |
+
const [model, setModel] = useState<InferenceSession | null>(null);
|
| 36 |
+
const [tensor, setTensor] = useState<Tensor | null>(null);
|
| 37 |
+
const [modelScale, setModelScale] = useState<modelScaleProps | null>(null);
|
| 38 |
+
const [isLoading, setIsLoading] = useState<boolean>(false);
|
| 39 |
+
const [error, setError] = useState<string | null>(null);
|
| 40 |
+
const [descriptionState, setDescriptionState] = useState<DescriptionState>({
|
| 41 |
+
state: 'ready',
|
| 42 |
+
description: ''
|
| 43 |
+
});
|
| 44 |
+
const [queueStatus, setQueueStatus] = useState<QueueStatus>({ inQueue: false });
|
| 45 |
+
|
| 46 |
+
// Initialize the ONNX model
|
| 47 |
+
// Create the ONNX inference session once, when the component first mounts.
useEffect(() => {
  const initModel = async () => {
    try {
      const session = await InferenceSession.create(MODEL_DIR);
      setModel(session);
    } catch (e) {
      // Without a model no mask can ever be predicted, so surface the failure
      // in the error modal instead of only logging it.
      setError('Failed to load the segmentation model. Please reload the page.');
      console.error('Error loading ONNX model:', e);
    }
  };
  initModel();
}, []);
|
| 60 |
+
|
| 61 |
+
/**
 * Change handler for the image file input.
 *
 * Takes the first selected file, wraps it in an object URL and passes it to
 * `loadImage`. Does nothing when no file was selected. Any error raised while
 * loading is reported through the error state and logged.
 */
const handleImageUpload = async (event: React.ChangeEvent<HTMLInputElement>) => {
  const selectedFile = event.target.files?.[0];
  if (selectedFile) {
    try {
      const objectUrl = new URL(URL.createObjectURL(selectedFile));
      await loadImage(objectUrl);
    } catch (error) {
      setError('Failed to load image. Please try again with a different image.');
      console.error('Error loading image:', error);
    }
  }
};
|
| 73 |
+
|
| 74 |
+
/**
 * Loads the image at `url`, stores it together with its SAM scale in state,
 * and then fetches its embedding.
 *
 * The returned promise settles only after the image has loaded and the
 * embedding request has finished. It rejects if the image cannot be loaded,
 * so the caller can report the failure. The loading overlay is always
 * cleared, on success and on failure.
 *
 * @param url Location of the image to load.
 */
const loadImage = async (url: URL) => {
  setIsLoading(true);
  // Drop the previous image's embedding so it can never be paired with the
  // new image while the new embedding is still being fetched.
  setTensor(null);
  try {
    const img = new Image();
    // Wrap the load/error events in a promise: errors thrown from an `onload`
    // callback would not reach a surrounding try/catch, and without an
    // `onerror` handler a broken image would leave the overlay up forever.
    await new Promise<void>((resolve, reject) => {
      img.onload = () => resolve();
      img.onerror = () => reject(new Error(`Failed to load image: ${url.href}`));
      // Assign the source only after both handlers are attached.
      img.src = url.href;
    });

    const { height, width, samScale } = handleImageScale(img);
    setModelScale({
      height: height,
      width: width,
      samScale: samScale,
    });
    img.width = width;
    img.height = height;
    setImage(img);

    // After image is loaded, fetch its embedding from Gradio
    await fetchImageEmbedding(img);
  } finally {
    setIsLoading(false);
  }
};
|
| 99 |
+
|
| 100 |
+
/**
 * Requests the SAM embedding for `img` and stores it as a tensor in state.
 *
 * The image is drawn onto an off-screen canvas, encoded as base64 JPEG and
 * sent through `imageToSamEmbedding`; queue updates reported by that call are
 * forwarded to `setQueueStatus`. The base64 reply is decoded into a
 * Float32Array and wrapped in a [1, 256, 64, 64] float32 tensor.
 *
 * Failures are not rethrown: the queue status is reset, an error message is
 * shown and the error is logged.
 *
 * @param img Image whose embedding should be fetched.
 */
const fetchImageEmbedding = async (img: HTMLImageElement) => {
  try {
    // Render the image into a canvas sized to the image element.
    const scratchCanvas = document.createElement('canvas');
    scratchCanvas.width = img.width;
    scratchCanvas.height = img.height;
    scratchCanvas.getContext('2d')?.drawImage(img, 0, 0);

    // A data URL looks like "data:image/jpeg;base64,<payload>"; keep the payload.
    const [, base64Image] = scratchCanvas.toDataURL('image/jpeg').split(',');

    // Ask the Gradio API for the embedding, relaying queue status updates.
    const samEmbedding = await imageToSamEmbedding(
      base64Image,
      (status: QueueStatus) => setQueueStatus(status)
    );

    // Decode the base64 reply into raw bytes (one byte per decoded character).
    const decoded = window.atob(samEmbedding);
    const rawBytes = Uint8Array.from(decoded, (ch) => ch.charCodeAt(0));

    // Reinterpret the bytes as float32 values in the SAM embedding shape.
    const embedding = new ort.Tensor(
      'float32',
      new Float32Array(rawBytes.buffer),
      [1, 256, 64, 64]
    );
    setTensor(embedding);
  } catch (error) {
    // Reset queue status on error
    setQueueStatus({ inQueue: false });

    // Prefer the message from the HTTP response when there is one.
    const fallbackMessage = 'Failed to process image. Please try again.';
    const errorMessage = axios.isAxiosError(error)
      ? error.response?.data?.message || fallbackMessage
      : fallbackMessage;
    setError(errorMessage);
    console.error('Error fetching embedding:', error);
  }
};
|
| 145 |
+
|
| 146 |
+
// Re-run the mask prediction every time the click prompt changes.
// runONNX handles its own errors, so the returned promise is not awaited here.
useEffect(() => {
  runONNX();
}, [clicks]);
|
| 152 |
+
|
| 153 |
+
/**
 * Runs the ONNX model for the current click prompt and updates the mask state.
 *
 * Does nothing unless the description state is 'ready' and the model, clicks,
 * embedding tensor and model scale are all available. When the predicted mask
 * has at least one positive value, both the mask image and the mask image
 * data URL are updated; otherwise the mask image is cleared. Errors are
 * reported through the error state and logged, never rethrown.
 */
const runONNX = async () => {
  try {
    // Don't run if already described or is describing
    if (descriptionState.state !== 'ready') return;

    console.log('Running ONNX model with:', {
      modelLoaded: model !== null,
      hasClicks: clicks !== null,
      hasTensor: tensor !== null,
      hasModelScale: modelScale !== null
    });

    if (
      model === null ||
      clicks === null ||
      tensor === null ||
      modelScale === null
    ) {
      console.log('Missing required inputs, returning early');
      return;
    }

    console.log('Preparing model feeds with:', {
      clicks,
      tensorShape: tensor.dims,
      modelScale
    });

    const feeds = modelData({
      clicks,
      tensor,
      modelScale,
    });

    if (feeds === undefined) {
      console.log('Model feeds undefined, returning early');
      return;
    }

    console.log('Running model with feeds:', feeds);
    const results = await model.run(feeds);
    console.log('Model run complete, got results:', results);

    const output = results[model.outputNames[0]];
    console.log('Processing output with dims:', output.dims);

    // Count the positive mask values in place. This runs on every prompt
    // update, so avoid copying the whole mask into temporary arrays.
    const maskValues = output.data as ArrayLike<number>;
    let maskArea = 0;
    for (let i = 0; i < maskValues.length; i++) {
      if (maskValues[i] > 0) maskArea++;
    }
    console.log('Mask area (number of non-zero pixels):', maskArea);

    // Double check that the state is ready before processing the mask since the state may have changed
    // NOTE(review): `descriptionState` here is the same value captured by this
    // closure for the check at the top, so this cannot observe a change made
    // while the model was running - a ref would be needed for that.
    if (descriptionState.state !== 'ready') return;
    // If clicked, we only handle the first mask (note that mask will be cleared after clicking before handling to let us know if it's the first mask).
    if (isClicked && maskImgData != null) return;

    if (maskArea > 0) {
      setMaskImg(onnxMaskToImage(output.data, output.dims[2], output.dims[3], false));
      setMaskImgData(imageDataToURL(arrayToImageData(output.data, output.dims[2], output.dims[3], true)));
    } else {
      console.warn('No mask area detected, clearing mask');
      // Only the mask image is cleared; maskImgData keeps its previous value.
      setMaskImg(null);
    }

    console.log('Mask processing complete');
  } catch (e) {
    setError('Failed to process the image. Please try again.');
    console.error('Error running ONNX model:', e);
  }
};
|
| 224 |
+
|
| 225 |
+
/**
 * Prepares the UI for describing another region of the same image: puts the
 * description state back to 'ready' with an empty description, clears the
 * mask image and resets the clicked flag. maskImgData is not cleared here.
 */
const handleNewRegion = () => {
  const readyState = { state: 'ready', description: '' } as DescriptionState;
  setDescriptionState(readyState);
  setMaskImg(null);
  setIsClicked(false);
};
|
| 234 |
+
|
| 235 |
+
/**
 * Copies the current description text to the clipboard.
 *
 * The clipboard write is asynchronous and can fail (for example when the
 * Clipboard API is unavailable or permission is denied); such failures are
 * shown in the error modal instead of surfacing as an unhandled rejection.
 */
const handleCopyDescription = async () => {
  try {
    await navigator.clipboard.writeText(descriptionState.description);
  } catch (e) {
    setError('Failed to copy the description to the clipboard.');
    console.error('Error copying description:', e);
  }
};
|
| 238 |
+
|
| 239 |
+
/**
 * Returns the app to its initial "no image" state: resets the description,
 * clears the mask image, the image, the click prompt and the clicked flag,
 * and drops the embedding and scale that belonged to the discarded image.
 * maskImgData is not cleared here.
 */
const handleReset = () => {
  // Clear all states
  setDescriptionState({
    state: 'ready',
    description: ''
  } as DescriptionState);
  setMaskImg(null);
  setImage(null);
  setClicks(null);
  setIsClicked(false);
  // The embedding tensor and model scale describe the image that was just
  // removed; clear them so they cannot be reused with a later image.
  setTensor(null);
  setModelScale(null);
};
|
| 251 |
+
|
| 252 |
+
return (
|
| 253 |
+
<div className="flex flex-col h-screen">
|
| 254 |
+
{isLoading && <LoadingOverlay />}
|
| 255 |
+
{error && <ErrorModal message={error} onClose={() => setError(null)} />}
|
| 256 |
+
<QueueStatusIndicator queueStatus={queueStatus} />
|
| 257 |
+
<div className="flex-1">
|
| 258 |
+
<Stage
|
| 259 |
+
onImageUpload={handleImageUpload}
|
| 260 |
+
descriptionState={descriptionState}
|
| 261 |
+
setDescriptionState={setDescriptionState}
|
| 262 |
+
queueStatus={queueStatus}
|
| 263 |
+
setQueueStatus={setQueueStatus}
|
| 264 |
+
/>
|
| 265 |
+
</div>
|
| 266 |
+
<div className="description-container">
|
| 267 |
+
<div className={`description-box ${descriptionState.state !== 'described' ? descriptionState.state : ''}`}>
|
| 268 |
+
{descriptionState.description ? (
|
| 269 |
+
descriptionState.description + (descriptionState.state === 'describing' ? '...' : '')
|
| 270 |
+
) : descriptionState.state === 'describing' ? (
|
| 271 |
+
<em>Describing the region... (this may take a while if compute resources are busy)</em>
|
| 272 |
+
) : (
|
| 273 |
+
image ? (
|
| 274 |
+
<em>Click on the image to describe the region</em>
|
| 275 |
+
) : (
|
| 276 |
+
<em>Upload an image to describe the region</em>
|
| 277 |
+
)
|
| 278 |
+
)}
|
| 279 |
+
</div>
|
| 280 |
+
<div className="description-controls">
|
| 281 |
+
<button
|
| 282 |
+
onClick={handleCopyDescription}
|
| 283 |
+
disabled={descriptionState.state !== 'described'}
|
| 284 |
+
>
|
| 285 |
+
Copy description
|
| 286 |
+
</button>
|
| 287 |
+
<button
|
| 288 |
+
onClick={handleNewRegion}
|
| 289 |
+
disabled={descriptionState.state !== 'described'}
|
| 290 |
+
>
|
| 291 |
+
Describe a new region
|
| 292 |
+
</button>
|
| 293 |
+
<button
|
| 294 |
+
onClick={handleReset}
|
| 295 |
+
className="reset-button"
|
| 296 |
+
disabled={descriptionState.state === 'describing' || !image}
|
| 297 |
+
>
|
| 298 |
+
Try a new image
|
| 299 |
+
</button>
|
| 300 |
+
</div>
|
| 301 |
+
</div>
|
| 302 |
+
</div>
|
| 303 |
+
);
|
| 304 |
+
};
|
| 305 |
+
|
| 306 |
+
export default App;
|
frontend/src/assets/examples/1.jpg
ADDED
|
Git LFS Details
|
frontend/src/assets/examples/10.jpg
ADDED
|
Git LFS Details
|
frontend/src/assets/examples/11.jpg
ADDED
|
Git LFS Details
|
frontend/src/assets/examples/12.jpg
ADDED
|
Git LFS Details
|
frontend/src/assets/examples/13.jpg
ADDED
|
Git LFS Details
|
frontend/src/assets/examples/14.jpg
ADDED
|
Git LFS Details
|
frontend/src/assets/examples/15.jpg
ADDED
|
Git LFS Details
|
frontend/src/assets/examples/16.jpg
ADDED
|
Git LFS Details
|
frontend/src/assets/examples/17.jpg
ADDED
|
Git LFS Details
|
frontend/src/assets/examples/18.jpg
ADDED
|
Git LFS Details
|
frontend/src/assets/examples/19.jpg
ADDED
|
Git LFS Details
|
frontend/src/assets/examples/2.jpg
ADDED
|
Git LFS Details
|
frontend/src/assets/examples/20.jpg
ADDED
|
Git LFS Details
|
frontend/src/assets/examples/21.jpg
ADDED
|
Git LFS Details
|
frontend/src/assets/examples/3.jpg
ADDED
|
Git LFS Details
|
frontend/src/assets/examples/4.jpg
ADDED
|
Git LFS Details
|
frontend/src/assets/examples/5.jpg
ADDED
|
Git LFS Details
|
frontend/src/assets/examples/6.jpg
ADDED
|
Git LFS Details
|
frontend/src/assets/examples/7.jpg
ADDED
|
Git LFS Details
|
frontend/src/assets/examples/8.jpg
ADDED
|
Git LFS Details
|
frontend/src/assets/examples/9.jpg
ADDED
|
Git LFS Details
|
frontend/src/assets/index.html
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
<!-- Host page for the demo: the application is mounted into #root. -->
<html lang="en" dir="ltr" prefix="og: https://ogp.me/ns#" class="w-full h-full">
  <head>
    <meta charset="utf-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no" />
    <title>Describe Anything Demo</title>

    <!-- Open Graph meta tags (the "og" prefix is declared on the html element) -->
    <meta property="og:type" content="website" />
    <meta property="og:title" content="Describe Anything Demo" />
  </head>
  <body class="w-full h-full">
    <!-- Application mount point -->
    <div id="root" class="w-full h-full"></div>
  </body>
</html>
|
frontend/src/assets/scss/App.scss
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Tailwind layers.
@tailwind base;
@tailwind components;
@tailwind utilities;

// Shared values used by the rules below.
$spacing-lg: 20px;
$radius: 4px;
$color-button: #007bff;
$color-button-hover: #0056b3;
$color-button-disabled: #cccccc;

// Explicit definitions for a few utility class names used by the overlays.
.fixed {
  position: fixed;
}

.inset-0 {
  top: 0;
  right: 0;
  bottom: 0;
  left: 0;
}

.bg-opacity-75 {
  --tw-bg-opacity: 0.75;
}

.z-50 {
  z-index: 50;
}

// Bottom strip holding the description text and its action buttons.
.description-container {
  margin: $spacing-lg;
  display: flex;
  gap: $spacing-lg;
  height: 140px;
}

// Scrollable text box showing the description or a placeholder hint.
.description-box {
  flex: 1;
  background-color: #f5f5f5;
  border-radius: $radius;
  padding: 15px;
  margin-bottom: 0;
  color: #333;
  overflow-y: auto;

  // Muted, italic look for the 'describing' and 'ready' state classes.
  &.describing, &.ready {
    background-color: #e9ecef;
    color: #6c757d;
    font-style: italic;
  }
}

// Vertical stack of action buttons next to the description box.
.description-controls {
  display: flex;
  flex-direction: column;
  justify-content: space-between;
  width: 200px;
  gap: 10px;

  button {
    padding: 8px 16px;
    border: none;
    border-radius: $radius;
    background-color: $color-button;
    color: white;
    cursor: pointer;
    white-space: nowrap;

    &:disabled {
      background-color: $color-button-disabled;
      cursor: not-allowed;
    }

    &:hover:not(:disabled) {
      background-color: $color-button-hover;
    }

    // Currently uses the same colours as the other buttons; kept as its own
    // (more specific) rule so it can be restyled independently.
    &.reset-button {
      background-color: $color-button;

      &:hover:not(:disabled) {
        background-color: $color-button-hover;
      }

      &:disabled {
        background-color: $color-button-disabled;
      }
    }
  }
}

// The mount point fills the viewport and stacks its children vertically.
#root {
  display: flex;
  flex-direction: column;
  height: 100vh;
}

// Centers its content and clips overflow; min-height: 0 lets it shrink
// inside a flex column.
.stage-container {
  flex: 1;
  min-height: 0;
  display: flex;
  align-items: center;
  justify-content: center;
  overflow: hidden;
}
|
frontend/src/components/ErrorModal.tsx
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import React from 'react';
|
| 2 |
+
|
| 3 |
+
/** Props accepted by the ErrorModal component. */
interface ErrorModalProps {
  /** Text rendered in the body of the modal. */
  message: string;
  /** Invoked when the "Close" button is clicked. */
  onClose: () => void;
}
|
| 7 |
+
|
| 8 |
+
/**
 * Full-screen modal that shows an error message with a single "Close" button.
 *
 * The panel is exposed to assistive technology as a modal alert dialog, and
 * the decorative cross icon is hidden from it. No element ids are used for
 * labelling, so several instances can be rendered at once without id clashes.
 *
 * @param message Text displayed in the modal body.
 * @param onClose Called when the user presses "Close".
 */
const ErrorModal: React.FC<ErrorModalProps> = ({ message, onClose }) => {
  return (
    <div className="fixed inset-0 bg-black bg-opacity-50 flex items-center justify-center z-50">
      <div
        role="alertdialog"
        aria-modal="true"
        aria-label="Error"
        className="bg-white p-6 rounded-lg shadow-xl max-w-md w-full mx-4"
      >
        <div className="flex flex-col items-center">
          <div className="bg-red-100 p-4 rounded-full mb-4">
            {/* Decorative icon: the heading below already conveys the meaning */}
            <svg aria-hidden="true" className="w-6 h-6 text-red-600" fill="none" stroke="currentColor" viewBox="0 0 24 24">
              <path strokeLinecap="round" strokeLinejoin="round" strokeWidth="2" d="M6 18L18 6M6 6l12 12" />
            </svg>
          </div>
          <h3 className="text-lg font-semibold text-gray-900 mb-2">Error</h3>
          <p className="text-gray-600 text-center mb-6">{message}</p>
          {/* Explicit type so the button can never act as a form submit */}
          <button
            type="button"
            onClick={onClose}
            className="bg-red-600 text-white px-4 py-2 rounded hover:bg-red-700 transition-colors"
          >
            Close
          </button>
        </div>
      </div>
    </div>
  );
};
|
| 31 |
+
|
| 32 |
+
export default ErrorModal;
|
frontend/src/components/LoadingOverlay.tsx
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import React from 'react';
|
| 2 |
+
|
| 3 |
+
// Path data for the straight connecting lines of the loading icon,
// in the order they are drawn.
const LOADING_ICON_EDGES = [
  "M5.92017 41.0562L27.0002 48.0802",
  "M5.92017 12.9438L27.0002 26.9998",
  "M27 5.91992L48.08 26.9999",
  "M5.92017 41.0559L27.0002 5.91992",
  "M27 48.08L48.08 27",
  "M27 27H48.08",
  "M5.92017 12.9439L27.0002 5.91992",
  "M5.92017 41.056L27.0002 27",
  "M5.92017 12.9438L27.0002 48.0798",
];

// Path data for the white-filled circular nodes, drawn on top of the lines.
const LOADING_ICON_NODES = [
  "M26.9998 31.9201C29.7171 31.9201 31.9198 29.7173 31.9198 27.0001C31.9198 24.2828 29.7171 22.0801 26.9998 22.0801C24.2826 22.0801 22.0798 24.2828 22.0798 27.0001C22.0798 29.7173 24.2826 31.9201 26.9998 31.9201Z",
  "M5.92 17.8639C8.63724 17.8639 10.84 15.6612 10.84 12.9439C10.84 10.2267 8.63724 8.02393 5.92 8.02393C3.20276 8.02393 1 10.2267 1 12.9439C1 15.6612 3.20276 17.8639 5.92 17.8639Z",
  "M5.92 45.9757C8.63724 45.9757 10.84 43.773 10.84 41.0557C10.84 38.3385 8.63724 36.1357 5.92 36.1357C3.20276 36.1357 1 38.3385 1 41.0557C1 43.773 3.20276 45.9757 5.92 45.9757Z",
  "M48.0806 31.9201C50.7979 31.9201 53.0006 29.7173 53.0006 27.0001C53.0006 24.2828 50.7979 22.0801 48.0806 22.0801C45.3634 22.0801 43.1606 24.2828 43.1606 27.0001C43.1606 29.7173 45.3634 31.9201 48.0806 31.9201Z",
  "M26.9998 53.0002C29.7171 53.0002 31.9198 50.7974 31.9198 48.0802C31.9198 45.3629 29.7171 43.1602 26.9998 43.1602C24.2826 43.1602 22.0798 45.3629 22.0798 48.0802C22.0798 50.7974 24.2826 53.0002 26.9998 53.0002Z",
  "M26.9998 10.84C29.7171 10.84 31.9198 8.63724 31.9198 5.92C31.9198 3.20276 29.7171 1 26.9998 1C24.2826 1 22.0798 3.20276 22.0798 5.92C22.0798 8.63724 24.2826 10.84 26.9998 10.84Z",
];

/**
 * Full-screen overlay with an icon and the text
 * "Loading image embedding...". Takes no props.
 */
const LoadingOverlay: React.FC = () => (
  <div className="fixed inset-0 bg-gray-500 bg-opacity-75 flex items-center justify-center z-50">
    <div className="bg-white p-8 rounded-lg shadow-xl flex flex-col items-center">
      <svg width="54" height="54" viewBox="0 0 54 54" fill="none" xmlns="http://www.w3.org/2000/svg" className="w-16 h-16 mb-4">
        {LOADING_ICON_EDGES.map((pathData) => (
          <path key={pathData} d={pathData} stroke="#1C2B33" strokeWidth="2" strokeMiterlimit="10"/>
        ))}
        {LOADING_ICON_NODES.map((pathData) => (
          <path key={pathData} d={pathData} fill="white" stroke="#1C2B33" strokeWidth="2" strokeMiterlimit="10"/>
        ))}
      </svg>
      <p className="text-lg font-semibold text-gray-800">Loading image embedding...</p>
    </div>
  </div>
);
|
| 29 |
+
|
| 30 |
+
export default LoadingOverlay;
|
frontend/src/components/QueueStatusIndicator.tsx
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import React from 'react';
|
| 2 |
+
import { QueueStatus } from './helpers/Interfaces';
|
| 3 |
+
|
| 4 |
+
/** Props accepted by the QueueStatusIndicator component. */
interface QueueStatusIndicatorProps {
  /** Current queue information; the indicator renders only while `inQueue` is truthy. */
  queueStatus: QueueStatus;
}
|
| 7 |
+
|
| 8 |
+
/**
 * Small fixed badge in the top-right corner showing the user's position in
 * the queue and, when available, the estimated wait.
 *
 * Renders nothing while `queueStatus.inQueue` is falsy.
 *
 * @param queueStatus Queue information (`inQueue`, `rank`, `queueSize`, `rankEta`).
 */
const QueueStatusIndicator: React.FC<QueueStatusIndicatorProps> = ({ queueStatus }) => {
  if (!queueStatus.inQueue) return null;

  const { rank, queueSize, rankEta } = queueStatus;

  // Explicit numeric check instead of `rankEta && (...)`: with `&&`, an ETA
  // of 0 (or NaN) is itself returned to React and rendered as a stray "0"
  // (or "NaN") inside the badge.
  const hasEta = typeof rankEta === 'number' && rankEta > 0;

  return (
    <div className="fixed top-4 right-4 bg-white rounded-lg shadow-lg p-4 z-50">
      <div className="flex flex-col gap-2">
        {rank === 0 ? (
          <p className="text-sm">You're next in line! ({queueSize} total in queue)</p>
        ) : (
          // rank is zero-based, so the displayed position is rank + 1
          <p className="text-sm">Queue position: {rank! + 1} of {queueSize}</p>
        )}
        {hasEta && (
          <p className="text-sm text-gray-600">
            Estimated wait: {Math.ceil(rankEta!)} seconds
          </p>
        )}
      </div>
    </div>
  );
};
|
| 28 |
+
|
| 29 |
+
export default QueueStatusIndicator;
|
frontend/src/components/Stage.tsx
ADDED
|
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import React, { useContext, useState, useEffect } from "react";
|
| 8 |
+
import * as _ from "underscore";
|
| 9 |
+
import Tool from "./Tool";
|
| 10 |
+
import { modelInputProps, QueueStatus } from "./helpers/Interfaces";
|
| 11 |
+
import AppContext from "./hooks/createContext";
|
| 12 |
+
// import { describeMask } from '../services/maskApi';
|
| 13 |
+
|
| 14 |
+
/** State of the description workflow shared between the app and the stage. */
interface DescriptionState {
  /** Workflow phase: 'ready', 'describing' or 'described'. */
  state: string;
  /** Description text; the empty string is used when there is none. */
  description: string;
}
|
| 18 |
+
|
| 19 |
+
/** Props accepted by the Stage component. */
interface StageProps {
  /** Upload handler; Stage calls it for the file input, dropped files and example images. */
  onImageUpload: (event: React.ChangeEvent<HTMLInputElement>) => Promise<void>;
  /** Current description workflow state. */
  descriptionState: DescriptionState;
  /** Setter for the description workflow state (forwarded to Tool). */
  setDescriptionState: React.Dispatch<React.SetStateAction<DescriptionState>>;
  /** Current queue information (forwarded to Tool). */
  queueStatus: QueueStatus;
  /** Setter for the queue information (forwarded to Tool). */
  setQueueStatus: (status: QueueStatus) => void;
}
|
| 26 |
+
|
| 27 |
+
// Number of bundled example images, served as /examples/1.jpg .. /examples/N.jpg.
const EXAMPLE_IMAGE_COUNT = 21;

// URLs of the example images in display order. The unused first callback
// argument is named explicitly so it does not shadow the module-level
// underscore import (`_`).
const EXAMPLE_IMAGES = Array.from(
  { length: EXAMPLE_IMAGE_COUNT },
  (_unused, i) => `/examples/${i + 1}.jpg`
);

// Viewport widths (px) below which fewer example images are shown per page:
// under BREAKPOINT_SMALL -> 1 image, under BREAKPOINT_MEDIUM -> 4, otherwise 8.
const BREAKPOINT_MEDIUM = 2100;
const BREAKPOINT_SMALL = 1100;
|
| 30 |
+
|
| 31 |
+
// Wraps a File in the minimal event shape that `onImageUpload` reads
// (`event.target.files`), so dropped files and example images can reuse the
// same upload handler as the <input type="file"> element.
const createUploadEvent = (file: File) =>
  ({ target: { files: [file] } } as unknown as React.ChangeEvent<HTMLInputElement>);

/**
 * Main interaction surface of the demo.
 *
 * Without an image it shows the intro text, an upload button, a drag-and-drop
 * target and a paginated gallery of example images. With an image it renders
 * the Tool component and converts mouse positions into image-space click
 * prompts stored in the shared app context.
 *
 * @param onImageUpload       Upload handler shared by the file input, drops and examples.
 * @param descriptionState    Current description workflow state.
 * @param setDescriptionState Setter forwarded to Tool.
 * @param queueStatus         Queue information forwarded to Tool.
 * @param setQueueStatus      Setter forwarded to Tool.
 */
const Stage = ({ onImageUpload, descriptionState, setDescriptionState, queueStatus, setQueueStatus }: StageProps) => {
  const {
    clicks: [, setClicks],
    image: [image],
  } = useContext(AppContext)!;

  const [isDragging, setIsDragging] = useState(false);
  const [currentPage, setCurrentPage] = useState(1);
  const [imagesPerPage, setImagesPerPage] = useState(8);

  // Keep the gallery page size in sync with the viewport width.
  useEffect(() => {
    const handleResize = () => {
      if (window.innerWidth < BREAKPOINT_SMALL) {
        setImagesPerPage(1);
      } else if (window.innerWidth < BREAKPOINT_MEDIUM) {
        setImagesPerPage(4);
      } else {
        setImagesPerPage(8);
      }
    };

    handleResize(); // initial value
    window.addEventListener('resize', handleResize);
    return () => window.removeEventListener('resize', handleResize);
  }, []);

  // Pagination values are derived on every render. Clamping the stored page
  // index matters when a resize increases imagesPerPage: the stored page may
  // then point past the last page, which would show an empty gallery.
  const totalPages = Math.max(1, Math.ceil(EXAMPLE_IMAGES.length / imagesPerPage));
  const page = Math.min(currentPage, totalPages);
  const pageStart = (page - 1) * imagesPerPage;
  const isFirstPage = page === 1;
  const isLastPage = page === totalPages;

  // Throttled conversion of a pointer position into a click prompt. It is
  // memoized so the SAME throttled function survives re-renders; creating a
  // new one on each render would reset the throttle timer every time and
  // defeat the throttling. It receives plain values rather than the event
  // object, so a delayed (trailing) invocation never reads from a stale event.
  const updateClickFromPointer = React.useMemo(
    () =>
      _.throttle((clientX: number, clientY: number, el: HTMLElement) => {
        const rect = el.getBoundingClientRect();

        // Work out the size of the image as actually rendered inside the element.
        const containerAspectRatio = el.offsetWidth / el.offsetHeight;
        const imageAspectRatio = image ? image.width / image.height : 1;
        const constrainedByHeight = containerAspectRatio > imageAspectRatio;
        const renderedHeight = constrainedByHeight ? el.offsetHeight : el.offsetWidth / imageAspectRatio;
        const renderedWidth = constrainedByHeight ? el.offsetHeight * imageAspectRatio : el.offsetWidth;

        // Empty space on each side when the image does not fill the element.
        const offsetX = (el.offsetWidth - renderedWidth) / 2;
        const offsetY = (el.offsetHeight - renderedHeight) / 2;

        // Pointer position relative to the rendered image, scaled to original image pixels.
        const scaleX = image ? image.width / renderedWidth : 1;
        const scaleY = image ? image.height / renderedHeight : 1;
        let x = (clientX - rect.left - offsetX) * scaleX;
        let y = (clientY - rect.top - offsetY) * scaleY;

        // Keep the coordinates within the image bounds.
        if (image) {
          x = Math.max(0, Math.min(x, image.width));
          y = Math.max(0, Math.min(y, image.height));
        }

        const click: modelInputProps = { x, y, clickType: 1 };
        setClicks([click]);
      }, 15),
    [image, setClicks]
  );

  // Mouse-move handler passed to Tool. Ignored unless the workflow is 'ready'.
  const handleMouseMove = (e: any) => {
    if (descriptionState.state !== 'ready') return;
    if (e.clientX === undefined || e.clientY === undefined) {
      console.warn('Mouse move event does not contain clientX or clientY');
      return;
    }
    updateClickFromPointer(e.clientX, e.clientY, e.nativeEvent.target);
  };

  // Sends a file through the shared upload handler and logs failures instead
  // of leaving a rejected promise unhandled.
  const uploadFile = async (file: File) => {
    try {
      await onImageUpload(createUploadEvent(file));
    } catch (err) {
      console.error('Failed to upload image:', err);
    }
  };

  // Downloads an example image and uploads it like a user-selected file.
  // `exampleNumber` is the 1-based position in EXAMPLE_IMAGES, so the
  // generated file name matches the image regardless of the gallery page.
  const loadExampleImage = async (src: string, exampleNumber: number) => {
    try {
      const res = await fetch(src);
      if (!res.ok) {
        throw new Error(`Failed to fetch example image ${src} (HTTP ${res.status})`);
      }
      const blob = await res.blob();
      const file = new File([blob], `example-${exampleNumber}.jpg`, { type: 'image/jpeg' });
      await onImageUpload(createUploadEvent(file));
    } catch (err) {
      console.error('Failed to load example image:', err);
    }
  };

  const handleDragEnter = (e: React.DragEvent) => {
    e.preventDefault();
    e.stopPropagation();
    setIsDragging(true);
  };

  const handleDragLeave = (e: React.DragEvent) => {
    e.preventDefault();
    e.stopPropagation();
    // dragleave also fires when the pointer moves onto a child element;
    // ignore those so the drop highlight does not flicker.
    const nextTarget = e.relatedTarget as Node | null;
    if (nextTarget && e.currentTarget.contains(nextTarget)) return;
    setIsDragging(false);
  };

  const handleDragOver = (e: React.DragEvent) => {
    // Required so the browser allows dropping on this element.
    e.preventDefault();
    e.stopPropagation();
  };

  const handleDrop = (e: React.DragEvent) => {
    e.preventDefault();
    e.stopPropagation();
    setIsDragging(false);

    // Only the first dropped file is used.
    const files = e.dataTransfer.files;
    if (files && files[0]) {
      void uploadFile(files[0]);
    }
  };

  const flexCenterClasses = "flex items-center justify-center";
  const arrowClasses = (disabled: boolean) =>
    disabled ? 'text-gray-300 cursor-not-allowed' : 'text-gray-600 hover:text-gray-800';

  return (
    <div
      className="flex flex-col w-full h-full relative"
      onDragEnter={handleDragEnter}
      onDragOver={handleDragOver}
      onDragLeave={handleDragLeave}
      onDrop={handleDrop}
    >
      {/* Title and Description */}
      <div className="w-full px-8 mb-8 flex flex-col justify-center mt-4">
        <div className="flex flex-col sm:flex-row justify-between items-center gap-4">
          <h1 className="text-3xl font-bold text-center sm:text-left"><a href="/">Describe Anything Model Demo</a></h1>
          <div className="flex flex-wrap justify-center gap-4 sm:space-x-8 text-lg font-medium">
            <a href="https://describe-anything.github.io/" target="_blank" rel="noopener noreferrer" className="text-gray-600 hover:text-gray-800">Project Page</a>
            <a href="https://github.com/NVlabs/describe-anything?tab=readme-ov-file#simple-gradio-demo-for-detailed-localized-video-descriptions" target="_blank" rel="noopener noreferrer" className="text-gray-600 hover:text-gray-800">DAM for video</a>
          </div>
        </div>
        <div className="border-b border-gray-300 mt-4 mb-4"></div>
        {!image && <div className="space-y-4 text-gray-600 text-left">
          <p>Describe Anything Model (DAM) takes in a region of an image or a video in the form of points/boxes/scribbles/masks and outputs detailed descriptions to the region. For videos, it is sufficient to supply annotation on any frame.</p>
          {/* The link previously had an empty href (which just reloads the page); it now points at the project repository. */}
          <p>This demo supports DAM model that takes points on images as queries. For other use cases, please refer to the <a href="https://github.com/NVlabs/describe-anything" target="_blank" rel="noopener noreferrer" className="text-gray-600 hover:text-gray-800 underline">inference scripts and video demo</a> for more details.</p>
        </div>}
      </div>

      {/* Main Content Area */}
      <div className="flex items-center justify-center flex-grow overflow-hidden">
        {/* Main Stage */}
        <div
          className={`${flexCenterClasses} relative w-full h-full max-h-[calc(100vh-300px)] px-8 ${
            isDragging ? 'border-4 border-dashed border-blue-500 bg-blue-50' : ''
          }`}
        >
          {image ? (
            <Tool
              handleMouseMove={handleMouseMove}
              descriptionState={descriptionState}
              setDescriptionState={setDescriptionState}
              queueStatus={queueStatus}
              setQueueStatus={setQueueStatus}
            />
          ) : (
            <div className="flex flex-col items-center gap-6 w-full h-full">
              {/* Top spacer */}
              <div className="flex-1" />

              <div className="text-gray-500 text-lg">
                {isDragging ? 'Drop image here' : 'Upload your own image'}
              </div>
              <div className="flex gap-4 mb-8">
                <input
                  type="file"
                  id="imageUpload"
                  accept="image/*"
                  onChange={onImageUpload}
                  className="hidden"
                />
                <label
                  htmlFor="imageUpload"
                  className="bg-blue-500 hover:bg-blue-700 text-white font-bold py-2 px-4 rounded cursor-pointer"
                >
                  Upload Image
                </label>
              </div>

              <div className="text-gray-500 text-lg">
                or choose an example image below
              </div>

              <div className="relative w-full max-w-[2200px]">
                {/* Left Arrow */}
                <button
                  type="button"
                  aria-label="Previous example images"
                  onClick={() => setCurrentPage(Math.max(page - 1, 1))}
                  disabled={isFirstPage}
                  className={`absolute left-0 top-1/2 -translate-y-1/2 z-10 p-4 ${arrowClasses(isFirstPage)}`}
                >
                  <svg xmlns="http://www.w3.org/2000/svg" className="h-8 w-8" fill="none" viewBox="0 0 24 24" stroke="currentColor">
                    <path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M15 19l-7-7 7-7" />
                  </svg>
                </button>

                {/* Example Images */}
                <div className="flex flex-wrap justify-center gap-8 px-16">
                  {EXAMPLE_IMAGES.slice(pageStart, pageStart + imagesPerPage).map((src, index) => {
                    // 1-based position in the full list, not just within the current page
                    const exampleNumber = pageStart + index + 1;
                    return (
                      <img
                        key={src}
                        src={src}
                        alt={`Example ${exampleNumber}`}
                        className="w-[200px] h-[150px] object-cover rounded-sm cursor-pointer hover:opacity-80 transition-opacity"
                        onClick={() => {
                          void loadExampleImage(src, exampleNumber);
                        }}
                      />
                    );
                  })}
                </div>

                {/* Right Arrow */}
                <button
                  type="button"
                  aria-label="Next example images"
                  onClick={() => setCurrentPage(Math.min(page + 1, totalPages))}
                  disabled={isLastPage}
                  className={`absolute right-0 top-1/2 -translate-y-1/2 z-10 p-4 ${arrowClasses(isLastPage)}`}
                >
                  <svg xmlns="http://www.w3.org/2000/svg" className="h-8 w-8" fill="none" viewBox="0 0 24 24" stroke="currentColor">
                    <path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M9 5l7 7-7 7" />
                  </svg>
                </button>
              </div>

              {/* Bottom spacer */}
              <div className="flex-1" />

              {/* Image Credits (this branch is only rendered when no image is loaded) */}
              <div className="pl-5 pr-5 text-gray-500 text-sm">
                Image credit for example images: {' '}
                <a
                  href="https://segment-anything.com/terms"
                  target="_blank"
                  rel="noopener noreferrer"
                  className="text-gray-600 hover:text-gray-800 underline"
                >
                  Segment Anything Materials
                </a>
                {' '}(CC BY-SA 4.0)
              </div>
            </div>
          )}
        </div>
      </div>
    </div>
  );
};
|
| 341 |
+
|
| 342 |
+
export default Stage;
|
| 343 |
+
export type { DescriptionState };
|
frontend/src/components/Tool.tsx
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import React, { useContext, useEffect, useState } from "react";
|
| 2 |
+
import AppContext from "./hooks/createContext";
|
| 3 |
+
import { ToolProps, QueueStatus } from "./helpers/Interfaces";
|
| 4 |
+
import * as _ from "underscore";
|
| 5 |
+
import { describeMask, describeMaskWithoutStreaming } from "../services/maskApi";
|
| 6 |
+
import ErrorModal from './ErrorModal';
|
| 7 |
+
import { DescriptionState } from "./Stage";
|
| 8 |
+
|
| 9 |
+
const prompt = "<image>\nDescribe the masked region in detail.";
|
| 10 |
+
|
| 11 |
+
/**
 * Renders the current image with the hover/click mask overlaid and, once a
 * region has been clicked and its mask data is available, requests a textual
 * description of the masked region.
 *
 * Props:
 *  - handleMouseMove: forwarded to the image for mouse move / touch start, and
 *    invoked once more on click.
 *  - descriptionState / setDescriptionState: description state owned by the
 *    parent; this component moves it through 'ready' -> 'describing' ->
 *    'described' (or back to 'ready' on failure).
 *  - queueStatus / setQueueStatus: queue indicator state owned by the parent;
 *    only the setter is used here.
 *
 * The description request first tries the streaming endpoint and falls back
 * to the non-streaming endpoint exactly once if streaming fails.
 */
const Tool = ({
  handleMouseMove,
  descriptionState,
  setDescriptionState,
  queueStatus,
  setQueueStatus
}: ToolProps) => {
  const {
    image: [image],
    maskImg: [maskImg, setMaskImg],
    maskImgData: [maskImgData, setMaskImgData],
    isClicked: [isClicked, setIsClicked]
  } = useContext(AppContext)!;

  const [shouldFitToWidth, setShouldFitToWidth] = useState(true);
  const [error, setError] = useState<string | null>(null);
  const [useStreaming, setUseStreaming] = useState(true);

  // Decide whether the image should be constrained by width or by height,
  // and re-evaluate whenever the body is resized. Everything lives inside the
  // effect so that an observer is only constructed when the effect runs,
  // not on every render.
  useEffect(() => {
    const bodyEl = document.body;

    const fitToPage = () => {
      if (!image) return;
      const maxWidth = window.innerWidth - 64; // Account for padding (32px on each side)
      const maxHeight = window.innerHeight - 200; // Account for header and some padding
      const imageAspectRatio = image.width / image.height;
      const containerAspectRatio = maxWidth / maxHeight;

      setShouldFitToWidth(
        imageAspectRatio > containerAspectRatio ||
        image.width > maxWidth
      );
    };

    const resizeObserver = new ResizeObserver((entries) => {
      for (const entry of entries) {
        if (entry.target === bodyEl) {
          fitToPage();
        }
      }
    });

    fitToPage();
    resizeObserver.observe(bodyEl);
    return () => {
      resizeObserver.unobserve(bodyEl);
    };
  }, [image]);

  const imageClasses = "";
  const maskImageClasses = `absolute opacity-40 pointer-events-none`;

  // Kick off a description request when new mask data arrives after a click.
  // The effect is intentionally keyed on maskImgData only: the other values
  // are read as they were at the time the mask data changed.
  useEffect(() => {
    if (!isClicked || !maskImg || !maskImgData || !image || descriptionState.state !== 'ready') {
      console.log("Not ready to call model, isClicked:", isClicked, "maskImg:", maskImg !== null, "maskImgData:", maskImgData !== null, "image:", image !== null, "descriptionState.state:", descriptionState.state);
      return;
    }

    // Shared failure path: surface the error and return to the 'ready' state.
    // The queue indicator is cleared too, otherwise a failure while queued
    // would leave it showing "in queue" indefinitely.
    const resetAfterFailure = (cause: unknown) => {
      setIsClicked(false);
      setError('Failed to generate description. Please try again.');
      setDescriptionState({
        state: 'ready',
        description: ''
      } as DescriptionState);
      setQueueStatus({ inQueue: false });
      console.error('Failed to describe mask:', cause);
    };

    try {
      setDescriptionState({
        state: 'describing',
        description: ''
      } as DescriptionState);

      // Re-encode the image as JPEG and strip the data-URL prefix from both
      // payloads so only the base64 part is sent.
      const canvas = document.createElement('canvas');
      canvas.width = image.width;
      canvas.height = image.height;
      const ctx = canvas.getContext('2d');
      ctx?.drawImage(image, 0, 0);
      const imageBase64 = canvas.toDataURL('image/jpeg').split(',')[1];
      const maskBase64 = maskImgData.split(',')[1];

      const describeMaskWithFallback = async (useStreamingInFunction: boolean) => {
        try {
          let result;
          console.log("useStreaming", useStreaming, "useStreamingInFunction", useStreamingInFunction);
          if (useStreamingInFunction) {
            result = await describeMask(
              maskBase64,
              imageBase64,
              prompt,
              (streamResult: string) => {
                setDescriptionState({
                  state: 'describing',
                  description: streamResult
                } as DescriptionState);
              },
              (status: QueueStatus) => {
                setQueueStatus(status);
              }
            );
          } else {
            result = await describeMaskWithoutStreaming(
              maskBase64,
              imageBase64,
              prompt
            );
          }

          setDescriptionState({
            state: 'described',
            description: result
          } as DescriptionState);
          setQueueStatus({ inQueue: false });
          setIsClicked(false);
        } catch (err) {
          // Decide on the argument, not on the `useStreaming` state: the state
          // value captured by this closure never changes, so testing it here
          // would retry the non-streaming call forever when it keeps failing.
          if (useStreamingInFunction) {
            console.log("Error describing mask, switching to non-streaming", err);
            setUseStreaming(false);
            await describeMaskWithFallback(false);
          } else {
            resetAfterFailure(err);
          }
        }
      };

      // Fire-and-forget: all failures are handled inside the function itself.
      void describeMaskWithFallback(useStreaming);
    } catch (err) {
      // Synchronous failures while preparing the request payload.
      resetAfterFailure(err);
    }
  }, [maskImgData]);

  // A click is only honoured while idle; it clears the current mask, marks the
  // click, and lets the parent's handler process the pointer position.
  const handleClick = async (e: React.MouseEvent<HTMLImageElement>) => {
    if (descriptionState.state !== 'ready') return;

    setMaskImg(null);
    setMaskImgData(null);
    setIsClicked(true);
    handleMouseMove(e);
  };

  return (
    <>
      {error && <ErrorModal message={error} onClose={() => setError(null)} />}
      <div className="relative flex items-center justify-center w-full h-full">
        {image && (
          <img
            onMouseMove={handleMouseMove}
            onMouseLeave={() => _.defer(() => (descriptionState.state === 'ready' && !isClicked) ? setMaskImg(null) : undefined)}
            onTouchStart={handleMouseMove}
            onClick={handleClick}
            src={image.src}
            className={`${
              shouldFitToWidth ? "w-full" : "h-full"
            } ${imageClasses} object-contain max-h-full max-w-full`}
          ></img>
        )}
        {maskImg && (
          <img
            src={maskImg.src}
            className={`${
              shouldFitToWidth ? "w-full" : "h-full"
            } ${maskImageClasses} object-contain max-h-full max-w-full`}
          ></img>
        )}
      </div>
    </>
  );
};
|
| 181 |
+
|
| 182 |
+
export default Tool;
|
frontend/src/components/helpers/Interfaces.tsx
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import { Tensor } from "onnxruntime-web";
|
| 8 |
+
import { DescriptionState } from "../Stage";
|
| 9 |
+
|
| 10 |
+
/** Image dimensions together with the factor used to scale click coordinates for the model. */
export interface modelScaleProps {
  /** Image height. */
  height: number;
  /** Image width. */
  width: number;
  /** Multiplier applied to click x/y coordinates before they are sent to the model. */
  samScale: number;
}
|
| 15 |
+
|
| 16 |
+
/** A single point prompt. */
export interface modelInputProps {
  /** Horizontal coordinate of the point. */
  x: number;
  /** Vertical coordinate of the point. */
  y: number;
  /** Numeric label attached to the point when it is passed to the model. */
  clickType: number;
}
|
| 21 |
+
|
| 22 |
+
/** Everything needed to assemble one set of model inputs. */
export interface modeDataProps {
  /** Tensor forwarded unchanged as the image embedding input. */
  tensor: Tensor;
  /** Image size and coordinate scale. */
  modelScale: modelScaleProps;
  /** Optional point prompts. */
  clicks?: modelInputProps[];
}
|
| 27 |
+
|
| 28 |
+
/** Props accepted by the Tool component. */
export interface ToolProps {
  /** Pointer handler used for mouse move, touch start and click events on the image. */
  handleMouseMove: (e: any) => void;
  /** Current description state. */
  descriptionState: DescriptionState;
  /** Replaces the description state. */
  setDescriptionState: (value: DescriptionState) => void;
  /** Current queue status. */
  queueStatus: QueueStatus;
  /** Replaces the queue status. */
  setQueueStatus: (value: QueueStatus) => void;
}
|
| 35 |
+
|
| 36 |
+
/** Props accepted by the Stage component. */
export interface StageProps {
  /** Change handler for the file input used to upload an image. */
  onImageUpload: (event: React.ChangeEvent<HTMLInputElement>) => void;
  /** Current description state. */
  descriptionState: DescriptionState;
  /** Replaces the description state. */
  setDescriptionState: (value: DescriptionState) => void;
}
|
| 41 |
+
|
| 42 |
+
/** Snapshot of a request's position in the backend queue. */
export interface QueueStatus {
  /** Whether the request is currently waiting in the queue. */
  inQueue: boolean;
  /** Position in the queue, when reported. */
  rank?: number;
  /** Total queue length, when reported. */
  queueSize?: number;
  /** Estimated wait for this position, when reported; may be null. */
  rankEta?: number | null;
}
|
frontend/src/components/helpers/imageUtils.tsx
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { Buffer } from 'buffer';
|
| 2 |
+
|
| 3 |
+
/**
 * Load an image from base64 data.
 *
 * @param base64String Either a complete `data:` URL or a bare base64 payload;
 *   a bare payload is wrapped as a PNG data URL.
 * @returns A promise that resolves with the image element once it has loaded
 *   and rejects with the element's error event otherwise.
 */
export const base64ToImage = async (base64String: string): Promise<HTMLImageElement> => {
  return new Promise((resolve, reject) => {
    let source: string;
    if (base64String.startsWith('data:')) {
      source = base64String;
    } else {
      source = `data:image/png;base64,${base64String}`;
    }

    const element = new Image();
    element.onload = () => resolve(element);
    element.onerror = reject;
    // Handlers are attached before the source is set so the load event cannot be missed.
    element.src = source;
  });
};
|
| 13 |
+
|
| 14 |
+
/**
 * Draw an image onto an off-screen canvas of the same width/height and
 * return the canvas contents as a PNG data URL.
 *
 * If no 2D context can be obtained the draw is skipped and the data URL of
 * the untouched canvas is returned.
 */
export const imageToBase64 = (img: HTMLImageElement): string => {
  const surface = document.createElement('canvas');
  surface.width = img.width;
  surface.height = img.height;

  surface.getContext('2d')?.drawImage(img, 0, 0);

  return surface.toDataURL('image/png');
};
|
frontend/src/components/helpers/maskUtils.tsx
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
/**
 * Convert a mask prediction array into RGBA ImageData.
 *
 * Every element of `input` greater than 0.0 is treated as foreground.
 *  - binary = false: foreground is opaque blue (0, 114, 189), everything else
 *    stays fully transparent.
 *  - binary = true: foreground is opaque white, background is opaque black.
 *
 * NOTE(review): the ImageData constructor is called with `height` in the
 * width position and `width` in the height position. Callers presumably pass
 * the two dimensions swapped; confirm against the call sites before changing.
 */
function arrayToImageData(input: any, width: number, height: number, binary: boolean) {
  const foreground = binary ? [255, 255, 255, 255] : [0, 114, 189, 255];
  // Transparent black for the overlay, opaque black for a binary mask.
  const background = binary ? [0, 0, 0, 255] : [0, 0, 0, 0];

  // Typed arrays start zero-filled, i.e. fully transparent.
  const pixels = new Uint8ClampedArray(4 * width * height);

  for (let i = 0; i < input.length; i++) {
    // Thresholding at 0.0 is equivalent to thresholding the mask using
    // predictor.model.mask_threshold in python.
    const color = input[i] > 0.0 ? foreground : background;
    const base = 4 * i;
    pixels[base] = color[0];
    pixels[base + 1] = color[1];
    pixels[base + 2] = color[2];
    pixels[base + 3] = color[3];
  }

  return new ImageData(pixels, height, width);
}
|
| 36 |
+
|
| 37 |
+
/**
 * Turn ImageData into an image element by painting it on a canvas and using
 * the canvas data URL as the image source. The element is returned
 * immediately, without waiting for it to load.
 */
function imageDataToImage(imageData: ImageData) {
  const dataUrl = imageDataToCanvas(imageData).toDataURL();
  const result = new Image();
  result.src = dataUrl;
  return result;
}
|
| 44 |
+
|
| 45 |
+
/** Paint ImageData on a canvas and return the canvas contents as a data URL. */
function imageDataToURL(imageData: ImageData) {
  return imageDataToCanvas(imageData).toDataURL();
}
|
| 49 |
+
|
| 50 |
+
/**
 * Create a canvas sized to the given ImageData and paint the data at the
 * origin. If no 2D context is available the canvas is returned unpainted.
 */
function imageDataToCanvas(imageData: ImageData) {
  const { width, height } = imageData;
  const surface = document.createElement("canvas");
  const context = surface.getContext("2d");
  surface.width = width;
  surface.height = height;
  if (context) {
    context.putImageData(imageData, 0, 0);
  }
  return surface;
}
|
| 59 |
+
|
| 60 |
+
/**
 * Convert a mask prediction array straight into an image element: the array
 * is first rendered to ImageData, then to an image via a canvas. Arguments
 * are forwarded unchanged to arrayToImageData.
 */
function onnxMaskToImage(input: any, width: number, height: number, binary: boolean) {
  const maskPixels = arrayToImageData(input, width, height, binary);
  return imageDataToImage(maskPixels);
}
|
| 64 |
+
|
| 65 |
+
export { arrayToImageData, imageDataToImage, onnxMaskToImage, imageDataToURL };
|
frontend/src/components/helpers/onnxModelAPI.tsx
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import { Tensor } from "onnxruntime-web";
|
| 8 |
+
import { modeDataProps } from "./Interfaces";
|
| 9 |
+
|
| 10 |
+
/**
 * Assemble the named input tensors for one model run.
 *
 * @param clicks Point prompts; each coordinate is multiplied by
 *   `modelScale.samScale` and each `clickType` becomes the point's label.
 * @param tensor Forwarded unchanged as `image_embeddings`.
 * @param modelScale Supplies the coordinate scale and the image size.
 * @returns The input feed object, or `undefined` when no clicks were given.
 */
const modelData = ({ clicks, tensor, modelScale }: modeDataProps) => {
  let pointCoordsTensor;
  let pointLabelsTensor;

  if (clicks) {
    const clickCount = clicks.length;
    // With no box input, one padding point with label -1 at (0.0, 0.0) is
    // appended, so room is reserved for clickCount + 1 points.
    const totalPoints = clickCount + 1;
    const coords = new Float32Array(2 * totalPoints);
    const labels = new Float32Array(totalPoints);

    // Scale every click to the coordinate space the model expects.
    for (let i = 0; i < clickCount; i++) {
      const click = clicks[i];
      coords[2 * i] = click.x * modelScale.samScale;
      coords[2 * i + 1] = click.y * modelScale.samScale;
      labels[i] = click.clickType;
    }

    // Padding point: coordinates (0, 0), label -1.
    coords[2 * clickCount] = 0.0;
    coords[2 * clickCount + 1] = 0.0;
    labels[clickCount] = -1.0;

    pointCoordsTensor = new Tensor("float32", coords, [1, totalPoints, 2]);
    pointLabelsTensor = new Tensor("float32", labels, [1, totalPoints]);
  }

  const imageSizeTensor = new Tensor("float32", [
    modelScale.height,
    modelScale.width,
  ]);

  // Without point prompts there is nothing to run.
  if (pointCoordsTensor === undefined || pointLabelsTensor === undefined) {
    return;
  }

  // No previous mask exists: supply an all-zero mask and a zero flag.
  const emptyMask = new Tensor(
    "float32",
    new Float32Array(256 * 256),
    [1, 1, 256, 256]
  );
  const noMaskFlag = new Tensor("float32", [0]);

  return {
    image_embeddings: tensor,
    point_coords: pointCoordsTensor,
    point_labels: pointLabelsTensor,
    orig_im_size: imageSizeTensor,
    mask_input: emptyMask,
    has_mask_input: noMaskFlag,
  };
};
|
| 70 |
+
|
| 71 |
+
export { modelData };
|
frontend/src/components/helpers/scaleHelper.tsx
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
/**
 * Compute the scale information needed for SAM from an image's natural size.
 *
 * Input images to SAM must be resized so the longest side is 1024, so
 * `samScale` is 1024 divided by the longer of the two natural dimensions.
 *
 * @returns `{ height, width, samScale }` with the natural height and width.
 */
const handleImageScale = (image: HTMLImageElement) => {
  const LONG_SIDE_LENGTH = 1024;
  const { naturalWidth: width, naturalHeight: height } = image;
  const longestSide = Math.max(height, width);
  return { height, width, samScale: LONG_SIDE_LENGTH / longestSide };
};
|
| 17 |
+
|
| 18 |
+
export { handleImageScale };
|
frontend/src/components/hooks/context.tsx
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import React, { useState } from "react";
|
| 8 |
+
import { modelInputProps } from "../helpers/Interfaces";
|
| 9 |
+
import AppContext from "./createContext";
|
| 10 |
+
|
| 11 |
+
/**
 * Owns the shared application state (clicks, image, mask image, mask image
 * data and the clicked flag) and exposes each piece to descendants as a
 * `[value, setter]` pair through AppContext.
 */
const AppContextProvider = (props: {
  children: React.ReactElement<any, string | React.JSXElementConstructor<any>>;
}) => {
  // Each entry is kept as the [value, setter] tuple returned by useState.
  const clicks = useState<Array<modelInputProps> | null>(null);
  const image = useState<HTMLImageElement | null>(null);
  const maskImg = useState<HTMLImageElement | null>(null);
  const maskImgData = useState<string | null>(null);
  const isClicked = useState<boolean>(false);

  const contextValue = {
    clicks: [clicks[0], clicks[1]] as [typeof clicks[0], typeof clicks[1]],
    image: [image[0], image[1]] as [typeof image[0], typeof image[1]],
    maskImg: [maskImg[0], maskImg[1]] as [typeof maskImg[0], typeof maskImg[1]],
    maskImgData: [maskImgData[0], maskImgData[1]] as [typeof maskImgData[0], typeof maskImgData[1]],
    isClicked: [isClicked[0], isClicked[1]] as [typeof isClicked[0], typeof isClicked[1]],
  };

  return (
    <AppContext.Provider value={contextValue}>
      {props.children}
    </AppContext.Provider>
  );
};
|
| 34 |
+
|
| 35 |
+
export default AppContextProvider;
|
frontend/src/components/hooks/createContext.tsx
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import { createContext } from "react";
|
| 8 |
+
import { modelInputProps } from "../helpers/Interfaces";
|
| 9 |
+
|
| 10 |
+
/**
 * Shape of the value carried by AppContext. Every entry is a
 * `[value, setter]` pair.
 */
interface contextProps {
  /** Point prompts, or null when there are none. */
  clicks: [
    clicks: modelInputProps[] | null,
    setClicks: (e: modelInputProps[] | null) => void
  ];
  /** The image element, or null when no image is set. */
  image: [
    image: HTMLImageElement | null,
    setImage: (e: HTMLImageElement | null) => void
  ];
  /** The mask image element, or null when no mask is set. */
  maskImg: [
    maskImg: HTMLImageElement | null,
    setMaskImg: (e: HTMLImageElement | null) => void
  ];
  /** The mask image data as a string, or null. */
  maskImgData: [
    maskImgData: string | null,
    setMaskImgData: (e: string | null) => void
  ];
  /** Whether a click is currently being processed. */
  isClicked: [
    isClicked: boolean,
    setIsClicked: (e: boolean) => void
  ];
}

// The default value is null, so consumers outside a provider receive null.
const AppContext = createContext<contextProps | null>(null);
|
| 34 |
+
|
| 35 |
+
export default AppContext;
|
frontend/src/index.tsx
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import * as React from "react";
|
| 8 |
+
import { createRoot } from "react-dom/client";
|
| 9 |
+
import AppContextProvider from "./components/hooks/context";
|
| 10 |
+
import App from "./App";
|
| 11 |
+
// Mount the application, wrapped in the shared context provider, into the
// element with id "root".
const container = document.getElementById("root");
const root = createRoot(container!);
const application = (
  <AppContextProvider>
    <App/>
  </AppContextProvider>
);
root.render(application);
|
frontend/src/services/maskApi.tsx
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import axios from 'axios';
|
| 2 |
+
import * as _ from 'underscore';
|
| 3 |
+
|
| 4 |
+
// Base URL for all API requests: the local server on port 7860 when NODE_ENV is 'development', otherwise a path relative to the current origin.
const API_URL = process.env.NODE_ENV === 'development' ? 'http://localhost:7860/gradio_api' : '/gradio_api';
|
| 5 |
+
|
| 6 |
+
/**
 * Request a description of the masked region in a single, non-streaming call.
 *
 * The function is wrapped in underscore's `throttle` with a 100 ms wait.
 *
 * @param maskBase64 Base64-encoded mask.
 * @param imageBase64 Base64-encoded image.
 * @param query Prompt text.
 * @returns The first element of the response's `data` array.
 * @throws Re-throws any request error after logging it.
 */
export const describeMaskWithoutStreaming = _.throttle(async (
  maskBase64: string,
  imageBase64: string,
  query: string
): Promise<string> => {
  const endpoint = `${API_URL}/run/describe_without_streaming`;
  // The endpoint expects the image first, then the mask, then the prompt.
  const payload = { data: [imageBase64, maskBase64, query] };

  try {
    const { data: body } = await axios.post(endpoint, payload);
    console.log("response", body);
    return body.data[0];
  } catch (error) {
    console.error('Error describing mask:', error);
    throw error;
  }
}, 100);
|
| 23 |
+
|
| 24 |
+
/**
 * Request a description of the masked region and stream partial results.
 *
 * The request is started with a POST, then the server-sent event stream is
 * read line by line. The function is wrapped in underscore's `throttle` with
 * a 100 ms wait.
 *
 * @param maskBase64 Base64-encoded mask.
 * @param imageBase64 Base64-encoded image.
 * @param query Prompt text.
 * @param onStreamUpdate Called with the latest text on every 'generating' /
 *   'complete' event that carries data.
 * @param onQueueUpdate Optional; called with queue information on
 *   'estimation' events and with `{ inQueue: false }` on 'process_starts'.
 * @returns The last text received; returned as soon as a 'complete' event
 *   arrives or when the stream ends.
 * @throws If either request fails, if reading the stream fails, or if the
 *   server reports the job as completed unsuccessfully.
 */
export const describeMask = _.throttle(async (
  maskBase64: string,
  imageBase64: string,
  query: string,
  onStreamUpdate: (token: string) => void,
  onQueueUpdate?: (status: {
    inQueue: boolean,
    rank?: number,
    queueSize?: number,
    rankEta?: number | null
  }) => void
): Promise<string> => {
  console.log("describeMask");
  const initiateResponse = await axios.post(`${API_URL}/call/describe`, {
    data: [imageBase64, maskBase64, query],
  });

  const eventId = initiateResponse.data.event_id;

  const response = await axios.get(`${API_URL}/queue/data?session_hash=${eventId}`, {
    headers: {
      'Accept': 'text/event-stream',
    },
    responseType: 'stream',
    adapter: 'fetch',
  });

  const stream = response.data;
  const reader = stream.pipeThrough(new TextDecoderStream()).getReader();

  let result = '';
  let partialMessage = '';
  // Kept outside the read loop: network chunks end at arbitrary positions, so
  // an 'event:' line and its 'data:' line may arrive in different chunks.
  let eventType = '';

  try {
    while (true) {
      const { value, done } = await reader.read();
      if (done) {
        return result;
      }

      // Concatenate with any previous partial message
      const lines = (partialMessage + value).split('\n');

      // The last element is an incomplete line (or '' if the chunk ended on
      // a newline); keep it for the next chunk.
      partialMessage = lines[lines.length - 1];

      for (let i = 0; i < lines.length - 1; i++) {
        const line = lines[i];

        if (line.startsWith('event: ')) {
          eventType = line.slice(7); // Remove 'event: ' prefix
          console.log('Event message', line);
          continue;
        }

        if (!line.startsWith('data: ')) {
          if (line !== '') {
            console.log('Unknown message', line);
          }
          continue;
        }

        // Only JSON parsing is treated as best-effort; a malformed line is
        // logged and skipped.
        let data;
        try {
          data = JSON.parse(line.slice(6)); // Remove 'data: ' prefix
        } catch (e) {
          console.log('Error parsing SSE message:', e);
          continue;
        }

        const msg = data?.msg;
        if (msg) {
          eventType = msg;
          if (msg === 'process_generating') {
            eventType = 'generating';
            data = data.output?.data;
          } else if (msg === 'process_completed') {
            // NOTE(review): relies on the backend marking failed jobs with
            // `success: false` on the completion message; confirm against
            // the server version in use. Without this check a failed job
            // ended up being reported as an empty description.
            if (data.success === false) {
              throw new Error('Describe request failed on the server', { cause: data.output });
            }
            eventType = 'complete';
            data = data.output?.data;
          }
        }

        if (eventType === 'estimation' && onQueueUpdate) {
          onQueueUpdate({
            inQueue: true,
            rank: data?.rank,
            queueSize: data?.queue_size,
            rankEta: data?.rank_eta
          });
        } else if (eventType === 'process_starts' && onQueueUpdate) {
          onQueueUpdate({
            inQueue: false
          });
        } else if ((eventType === 'generating' || eventType === 'complete') && data?.[0]) {
          result = data[0];
          onStreamUpdate(data[0]);

          if (eventType === 'complete') {
            return result;
          }
        }
      }
    }
  } finally {
    // Release the connection on every exit path (early return on 'complete',
    // end of stream, or error). A rejection here only means the stream was
    // already errored, which has been reported through read() above.
    reader.cancel().catch(() => undefined);
  }
}, 100);
|
| 120 |
+
|
| 121 |
+
/**
 * Request the SAM embedding for an image and wait for the result on the
 * server-sent event stream.
 *
 * The function is wrapped in underscore's `throttle` with a 100 ms wait.
 *
 * @param imageBase64 Base64-encoded image.
 * @param onQueueUpdate Optional; called with queue information on
 *   'estimation' events and with `{ inQueue: false }` on 'process_starts'.
 * @returns The first element of the completed job's output data, or an empty
 *   string if the stream ends without a completed result.
 * @throws If either request or reading the stream fails.
 */
export const imageToSamEmbedding = _.throttle(async (
  imageBase64: string,
  onQueueUpdate?: (status: {
    inQueue: boolean,
    rank?: number,
    queueSize?: number,
    rankEta?: number | null
  }) => void
): Promise<string> => {
  // First call to initiate the process
  const initiateResponse = await axios.post(`${API_URL}/call/image_to_sam_embedding`, {
    data: [imageBase64]
  });

  const eventId = initiateResponse.data.event_id;

  // Get the stream for queue updates and results
  const response = await axios.get(`${API_URL}/queue/data?session_hash=${eventId}`, {
    headers: {
      'Accept': 'text/event-stream',
    },
    responseType: 'stream',
    adapter: 'fetch',
  });

  const stream = response.data;
  const reader = stream.pipeThrough(new TextDecoderStream()).getReader();

  let result = '';
  let partialMessage = '';
  // Kept outside the read loop: network chunks end at arbitrary positions, so
  // an 'event:' line and its 'data:' line may arrive in different chunks.
  let eventType = '';

  try {
    while (true) {
      const { value, done } = await reader.read();
      if (done) {
        return result;
      }

      // Concatenate with any previous partial message
      const currentData = partialMessage + value;
      const lines = currentData.split('\n');

      // Save the last line if it's incomplete (doesn't end with \n)
      // The endpoint will send an empty line to indicate the end of a message, so it's ok to not process the partial message.
      partialMessage = lines[lines.length - 1];

      // Process all complete lines except the last one
      for (let i = 0; i < lines.length - 1; i++) {
        const line = lines[i];
        if (line.startsWith('event: ')) {
          eventType = line.slice(7);
        } else if (line.startsWith('data: ')) {
          const eventData = line.slice(6);
          // Handling stays best-effort: any failure while processing one
          // message is logged and the loop moves on to the next line.
          try {
            let data = JSON.parse(eventData);
            if (data['msg']) {
              eventType = data['msg'];
              console.log("Event type:", eventType);
              if (eventType === 'process_completed') {
                eventType = 'complete';
                data = data['output']['data'];
              }
            }

            if (eventType === 'estimation' && onQueueUpdate) {
              onQueueUpdate({
                inQueue: true,
                rank: data.rank,
                queueSize: data.queue_size,
                rankEta: data.rank_eta
              });
            } else if (eventType === 'process_starts' && onQueueUpdate) {
              onQueueUpdate({
                inQueue: false
              });
            } else if (eventType === 'complete' && data[0]) {
              result = data[0];
              console.log("Result for image to sam embedding:", result);
              return result;
            } else {
              console.log("Unknown event type:", eventType);
            }
          } catch (e) {
            console.log('Error parsing SSE message:', e, 'Raw data:', eventData);
          }
        }
      }
    }
  } finally {
    // Release the connection on every exit path (early return on 'complete',
    // end of stream, or error). A rejection here only means the stream was
    // already errored, which has been reported through read() above.
    reader.cancel().catch(() => undefined);
  }
}, 100);
|
| 210 |
+
|
| 211 |
+
export { API_URL };
|
frontend/tailwind.config.js
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
// Tailwind configuration: scan html, js and tsx files under ./src for class
// names; no theme overrides and no plugins.
/** @type {import('tailwindcss').Config} */
const tailwindConfig = {
  content: ["./src/**/*.{html,js,tsx}"],
  theme: {},
  plugins: [],
};

module.exports = tailwindConfig;
|
frontend/tsconfig.json
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"compilerOptions": {
|
| 3 |
+
"lib": ["dom", "dom.iterable", "esnext"],
|
| 4 |
+
"allowJs": true,
|
| 5 |
+
"skipLibCheck": true,
|
| 6 |
+
"strict": true,
|
| 7 |
+
"forceConsistentCasingInFileNames": true,
|
| 8 |
+
"noEmit": false,
|
| 9 |
+
"esModuleInterop": true,
|
| 10 |
+
"module": "esnext",
|
| 11 |
+
"moduleResolution": "node",
|
| 12 |
+
"resolveJsonModule": true,
|
| 13 |
+
"isolatedModules": true,
|
| 14 |
+
"jsx": "react",
|
| 15 |
+
"incremental": true,
|
| 16 |
+
"target": "ESNext",
|
| 17 |
+
"useDefineForClassFields": true,
|
| 18 |
+
"allowSyntheticDefaultImports": true,
|
| 19 |
+
"outDir": "./dist/",
|
| 20 |
+
"sourceMap": true
|
| 21 |
+
},
|
| 22 |
+
"include": ["next-env.d.ts", "**/*.ts", "**/*.tsx", "src"],
|
| 23 |
+
"exclude": ["node_modules"]
|
| 24 |
+
}
|
frontend/yarn.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|