Files changed (3) hide show
  1. package.json +1 -0
  2. src/App.jsx +106 -45
  3. src/worker.js +5 -10
package.json CHANGED
@@ -13,6 +13,7 @@
13
  "@huggingface/transformers": "3.7.5",
14
  "dompurify": "^3.1.2",
15
  "marked": "^12.0.2",
 
16
  "react": "^18.3.1",
17
  "react-dom": "^18.3.1"
18
  },
 
13
  "@huggingface/transformers": "3.7.5",
14
  "dompurify": "^3.1.2",
15
  "marked": "^12.0.2",
16
+ "pdfjs-dist": "^5.4.296",
17
  "react": "^18.3.1",
18
  "react-dom": "^18.3.1"
19
  },
src/App.jsx CHANGED
@@ -26,18 +26,46 @@ function App() {
26
  const [loadingMessage, setLoadingMessage] = useState("");
27
  const [progressItems, setProgressItems] = useState([]);
28
  const [isRunning, setIsRunning] = useState(false);
 
29
 
30
  // Inputs and outputs
31
  const [input, setInput] = useState("");
32
  const [messages, setMessages] = useState([]);
33
  const [tps, setTps] = useState(null);
34
  const [numTokens, setNumTokens] = useState(null);
35
-
36
- function onEnter(message) {
37
- setMessages((prev) => [...prev, { role: "user", content: message }]);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  setTps(null);
39
  setIsRunning(true);
40
  setInput("");
 
41
  }
42
 
43
  function onInterrupt() {
@@ -73,7 +101,6 @@ function App() {
73
  const onMessageReceived = (e) => {
74
  switch (e.data.status) {
75
  case "loading":
76
- // Model file start load: add a new progress item to the list.
77
  setStatus("loading");
78
  setLoadingMessage(e.data.data);
79
  break;
@@ -83,7 +110,6 @@ function App() {
83
  break;
84
 
85
  case "progress":
86
- // Model file progress: update one of the progress items.
87
  setProgressItems((prev) =>
88
  prev.map((item) => {
89
  if (item.file === e.data.file) {
@@ -95,53 +121,45 @@ function App() {
95
  break;
96
 
97
  case "done":
98
- // Model file loaded: remove the progress item from the list.
99
  setProgressItems((prev) =>
100
  prev.filter((item) => item.file !== e.data.file),
101
  );
102
  break;
103
 
104
  case "ready":
105
- // Pipeline ready: the worker is ready to accept messages.
106
  setStatus("ready");
107
  break;
108
 
109
  case "start":
110
- {
111
- // Start generation
112
- setMessages((prev) => [
113
- ...prev,
114
- { role: "assistant", content: "" },
115
- ]);
116
- }
117
  break;
118
 
119
  case "update":
120
- {
121
- // Generation update: update the output text.
122
- // Parse messages
123
- const { output, tps, numTokens } = e.data;
124
- setTps(tps);
125
- setNumTokens(numTokens);
126
- setMessages((prev) => {
127
- const cloned = [...prev];
128
- const last = cloned.at(-1);
129
- cloned[cloned.length - 1] = {
130
- ...last,
131
- content: last.content + output,
132
- };
133
- return cloned;
134
- });
135
- }
136
  break;
137
 
138
  case "complete":
139
- // Generation complete: re-enable the "Generate" button
140
  setIsRunning(false);
141
  break;
142
 
143
  case "error":
144
- setError(e.data.data);
 
 
145
  break;
146
  }
147
  };
@@ -239,16 +257,36 @@ function App() {
239
  </div>
240
  )}
241
 
242
- <button
243
- className="border px-4 py-2 rounded-lg bg-blue-400 text-white hover:bg-blue-500 disabled:bg-blue-100 disabled:cursor-not-allowed select-none"
244
- onClick={() => {
245
- worker.current.postMessage({ type: "load" });
246
- setStatus("loading");
247
- }}
248
- disabled={status !== null || error !== null}
249
- >
250
- Load model
251
- </button>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
  </div>
253
  </div>
254
  )}
@@ -326,10 +364,10 @@ function App() {
326
  </div>
327
  )}
328
 
329
- <div className="mt-2 border dark:bg-gray-700 rounded-lg w-[600px] max-w-[80%] max-h-[200px] mx-auto relative mb-3 flex">
330
  <textarea
331
  ref={textareaRef}
332
- className="scrollbar-thin w-[550px] dark:bg-gray-700 px-3 py-4 rounded-lg bg-transparent border-none outline-none text-gray-800 disabled:text-gray-400 dark:text-gray-200 placeholder-gray-500 dark:placeholder-gray-400 disabled:placeholder-gray-200 resize-none disabled:cursor-not-allowed"
333
  placeholder="Type your message..."
334
  type="text"
335
  rows={1}
@@ -349,11 +387,31 @@ function App() {
349
  }}
350
  onInput={(e) => setInput(e.target.value)}
351
  />
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
352
  {isRunning ? (
353
  <div className="cursor-pointer" onClick={onInterrupt}>
354
  <StopIcon className="h-8 w-8 p-1 rounded-md text-gray-800 dark:text-gray-100 absolute right-3 bottom-3" />
355
  </div>
356
- ) : input.length > 0 ? (
357
  <div className="cursor-pointer" onClick={() => onEnter(input)}>
358
  <ArrowRightIcon
359
  className={`h-8 w-8 p-1 bg-gray-800 dark:bg-gray-100 text-white dark:text-black rounded-md absolute right-3 bottom-3`}
@@ -366,6 +424,9 @@ function App() {
366
  />
367
  </div>
368
  )}
 
 
 
369
  </div>
370
 
371
  <p className="text-xs text-gray-400 text-center mb-3">
 
26
  const [loadingMessage, setLoadingMessage] = useState("");
27
  const [progressItems, setProgressItems] = useState([]);
28
  const [isRunning, setIsRunning] = useState(false);
29
+ const [modelFiles, setModelFiles] = useState([]);
30
 
31
  // Inputs and outputs
32
  const [input, setInput] = useState("");
33
  const [messages, setMessages] = useState([]);
34
  const [tps, setTps] = useState(null);
35
  const [numTokens, setNumTokens] = useState(null);
36
+ const [attachedFile, setAttachedFile] = useState(null);
37
+
38
+ async function onEnter(message) {
39
+ let fileText = "";
40
+ if (attachedFile) {
41
+ if (attachedFile.name.endsWith(".txt")) {
42
+ fileText = await attachedFile.text();
43
+ } else if (attachedFile.name.endsWith(".pdf")) {
44
+ // Dynamically import pdfjs-dist
45
+ const pdfjsLib = await import("pdfjs-dist/build/pdf");
46
+ const workerSrc = (await import("pdfjs-dist/build/pdf.worker?url")).default;
47
+ pdfjsLib.GlobalWorkerOptions.workerSrc = workerSrc;
48
+ const arrayBuffer = await attachedFile.arrayBuffer();
49
+ const pdf = await pdfjsLib.getDocument({ data: arrayBuffer }).promise;
50
+ let pdfText = "";
51
+ for (let i = 1; i <= pdf.numPages; i++) {
52
+ const page = await pdf.getPage(i);
53
+ const content = await page.getTextContent();
54
+ pdfText += content.items.map(item => item.str).join(" ") + "\n";
55
+ }
56
+ fileText = pdfText;
57
+ }
58
+ }
59
+ let fullPrompt = message;
60
+ if (fileText) {
61
+ fullPrompt += "\n\n--- File Content ---\n" + fileText;
62
+ }
63
+ let userMsg = { role: "user", content: fullPrompt };
64
+ setMessages((prev) => [...prev, userMsg]);
65
  setTps(null);
66
  setIsRunning(true);
67
  setInput("");
68
+ setAttachedFile(null);
69
  }
70
 
71
  function onInterrupt() {
 
101
  const onMessageReceived = (e) => {
102
  switch (e.data.status) {
103
  case "loading":
 
104
  setStatus("loading");
105
  setLoadingMessage(e.data.data);
106
  break;
 
110
  break;
111
 
112
  case "progress":
 
113
  setProgressItems((prev) =>
114
  prev.map((item) => {
115
  if (item.file === e.data.file) {
 
121
  break;
122
 
123
  case "done":
 
124
  setProgressItems((prev) =>
125
  prev.filter((item) => item.file !== e.data.file),
126
  );
127
  break;
128
 
129
  case "ready":
 
130
  setStatus("ready");
131
  break;
132
 
133
  case "start":
134
+ setMessages((prev) => [
135
+ ...prev,
136
+ { role: "assistant", content: "" },
137
+ ]);
 
 
 
138
  break;
139
 
140
  case "update":
141
+ const { output, tps, numTokens } = e.data;
142
+ setTps(tps);
143
+ setNumTokens(numTokens);
144
+ setMessages((prev) => {
145
+ const cloned = [...prev];
146
+ const last = cloned.at(-1);
147
+ cloned[cloned.length - 1] = {
148
+ ...last,
149
+ content: last.content + output,
150
+ };
151
+ return cloned;
152
+ });
 
 
 
 
153
  break;
154
 
155
  case "complete":
 
156
  setIsRunning(false);
157
  break;
158
 
159
  case "error":
160
+ setError(e.data.data || "Unknown error during model loading.");
161
+ setStatus(null);
162
+ setLoadingMessage("");
163
  break;
164
  }
165
  };
 
257
  </div>
258
  )}
259
 
260
+ <div className="flex flex-col items-center gap-2">
261
+ <label className="border px-4 py-2 rounded-lg bg-blue-100 text-blue-700 cursor-pointer hover:bg-blue-200">
262
+ Select model directory
263
+ <input
264
+ type="file"
265
+ webkitdirectory="true"
266
+ directory="true"
267
+ multiple
268
+ className="hidden"
269
+ onChange={e => {
270
+ const files = Array.from(e.target.files);
271
+ setModelFiles(files);
272
+ }}
273
+ disabled={status !== null || error !== null}
274
+ />
275
+ </label>
276
+ <button
277
+ className="border px-4 py-2 rounded-lg bg-blue-400 text-white hover:bg-blue-500 disabled:bg-blue-100 disabled:cursor-not-allowed select-none"
278
+ onClick={() => {
279
+ worker.current.postMessage({ type: "load", files: modelFiles });
280
+ setStatus("loading");
281
+ }}
282
+ disabled={status !== null || error !== null}
283
+ >
284
+ {modelFiles.length > 0 ? `Load selected model directory` : `Load default model`}
285
+ </button>
286
+ {modelFiles.length > 0 && (
287
+ <span className="text-xs text-gray-600">Selected files: {modelFiles.map(f => f.name).join(", ")}</span>
288
+ )}
289
+ </div>
290
  </div>
291
  </div>
292
  )}
 
364
  </div>
365
  )}
366
 
367
+ <div className="mt-2 border dark:bg-gray-700 rounded-lg w-[600px] max-w-[80%] max-h-[200px] mx-auto relative mb-3 flex items-center gap-2">
368
  <textarea
369
  ref={textareaRef}
370
+ className="scrollbar-thin w-[420px] dark:bg-gray-700 px-3 py-4 rounded-lg bg-transparent border-none outline-none text-gray-800 disabled:text-gray-400 dark:text-gray-200 placeholder-gray-500 dark:placeholder-gray-400 disabled:placeholder-gray-200 resize-none disabled:cursor-not-allowed"
371
  placeholder="Type your message..."
372
  type="text"
373
  rows={1}
 
387
  }}
388
  onInput={(e) => setInput(e.target.value)}
389
  />
390
+ <label
391
+ className={`flex items-center px-2 py-2 bg-blue-100 text-blue-700 rounded cursor-pointer hover:bg-blue-200 ${status !== "ready" ? "opacity-50 cursor-not-allowed" : ""}`}
392
+ >
393
+ 📎 Attach
394
+ <input
395
+ type="file"
396
+ accept=".txt,.pdf"
397
+ className="hidden"
398
+ onChange={e => {
399
+ const file = e.target.files[0];
400
+ if (file && ![".txt", ".pdf"].some(ext => file.name.toLowerCase().endsWith(ext))) {
401
+ alert("Only .txt and .pdf files are allowed.");
402
+ e.target.value = "";
403
+ return;
404
+ }
405
+ setAttachedFile(file);
406
+ }}
407
+ disabled={status !== "ready"}
408
+ />
409
+ </label>
410
  {isRunning ? (
411
  <div className="cursor-pointer" onClick={onInterrupt}>
412
  <StopIcon className="h-8 w-8 p-1 rounded-md text-gray-800 dark:text-gray-100 absolute right-3 bottom-3" />
413
  </div>
414
+ ) : input.length > 0 || attachedFile ? (
415
  <div className="cursor-pointer" onClick={() => onEnter(input)}>
416
  <ArrowRightIcon
417
  className={`h-8 w-8 p-1 bg-gray-800 dark:bg-gray-100 text-white dark:text-black rounded-md absolute right-3 bottom-3`}
 
424
  />
425
  </div>
426
  )}
427
+ {attachedFile && (
428
+ <span className="ml-2 text-xs text-gray-600">{attachedFile.name}</span>
429
+ )}
430
  </div>
431
 
432
  <p className="text-xs text-gray-400 text-center mb-3">
src/worker.js CHANGED
@@ -30,18 +30,15 @@ async function check() {
30
  */
31
  class TextGenerationPipeline {
32
  static model_id = "onnx-community/granite-4.0-micro-ONNX-web";
33
-
34
  static async getInstance(progress_callback = null) {
35
  this.tokenizer ??= AutoTokenizer.from_pretrained(this.model_id, {
36
  progress_callback,
37
  });
38
-
39
  this.model ??= AutoModelForCausalLM.from_pretrained(this.model_id, {
40
  dtype: "q4f16",
41
  device: "webgpu",
42
  progress_callback,
43
  });
44
-
45
  return Promise.all([this.tokenizer, this.model]);
46
  }
47
  }
@@ -114,18 +111,16 @@ async function generate(messages) {
114
  });
115
  }
116
 
117
- async function load() {
118
  self.postMessage({
119
  status: "loading",
120
- data: "Loading model...",
121
  });
122
 
123
  // Load the pipeline and save it for future use.
124
  const [tokenizer, model] = await TextGenerationPipeline.getInstance((x) => {
125
- // We also add a progress callback to the pipeline so that we can
126
- // track model loading.
127
  self.postMessage(x);
128
- });
129
 
130
  self.postMessage({
131
  status: "loading",
@@ -139,7 +134,7 @@ async function load() {
139
  }
140
  // Listen for messages from the main thread
141
  self.addEventListener("message", async (e) => {
142
- const { type, data } = e.data;
143
 
144
  switch (type) {
145
  case "check":
@@ -147,7 +142,7 @@ self.addEventListener("message", async (e) => {
147
  break;
148
 
149
  case "load":
150
- load();
151
  break;
152
 
153
  case "generate":
 
30
  */
31
  class TextGenerationPipeline {
32
  static model_id = "onnx-community/granite-4.0-micro-ONNX-web";
 
33
  static async getInstance(progress_callback = null) {
34
  this.tokenizer ??= AutoTokenizer.from_pretrained(this.model_id, {
35
  progress_callback,
36
  });
 
37
  this.model ??= AutoModelForCausalLM.from_pretrained(this.model_id, {
38
  dtype: "q4f16",
39
  device: "webgpu",
40
  progress_callback,
41
  });
 
42
  return Promise.all([this.tokenizer, this.model]);
43
  }
44
  }
 
111
  });
112
  }
113
 
114
+ async function load(files = null) {
115
  self.postMessage({
116
  status: "loading",
117
+ data: files && files.length > 0 ? `Loading model from selected directory` : "Loading model...",
118
  });
119
 
120
  // Load the pipeline and save it for future use.
121
  const [tokenizer, model] = await TextGenerationPipeline.getInstance((x) => {
 
 
122
  self.postMessage(x);
123
+ }, files);
124
 
125
  self.postMessage({
126
  status: "loading",
 
134
  }
135
  // Listen for messages from the main thread
136
  self.addEventListener("message", async (e) => {
137
+ const { type, data, files } = e.data;
138
 
139
  switch (type) {
140
  case "check":
 
142
  break;
143
 
144
  case "load":
145
+ await load(files || null);
146
  break;
147
 
148
  case "generate":