from typing import Any, List, Tuple

from diffusers import FluxTransformer2DModel
from diffusers.modular_pipelines import ModularPipelineBlocks, ComponentSpec, InputParam, PipelineState, OutputParam


class DummyCustomBlockSimple(ModularPipelineBlocks):
    """Minimal custom pipeline block that prefixes the incoming prompt.

    The block declares one required input (``prompt``) and one intermediate
    output (``output_prompt``). When called, it writes
    ``"Modular diffusers + " + prompt`` to ``output_prompt``.

    Args:
        use_dummy_model_component: When true, the block additionally declares
            a ``"transformer"`` component of type ``FluxTransformer2DModel``
            in ``expected_components``; otherwise it declares no components.
    """

    def __init__(self, use_dummy_model_component: bool = False):
        # The flag is assigned before the base initializer runs, preserving the
        # original ordering (presumably so the properties below are already
        # usable during base-class setup -- confirm against the base class).
        self.use_dummy_model_component = use_dummy_model_component
        super().__init__()

    @property
    def expected_components(self) -> List[ComponentSpec]:
        """Return the component specs this block declares.

        A single ``"transformer"`` spec when ``use_dummy_model_component`` is
        set, otherwise an empty list.
        """
        if self.use_dummy_model_component:
            return [ComponentSpec("transformer", FluxTransformer2DModel)]
        else:
            return []

    @property
    def inputs(self) -> List[InputParam]:
        """Return the declared inputs: a single required ``prompt`` string."""
        return [InputParam("prompt", type_hint=str, required=True, description="Prompt to use")]

    @property
    def intermediate_inputs(self) -> List[InputParam]:
        """Return the declared intermediate inputs (none for this block)."""
        return []

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        """Return the declared intermediate outputs: the ``output_prompt`` string."""
        return [
            OutputParam(
                "output_prompt",
                type_hint=str,
                description="Modified prompt",
            )
        ]

    def __call__(self, components, state: PipelineState) -> Tuple[Any, PipelineState]:
        """Prefix the prompt and store it as ``output_prompt``.

        Args:
            components: Passed through unchanged; this method never reads it.
            state: Pipeline state from which the block state is obtained and
                to which it is written back.

        Returns:
            The ``(components, state)`` pair. The annotation reflects the
            tuple that is actually returned.
        """
        block_state = self.get_block_state(state)

        old_prompt = block_state.prompt
        block_state.output_prompt = "Modular diffusers + " + old_prompt

        # Hand the modified block state back via the inherited helper.
        self.set_block_state(state, block_state)

        return components, state