Files
2026-05-16 01:43:15 +07:00

1097 lines
45 KiB
C++
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// ============================================================================
// Xenith AI Studio v3.2
// ============================================================================
#define IMGUI_DEFINE_MATH_OPERATORS
#include "imgui.h"
#include "imgui_internal.h"
#include "imgui_impl_glfw.h"
#include "imgui_impl_opengl3.h"
#include <GLFW/glfw3.h>
#include <iostream>
#include <fstream>
#include <vector>
#include <string>
#include <map>
#include <memory>
#include <algorithm>
#include <thread>
#include <atomic>
#include <mutex>
#include <cmath>
#include <sstream>
#include <iomanip>
#include <chrono>
#include <cstring>
#include <stdexcept>
#include "Xenith/core.hpp"
#include "Xenith/token/token.hpp"
// ============================================================================
// GLOBAL SETTINGS - must be declared before every function that uses them
// ============================================================================
// Mutable at runtime: the "Training" tab edits them via ImGui::InputInt and
// the node editor reads them to size Input/Output nodes.
int UI_CONTEXT{512};          // token positions in the input window
int UI_EMBED{128};            // embedding width per token
int UI_VOCAB{300};            // Output layer size and upper bound for generated ids
int MAX_RESPONSE_TOKENS{128}; // cap on tokens generated per chat reply
// ============================================================================
// NODE EDITOR ENGINE (placed after UI_CONTEXT, UI_EMBED, UI_VOCAB so it can use them)
// ============================================================================
namespace NodeEditor {
/// Direction of a port: links arrive at Input ports and leave from Output ports.
enum class PortType {
    Input,
    Output
};

/// Kind of node in the editor graph.
enum class NodeType {
    Input,
    Hidden,
    Output,
    Splitter,
    Merger
};

/// A named connection point on a node.
struct Port {
    std::string name;  // label drawn next to the port circle
    PortType type;     // Input or Output side
    int index;         // caller-supplied index, 0 by default

    Port(const std::string& label, PortType kind, int idx = 0)
        : name(label), type(kind), index(idx) {}
};
struct Node {
int id;
std::string title;
NodeType type;
ImVec2 pos;
ImVec2 size;
bool selected;
bool dragging;
ImVec2 dragOffset;
int layerSize;
int branchCount;
int activeBranch;
bool isSizeFixed; // NEW: для Input/Output
std::vector<Port> inputs;
std::vector<Port> outputs;
Node(int id_, const std::string& t_, NodeType type_)
: id(id_), title(t_), type(type_), pos(0,0), size(180,100),
selected(false), dragging(false), layerSize(128), branchCount(2),
activeBranch(-1), isSizeFixed(false) {
UpdatePorts();
}
void UpdatePorts() {
inputs.clear(); outputs.clear();
switch(type) {
case NodeType::Input:
outputs.push_back(Port("Out", PortType::Output));
isSizeFixed = true;
// Размер вычисляется из глобальных настроек
layerSize = UI_CONTEXT * UI_EMBED;
break;
case NodeType::Hidden:
inputs.push_back(Port("In", PortType::Input));
outputs.push_back(Port("Out", PortType::Output));
isSizeFixed = false;
break;
case NodeType::Output:
inputs.push_back(Port("In", PortType::Input));
isSizeFixed = true;
layerSize = UI_VOCAB;
break;
case NodeType::Splitter:
inputs.push_back(Port("In", PortType::Input));
for(int i=0;i<branchCount;i++)
outputs.push_back(Port("Br "+std::to_string(i), PortType::Output));
isSizeFixed = false;
break;
case NodeType::Merger:
for(int i=0;i<branchCount;i++)
inputs.push_back(Port("Br "+std::to_string(i), PortType::Input));
outputs.push_back(Port("Out", PortType::Output));
isSizeFixed = false;
break;
}
}
ImVec2 GetPortScreenPos(int portIdx, bool isInput, const ImVec2& canvasPos, const ImVec2& pan) const {
float y = pos.y + 30 + portIdx * 22;
return canvasPos + (isInput ? ImVec2(pos.x - 6, y) : ImVec2(pos.x + size.x + 6, y)) + pan;
}
};
/// Directed link from output port `fromPort` of node `fromNode` to input port
/// `toPort` of node `toNode`. Nodes are referenced by Node::id.
struct Connection {
    int fromNode;
    int fromPort;
    int toNode;
    int toPort;

    Connection(int srcNode, int srcPort, int dstNode, int dstPort)
        : fromNode(srcNode), fromPort(srcPort), toNode(dstNode), toPort(dstPort) {}
};
/// Complete editor state: the graph itself plus transient interaction state
/// (hovered port, link being dragged, panning).
struct GraphState {
    std::vector<Node> nodes;
    std::vector<Connection> connections;
    int nextId{0};                               // id given to the next created node
    int selectedNode{-1};                        // Node::id of the selection, -1 when none
    int hoveredPortNode{-1};                     // node owning the hovered port, -1 when none
    int hoveredPortIdx{-1};
    PortType hoveredPortType{PortType::Input};
    bool creatingConn{false};                    // a link is being dragged from a port
    int connStartNode{-1};
    int connStartPort{-1};
    PortType connStartType{PortType::Output};
    ImVec2 connMousePos;                         // screen-space end of the link being dragged
    ImVec2 pan{100.0f, 80.0f};                   // canvas scroll offset in pixels
    bool panning{false};
    ImVec2 panStart;
};
/// Outline colour for a node: yellow when selected, otherwise a per-type
/// accent colour (grey for any unlisted type).
ImU32 GetNodeColor(NodeType type, bool selected) {
    if(selected) {
        return IM_COL32(255,255,100,255);
    }
    ImU32 colour = IM_COL32(120,120,120,255);
    switch(type) {
        case NodeType::Input:    colour = IM_COL32(80,200,120,255); break;
        case NodeType::Hidden:   colour = IM_COL32(80,140,220,255); break;
        case NodeType::Output:   colour = IM_COL32(220,80,80,255);  break;
        case NodeType::Splitter: colour = IM_COL32(220,180,50,255); break;
        case NodeType::Merger:   colour = IM_COL32(180,80,220,255); break;
    }
    return colour;
}
/// Renders one node into `dl`: body, outline, title, size caption, the
/// Splitter/Merger branch count, and every input/output port with its label.
/// Ports matching the graph's hovered port are highlighted.
void DrawNode(ImDrawList* dl, const Node& n, const GraphState& g, const ImVec2& canvasPos, const ImVec2& pan) {
    const ImVec2 topLeft = canvasPos + n.pos + pan;
    const ImVec2 bottomRight = topLeft + n.size;
    const ImU32 captionCol = IM_COL32(160,160,160,200);
    const ImU32 hintCol = IM_COL32(100,100,100,150);
    const ImU32 labelCol = IM_COL32(200,200,200,255);

    dl->AddRectFilled(topLeft, bottomRight, IM_COL32(35,38,48,255), 6);
    dl->AddRect(topLeft, bottomRight, GetNodeColor(n.type, n.selected), 6, 0, 2.0f);
    dl->AddText(topLeft + ImVec2(10,4), IM_COL32(255,255,255,255), n.title.c_str());

    // Size caption; Input/Output sizes are derived from the global settings.
    const std::string caption = n.isSizeFixed
        ? std::to_string(n.layerSize) + " (auto)"
        : std::to_string(n.layerSize);
    dl->AddText(topLeft + ImVec2(10,20), captionCol, caption.c_str());
    if(n.isSizeFixed) {
        const char* hint = nullptr;
        if(n.type == NodeType::Input) hint = "CTX×EMB";
        else if(n.type == NodeType::Output) hint = "VOCAB";
        if(hint != nullptr) dl->AddText(topLeft + ImVec2(10,36), hintCol, hint);
    }

    if(n.type == NodeType::Splitter || n.type == NodeType::Merger) {
        const std::string branches = " x" + std::to_string(n.branchCount);
        dl->AddText(topLeft + ImVec2(10, n.size.y-18), IM_COL32(255,220,100,255), branches.c_str());
    }

    // Draws the circle of port `idx` on the given side and returns its centre.
    auto drawPortCircle = [&](int idx, PortType side) {
        const ImVec2 centre = n.GetPortScreenPos(idx, side == PortType::Input, canvasPos, pan);
        const bool hovered = g.hoveredPortNode == n.id && g.hoveredPortIdx == idx && g.hoveredPortType == side;
        dl->AddCircleFilled(centre, 5, hovered ? IM_COL32(255,255,100,255) : IM_COL32(180,180,180,255));
        dl->AddCircle(centre, 5, IM_COL32(50,50,50,255), 12, 1.5f);
        return centre;
    };

    // Input labels sit to the right of the port circle.
    for(int idx = 0; idx < static_cast<int>(n.inputs.size()); ++idx) {
        const ImVec2 centre = drawPortCircle(idx, PortType::Input);
        dl->AddText(centre + ImVec2(10,-5), labelCol, n.inputs[idx].name.c_str());
    }
    // Output labels are right-aligned to the left of the port circle.
    for(int idx = 0; idx < static_cast<int>(n.outputs.size()); ++idx) {
        const ImVec2 centre = drawPortCircle(idx, PortType::Output);
        const char* label = n.outputs[idx].name.c_str();
        dl->AddText(centre - ImVec2(10 + ImGui::CalcTextSize(label).x, 5), labelCol, label);
    }
}
/// Draws every stored link as a horizontal-tangent cubic bezier from its
/// output port to its input port. While a link is being dragged, also draws it
/// from its start port to the cursor. Links whose nodes no longer exist are skipped.
void DrawConnections(ImDrawList* dl, const GraphState& g, const ImVec2& canvasPos, const ImVec2& pan) {
    // Look nodes up by id without copying the graph. A by-value capture of `g`
    // copied every node and connection on each call.
    auto findNode = [&g](int id) {
        return std::find_if(g.nodes.begin(), g.nodes.end(), [id](const Node& n){ return n.id == id; });
    };
    // Bezier leaving `start` to the right and entering `end` from the left.
    auto drawLink = [dl](const ImVec2& start, const ImVec2& end, ImU32 col, float thickness) {
        dl->AddBezierCubic(start, start + ImVec2(60,0), end - ImVec2(60,0), end, col, thickness, 24);
    };

    for(const auto& c : g.connections) {
        const auto fn = findNode(c.fromNode);
        const auto tn = findNode(c.toNode);
        if(fn == g.nodes.end() || tn == g.nodes.end()) continue;
        drawLink(fn->GetPortScreenPos(c.fromPort, false, canvasPos, pan),
                 tn->GetPortScreenPos(c.toPort, true, canvasPos, pan),
                 IM_COL32(160,160,160,150), 2.0f);
    }

    if(g.creatingConn) {
        const auto sn = findNode(g.connStartNode);
        if(sn != g.nodes.end()) {
            const bool startIsInput = (g.connStartType != PortType::Output);
            drawLink(sn->GetPortScreenPos(g.connStartPort, startIsInput, canvasPos, pan),
                     g.connMousePos, IM_COL32(255,255,100,180), 2.5f);
        }
    }
}
/// Processes mouse input for the node canvas:
///  - middle-drag pans, the wheel scrolls vertically;
///  - ports within 8px of the cursor become the hovered port;
///  - left-dragging from a port and releasing on a port of the opposite kind
///    on another node adds a link (duplicates are ignored); right-click cancels;
///  - left-clicking a node selects only the topmost node under the cursor and
///    starts dragging it; clicking empty canvas clears the selection.
/// `canvasSize` is not used.
void HandleInput(GraphState& g, const ImVec2& canvasPos, const ImVec2& canvasSize) {
    (void)canvasSize;
    ImGuiIO& io = ImGui::GetIO();
    const ImVec2 mouseScreen = io.MousePos;
    const ImVec2 mouseWorld = mouseScreen - canvasPos - g.pan;

    // --- Panning ---
    if(ImGui::IsMouseDown(ImGuiMouseButton_Middle)) {
        if(!g.panning) { g.panning = true; g.panStart = mouseScreen - g.pan; }
        g.pan = mouseScreen - g.panStart;
    } else {
        g.panning = false;
    }
    if(ImGui::IsWindowHovered() && io.MouseWheel != 0.0f) {
        g.pan.y -= io.MouseWheel * 20.0f;
    }

    // --- Port hover (8px radius; the last matching port wins) ---
    g.hoveredPortNode = -1; g.hoveredPortIdx = -1;
    for(const auto& n : g.nodes) {
        for(size_t i = 0; i < n.inputs.size(); i++) {
            const ImVec2 p = n.GetPortScreenPos((int)i, true, canvasPos, g.pan);
            if(ImLengthSqr(mouseScreen - p) < 64) {
                g.hoveredPortNode = n.id; g.hoveredPortIdx = (int)i; g.hoveredPortType = PortType::Input;
            }
        }
        for(size_t i = 0; i < n.outputs.size(); i++) {
            const ImVec2 p = n.GetPortScreenPos((int)i, false, canvasPos, g.pan);
            if(ImLengthSqr(mouseScreen - p) < 64) {
                g.hoveredPortNode = n.id; g.hoveredPortIdx = (int)i; g.hoveredPortType = PortType::Output;
            }
        }
    }

    // --- Link creation ---
    if(ImGui::IsMouseClicked(ImGuiMouseButton_Left) && g.hoveredPortNode != -1 && !g.panning) {
        g.creatingConn = true;
        g.connStartNode = g.hoveredPortNode;
        g.connStartPort = g.hoveredPortIdx;
        g.connStartType = g.hoveredPortType;
    }
    if(g.creatingConn) {
        g.connMousePos = mouseScreen;
        if(ImGui::IsMouseReleased(ImGuiMouseButton_Left)) {
            if(g.hoveredPortNode != -1 && g.hoveredPortNode != g.connStartNode && g.hoveredPortType != g.connStartType) {
                // Links are always stored output -> input, whichever end the drag began at.
                const Connection link = (g.connStartType == PortType::Output)
                    ? Connection(g.connStartNode, g.connStartPort, g.hoveredPortNode, g.hoveredPortIdx)
                    : Connection(g.hoveredPortNode, g.hoveredPortIdx, g.connStartNode, g.connStartPort);
                const bool alreadyLinked = std::any_of(g.connections.begin(), g.connections.end(),
                    [&link](const Connection& c) {
                        return c.fromNode == link.fromNode && c.fromPort == link.fromPort &&
                               c.toNode == link.toNode && c.toPort == link.toPort;
                    });
                if(!alreadyLinked) g.connections.push_back(link);
            }
            g.creatingConn = false;
        }
        if(ImGui::IsMouseClicked(ImGuiMouseButton_Right)) g.creatingConn = false;
    }

    // --- Selection ---
    if(ImGui::IsMouseClicked(ImGuiMouseButton_Left) && !g.creatingConn) {
        // Nodes are drawn in vector order, so the last hit is the one on top.
        Node* hit = nullptr;
        for(auto it = g.nodes.rbegin(); it != g.nodes.rend() && hit == nullptr; ++it) {
            const ImVec2 scr = canvasPos + it->pos + g.pan;
            if(ImRect(scr, scr + it->size).Contains(mouseScreen)) hit = &*it;
        }
        if(hit != nullptr) {
            for(auto& n : g.nodes) n.selected = false;
            hit->selected = true;
            hit->dragging = true;
            hit->dragOffset = mouseWorld - hit->pos;
            g.selectedNode = hit->id;
        } else if(g.hoveredPortNode == -1 && !g.panning) {
            g.selectedNode = -1;
            for(auto& n : g.nodes) n.selected = false;
        }
    }

    // --- Dragging ---
    for(auto& n : g.nodes) {
        if(!n.dragging) continue;
        n.pos = mouseWorld - n.dragOffset;
        if(ImGui::IsMouseReleased(ImGuiMouseButton_Left)) n.dragging = false;
    }
}
/// Draws the node canvas (grid, links, nodes) at the current cursor position,
/// routes mouse input to HandleInput, shows the right-click "add node" menu,
/// and, while a node is selected, a floating properties panel for it.
void DrawGraph(GraphState& g, const ImVec2& canvasSize) {
    ImDrawList* dl = ImGui::GetWindowDrawList();
    const ImVec2 canvasPos = ImGui::GetCursorScreenPos();
    dl->AddRectFilled(canvasPos, canvasPos+canvasSize, IM_COL32(25,27,35,255));

    // Background grid that scrolls with the pan offset.
    const float gs = 40.0f;
    const ImVec2 off = ImVec2(fmodf(g.pan.x, gs), fmodf(g.pan.y, gs));
    for(float x=off.x; x<canvasSize.x; x+=gs)
        dl->AddLine(canvasPos+ImVec2(x,0), canvasPos+ImVec2(x,canvasSize.y), IM_COL32(40,44,55,120));
    for(float y=off.y; y<canvasSize.y; y+=gs)
        dl->AddLine(canvasPos+ImVec2(0,y), canvasPos+ImVec2(canvasSize.x,y), IM_COL32(40,44,55,120));

    DrawConnections(dl, g, canvasPos, g.pan);
    for(auto& n : g.nodes) DrawNode(dl, n, g, canvasPos, g.pan);

    // InvisibleButton asserts on a zero-sized area (e.g. a squeezed column).
    ImGui::InvisibleButton("##Canvas", ImMax(canvasSize, ImVec2(1.0f, 1.0f)));
    // Keep handling input while a press that began on the canvas is held, and
    // on the frame it is released, even if the cursor has left the canvas.
    // Otherwise the release is missed and a node stays stuck in "dragging".
    if(ImGui::IsItemHovered() || ImGui::IsItemActive() || ImGui::IsItemDeactivated())
        HandleInput(g, canvasPos, canvasSize);

    if(ImGui::BeginPopupContextItem("##Ctx")) {
        // Place new nodes where the menu was opened, not where the item is clicked.
        const ImVec2 wPos = ImGui::GetMousePosOnOpeningCurrentPopup() - canvasPos - g.pan;
        auto addNode = [&g, &wPos](const char* title, NodeType type) {
            Node& n = g.nodes.emplace_back(g.nextId++, title, type);
            n.pos = wPos;
        };
        if(ImGui::MenuItem("🟢 Input")) addNode("Input", NodeType::Input);
        if(ImGui::MenuItem(" Hidden")) addNode("Hidden", NodeType::Hidden);
        if(ImGui::MenuItem(" Output")) addNode("Output", NodeType::Output);
        ImGui::Separator();
        // Node's constructor already sets branchCount = 2 and builds the ports.
        if(ImGui::MenuItem(" Splitter")) addNode("Splitter", NodeType::Splitter);
        if(ImGui::MenuItem(" Merger")) addNode("Merger", NodeType::Merger);
        ImGui::Separator();
        ImGui::Text("LMB: Drag/Connect | RMB: Cancel | MMB: Pan");
        ImGui::EndPopup();
    }

    if(g.selectedNode == -1) return;
    const int selectedId = g.selectedNode;
    auto it = std::find_if(g.nodes.begin(), g.nodes.end(), [selectedId](const Node& n){ return n.id == selectedId; });
    if(it == g.nodes.end()) return;

    Node& sel = *it;
    ImGui::SetNextWindowPos(canvasPos + ImVec2(10,10));
    ImGui::SetNextWindowSize(ImVec2(240,240));
    // ImGui requires End() after every Begin(), whatever Begin() returned.
    const bool visible = ImGui::Begin("##Props", nullptr,
        ImGuiWindowFlags_NoTitleBar|ImGuiWindowFlags_AlwaysAutoResize|ImGuiWindowFlags_NoMove);
    if(visible) {
        ImGui::TextColored(ImVec4(1,1,0.5f,1), "%s # %d", sel.title.c_str(), sel.id);
        ImGui::Separator();
        // Only non-fixed layers have an editable size.
        if(!sel.isSizeFixed) {
            ImGui::InputInt("Neurons", &sel.layerSize);
            if(sel.layerSize < 1) sel.layerSize = 1;
        } else {
            ImGui::Text("Size: %d (auto)", sel.layerSize);
            if(sel.type == NodeType::Input) {
                ImGui::TextColored(ImVec4(0.5,0.5,0.5,1), "= CONTEXT × EMBED");
                ImGui::TextColored(ImVec4(0.5,0.5,0.5,1), "= %d × %d", UI_CONTEXT, UI_EMBED);
            } else if(sel.type == NodeType::Output) {
                ImGui::TextColored(ImVec4(0.5,0.5,0.5,1), "= VOCAB_SIZE");
                ImGui::TextColored(ImVec4(0.5,0.5,0.5,1), "= %d", UI_VOCAB);
            }
        }
        if(sel.type == NodeType::Splitter || sel.type == NodeType::Merger) {
            ImGui::Separator();
            ImGui::SliderInt("Branches", &sel.branchCount, 2, 8);
            const bool committed = ImGui::IsItemDeactivatedAfterEdit();
            // Ctrl+click lets the user type a value outside the slider range.
            sel.branchCount = std::clamp(sel.branchCount, 2, 8);
            if(committed) {
                sel.UpdatePorts();
                // Drop links attached to branch ports that no longer exist.
                const int nodeId = sel.id;
                const int inCount = static_cast<int>(sel.inputs.size());
                const int outCount = static_cast<int>(sel.outputs.size());
                g.connections.erase(std::remove_if(g.connections.begin(), g.connections.end(),
                    [=](const Connection& c) {
                        return (c.toNode == nodeId && c.toPort >= inCount) ||
                               (c.fromNode == nodeId && c.fromPort >= outCount);
                    }), g.connections.end());
            }
        } else if(sel.type == NodeType::Input) {
            ImGui::Separator();
            ImGui::Text("Branch:");
            ImGui::RadioButton("Combined", &sel.activeBranch, -1);
            ImGui::RadioButton("A (0)", &sel.activeBranch, 0);
            ImGui::RadioButton("B (1)", &sel.activeBranch, 1);
        }
        ImGui::Spacing();
        if(ImGui::Button("🗑 Delete", ImVec2(-1,28))) {
            const int nodeId = sel.id;
            g.connections.erase(std::remove_if(g.connections.begin(), g.connections.end(),
                [nodeId](const Connection& c){ return c.fromNode == nodeId || c.toNode == nodeId; }),
                g.connections.end());
            g.nodes.erase(it);   // `sel` dangles from here on; it is not used again
            g.selectedNode = -1;
        }
    }
    ImGui::End();
}
/// Rebuilds the editor graph from a flat layer list.
///  - One node per layer, with node id == layer index.
///  - A layer with no sources becomes an Input node, the last layer an Output
///    node, and everything else a Hidden node.
///  - Nodes are staggered left to right.
///  - One link (port 0 -> port 0) is added for every entry in a layer's `sources`.
void SyncFromConfigs(GraphState& g, const std::vector<LayerStructure_t>& cfgs) {
    g.nodes.clear();
    g.connections.clear();
    g.nextId = 0;

    const size_t count = cfgs.size();
    for(size_t idx = 0; idx < count; ++idx) {
        const auto& layer = cfgs[idx];

        NodeType kind = NodeType::Hidden;
        if(layer.sources.empty()) kind = NodeType::Input;
        else if(idx == count - 1) kind = NodeType::Output;

        const char* title = "Hidden";
        if(kind == NodeType::Input) title = "Input";
        else if(kind == NodeType::Output) title = "Output";

        Node node(g.nextId++, title, kind);
        node.pos = ImVec2(120 + idx*260, 150 + (idx%3)*120);

        // Input/Output sizes come from the global settings, not the config.
        switch(kind) {
            case NodeType::Input:
                node.layerSize = UI_CONTEXT * UI_EMBED;
                node.isSizeFixed = true;
                break;
            case NodeType::Output:
                node.layerSize = UI_VOCAB;
                node.isSizeFixed = true;
                break;
            default:
                node.layerSize = layer.size;
                node.isSizeFixed = false;
                break;
        }
        node.activeBranch = layer.branch;
        node.UpdatePorts();
        g.nodes.push_back(node);
    }

    for(size_t dst = 0; dst < count; ++dst) {
        for(const auto& src : cfgs[dst].sources)
            g.connections.emplace_back(src, 0, (int)dst, 0);
    }
}
/// Exports the editor graph as a flat layer list (the inverse of SyncFromConfigs).
///  - Splitter/Merger nodes are editor-only; they, and links touching them,
///    are skipped.
///  - Layers are emitted in dependency order: every source comes before the
///    layers that read it. Ties are broken left to right by node x position,
///    so a left-to-right graph keeps its x order.
///  - If the links form a cycle, the leftmost remaining node is emitted next
///    so the export always terminates.
///  - Each layer's `sources` are indices into `cfgs`, and `sourceBranches`
///    repeats the layer's own activeBranch once per source.
void SyncToConfigs(GraphState& g, std::vector<LayerStructure_t>& cfgs) {
    cfgs.clear();

    std::vector<const Node*> byX;
    for(const auto& n : g.nodes)
        if(n.type != NodeType::Splitter && n.type != NodeType::Merger) byX.push_back(&n);
    // Stable, so nodes sharing an x keep their creation order.
    std::stable_sort(byX.begin(), byX.end(), [](const Node* a, const Node* b){ return a->pos.x < b->pos.x; });

    std::map<int, size_t> slot;   // node id -> position in byX
    for(size_t i = 0; i < byX.size(); i++) slot[byX[i]->id] = i;

    // Number of not-yet-emitted exported sources feeding each node.
    std::vector<int> pending(byX.size(), 0);
    for(const auto& c : g.connections) {
        if(slot.count(c.fromNode) == 0) continue;
        const auto to = slot.find(c.toNode);
        if(to != slot.end()) pending[to->second]++;
    }

    std::vector<const Node*> sorted;
    sorted.reserve(byX.size());
    std::vector<char> emitted(byX.size(), 0);
    while(sorted.size() < byX.size()) {
        size_t pick = byX.size();
        // Leftmost node whose sources are all emitted...
        for(size_t i = 0; i < byX.size() && pick == byX.size(); i++)
            if(!emitted[i] && pending[i] == 0) pick = i;
        // ...or, if a cycle blocks everything, the leftmost remaining node.
        for(size_t i = 0; i < byX.size() && pick == byX.size(); i++)
            if(!emitted[i]) pick = i;

        emitted[pick] = 1;
        sorted.push_back(byX[pick]);
        for(const auto& c : g.connections) {
            if(c.fromNode != byX[pick]->id) continue;
            const auto to = slot.find(c.toNode);
            if(to != slot.end() && !emitted[to->second]) pending[to->second]--;
        }
    }

    std::map<int, int> idToIdx;
    for(size_t i = 0; i < sorted.size(); i++) idToIdx[sorted[i]->id] = (int)i;
    for(const Node* n : sorted) {
        LayerStructure_t l;
        l.size = n->layerSize;
        l.branch = n->activeBranch;
        for(const auto& c : g.connections) {
            if(c.toNode != n->id) continue;
            const auto src = idToIdx.find(c.fromNode);
            if(src == idToIdx.end()) continue;
            l.sources.push_back(src->second);
            l.sourceBranches.push_back(n->activeBranch);
        }
        cfgs.push_back(l);
    }
}
std::string GenerateArchitectureText(const GraphState& g, const std::vector<LayerStructure_t>& configs) {
std::stringstream ss;
ss << "=== NEURAL NETWORK ARCHITECTURE ===\n\n";
long long totalParams = 0;
for(size_t i=0;i<configs.size();i++) {
if(configs[i].sources.empty()) continue;
int inputSize = 0;
for(int src : configs[i].sources) inputSize += configs[src].size;
long long layerParams = (long long)inputSize * configs[i].size + configs[i].size;
totalParams += layerParams;
}
ss << "Total Parameters: " << totalParams << " (" << std::fixed << std::setprecision(2) << (totalParams/1000000.0) << "M)\n\n";
ss << "Layer Structure:\n----------------\n";
for(size_t i=0;i<configs.size();i++) {
auto& l = configs[i];
std::string type = l.sources.empty() ? "INPUT" : (i==configs.size()-1 ? "OUTPUT" : "HIDDEN");
if(type == "INPUT") {
ss << "Layer " << i << " [INPUT]: " << l.size << " neurons";
ss << " (CONTEXT=" << UI_CONTEXT << " × EMBED=" << UI_EMBED << ")";
} else if(type == "OUTPUT") {
ss << "Layer " << i << " [OUTPUT]: " << l.size << " neurons";
ss << " (VOCAB=" << UI_VOCAB << ")";
} else {
ss << "Layer " << i << " [HIDDEN]: " << l.size << " neurons";
}
if(l.branch != -1) ss << " (Branch " << (char)('A'+l.branch) << ")";
ss << "\n";
if(!l.sources.empty()) {
ss << " ← From: ";
for(size_t j=0;j<l.sources.size();j++) {
ss << "Layer " << l.sources[j];
if(j < l.sources.size()-1) ss << ", ";
}
ss << "\n";
}
}
ss << "\n=== NODE GRAPH ===\n";
ss << "Total Nodes: " << g.nodes.size() << "\n";
ss << "Total Connections: " << g.connections.size() << "\n\n";
for(const auto& n : g.nodes) {
ss << "Node #" << n.id << " [" << n.title << "] @ ("
<< (int)n.pos.x << "," << (int)n.pos.y << ")\n";
ss << " Size: " << n.layerSize << " neurons";
if(n.isSizeFixed) ss << " (auto)";
ss << "\n Inputs: " << n.inputs.size() << " | Outputs: " << n.outputs.size() << "\n";
if(n.type == NodeType::Splitter || n.type == NodeType::Merger)
ss << " Branches: " << n.branchCount << "\n";
}
return ss.str();
}
/// Writes `configs` to `filename` in the INI-like format that
/// LoadArchitectureFromFile reads:
///   [ARCHITECTURE] / layers=N, then per layer
///   [LAYER i] / size= / branch= / sources=a,b / branches=a,b.
/// Returns false if the file cannot be opened or if any write, flush or close
/// fails, so callers do not assume a save succeeded when it did not.
bool SaveArchitectureToFile(const std::vector<LayerStructure_t>& configs, const std::string& filename) {
    std::ofstream file(filename);
    if(!file.is_open()) return false;

    // Writes values as "a,b,c" (nothing for an empty list).
    auto writeCsv = [&file](const auto& values) {
        for(size_t j = 0; j < values.size(); j++) {
            if(j > 0) file << ",";
            file << values[j];
        }
    };

    file << "[ARCHITECTURE]\n";
    file << "layers=" << configs.size() << "\n";
    for(size_t i = 0; i < configs.size(); i++) {
        const auto& l = configs[i];
        file << "[LAYER " << i << "]\n";
        file << "size=" << l.size << "\n";
        file << "branch=" << l.branch << "\n";
        file << "sources=";
        writeCsv(l.sources);
        file << "\n";
        file << "branches=";
        writeCsv(l.sourceBranches);
        file << "\n\n";
    }
    file.close();   // flushes; a failed flush/close sets failbit
    return !file.fail();
}
bool LoadArchitectureFromFile(std::vector<LayerStructure_t>& configs, const std::string& filename) {
std::ifstream file(filename);
if(!file.is_open()) return false;
configs.clear();
std::string line;
int currentLayer = -1;
while(std::getline(file, line)) {
if(line.find("[LAYER ") == 0) {
size_t start = line.find(' ') + 1;
size_t end = line.find(']');
currentLayer = std::stoi(line.substr(start, end - start));
configs.push_back(LayerStructure_t());
} else if(line.find("size=") == 0 && currentLayer >= 0) {
configs[currentLayer].size = std::stoi(line.substr(5));
} else if(line.find("branch=") == 0 && currentLayer >= 0) {
configs[currentLayer].branch = std::stoi(line.substr(7));
} else if(line.find("sources=") == 0 && currentLayer >= 0) {
std::string srcs = line.substr(8);
std::stringstream ss(srcs);
std::string token;
while(std::getline(ss, token, ',')) {
if(!token.empty()) configs[currentLayer].sources.push_back(std::stoi(token));
}
} else if(line.find("branches=") == 0 && currentLayer >= 0) {
std::string brs = line.substr(9);
std::stringstream ss(brs);
std::string token;
while(std::getline(ss, token, ',')) {
if(!token.empty()) configs[currentLayer].sourceBranches.push_back(std::stoi(token));
}
}
}
file.close();
return !configs.empty();
}
} // namespace NodeEditor
// ============================================================================
// CHAT SYSTEM
// ============================================================================
/// One entry of a chat transcript. The UI stores "USER" for typed input and
/// "AI" for generated replies in `role`.
struct ChatMessage {
    std::string role;
    std::string content;
};

/// A numbered, named conversation holding its messages in order.
struct ChatSession {
    std::string name;
    std::vector<ChatMessage> messages;
    int id;

    ChatSession(int sessionId, const std::string& title)
        : name(title), id(sessionId) {}
};
// ============================================================================
// DATASET
// ============================================================================
/// One prompt/response pair taken from the dataset text.
struct TrainingSample {
    std::string userText;  // text between [USER] and [AI]
    std::string aiText;    // text between [AI] and [EOS]
};
/// Splits dataset text made of repeated "[USER]prompt[AI]reply[EOS]" blocks
/// into prompt/reply pairs.
///  - Whitespace (space, tab, CR, LF) is trimmed from both ends of each part,
///    and pairs with an empty side are dropped.
///  - A [USER] block whose [AI] or [EOS] marker is missing before the next
///    [USER] is skipped, instead of swallowing the following sample's text.
///  - Parsing stops when no further complete block exists.
/// `tok` is not used.
std::vector<TrainingSample> ParseDataset(const std::string& text, Tokenizer& tok) {
    (void)tok;
    static const std::string kUser = "[USER]";
    static const std::string kAi = "[AI]";
    static const std::string kEos = "[EOS]";
    const size_t npos = std::string::npos;

    auto trim = [](const std::string& s) {
        const size_t first = s.find_first_not_of(" \t\r\n");
        if(first == std::string::npos) return std::string();
        const size_t last = s.find_last_not_of(" \t\r\n");
        return s.substr(first, last - first + 1);
    };

    std::vector<TrainingSample> samples;
    size_t pos = 0;
    while(pos < text.size()) {
        const size_t userMarker = text.find(kUser, pos);
        if(userMarker == npos) break;
        const size_t userStart = userMarker + kUser.size();
        const size_t aiMarker = text.find(kAi, userStart);
        const size_t eosMarker = (aiMarker == npos) ? npos : text.find(kEos, aiMarker + kAi.size());
        const size_t nextUser = text.find(kUser, userStart);
        if(eosMarker == npos || nextUser < eosMarker) {
            // Incomplete block: resync at the next [USER], or stop if there is none.
            if(nextUser == npos) break;
            pos = nextUser;
            continue;
        }
        const size_t aiStart = aiMarker + kAi.size();
        std::string userText = trim(text.substr(userStart, aiMarker - userStart));
        std::string aiText = trim(text.substr(aiStart, eosMarker - aiStart));
        if(!userText.empty() && !aiText.empty()) {
            samples.push_back({std::move(userText), std::move(aiText)});
        }
        pos = eosMarker + kEos.size();
    }
    return samples;
}
/// Reads the entire file `filename` into `buffer`.
/// Returns false, leaving `buffer` untouched, when the file cannot be opened.
bool LoadDatasetFromFile(std::string& buffer, const std::string& filename) {
    std::ifstream in(filename);
    if(!in.is_open()) {
        return false;
    }
    std::ostringstream contents;
    contents << in.rdbuf();
    buffer = contents.str();
    return true;   // `in` is closed by its destructor
}
// ============================================================================
// GLOBAL STATE
// ============================================================================
/// Process-wide application state, read by the render loop and by the
/// background TrainTask thread. `mtx` is taken around `lastStatus` and the
/// displayed metrics; `training` and `stop` are atomics that start and cancel a run.
struct UIState {
    // --- Chat ---
    std::vector<ChatSession> chats;
    int activeChat{0};
    char inputBuf[512]{};
    bool scrollChat{true};

    // --- Training status ---
    TrainStatus lastStatus;
    std::mutex mtx;
    float lr{0.01f};
    int epochs{10};
    int doneEpochs{0};

    // --- Performance metrics ---
    double tokensPerSecond{0.0};
    double generationTokensPerSecond{0.0};
    int totalTokensTrained{0};
    std::chrono::steady_clock::time_point trainingStartTime;

    // --- Dataset ---
    char dsBuf[524288]{};
    std::vector<TrainingSample> parsedSamples;
    bool datasetLoaded{false};

    // --- Run control ---
    std::atomic<bool> training{false};
    std::atomic<bool> stop{false};

    // --- Architecture ---
    std::vector<LayerStructure_t> layers;
    NodeEditor::GraphState graph;
    std::string architectureText;
    char archFilePath[256]{"architecture.txt"};
    char datasetFilePath[256]{"dataset.txt"};
} ui;
// ============================================================================
// NEURAL NETWORK HELPERS
// ============================================================================
/// Builds the input activations for every source-less (input) layer in
/// `ui.layers`, keyed by layer index.
///
/// Each token position contributes `emb.get(token)`, and missing positions
/// are zero-padded at the end with UI_EMBED zeros each. The window depends on
/// the layer's branch:
///  - branch -1: the last UI_CONTEXT tokens;
///  - branch 1: UI_CONTEXT/2 tokens starting halfway into that window;
///  - any other branch (e.g. 0): the first UI_CONTEXT/2 tokens of that window.
///
/// NOTE(review): padding assumes emb.get returns UI_EMBED values per token;
/// confirm against Embedder.
std::map<int, std::vector<double>> PrepInput(const std::vector<int>& toks, Embedder& emb) {
    std::map<int, std::vector<double>> res;
    const int total = static_cast<int>(toks.size());
    for(size_t i = 0; i < ui.layers.size(); i++) {
        const auto& layer = ui.layers[i];
        if(!layer.sources.empty()) continue;

        const int sz = (layer.branch == -1) ? UI_CONTEXT : UI_CONTEXT/2;  // token positions
        const int st = (layer.branch == 1) ? UI_CONTEXT/2 : 0;            // offset into window
        std::vector<double> d;
        d.reserve(static_cast<size_t>(sz) * UI_EMBED);

        int cnt = 0;
        for(int j = std::max(0, total - UI_CONTEXT + st); j < total && cnt < sz; j++, cnt++) {
            const auto& v = emb.get(toks[j]);   // bind instead of copying when get returns a reference
            d.insert(d.end(), v.begin(), v.end());
        }
        if(cnt < sz) d.insert(d.end(), static_cast<size_t>(sz - cnt) * UI_EMBED, 0.0);

        // Move: this runs once per training/generation step and the vector can
        // hold UI_CONTEXT*UI_EMBED doubles.
        res[static_cast<int>(i)] = std::move(d);
    }
    return res;
}
/// Background training loop, started on its own thread by main.
///
/// For each epoch and each sample, the user text is the starting context and
/// the network is trained to predict every AI token in turn (teacher forcing).
/// The true token is then appended and the context is trimmed by one token
/// when it exceeds UI_CONTEXT.
///
/// - Epoch count and learning rate are captured once at start; the UI may
///   edit them while this thread runs.
/// - Status and metrics shown by the UI are written under `ui.mtx`.
/// - `ui.stop` cancels between steps.
/// - `ui.training` and `ui.stop` are cleared on every exit path, including
///   exceptions, which are logged rather than escaping the thread.
void TrainTask(NeuralNetwork* nn, Tokenizer* tk, Embedder* eb) {
    struct RunFlagsReset {
        ~RunFlagsReset() { ui.training = false; ui.stop = false; }
    } resetOnExit;

    if(ui.parsedSamples.empty() || ui.layers.empty() || nn == nullptr) return;

    const int epochs = ui.epochs;
    const float lr = ui.lr;
    const int os = ui.layers.back().size;   // output layer width
    if(os <= 0) return;
    const size_t sampleTotal = ui.parsedSamples.size();
    const float totalSteps = static_cast<float>(static_cast<size_t>(epochs) * sampleTotal);

    double totalLoss = 0;
    int sampleCount = 0;
    {
        std::lock_guard<std::mutex> lk(ui.mtx);
        ui.trainingStartTime = std::chrono::steady_clock::now();
        ui.totalTokensTrained = 0;
    }

    try {
        for(int e = 0; e < epochs && !ui.stop; e++) {
            const auto epochStart = std::chrono::steady_clock::now();
            double epochLoss = 0;
            int epochCount = 0;
            for(size_t s = 0; s < sampleTotal && !ui.stop; s++) {
                const auto& sample = ui.parsedSamples[s];
                auto userToks = tk->textToTokens(sample.userText);
                auto aiToks = tk->textToTokens(sample.aiText);
                if(userToks.empty() || aiToks.empty()) continue;
                std::vector<int> context = userToks;
                for(size_t i = 0; i < aiToks.size() && !ui.stop; i++) {
                    const auto tokenStart = std::chrono::steady_clock::now();
                    const int targetId = aiToks[i];
                    std::vector<double> target(os, 0.0);
                    // Ids outside [0, os) have no output neuron: all-zero target.
                    if(targetId >= 0 && targetId < os) target[targetId] = 1.0;
                    const double loss = nn->train(PrepInput(context, *eb), target, lr);
                    totalLoss += loss;
                    sampleCount++;
                    epochLoss += loss;
                    epochCount++;
                    context.push_back(targetId);
                    if((int)context.size() > UI_CONTEXT) context.erase(context.begin());

                    const double tokenTime = std::chrono::duration<double>(
                        std::chrono::steady_clock::now() - tokenStart).count();
                    {
                        std::lock_guard<std::mutex> lk(ui.mtx);
                        ui.totalTokensTrained++;
                        if(tokenTime > 0) ui.tokensPerSecond = 1.0 / tokenTime;
                        ui.lastStatus.loss = loss;
                        const float currentStep = static_cast<float>(static_cast<size_t>(e) * sampleTotal + s);
                        ui.lastStatus.progress = (currentStep / totalSteps) * 100.0f;
                        ui.lastStatus.epoch = e + 1;
                        ui.lastStatus.speed = ui.tokensPerSecond;
                    }
                }
            }
            const double epochTime = std::chrono::duration<double>(
                std::chrono::steady_clock::now() - epochStart).count();
            double speed = 0.0;
            {
                std::lock_guard<std::mutex> lk(ui.mtx);
                ui.doneEpochs++;
                speed = ui.tokensPerSecond;
            }
            // Report this epoch's mean loss, not the running mean over all epochs.
            std::cout << "Epoch " << (e+1) << "/" << epochs
                      << " | Loss: " << (epochCount > 0 ? epochLoss / epochCount : 0.0)
                      << " | Time: " << epochTime << "s"
                      << " | Tokens/sec: " << speed << "\n";
        }

        std::lock_guard<std::mutex> lk(ui.mtx);
        if(sampleCount > 0) {
            ui.lastStatus.totalLoss = totalLoss / sampleCount;
            const double totalTime = std::chrono::duration<double>(
                std::chrono::steady_clock::now() - ui.trainingStartTime).count();
            if(totalTime > 0) ui.tokensPerSecond = ui.totalTokensTrained / totalTime;
            // The per-step formula tops out just below 100%; mark a finished run complete.
            if(!ui.stop) ui.lastStatus.progress = 100.0f;
        }
    } catch(const std::exception& ex) {
        std::cerr << "Training aborted: " << ex.what() << "\n";
    } catch(...) {
        std::cerr << "Training aborted: unknown exception\n";
    }
}
// ============================================================================
// MAIN
// ============================================================================
int main() {
if(!glfwInit()) return 1;
GLFWwindow* win = glfwCreateWindow(1600,1000,"Xenith Studio",nullptr,nullptr);
glfwMakeContextCurrent(win); glfwSwapInterval(1);
IMGUI_CHECKVERSION(); ImGui::CreateContext();
ImGuiIO& io = ImGui::GetIO();
io.Fonts->AddFontFromFileTTF("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 16.0f, nullptr, io.Fonts->GetGlyphRangesCyrillic());
ImGui_ImplGlfw_InitForOpenGL(win, true); ImGui_ImplOpenGL3_Init("#version 130");
// Init default architecture
ui.layers = {
LayerStructure_t(UI_CONTEXT*UI_EMBED,{}),
LayerStructure_t(512,{0}),
LayerStructure_t(UI_VOCAB,{1})
};
ui.layers[0].branch = -1;
NodeEditor::SyncFromConfigs(ui.graph, ui.layers);
ui.chats.push_back(ChatSession(0, "Chat 1"));
Tokenizer tok;
Embedder emb(UI_VOCAB, UI_EMBED);
NeuralNetwork* nn = new NeuralNetwork(ui.layers.data(), ui.layers.size(), true);
// Try load architecture
if(NodeEditor::LoadArchitectureFromFile(ui.layers, ui.archFilePath)) {
NodeEditor::SyncFromConfigs(ui.graph, ui.layers);
delete nn;
nn = new NeuralNetwork(ui.layers.data(), ui.layers.size(), true);
std::cout << "Loaded architecture from " << ui.archFilePath << "\n";
}
// Track last values for auto-update
int lastContext = UI_CONTEXT, lastEmbed = UI_EMBED, lastVocab = UI_VOCAB;
while(!glfwWindowShouldClose(win)) {
glfwPollEvents();
ImGui_ImplOpenGL3_NewFrame(); ImGui_ImplGlfw_NewFrame(); ImGui::NewFrame();
ImGui::SetNextWindowPos(ImVec2(0,0)); ImGui::SetNextWindowSize(io.DisplaySize);
ImGui::Begin("Main", nullptr, ImGuiWindowFlags_NoDecoration|ImGuiWindowFlags_MenuBar|ImGuiWindowFlags_NoMove|ImGuiWindowFlags_NoResize);
if(ImGui::BeginMenuBar()) {
if(ImGui::BeginMenu("File")) {
if(ImGui::MenuItem("Apply & Save Arch")) {
NodeEditor::SyncToConfigs(ui.graph, ui.layers);
NodeEditor::SaveArchitectureToFile(ui.layers, ui.archFilePath);
delete nn;
nn = new NeuralNetwork(ui.layers.data(), ui.layers.size(), true);
ui.architectureText = NodeEditor::GenerateArchitectureText(ui.graph, ui.layers);
}
if(ImGui::MenuItem("Load Arch")) {
if(NodeEditor::LoadArchitectureFromFile(ui.layers, ui.archFilePath)) {
NodeEditor::SyncFromConfigs(ui.graph, ui.layers);
delete nn;
nn = new NeuralNetwork(ui.layers.data(), ui.layers.size(), true);
ui.architectureText = NodeEditor::GenerateArchitectureText(ui.graph, ui.layers);
}
}
ImGui::EndMenu();
}
ImGui::EndMenuBar();
}
// Status bar
{
std::lock_guard<std::mutex> lk(ui.mtx);
ImGui::Text("Epoch: %d/%d | Loss: %.5f | %.1f%%",
ui.lastStatus.epoch, ui.epochs, ui.lastStatus.loss, ui.lastStatus.progress);
ImGui::SameLine();
ImGui::TextColored(ui.training?ImVec4(0,1,0,1):ImVec4(1,1,0,1),
ui.training?"[TRAINING]":"[IDLE]");
ImGui::ProgressBar(ui.lastStatus.progress/100.0f, ImVec2(-1,18));
}
ImGui::Separator();
// Three columns
ImGui::Columns(3, "MainCols", true);
ImGui::SetColumnWidth(0, io.DisplaySize.x*0.35f);
ImGui::SetColumnWidth(1, io.DisplaySize.x*0.35f);
// === COLUMN 1: NODE EDITOR ===
ImGui::BeginChild("Graph", ImVec2(0,0), true);
ImGui::TextColored(ImVec4(0,1,1,1), "NODE GRAPH EDITOR");
ImGui::Separator();
NodeEditor::DrawGraph(ui.graph, ImGui::GetContentRegionAvail());
ImGui::EndChild();
// === COLUMN 2: ARCHITECTURE & SETTINGS ===
ImGui::NextColumn();
ImGui::BeginChild("Info", ImVec2(0,0), true);
if(ImGui::BeginTabBar("InfoTabs")) {
if(ImGui::BeginTabItem("Architecture")) {
if(ImGui::Button("Refresh")) {
ui.architectureText = NodeEditor::GenerateArchitectureText(ui.graph, ui.layers);
}
ImGui::SameLine();
ImGui::InputText("Arch File", ui.archFilePath, 256);
ImGui::Separator();
ImGui::BeginChild("ArchText", ImVec2(0,0), true);
ImGui::TextWrapped("%s", ui.architectureText.c_str());
ImGui::EndChild();
ImGui::EndTabItem();
}
if(ImGui::BeginTabItem("Training")) {
// Performance metrics
ImGui::TextColored(ImVec4(0,1,1,1), "Performance Metrics:");
ImGui::Separator();
{
std::lock_guard<std::mutex> lk(ui.mtx);
ImGui::Text("Training Speed: %.2f tokens/sec", ui.tokensPerSecond);
ImGui::Text("Generation Speed: %.2f tokens/sec", ui.generationTokensPerSecond);
ImGui::Text("Total Tokens Trained: %d", ui.totalTokensTrained);
}
ImGui::Separator();
// Settings
ImGui::SliderFloat("Learning Rate", &ui.lr, 0.0001f, 0.1f, "%.5f");
ImGui::InputInt("Epochs", &ui.epochs); if(ui.epochs<1) ui.epochs=1;
ImGui::InputInt("Max Response Tokens", &MAX_RESPONSE_TOKENS); if(MAX_RESPONSE_TOKENS<1) MAX_RESPONSE_TOKENS=1;
ImGui::InputInt("Context Size", &UI_CONTEXT); if(UI_CONTEXT<1) UI_CONTEXT=1;
ImGui::InputInt("Embedding Dim", &UI_EMBED); if(UI_EMBED<1) UI_EMBED=1;
ImGui::InputInt("Vocab Size", &UI_VOCAB); if(UI_VOCAB<1) UI_VOCAB=1;
// Auto-update fixed layer sizes when globals change
if(UI_CONTEXT != lastContext || UI_EMBED != lastEmbed || UI_VOCAB != lastVocab) {
lastContext = UI_CONTEXT; lastEmbed = UI_EMBED; lastVocab = UI_VOCAB;
for(auto& node : ui.graph.nodes) {
if(node.type == NodeEditor::NodeType::Input) {
node.layerSize = UI_CONTEXT * UI_EMBED;
node.UpdatePorts();
} else if(node.type == NodeEditor::NodeType::Output) {
node.layerSize = UI_VOCAB;
node.UpdatePorts();
}
}
ui.architectureText = NodeEditor::GenerateArchitectureText(ui.graph, ui.layers);
}
ImGui::Separator();
ImGui::TextColored(ImVec4(1,0.5f,0,1), "⚠️ Recommendations:");
if(ui.parsedSamples.size() < 10) {
ImGui::TextColored(ImVec4(1,0,0,1), "• Too few samples! Add at least 10-20 examples");
}
if(ui.epochs > 50 && ui.parsedSamples.size() < 5) {
ImGui::TextColored(ImVec4(1,0,0,1), "• Too many epochs for small dataset! Reduce to 10-20");
}
if(ui.lr > 0.01f) {
ImGui::TextColored(ImVec4(1,0.8f,0,1), "• High learning rate! Try 0.001-0.01");
}
// Dataset loading
ImGui::Separator();
ImGui::Text("Dataset File:");
ImGui::InputText("##DSFile", ui.datasetFilePath, 256);
if(ImGui::Button("Load Dataset from File")) {
std::string tempBuffer;
if(LoadDatasetFromFile(tempBuffer, ui.datasetFilePath)) {
strncpy(ui.dsBuf, tempBuffer.c_str(), sizeof(ui.dsBuf) - 1);
ui.dsBuf[sizeof(ui.dsBuf) - 1] = '\0';
ui.parsedSamples = ParseDataset(ui.dsBuf, tok);
ui.datasetLoaded = true;
std::cout << "Loaded " << ui.parsedSamples.size() << " samples\n";
} else {
std::cerr << "Failed to load dataset\n";
}
}
ImGui::SameLine();
if(ui.datasetLoaded) ImGui::TextColored(ImVec4(0,1,0,1), "✓ %lu samples", ui.parsedSamples.size());
else ImGui::Text("Not loaded");
ImGui::Separator();
if(!ui.training) {
if(ui.parsedSamples.empty()) {
ImGui::PushStyleVar(ImGuiStyleVar_Alpha, 0.5f);
ImGui::Button("▶ START TRAINING (Load dataset first)", ImVec2(-1,45));
ImGui::PopStyleVar();
} else {
if(ImGui::Button("▶ START TRAINING", ImVec2(-1,45))) {
ui.training=true;
ui.stop=false;
ui.doneEpochs=0;
std::thread(TrainTask, nn, &tok, &emb).detach();
}
}
} else {
ImGui::PushStyleColor(ImGuiCol_Button, ImVec4(0.8f,0.2f,0.2f,1));
if(ImGui::Button("⏹ STOP", ImVec2(-1,45))) ui.stop=true;
ImGui::PopStyleColor();
}
ImGui::EndTabItem();
}
if(ImGui::BeginTabItem("Dataset")) {
if(ImGui::Button("Parse from Text")) {
ui.parsedSamples = ParseDataset(std::string(ui.dsBuf), tok);
ui.datasetLoaded = !ui.parsedSamples.empty();
}
ImGui::SameLine();
ImGui::Text("(%lu samples)", ui.parsedSamples.size());
ImGui::InputTextMultiline("##DS", ui.dsBuf, sizeof(ui.dsBuf), ImVec2(-1,-1),
ImGuiInputTextFlags_AllowTabInput);
ImGui::TextWrapped("Format: [USER]text[AI]response[EOS]");
ImGui::EndTabItem();
}
ImGui::EndTabBar();
}
ImGui::EndChild();
// === COLUMN 3: CHATS ===
ImGui::NextColumn();
ImGui::BeginChild("Chats", ImVec2(0,0), true);
static char newChatName[64] = "";
if(ImGui::Button("+ New Chat")) {
ui.chats.push_back(ChatSession(ui.chats.size(), std::string("Chat ") + std::to_string(ui.chats.size()+1)));
}
ImGui::SameLine();
ImGui::InputText("##NewChatName", newChatName, 64);
for(size_t i=0;i<ui.chats.size();i++) {
if(ImGui::Selectable(ui.chats[i].name.c_str(), (int)i==ui.activeChat)) {
ui.activeChat = i;
}
}
ImGui::Separator();
if(!ui.chats.empty() && ui.activeChat < ui.chats.size()) {
auto& chat = ui.chats[ui.activeChat];
ImGui::Text("%s", chat.name.c_str());
ImGui::Separator();
ImGui::BeginChild("Messages", ImVec2(0, -80), true);
for(auto& msg : chat.messages) {
if(msg.role == "USER") {
ImGui::PushStyleColor(ImGuiCol_Text, ImVec4(0.5f,0.8f,1.0f,1));
ImGui::Text("[You]: %s", msg.content.c_str());
ImGui::PopStyleColor();
} else {
ImGui::PushStyleColor(ImGuiCol_Text, ImVec4(0.8f,1.0f,0.5f,1));
ImGui::TextWrapped("[AI]: %s", msg.content.c_str());
ImGui::PopStyleColor();
}
}
if(ui.scrollChat) { ImGui::SetScrollHereY(1.0f); ui.scrollChat = false; }
ImGui::EndChild();
if(ImGui::InputText("##Input", ui.inputBuf, 512, ImGuiInputTextFlags_EnterReturnsTrue)) {
std::string input = ui.inputBuf;
chat.messages.push_back({"USER", input});
auto ctx = tok.textToTokens(input);
std::string ans;
int generatedTokens = 0;
auto genStartTime = std::chrono::steady_clock::now();
for(int g=0; g<MAX_RESPONSE_TOKENS; g++) {
auto tokenStart = std::chrono::steady_clock::now();
auto out = nn->feedForward(PrepInput(ctx, emb));
int id = std::max_element(out.begin(),out.end())-out.begin();
if(id<=0||id>=UI_VOCAB) break;
std::string w = tok.getWord(id);
if(w.empty()) break;
ans += w + " ";
ctx.push_back(id);
if((int)ctx.size()>UI_CONTEXT) ctx.erase(ctx.begin());
generatedTokens++;
auto tokenEnd = std::chrono::steady_clock::now();
double tokenTime = std::chrono::duration<double>(tokenEnd - tokenStart).count();
if(tokenTime > 0) {
ui.generationTokensPerSecond = 1.0 / tokenTime;
}
}
auto genEndTime = std::chrono::steady_clock::now();
double genTime = std::chrono::duration<double>(genEndTime - genStartTime).count();
std::cout << "Generated " << generatedTokens << " tokens in "
<< genTime << "s (" << (generatedTokens/genTime) << " tok/s)\n";
chat.messages.push_back({"AI", ans});
ui.inputBuf[0]=0;
ui.scrollChat=true;
}
}
ImGui::EndChild();
ImGui::Columns(1);
ImGui::End();
ImGui::Render();
glViewport(0,0,(int)io.DisplaySize.x,(int)io.DisplaySize.y);
glClearColor(0.08f,0.09f,0.12f,1); glClear(GL_COLOR_BUFFER_BIT);
ImGui_ImplOpenGL3_RenderDrawData(ImGui::GetDrawData());
glfwSwapBuffers(win);
}
delete nn;
ImGui_ImplOpenGL3_Shutdown(); ImGui_ImplGlfw_Shutdown(); ImGui::DestroyContext();
glfwTerminate();
return 0;
}