Files
BiPy/Xenith/node_editor.cpp
T
2026-05-15 18:45:50 +07:00

513 lines
21 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#include "node_editor.hpp"
#include "core.hpp"
#include <algorithm>
#include <cmath>
namespace NodeEditor {
// === Вспомогательные функции ===
ImVec2 GetPortPos(const Node& node, const Port& port, const ImVec2& canvasOffset) {
float portY = node.pos.y + 30 + port.index * 25;
if (port.type == PortType::Input) {
return ImVec2(node.pos.x + canvasOffset.x, portY + canvasOffset.y);
} else {
return ImVec2(node.pos.x + node.size.x + canvasOffset.x, portY + canvasOffset.y);
}
}
void DrawBezier(ImDrawList* dl, ImVec2 start, ImVec2 end, ImU32 color, float thickness) {
ImVec2 ctrl1 = start + ImVec2(50, 0);
ImVec2 ctrl2 = end - ImVec2(50, 0);
dl->AddBezierCubic(start, ctrl1, ctrl2, end, color, thickness, 32);
// Стрелочка на конце
float angle = atan2(end.y - ctrl2.y, end.x - ctrl2.x);
ImVec2 arrow1 = end + ImVec2(cos(angle - 0.5f) * 8, sin(angle - 0.5f) * 8);
ImVec2 arrow2 = end + ImVec2(cos(angle + 0.5f) * 8, sin(angle + 0.5f) * 8);
dl->AddLine(end, arrow1, color, thickness);
dl->AddLine(end, arrow2, color, thickness);
}
ImU32 GetNodeColor(NodeType type, bool selected) {
if (selected) return IM_COL32(255, 255, 100, 255);
switch(type) {
case NodeType::Input: return IM_COL32(100, 200, 100, 255);
case NodeType::Hidden: return IM_COL32(100, 150, 255, 255);
case NodeType::Output: return IM_COL32(255, 100, 100, 255);
default: return IM_COL32(150, 150, 150, 255);
}
}
// === Методы Node ===
ImVec2 Node::GetInputPos(int portIdx) const {
return ImVec2(pos.x, pos.y + 35 + portIdx * 25);
}
ImVec2 Node::GetOutputPos(int portIdx) const {
return ImVec2(pos.x + size.x, pos.y + 35 + portIdx * 25);
}
// === Инициализация ===
void Init(GraphState& graph) {
graph.nextNodeId = 0;
graph.zoom = 1.0f;
graph.panOffset = ImVec2(100, 50);
}
// === Отрисовка узла ===
void DrawNode(GraphState& graph, Node& node, const ImVec2& canvasOffset) {
ImDrawList* dl = ImGui::GetWindowDrawList();
ImVec2 pos = node.pos + canvasOffset;
// Фон узла
ImU32 bgColor = GetNodeColor(node.type, node.selected);
dl->AddRectFilled(pos, pos + node.size, IM_COL32(40, 45, 55, 255), 8);
dl->AddRect(pos, pos + node.size, bgColor, 8, 0, 2.0f);
// Заголовок
ImVec2 titlePos = pos + ImVec2(10, 5);
dl->AddText(titlePos, IM_COL32(255, 255, 255, 255), node.title.c_str());
// Размер слоя
std::string sizeText = std::to_string(node.layerSize) + " нейронов";
dl->AddText(pos + ImVec2(10, 22), IM_COL32(180, 180, 180, 200), sizeText.c_str());
// Ветка (если есть)
if (node.branch != -1) {
const char* branchName = node.branch == 0 ? "A" : "B";
ImU32 branchColor = node.branch == 0 ? IM_COL32(100, 255, 100, 255) : IM_COL32(100, 150, 255, 255);
dl->AddRectFilled(pos + ImVec2(node.size.x - 35, 3),
pos + ImVec2(node.size.x - 5, 18),
branchColor, 3);
dl->AddText(pos + ImVec2(node.size.x - 28, 5), IM_COL32(0,0,0,255), branchName);
}
// Порты ввода
for (size_t i = 0; i < node.inputs.size(); i++) {
Port& port = node.inputs[i];
ImVec2 portPos = GetPortPos(node, port, canvasOffset);
// Кружок порта
ImU32 portColor = IM_COL32(180, 180, 180, 255);
if (graph.hoveredPortNode == node.id && graph.hoveredPortIdx == (int)i &&
graph.hoveredPortType == PortType::Input) {
portColor = IM_COL32(255, 255, 100, 255);
}
dl->AddCircleFilled(portPos, 6, portColor);
dl->AddCircle(portPos, 6, IM_COL32(50, 50, 50, 255), 12, 1.5f);
// Название порта
dl->AddText(portPos + ImVec2(12, -6), IM_COL32(200, 200, 200, 255), port.name.c_str());
// Выбор ветки для порта
if (port.isBranchPort) {
ImVec2 btnPos = portPos + ImVec2(100, 0);
if (ImGui::InvisibleButton(("##branch" + std::to_string(node.id) + "_" + std::to_string(i)).c_str(), ImVec2(50, 15))) {
node.branch = (node.branch + 1) % 3 - 1; // -1 -> 0 -> 1 -> -1
}
const char* branchTxt = node.branch == -1 ? "All" : (node.branch == 0 ? "A" : "B");
dl->AddText(btnPos, IM_COL32(255, 255, 255, 200), branchTxt);
}
}
// Порты вывода
for (size_t i = 0; i < node.outputs.size(); i++) {
Port& port = node.outputs[i];
ImVec2 portPos = GetPortPos(node, port, canvasOffset);
ImU32 portColor = IM_COL32(180, 180, 180, 255);
if (graph.hoveredPortNode == node.id && graph.hoveredPortIdx == (int)i &&
graph.hoveredPortType == PortType::Output) {
portColor = IM_COL32(255, 255, 100, 255);
}
dl->AddCircleFilled(portPos, 6, portColor);
dl->AddCircle(portPos, 6, IM_COL32(50, 50, 50, 255), 12, 1.5f);
// Название справа от порта
ImVec2 textPos = portPos - ImVec2(12 + ImGui::CalcTextSize(port.name.c_str()).x, 6);
dl->AddText(textPos, IM_COL32(200, 200, 200, 255), port.name.c_str());
}
// Индикатор разделения
if (node.isSplit) {
dl->AddText(pos + ImVec2(10, node.size.y - 20), IM_COL32(255, 200, 100, 255), "✦ Split Output");
}
}
// === Отрисовка соединений ===
void DrawConnections(GraphState& graph, const ImVec2& canvasOffset) {
ImDrawList* dl = ImGui::GetWindowDrawList();
for (const auto& conn : graph.connections) {
auto fromNodeIt = std::find_if(graph.nodes.begin(), graph.nodes.end(),
[conn](const Node& n) { return n.id == conn.fromNode; });
auto toNodeIt = std::find_if(graph.nodes.begin(), graph.nodes.end(),
[conn](const Node& n) { return n.id == conn.toNode; });
if (fromNodeIt == graph.nodes.end() || toNodeIt == graph.nodes.end()) continue;
const Node& from = *fromNodeIt;
const Node& to = *toNodeIt;
ImVec2 start = GetPortPos(from, from.outputs[conn.fromPort], canvasOffset);
ImVec2 end = GetPortPos(to, to.inputs[conn.toPort], canvasOffset);
// Цвет в зависимости от ветки
ImU32 color = IM_COL32(200, 200, 200, 180);
if (to.branch == 0) color = IM_COL32(100, 255, 100, 180);
else if (to.branch == 1) color = IM_COL32(100, 150, 255, 180);
DrawBezier(dl, start, end, color);
}
// Линия при создании соединения
if (graph.creatingConnection) {
auto nodeIt = std::find_if(graph.nodes.begin(), graph.nodes.end(),
[graph](const Node& n) { return n.id == graph.connectionStartNode; });
if (nodeIt != graph.nodes.end()) {
const Node& startNode = *nodeIt;
ImVec2 start = graph.connectionStartType == PortType::Output
? startNode.GetOutputPos(graph.connectionStartPort)
: startNode.GetInputPos(graph.connectionStartPort);
start += canvasOffset;
DrawBezier(dl, start, graph.connectionMousePos + canvasOffset,
IM_COL32(255, 255, 100, 200), 2.5f);
}
}
}
// === Обработка ввода ===
void HandleInput(GraphState& graph, const ImVec2& canvasPos) {
ImGuiIO& io = ImGui::GetIO();
ImVec2 mousePos = io.MousePos - canvasPos - graph.panOffset;
// Панорамирование (средняя кнопка мыши)
if (ImGui::IsMouseDown(2)) {
if (!graph.panning) {
graph.panning = true;
graph.panStart = io.MousePos - graph.panOffset;
}
graph.panOffset = io.MousePos - graph.panStart;
} else {
graph.panning = false;
}
// Масштаб колесом
if (ImGui::IsWindowHovered() && io.MouseWheel != 0) {
float zoomDelta = io.MouseWheel * 0.1f;
graph.zoom = ImClamp(graph.zoom + zoomDelta, 0.5f, 3.0f);
}
// Проверка ховера над портами
graph.hoveredPortNode = -1;
for (auto& node : graph.nodes) {
for (size_t i = 0; i < node.inputs.size(); i++) {
ImVec2 pos = GetPortPos(node, node.inputs[i], graph.panOffset);
if (ImLengthSqr(mousePos - pos) < 100) {
graph.hoveredPortNode = node.id;
graph.hoveredPortIdx = (int)i;
graph.hoveredPortType = PortType::Input;
}
}
for (size_t i = 0; i < node.outputs.size(); i++) {
ImVec2 pos = GetPortPos(node, node.outputs[i], graph.panOffset);
if (ImLengthSqr(mousePos - pos) < 100) {
graph.hoveredPortNode = node.id;
graph.hoveredPortIdx = (int)i;
graph.hoveredPortType = PortType::Output;
}
}
}
// Создание соединения
if (ImGui::IsMouseClicked(0) && graph.hoveredPortNode != -1) {
graph.creatingConnection = true;
graph.connectionStartNode = graph.hoveredPortNode;
graph.connectionStartPort = graph.hoveredPortIdx;
graph.connectionStartType = graph.hoveredPortType;
}
if (graph.creatingConnection) {
graph.connectionMousePos = mousePos;
if (ImGui::IsMouseReleased(0)) {
// Завершение соединения
if (graph.hoveredPortNode != -1 &&
graph.hoveredPortNode != graph.connectionStartNode &&
graph.hoveredPortType != graph.connectionStartType) {
// Добавляем соединение (Output -> Input)
if (graph.connectionStartType == PortType::Output) {
graph.connections.emplace_back(
graph.connectionStartNode, graph.connectionStartPort,
graph.hoveredPortNode, graph.hoveredPortIdx);
} else {
graph.connections.emplace_back(
graph.hoveredPortNode, graph.hoveredPortIdx,
graph.connectionStartNode, graph.connectionStartPort);
}
}
graph.creatingConnection = false;
}
// Отмена правой кнопкой
if (ImGui::IsMouseClicked(1)) {
graph.creatingConnection = false;
}
}
// Перетаскивание узлов
for (auto& node : graph.nodes) {
ImVec2 nodeScreenPos = node.pos + graph.panOffset;
if (ImRect(nodeScreenPos, nodeScreenPos + node.size).Contains(io.MousePos - graph.panOffset)) {
if (ImGui::IsMouseClicked(0) && !graph.creatingConnection) {
node.selected = true;
node.dragging = true;
graph.selectedNode = node.id;
node.dragOffset = io.MousePos - node.pos - graph.panOffset;
}
}
if (node.dragging) {
node.pos = io.MousePos - node.dragOffset - graph.panOffset;
if (ImGui::IsMouseReleased(0)) {
node.dragging = false;
}
}
}
// Выбор узла по клику на пустом месте
if (ImGui::IsMouseClicked(0) && graph.hoveredPortNode == -1 && !graph.creatingConnection) {
bool clickedOnNode = false;
for (const auto& node : graph.nodes) {
if (ImRect(node.pos + graph.panOffset, node.pos + graph.panOffset + node.size)
.Contains(io.MousePos)) {
clickedOnNode = true;
break;
}
}
if (!clickedOnNode) {
graph.selectedNode = -1;
for (auto& node : graph.nodes) node.selected = false;
}
}
// Удаление соединения по правому клику
if (ImGui::IsMouseClicked(1) && !graph.creatingConnection) {
for (auto it = graph.connections.begin(); it != graph.connections.end(); ) {
// Проверка ховера над линией (упрощенная)
auto fromIt = std::find_if(graph.nodes.begin(), graph.nodes.end(),
[it](const Node& n) { return n.id == it->fromNode; });
auto toIt = std::find_if(graph.nodes.begin(), graph.nodes.end(),
[it](const Node& n) { return n.id == it->toNode; });
if (fromIt != graph.nodes.end() && toIt != graph.nodes.end()) {
ImVec2 start = GetPortPos(*fromIt, fromIt->outputs[it->fromPort], graph.panOffset);
ImVec2 end = GetPortPos(*toIt, toIt->inputs[it->toPort], graph.panOffset);
// Простая проверка расстояния до линии
float dist = ImLineClosestPoint(start, end, mousePos + graph.panOffset);
if (ImLengthSqr((mousePos + graph.panOffset) - dist) < 25) {
it = graph.connections.erase(it);
continue;
}
}
++it;
}
}
}
// === Основная отрисовка ===
void DrawGraph(GraphState& graph, const ImVec2& canvasSize) {
ImDrawList* dl = ImGui::GetWindowDrawList();
ImVec2 canvasPos = ImGui::GetCursorScreenPos();
// Фон с сеткой
dl->AddRectFilled(canvasPos, canvasPos + canvasSize, IM_COL32(30, 32, 40, 255));
// Сетка
float gridSize = 50 * graph.zoom;
for (float x = fmodf(-graph.panOffset.x, gridSize); x < canvasSize.x; x += gridSize) {
dl->AddLine(canvasPos + ImVec2(x, 0), canvasPos + ImVec2(x, canvasSize.y),
IM_COL32(50, 55, 70, 100));
}
for (float y = fmodf(-graph.panOffset.y, gridSize); y < canvasSize.y; y += gridSize) {
dl->AddLine(canvasPos + ImVec2(0, y), canvasPos + ImVec2(canvasSize.x, y),
IM_COL32(50, 55, 70, 100));
}
// Соединения (сначала, чтобы были под узлами)
DrawConnections(graph, graph.panOffset);
// Узлы
for (auto& node : graph.nodes) {
DrawNode(graph, node, graph.panOffset);
}
// Обработка ввода
ImGui::InvisibleButton("##GraphCanvas", canvasSize);
if (ImGui::IsItemHovered()) {
HandleInput(graph, canvasPos);
}
// Контекстное меню для добавления узлов
if (ImGui::BeginPopupContextItem("##GraphContext")) {
if (ImGui::MenuItem(" Входной слой")) {
Node newNode(graph.nextNodeId++, "Input", NodeType::Input);
newNode.pos = ImGui::GetMousePos() - canvasPos - graph.panOffset;
newNode.size = ImVec2(180, 90);
newNode.inputs = {};
newNode.outputs = {Port("Output", PortType::Output)};
newNode.layerSize = 256 * 8; // CONTEXT * EMBED
graph.nodes.push_back(newNode);
}
if (ImGui::MenuItem("⬜ Скрытый слой")) {
Node newNode(graph.nextNodeId++, "Hidden", NodeType::Hidden);
newNode.pos = ImGui::GetMousePos() - canvasPos - graph.panOffset;
newNode.size = ImVec2(180, 100);
newNode.inputs = {Port("Input", PortType::Input)};
newNode.outputs = {Port("Output", PortType::Output)};
graph.nodes.push_back(newNode);
}
if (ImGui::MenuItem("🔴 Выходной слой")) {
Node newNode(graph.nextNodeId++, "Output", NodeType::Output);
newNode.pos = ImGui::GetMousePos() - canvasPos - graph.panOffset;
newNode.size = ImVec2(180, 90);
newNode.inputs = {Port("Input", PortType::Input)};
newNode.outputs = {};
newNode.layerSize = 300; // VOCAB
graph.nodes.push_back(newNode);
}
ImGui::Separator();
ImGui::Text("Управление:");
ImGui::Text("• ЛКМ: перетащить узел / создать связь");
ImGui::Text("• ПКМ: удалить связь / отмена");
ImGui::Text("• Колесо: масштаб");
ImGui::Text("• Средняя кнопка: панорамирование");
ImGui::EndPopup();
}
// Панель свойств выбранного узла
if (graph.selectedNode != -1) {
auto selIt = std::find_if(graph.nodes.begin(), graph.nodes.end(),
[graph](const Node& n) { return n.id == graph.selectedNode; });
if (selIt != graph.nodes.end()) {
Node& sel = *selIt;
ImGui::SetNextWindowPos(canvasPos + ImVec2(10, 10));
ImGui::SetNextWindowSize(ImVec2(250, 200));
if (ImGui::Begin("##NodeProperties", nullptr,
ImGuiWindowFlags_NoTitleBar | ImGuiWindowFlags_AlwaysAutoResize |
ImGuiWindowFlags_NoMove | ImGuiWindowFlags_NoSavedSettings)) {
ImGui::TextColored(ImVec4(1,1,0,1), "Свойства: %s", sel.title.c_str());
ImGui::Separator();
if (sel.type != NodeType::Output) {
ImGui::InputInt("Нейронов", &sel.layerSize);
if (sel.layerSize < 1) sel.layerSize = 1;
}
if (sel.type == NodeType::Input) {
ImGui::Text("Ветка:");
if (ImGui::RadioButton("Объединенная", sel.branch == -1)) sel.branch = -1;
if (ImGui::RadioButton("Ветка A", sel.branch == 0)) sel.branch = 0;
if (ImGui::RadioButton("Ветка B", sel.branch == 1)) sel.branch = 1;
}
ImGui::Checkbox("Разделить выход", &sel.isSplit);
if (ImGui::Button("🗑 Удалить узел", ImVec2(-1, 30))) {
// Удаляем узел и все его соединения
graph.connections.erase(
std::remove_if(graph.connections.begin(), graph.connections.end(),
[id = sel.id](const Connection& c) {
return c.fromNode == id || c.toNode == id;
}),
graph.connections.end());
graph.nodes.erase(selIt);
graph.selectedNode = -1;
}
ImGui::End();
}
}
}
}
// === Синхронизация с LayerStructure ===
void SyncToLayerConfigs(GraphState& graph, std::vector<LayerStructure_t>& configs) {
configs.clear();
// Сортируем узлы по позиции (приблизительный топологический порядок)
std::vector<Node*> sortedNodes;
for (auto& n : graph.nodes) sortedNodes.push_back(&n);
std::sort(sortedNodes.begin(), sortedNodes.end(),
[](Node* a, Node* b) { return a->pos.x < b->pos.x; });
for (auto* node : sortedNodes) {
LayerStructure_t layer;
layer.size = node->layerSize;
layer.branch = node->branch;
layer.isSplit = node->isSplit;
// Находим источники по соединениям
for (const auto& conn : graph.connections) {
if (conn.toNode == node->id) {
// Находим индекс слоя-источника в configs
for (int i = 0; i < (int)configs.size(); i++) {
// Упрощенная логика - в реальности нужен маппинг node.id -> layer index
if (i == conn.fromNode) {
layer.sources.push_back(i);
layer.sourceBranches.push_back(node->branch);
}
}
}
}
configs.push_back(layer);
}
}
void SyncFromLayerConfigs(GraphState& graph, const std::vector<LayerStructure_t>& configs) {
graph.nodes.clear();
graph.connections.clear();
for (size_t i = 0; i < configs.size(); i++) {
const auto& cfg = configs[i];
NodeType type = cfg.sources.empty() ? NodeType::Input
: (i == configs.size()-1 ? NodeType::Output : NodeType::Hidden);
Node node((int)i, type == NodeType::Input ? "Input" :
type == NodeType::Output ? "Output" : "Hidden", type);
node.pos = ImVec2(100 + i * 250, 100 + (i % 3) * 150);
node.size = ImVec2(180, type == NodeType::Hidden ? 100 : 90);
node.layerSize = cfg.size;
node.branch = cfg.branch;
node.isSplit = cfg.isSplit;
node.layerIndex = (int)i;
if (type != NodeType::Output)
node.outputs.push_back(Port("Out", PortType::Output));
if (type != NodeType::Input)
node.inputs.push_back(Port("In", PortType::Input));
graph.nodes.push_back(node);
}
// Восстанавливаем соединения
for (size_t i = 0; i < configs.size(); i++) {
for (size_t j = 0; j < configs[i].sources.size(); j++) {
int srcIdx = configs[i].sources[j];
graph.connections.emplace_back(srcIdx, 0, (int)i, 0);
}
}
}
} // namespace NodeEditor