diff --git a/playground/testcases/flowchart.ts b/playground/testcases/flowchart.ts index 51ed4e8e..a2981e8d 100644 --- a/playground/testcases/flowchart.ts +++ b/playground/testcases/flowchart.ts @@ -360,6 +360,15 @@ style id2 fill:#bbf,stroke:#f66,stroke-width:2px,color:#fff,stroke-dasharray: 5 `, type: "flowchart", }, + { + name: "Styling a node using class", + definition: `flowchart LR + A:::foo & B:::bar --> C:::foobar + classDef foo stroke:#1971c2, fill:#4dabf7 + classDef bar stroke:#d6336c, fill:#f783ac + classDef foobar stroke:#00f stroke-width:2px`, + type: "flowchart", + }, { name: "Classes", definition: `flowchart LR diff --git a/src/parser/flowchart.ts b/src/parser/flowchart.ts index cae51d7c..ee096534 100644 --- a/src/parser/flowchart.ts +++ b/src/parser/flowchart.ts @@ -12,6 +12,7 @@ import { } from "../interfaces.js"; import type { Diagram } from "mermaid/dist/Diagram.js"; +import { DiagramStyleClassDef } from "mermaid/dist/diagram-api/types.js"; export interface Flowchart { type: "flowchart"; @@ -74,7 +75,11 @@ const parseSubGraph = (data: any, containerEl: Element): SubGraph => { }; }; -const parseVertex = (data: any, containerEl: Element): Vertex | undefined => { +const parseVertex = ( + data: any, + containerEl: Element, + classes: { [key: string]: DiagramStyleClassDef } +): Vertex | undefined => { // Find Vertex element const el: SVGSVGElement | null = containerEl.querySelector( `[id*="flowchart-${data.id}-"]` @@ -117,6 +122,7 @@ const parseVertex = (data: any, containerEl: Element): Vertex | undefined => { const value = property.split(":")[1].trim(); containerStyle[key] = value; }); + const labelStyle: Vertex["labelStyle"] = {}; labelStyleText?.split(";").forEach((property) => { if (!property) { @@ -128,6 +134,19 @@ const parseVertex = (data: any, containerEl: Element): Vertex | undefined => { labelStyle[key] = value; }); + if (data.classes) { + const classDef = classes[data.classes]; + if (classDef) { + classDef.styles?.forEach((style) => { + const [key, value] = style.split(":"); + containerStyle[key.trim() as CONTAINER_STYLE_PROPERTY] = value.trim(); + }); + classDef.textStyles?.forEach((style) => { + const [key, value] = style.split(":"); + labelStyle[key.trim() as LABEL_STYLE_PROPERTY] = value.trim(); + }); + } + } return { id: data.id, labelType: data.labelType, @@ -230,8 +249,9 @@ export const parseMermaidFlowChartDiagram = ( //@ts-ignore const mermaidParser = diagram.parser.yy; const vertices = mermaidParser.getVertices(); + const classes = mermaidParser.getClasses(); Object.keys(vertices).forEach((id) => { - vertices[id] = parseVertex(vertices[id], containerEl); + vertices[id] = parseVertex(vertices[id], containerEl, classes); }); // Track the count of edges based on the edge id